1
0
Fork 0
Face identification and recognition scalable server with multiple face directories. https://github.com/ehp/faceserver
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

108 lines
3.5 KiB

# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
class AngleLinear(nn.Module):
    """Bias-free linear layer whose outputs are cosine similarities.

    Both the incoming features and the rows of the weight matrix are
    L2-normalised before the matrix product, so each output entry is the
    cosine of the angle between one sample and one weight row.
    """

    def __init__(self, in_features, out_features):
        """Allocate an ``(out_features, in_features)`` weight and Xavier-initialise it."""
        super(AngleLinear, self).__init__()
        weight = torch.empty(out_features, in_features, dtype=torch.float32)
        self.W = nn.Parameter(weight)
        nn.init.xavier_uniform_(self.W)

    def forward(self, input):
        """Return the product of the normalised input and the normalised weight.

        ``input`` is normalised along dimension 1, as are the rows of ``W``;
        the result is ``F.linear`` of the two (no bias term).
        """
        unit_features = F.normalize(input, dim=1)
        unit_weight = F.normalize(self.W, dim=1)
        return F.linear(unit_features, unit_weight)
class AdaCos(nn.Module):
    """Fixed-scale AdaCos loss: cross-entropy over ``s * input``.

    The scale is computed once in the constructor as
    ``sqrt(2) * log(num_classes - 1)`` and is not updated afterwards.

    NOTE: an adaptive variant used to live in ``forward`` and is disabled
    ("changed to fixed adacos"). It re-estimated the scale on every batch as
    ``log(B_avg) / cos(min(pi / 4, median(theta)))``, where ``theta`` is the
    arccos of the clamped input and ``B_avg`` is the per-sample average of
    ``exp(s * input)`` summed over the non-target classes, and then clamped
    the result to ``[base_s / 2, base_s * 2]``. The original author left an
    open question on it: "TODO why converge to infinity ?".
    """

    def __init__(self, num_classes, m=0.50, is_cuda=True):
        """Set up the fixed scale and the cross-entropy criterion.

        Args:
            num_classes: number of classes; must be at least 2 because the
                scale uses ``log(num_classes - 1)``.
            m: stored as ``self.m``; not read by ``forward``.
            is_cuda: when true, ``.cuda()`` is called on the criterion.

        Raises:
            ValueError: if ``num_classes`` is smaller than 2.
        """
        super(AdaCos, self).__init__()
        if num_classes < 2:
            # math.log would otherwise fail with an opaque "math domain error".
            raise ValueError(
                "num_classes must be at least 2, got %r" % (num_classes,)
            )
        self.n_classes = num_classes
        # With exactly two classes this evaluates to 0, i.e. constant logits.
        self.s = math.sqrt(2) * math.log(num_classes - 1)
        self.base_s = self.s
        self.m = m
        self.criterion = nn.CrossEntropyLoss()
        if is_cuda:
            self.criterion = self.criterion.cuda()

    def forward(self, input, label):
        """Return the cross-entropy of the scaled logits against ``label``.

        Args:
            input: logits to be scaled; presumably cosine similarities such
                as those produced by ``AngleLinear`` - confirm with caller.
            label: targets for ``nn.CrossEntropyLoss``. Integer tensors of
                any width are cast to int64; floating-point targets are
                passed through unchanged.
        """
        if not label.is_floating_point():
            # CrossEntropyLoss wants int64 class indices; accept narrower ints.
            label = label.long()
        return self.criterion(self.s * input, label)
class ArcFace(nn.Module):
    """ArcFace loss: additive angular margin on the target-class logit.

    For each sample the target entry ``cos(theta)`` is replaced by
    ``cos(theta + m)``; all logits are then multiplied by ``s`` and fed to
    cross-entropy.
    """

    def __init__(self, s=30.0, m=0.50, is_cuda=True):
        """Store the scale ``s`` and margin ``m`` and build the criterion.

        Args:
            s: factor applied to every logit before cross-entropy.
            m: margin added to the target-class angle.
            is_cuda: when true, ``.cuda()`` is called on the criterion.
        """
        super(ArcFace, self).__init__()
        self.s = s
        self.m = m
        self.criterion = nn.CrossEntropyLoss()
        if is_cuda:
            self.criterion = self.criterion.cuda()

    def forward(self, input, label):
        """Return the margin-adjusted cross-entropy loss.

        Args:
            input: 2-D tensor with classes along dimension 1; expected to
                hold cosine similarities (it is clamped just inside
                ``(-1, 1)`` before ``acos``).
            label: one class index per row of ``input``. Any integer dtype
                and either ``(N,)`` or ``(N, 1)`` layout is accepted.
        """
        # Normalise the labels once so the one-hot mask and the criterion see
        # the same int64 ``(N,)`` tensor.
        target = label.reshape(-1).long()
        # Clamp keeps acos and its gradient finite at exactly +/-1.
        theta = torch.acos(torch.clamp(input, -1.0 + 1e-7, 1.0 - 1e-7))
        target_logits = torch.cos(theta + self.m)
        one_hot = torch.zeros_like(input)
        one_hot.scatter_(1, target.unsqueeze(1), 1)
        # Margin logit for the target class, plain input everywhere else.
        output = self.s * (input * (1 - one_hot) + target_logits * one_hot)
        return self.criterion(output, target)
class SphereFace(nn.Module):
    """SphereFace loss: multiplicative angular margin on the target logit.

    For each sample the target entry ``cos(theta)`` is replaced by
    ``cos(m * theta)``; all logits are then multiplied by ``s`` and fed to
    cross-entropy.
    """

    def __init__(self, s=30.0, m=1.35, is_cuda=True):
        """Store the scale ``s`` and margin ``m`` and build the criterion.

        Args:
            s: factor applied to every logit before cross-entropy.
            m: factor that multiplies the target-class angle.
            is_cuda: when true, ``.cuda()`` is called on the criterion.
        """
        super(SphereFace, self).__init__()
        self.s = s
        self.m = m
        self.criterion = nn.CrossEntropyLoss()
        if is_cuda:
            self.criterion = self.criterion.cuda()

    def forward(self, input, label):
        """Return the margin-adjusted cross-entropy loss.

        Args:
            input: 2-D tensor with classes along dimension 1; expected to
                hold cosine similarities (it is clamped just inside
                ``(-1, 1)`` before ``acos``).
            label: one class index per row of ``input``. Any integer dtype
                and either ``(N,)`` or ``(N, 1)`` layout is accepted.
        """
        # Normalise the labels once so the one-hot mask and the criterion see
        # the same int64 ``(N,)`` tensor.
        target = label.reshape(-1).long()
        # Clamp keeps acos and its gradient finite at exactly +/-1.
        theta = torch.acos(torch.clamp(input, -1.0 + 1e-7, 1.0 - 1e-7))
        target_logits = torch.cos(self.m * theta)
        one_hot = torch.zeros_like(input)
        one_hot.scatter_(1, target.unsqueeze(1), 1)
        # Margin logit for the target class, plain input everywhere else.
        output = self.s * (input * (1 - one_hot) + target_logits * one_hot)
        return self.criterion(output, target)
class CosFace(nn.Module):
    """CosFace loss: additive cosine margin on the target-class logit.

    For each sample the target entry is replaced by ``input - m``; all
    logits are then multiplied by ``s`` and fed to cross-entropy.
    """

    def __init__(self, s=30.0, m=0.35, is_cuda=True):
        """Store the scale ``s`` and margin ``m`` and build the criterion.

        Args:
            s: factor applied to every logit before cross-entropy.
            m: value subtracted from the target-class logit.
            is_cuda: when true, ``.cuda()`` is called on the criterion.
        """
        super(CosFace, self).__init__()
        self.s = s
        self.m = m
        self.criterion = nn.CrossEntropyLoss()
        if is_cuda:
            self.criterion = self.criterion.cuda()

    def forward(self, input, label):
        """Return the margin-adjusted cross-entropy loss.

        Args:
            input: 2-D tensor with classes along dimension 1; presumably
                cosine similarities such as those produced by
                ``AngleLinear`` - confirm with caller.
            label: one class index per row of ``input``. Any integer dtype
                and either ``(N,)`` or ``(N, 1)`` layout is accepted.
        """
        # Normalise the labels once so the one-hot mask and the criterion see
        # the same int64 ``(N,)`` tensor.
        target = label.reshape(-1).long()
        target_logits = input - self.m
        one_hot = torch.zeros_like(input)
        one_hot.scatter_(1, target.unsqueeze(1), 1)
        # Margin logit for the target class, plain input everywhere else.
        output = self.s * (input * (1 - one_hot) + target_logits * one_hot)
        return self.criterion(output, target)