diff --git a/datasets/vggface2.py b/datasets/vggface2.py
index a36d258c1..f86e3a4b0 100644
--- a/datasets/vggface2.py
+++ b/datasets/vggface2.py
@@ -23,8 +23,8 @@
 from torchvision import transforms
 
 import cv2
-import face_detection
 import kornia.geometry.transform as GT
+from batch_face import RetinaFace
 from PIL import Image
 from skimage import transform as trans
 from tqdm import tqdm
@@ -99,8 +99,11 @@ def __extract_gt(self):
         """
         Extracts the ground truth from the dataset
         """
-        detector = face_detection.build_detector("RetinaNetResNet50", confidence_threshold=.5,
-                                                 nms_iou_threshold=.4)
+        if torch.cuda.is_available():
+            detector = RetinaFace(gpu_id=torch.cuda.current_device(), network="resnet50")
+        else:
+            detector = RetinaFace(gpu_id=-1, network="resnet50")
+
         img_paths = list(glob.glob(os.path.join(self.d_path + '/**/', '*.jpg'), recursive=True))
         nf_number = 0
         words_count = 0
@@ -111,22 +114,17 @@ def __extract_gt(self):
             boxes = []
             image = cv2.imread(jpg)
 
-            img_max = max(image.shape[0], image.shape[1])
-            if img_max > 1320:
-                continue
-            bboxes, lndmrks = detector.batched_detect_with_landmarks(np.expand_dims(image, 0))
-            bboxes = bboxes[0]
-            lndmrks = lndmrks[0]
+            faces = detector(image)
 
-            if (bboxes.shape[0] == 0) or (lndmrks.shape[0] == 0):
+            if len(faces) == 0:
                 nf_number += 1
                 continue
 
-            for box in bboxes:
+            for face in faces:
+                box = face[0]
                 box = np.clip(box[:4], 0, None)
                 boxes.append(box)
-
-            lndmrks = lndmrks[0]
+            lndmrks = faces[0][1]
 
             dir_name = os.path.dirname(jpg)
             lbl = os.path.relpath(dir_name, self.d_path)
diff --git a/requirements.txt b/requirements.txt
index 95203ea3f..f553bacf0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ Pillow==10.3.0
 PyYAML==6.0.1
 albumentations==1.4.10
 faiss-cpu==1.8.0
-face-detection==0.2.2
+batch-face>=1.4.0
 h5py==3.11.0
 kornia==0.7.2
 librosa==0.10.2.post1
diff --git a/train.py b/train.py
index f44e1ba53..bd5e49139 100755
--- a/train.py
+++ b/train.py
@@ -1233,8 +1233,14 @@ def _validate(data_loader, model, criterion, loggers, args, epoch=-1, tflogger=N
                 _, class_preds_batch = torch.max(output, 1)
                 class_probs.append(class_probs_batch)
                 class_preds.append(class_preds_batch)
 
+        if args.kd_relationbased:
+            stats = (
+                '',
+                OrderedDict([('Loss', losses[OBJECTIVE_LOSS_KEY].mean),
+                             ('Overall Loss', losses[OVERALL_LOSS_KEY].mean)])
+            )
-        if args.obj_detection:
+        elif args.obj_detection:
             # Only run compute() if there is at least one new update()
             if have_mAP:
                 mAP = map_calculator.compute()['map_50']
@@ -1343,12 +1349,8 @@ def update_training_scores_history(perf_scores_history, model, top1, top5, mAP,
     if args.kd_relationbased:
         # Keep perf_scores_history sorted from best to worst based on overall loss
         # overall_loss = student_loss*student_weight + distillation_loss*distillation_weight
-        if not args.sparsity_perf:
-            perf_scores_history.sort(key=operator.attrgetter('vloss', 'epoch'),
-                                     reverse=True)
-        else:
-            perf_scores_history.sort(key=operator.attrgetter('params_nnz_cnt', 'vloss', 'epoch'),
-                                     reverse=True)
+        perf_scores_history.sort(key=operator.attrgetter('params_nnz_cnt', 'vloss', 'epoch'),
+                                 reverse=True)
 
     for score in perf_scores_history[:args.num_best_scores]:
         msglogger.info('==> Best [Overall Loss: %f on epoch: %d]', -score.vloss, score.epoch)
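
Note (not part of the patch): a minimal sketch of how the new batch_face detection API
is exercised by the vggface2.py changes above. The (box, landmarks, score) tuple layout
per detected face is assumed from batch-face >= 1.4; "sample.jpg" is a stand-in path.

    import cv2
    import torch
    from batch_face import RetinaFace

    # Mirror the patch's device selection: GPU if available, else CPU (gpu_id=-1).
    gpu_id = torch.cuda.current_device() if torch.cuda.is_available() else -1
    detector = RetinaFace(gpu_id=gpu_id, network="resnet50")

    image = cv2.imread("sample.jpg")  # hypothetical test image
    faces = detector(image)           # list with one (box, landmarks, score) per face

    for box, landmarks, score in faces:
        # box: 4 coordinates (x1, y1, x2, y2); landmarks: 5 facial keypoints
        print(box[:4], score)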