Scalable face identification and recognition server with multiple face directories.
https://github.com/ehp/faceserver
# -*- coding: utf-8 -*-

"""
Copyright 2019 Petr Masopust, Aprar s.r.o.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Adapted code from https://github.com/rainofmine/Face_Attention_Network
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def memprint(a):
    # Debug helper: print a tensor's shape and its size in bytes.
    print(a.shape)
    print(a.element_size() * a.nelement())

def calc_iou(a, b):
    # Pairwise IoU between boxes a [N, 4] and b [M, 4], both in
    # (x1, y1, x2, y2) format. b is processed in chunks of `step` boxes,
    # with in-place ops reusing buffers, to bound peak memory usage.
    step = 20
    IoU = torch.zeros((len(a), len(b)), device=a.device)
    step_count = len(b) // step
    if len(b) % step != 0:
        step_count += 1

    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    for i in range(step_count):
        # Intersection width and height against the current chunk of b.
        iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[i * step:(i + 1) * step, 2])
        iw.sub_(torch.max(torch.unsqueeze(a[:, 0], 1), b[i * step:(i + 1) * step, 0]))

        ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[i * step:(i + 1) * step, 3])
        ih.sub_(torch.max(torch.unsqueeze(a[:, 1], 1), b[i * step:(i + 1) * step, 1]))

        iw.clamp_(min=0)
        ih.clamp_(min=0)

        # Intersection area, reusing iw as the accumulator.
        iw.mul_(ih)
        del ih

        # Union = area(a) + area(b) - intersection, clamped away from zero.
        ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area[i * step:(i + 1) * step] - iw
        ua = torch.clamp(ua, min=1e-8)
        iw.div_(ua)
        del ua

        IoU[:, i * step:(i + 1) * step] = iw

    return IoU
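
# Example (illustrative): two 2x2 boxes that overlap in a 1x1 square give
# IoU = 1 / (4 + 4 - 1):
#
#   calc_iou(torch.tensor([[0., 0., 2., 2.]]),
#            torch.tensor([[1., 1., 3., 3.]]))  # -> tensor([[0.1429]])
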
def calc_iou_vis(a, b):
    # Overlap between boxes a [N, 4] and b [M, 4], normalized by the area
    # of b only (not the union), i.e. the fraction of each b box covered.
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

    iw = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 1])

    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)

    intersection = iw * ih

    IoU = intersection / area

    return IoU

def IoG(box_a, box_b):
    # Intersection over the area of box_a: how much of each box_a
    # (typically the ground truth) is covered by the paired box_b.
    inter_xmin = torch.max(box_a[:, 0], box_b[:, 0])
    inter_ymin = torch.max(box_a[:, 1], box_b[:, 1])
    inter_xmax = torch.min(box_a[:, 2], box_b[:, 2])
    inter_ymax = torch.min(box_a[:, 3], box_b[:, 3])
    Iw = torch.clamp(inter_xmax - inter_xmin, min=0)
    Ih = torch.clamp(inter_ymax - inter_ymin, min=0)
    I = Iw * Ih
    G = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
    return I / G
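
# The three overlap measures above differ only in the normalizer: calc_iou
# divides the intersection by the union, calc_iou_vis by the area of b, and
# IoG by the area of box_a. For the 2x2 boxes in the example above they give
# 1/7, 1/4 and 1/4 respectively.
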
class FocalLoss(nn.Module):
    def __init__(self, is_cuda=True):
        super(FocalLoss, self).__init__()
        self.is_cuda = is_cuda

    def forward(self, classifications, regressions, anchors, annotations):
        alpha = 0.25
        gamma = 2.0
        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchor = anchors[0, :, :]

        anchor_widths = anchor[:, 2] - anchor[:, 0]
        anchor_heights = anchor[:, 3] - anchor[:, 1]
        anchor_ctr_x = anchor[:, 0] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 1] + 0.5 * anchor_heights

        for j in range(batch_size):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            # Drop padded annotations (class index -1).
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                if self.is_cuda:
                    regression_losses.append(torch.tensor(0).float().cuda())
                    classification_losses.append(torch.tensor(0).float().cuda())
                else:
                    regression_losses.append(torch.tensor(0).float())
                    classification_losses.append(torch.tensor(0).float())
                continue

            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            IoU = calc_iou(anchor, bbox_annotation[:, :4])  # num_anchors x num_annotations

            IoU_max, IoU_argmax = torch.max(IoU, dim=1)  # num_anchors

            # Compute the loss for classification: anchors with IoU < 0.4 are
            # negatives, anchors with IoU >= 0.5 are positives, and anchors in
            # between are ignored (target -1).
            targets = torch.ones(classification.shape) * -1
            if self.is_cuda:
                targets = targets.cuda()

            targets[torch.lt(IoU_max, 0.4), :] = 0

            positive_indices = torch.ge(IoU_max, 0.5)

            num_positive_anchors = positive_indices.sum()

            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            alpha_factor = torch.ones(targets.shape)
            if self.is_cuda:
                alpha_factor = alpha_factor.cuda()
            alpha_factor *= alpha

            alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

            bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

            # cls_loss = focal_weight * torch.pow(bce, gamma)
            cls_loss = focal_weight * bce

            # Zero out the loss for ignored anchors.
            cls_zeros = torch.zeros(cls_loss.shape)
            if self.is_cuda:
                cls_zeros = cls_zeros.cuda()
            cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, cls_zeros)

            classification_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))

            # Compute the loss for regression.
            if positive_indices.sum() > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]

                anchor_widths_pi = anchor_widths[positive_indices]
                anchor_heights_pi = anchor_heights[positive_indices]
                anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                # Clip widths and heights to at least 1 pixel.
                gt_widths = torch.clamp(gt_widths, min=1)
                gt_heights = torch.clamp(gt_heights, min=1)

                targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                targets_dw = torch.log(gt_widths / anchor_widths_pi)
                targets_dh = torch.log(gt_heights / anchor_heights_pi)

                targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))
                targets = targets.t()

                if self.is_cuda:
                    targets = targets.cuda() / torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()
                else:
                    targets = targets / torch.Tensor([[0.1, 0.1, 0.2, 0.2]])

                # Smooth L1 loss with beta = 1/9.
                regression_diff = torch.abs(targets - regression[positive_indices, :])

                regression_loss = torch.where(
                    torch.le(regression_diff, 1.0 / 9.0),
                    0.5 * 9.0 * torch.pow(regression_diff, 2),
                    regression_diff - 0.5 / 9.0
                )
                regression_losses.append(regression_loss.mean())
            else:
                if self.is_cuda:
                    regression_losses.append(torch.tensor(0).float().cuda())
                else:
                    regression_losses.append(torch.tensor(0).float())

        return torch.stack(classification_losses).mean(dim=0, keepdim=True), \
            torch.stack(regression_losses).mean(dim=0, keepdim=True)
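
# For each non-ignored anchor the classification term above is the focal loss
# FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) with alpha = 0.25 and
# gamma = 2, summed and divided by the number of positive anchors, and the
# regression term is a smooth L1 penalty (quadratic below |diff| = 1/9) on the
# (dx, dy, dw, dh) offsets divided by (0.1, 0.1, 0.2, 0.2).
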
class LevelAttentionLoss(nn.Module):
    def __init__(self, is_cuda=True):
        super(LevelAttentionLoss, self).__init__()
        self.is_cuda = is_cuda

    def forward(self, img_batch_shape, attention_mask, bboxs):
        h, w = img_batch_shape[2], img_batch_shape[3]

        mask_losses = []

        batch_size = bboxs.shape[0]
        for j in range(batch_size):
            # Drop padded annotations (class index -1).
            bbox_annotation = bboxs[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                if self.is_cuda:
                    mask_losses.append(torch.tensor(0).float().cuda())
                else:
                    mask_losses.append(torch.tensor(0).float())
                continue

            # Keep only boxes that lie inside the image.
            cond1 = torch.le(bbox_annotation[:, 0], w)
            cond2 = torch.le(bbox_annotation[:, 1], h)
            cond3 = torch.le(bbox_annotation[:, 2], w)
            cond4 = torch.le(bbox_annotation[:, 3], h)
            cond = cond1 * cond2 * cond3 * cond4

            bbox_annotation = bbox_annotation[cond, :]

            if bbox_annotation.shape[0] == 0:
                if self.is_cuda:
                    mask_losses.append(torch.tensor(0).float().cuda())
                else:
                    mask_losses.append(torch.tensor(0).float())
                continue

            bbox_area = (bbox_annotation[:, 2] - bbox_annotation[:, 0]) * (
                    bbox_annotation[:, 3] - bbox_annotation[:, 1])

            mask_loss = []
            for level in range(len(attention_mask)):
                attention_map = attention_mask[level][j, 0, :, :]

                # Assign each box to a pyramid level by its area; the base
                # size at level `level` is 2 ** (level + 5), i.e. 32, 64, ...
                min_area = (2 ** (level + 5)) ** 2 * 0.5
                max_area = (2 ** (level + 5) * 1.58) ** 2 * 2

                level_bbox_indice1 = torch.ge(bbox_area, min_area)
                level_bbox_indice2 = torch.le(bbox_area, max_area)

                level_bbox_indice = level_bbox_indice1 * level_bbox_indice2

                level_bbox_annotation = bbox_annotation[level_bbox_indice, :].clone()

                # level_bbox_annotation = bbox_annotation.clone()

                attention_h, attention_w = attention_map.shape

                # Rescale the boxes from image coordinates to the attention
                # map's resolution.
                if level_bbox_annotation.shape[0]:
                    level_bbox_annotation[:, 0] *= attention_w / w
                    level_bbox_annotation[:, 1] *= attention_h / h
                    level_bbox_annotation[:, 2] *= attention_w / w
                    level_bbox_annotation[:, 3] *= attention_h / h

                # Ground-truth mask: 1 inside every (rescaled) box, 0 elsewhere.
                mask_gt = torch.zeros(attention_map.shape)
                if self.is_cuda:
                    mask_gt = mask_gt.cuda()

                for i in range(level_bbox_annotation.shape[0]):
                    x1 = max(int(level_bbox_annotation[i, 0]), 0)
                    y1 = max(int(level_bbox_annotation[i, 1]), 0)
                    x2 = min(math.ceil(level_bbox_annotation[i, 2]) + 1, attention_w)
                    y2 = min(math.ceil(level_bbox_annotation[i, 3]) + 1, attention_h)

                    mask_gt[y1:y2, x1:x2] = 1

                # Flatten both maps and compare them pixel-wise.
                mask_gt = mask_gt[mask_gt >= 0]
                mask_predict = attention_map[attention_map >= 0]

                mask_loss.append(F.binary_cross_entropy(mask_predict, mask_gt))
            mask_losses.append(torch.stack(mask_loss).mean())

        return torch.stack(mask_losses).mean(dim=0, keepdim=True)
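

if __name__ == '__main__':
    # Minimal CPU smoke test (illustrative, not part of the original API):
    # shapes follow the conventions assumed above -- anchors [1, A, 4] in
    # (x1, y1, x2, y2) format and annotations [B, M, 5] with the class
    # index (or -1 for padding) in column 4.
    torch.manual_seed(0)
    batch, num_classes = 2, 2

    annotations = torch.tensor([[[10., 10., 200., 200., 0.]],
                                [[50., 50., 300., 300., 1.]]])

    # Random 32x32 anchors plus the ground-truth boxes themselves, so the
    # positive (regression) branch is exercised.
    xy = torch.rand(1, 128, 2) * 480
    anchors = torch.cat([torch.cat([xy, xy + 32.], dim=2),
                         annotations[:, :, :4].reshape(1, -1, 4)], dim=1)
    num_anchors = anchors.shape[1]

    classifications = torch.sigmoid(torch.randn(batch, num_anchors, num_classes))
    regressions = torch.randn(batch, num_anchors, 4)

    cls_loss, reg_loss = FocalLoss(is_cuda=False)(
        classifications, regressions, anchors, annotations)
    print('classification loss:', cls_loss.item())
    print('regression loss:', reg_loss.item())

    # One fake sigmoid attention map per pyramid level (strides 8..128).
    attention_mask = [torch.sigmoid(torch.randn(batch, 1, 512 // s, 512 // s))
                      for s in (8, 16, 32, 64, 128)]
    att_loss = LevelAttentionLoss(is_cuda=False)(
        (batch, 3, 512, 512), attention_mask, annotations)
    print('attention loss:', att_loss.item())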