# -*- coding: utf-8 -*-
"""
Copyright 2019 Petr Masopust, Aprar s.r.o.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Adapted code from https://github.com/ronghuaiyang/arcface-pytorch
"""
import argparse
import os
import time
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms as T
from recognition.angle import AngleLinear, CosFace, SphereFace, ArcFace, AdaCos
from recognition.focal_loss import FocalLoss
from recognition.nets import get_net_by_name
from recognition.test import lfw_test2, get_pair_list, load_img_data
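

# The training list file is read as one image per line in the form
# "relative/path<whitespace>label" (inferred from how __getitem__ splits each
# line); the example entries below are hypothetical:
#   0000045/001.jpg 0000045
#   0000045/002.jpg 0000045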
class Dataset(torch.utils.data.Dataset):
    def __init__(self, root, data_list_file, imagesize):
        with open(data_list_file, 'r') as fd:
            imgs = fd.readlines()
        imgs = [os.path.join(root, img.strip()) for img in imgs]
        # Sort the label set so class indices are deterministic across runs.
        self.labels = sorted(set(img.split()[1] for img in imgs))
        self.label_to_idx = {label: idx for idx, label in enumerate(self.labels)}
        self.imgs = np.random.permutation(imgs)

        # ImageNet channel statistics.
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        self.transforms = T.Compose([
            T.RandomResizedCrop(imagesize),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            normalize,
            T.RandomErasing()
        ])

    def __getitem__(self, index):
        sample = self.imgs[index]
        splits = sample.split()
        img_path = splits[0]
        data = Image.open(img_path)
        data = data.convert(mode="RGB")
        data = self.transforms(data)
        cls = self.label_to_class(splits[1])
        return data.float(), cls

    def __len__(self):
        return len(self.imgs)

    def label_to_class(self, label):
        # Dict lookup replaces the original linear scan over self.labels.
        try:
            return self.label_to_idx[label]
        except KeyError:
            raise ValueError('Unknown label %s' % label)

    def num_labels(self):
        return len(self.labels)
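
# Minimal usage sketch (paths are hypothetical):
#   ds = Dataset('/data/casia', '/data/casia/list.txt', 224)
#   img, cls = ds[0]  # float tensor of shape (3, 224, 224), int class index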


def main(args=None):
    parser = argparse.ArgumentParser(description='Training script for face identification.')
    parser.add_argument('--print_freq', help='Print every N batches (default 100)', type=int, default=100)
    parser.add_argument('--epochs', help='Number of epochs', type=int, default=50)
    parser.add_argument('--net',
                        help='Net name, must be one of resnet18, resnet34, resnet50, resnet101, resnet152, resnext50, resnext101 or spherenet',
                        default='resnet50')
    parser.add_argument('--lr_step', help='Learning rate step (default 10)', type=int, default=10)
    parser.add_argument('--lr', help='Learning rate (default 0.1)', type=float, default=0.1)
    parser.add_argument('--weight_decay', help='Weight decay (default 0.0005)', type=float, default=0.0005)
    parser.add_argument('--easy_margin', help='Use easy margin (default false)', dest='easy_margin', default=False,
                        action='store_true')
    parser.add_argument('--parallel', help='Run training with DataParallel', dest='parallel',
                        default=False, action='store_true')
    parser.add_argument('--loss',
                        help='One of focal_loss, cross_entropy, arcface, cosface, sphereface, adacos (default cross_entropy)',
                        type=str, default='cross_entropy')
    parser.add_argument('--optimizer', help='One of sgd, adam (default sgd)',
                        type=str, default='sgd')
    parser.add_argument('--batch_size', help='Batch size (default 16)', type=int, default=16)
    parser.add_argument('--casia_list', help='Path to CASIA dataset file list (training)')
    parser.add_argument('--casia_root', help='Path to CASIA images (training)')
    parser.add_argument('--lfw_root', help='Path to LFW dataset (testing)')
    parser.add_argument('--lfw_pair_list', help='Path to LFW pair list file (testing)')
    parser.add_argument('--model_name', help='Name of the model to save')
    parser = parser.parse_args(args)
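
    # A typical invocation might look like this (script name and all paths
    # are illustrative placeholders, not taken from the repository):
    #   python train.py --net resnet50 --loss arcface --epochs 50 \
    #       --casia_root /data/casia --casia_list /data/casia/list.txt \
    #       --lfw_root /data/lfw --lfw_pair_list /data/lfw/pairs.txt \
    #       --model_name resnet50_arcface
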
    is_cuda = torch.cuda.is_available()
    print('CUDA available: {}'.format(is_cuda))

    imagesize = 224
    model = get_net_by_name(parser.net)

    # TODO split training dataset to train/validation and stop using test dataset for acc
    train_dataset = Dataset(parser.casia_root, parser.casia_list, imagesize)
    trainloader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=parser.batch_size,
                                              shuffle=True,
                                              # pin_memory=True,
                                              num_workers=0)
    num_classes = train_dataset.num_labels()
    if parser.loss == 'focal_loss':
        metric_fc = nn.Linear(512, num_classes)
        criterion = FocalLoss(gamma=2, is_cuda=is_cuda)
    elif parser.loss == 'cross_entropy':
        metric_fc = nn.Linear(512, num_classes)
        criterion = torch.nn.CrossEntropyLoss()
        if is_cuda:
            criterion = criterion.cuda()
    elif parser.loss == 'cosface':
        metric_fc = AngleLinear(512, num_classes)
        criterion = CosFace(is_cuda=is_cuda)
    elif parser.loss == 'arcface':
        metric_fc = AngleLinear(512, num_classes)
        criterion = ArcFace(is_cuda=is_cuda)
    elif parser.loss == 'sphereface':
        metric_fc = AngleLinear(512, num_classes)
        criterion = SphereFace(is_cuda=is_cuda)
    elif parser.loss == 'adacos':
        metric_fc = AngleLinear(512, num_classes)
        criterion = AdaCos(num_classes, is_cuda=is_cuda)
    else:
        raise ValueError('Unknown loss %s' % parser.loss)
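
    # Note: the plain losses (focal_loss, cross_entropy) classify through an
    # ordinary nn.Linear head, while the margin losses (cosface, arcface,
    # sphereface, adacos) pair with AngleLinear; in the upstream
    # arcface-pytorch code this script is adapted from, that head produces
    # cosine/angle-based logits that the margin criteria expect.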

    if parser.optimizer == 'sgd':
        optimizer = torch.optim.SGD([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
                                    lr=parser.lr, weight_decay=parser.weight_decay)
    elif parser.optimizer == 'adam':
        optimizer = torch.optim.Adam([{'params': model.parameters()}, {'params': metric_fc.parameters()}],
                                     lr=parser.lr, weight_decay=parser.weight_decay)
    else:
        raise ValueError('Unknown optimizer %s' % parser.optimizer)
    scheduler = StepLR(optimizer, step_size=parser.lr_step, gamma=0.1)

    if parser.parallel:
        model = nn.DataParallel(model)
        metric_fc = nn.DataParallel(metric_fc)
    if is_cuda:
        model.cuda()
        metric_fc.cuda()
    print(model)
    print(metric_fc)

    # LFW verification pairs are loaded once up front and reused after every epoch.
    identity_list = get_pair_list(parser.lfw_pair_list)
    img_data = load_img_data(identity_list, parser.lfw_root)

    print('{} train iters per epoch'.format(len(trainloader)))

    start = time.time()
    last_acc = 0.0
    for i in range(parser.epochs):
        model.train()
        for ii, data in enumerate(trainloader):
            data_input, label = data
            if is_cuda:
                data_input = data_input.cuda()
                label = label.cuda().long()
            feature = model(data_input)
            output = metric_fc(feature)
            loss = criterion(output, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            iters = i * len(trainloader) + ii
            if iters % parser.print_freq == 0:
                speed = parser.print_freq / (time.time() - start)
                time_str = time.asctime(time.localtime(time.time()))
                print('{} train epoch {} iter {} {} iters/s loss {}'.format(time_str, i, ii, speed, loss.item()))
                start = time.time()
        scheduler.step()

        model.eval()
        acc = lfw_test2(model, identity_list, img_data, is_cuda=is_cuda)
        print('Accuracy: %f' % acc)
        if last_acc < acc:
            # TODO remove makedir
            os.makedirs('./ckpt', exist_ok=True)
            torch.save(model.state_dict(), './ckpt/' + parser.model_name + '_{}.pt'.format(i))
            torch.save(metric_fc.state_dict(), './ckpt/' + parser.model_name + '_metric_{}.pt'.format(i))
            # Track the best accuracy so far; without this update, every epoch
            # with acc > 0 would be checkpointed.
            last_acc = acc
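            # Note: with --parallel enabled these checkpoints come from the
            # nn.DataParallel wrappers, so their state_dict keys carry a
            # "module." prefix; load them into a model wrapped the same way
            # or strip the prefix first.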


if __name__ == '__main__':
    main()