From f333fd85a18002ffbb01b8a3257f381c14bdd5ec Mon Sep 17 00:00:00 2001
From: leixy
Date: Tue, 15 Jun 2021 17:51:01 +0800
Subject: [PATCH 1/2] support test on video

---
 test_video.py | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 test_video.py

diff --git a/test_video.py b/test_video.py
new file mode 100644
index 00000000..421231d7
--- /dev/null
+++ b/test_video.py
@@ -0,0 +1,102 @@
+
+import cv2
+import torch
+import math
+import numpy as np
+import torch.nn.functional as F
+from torchvision import transforms
+from models.models import create_model
+from options.test_options import TestOptions
+
+
+# fractions.gcd was removed in Python 3.9; use math.gcd instead.
+def lcm(a, b): return abs(a * b) // math.gcd(a, b) if a and b else 0
+
+transformer = transforms.Compose([
+        transforms.ToTensor(),
+        #transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+transformer_Arcface = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+detransformer = transforms.Compose([
+        transforms.Normalize([0, 0, 0], [1/0.229, 1/0.224, 1/0.225]),
+        transforms.Normalize([-0.485, -0.456, -0.406], [1, 1, 1])
+    ])
+
+opt = TestOptions().parse()
+
+start_epoch, epoch_iter = 1, 0
+
+torch.nn.Module.dump_patches = True
+model = create_model(opt)
+model.eval()
+
+def img_b_atte(img_b):
+    # Turn an RGB frame (H, W, C) into an attribute batch (1, C, H, W) on the GPU.
+    img_b = transformer(img_b)
+    img_att = img_b.view(-1, img_b.shape[0], img_b.shape[1], img_b.shape[2])
+    img_att = img_att.cuda()
+    return img_att
+
+def swap(img_id, img_att, latend_id):
+    # Run the generator, then convert the fake batch to a BGR uint8 image for OpenCV.
+    img_fake = model(img_id, img_att, latend_id, latend_id, True)
+    for i in range(img_id.shape[0]):
+        if i == 0:
+            row1 = img_id[i]
+            row2 = img_att[i]
+            row3 = img_fake[i]
+        else:
+            row1 = torch.cat([row1, img_id[i]], dim=2)
+            row2 = torch.cat([row2, img_att[i]], dim=2)
+            row3 = torch.cat([row3, img_fake[i]], dim=2)
+
+    full = row3.detach()
+    full = full.permute(1, 2, 0)
+    output = full.to('cpu')
+    output = np.array(output)
+    output = output[..., ::-1]  # RGB -> BGR for OpenCV
+    output = np.clip(output * 255, 0, 255)  # clip before the cast to avoid uint8 wrap-around
+    output = output.astype(np.uint8)
+    return output
+
+# Build the identity batch from the source image.
+pic_a = opt.pic_a_path
+img_a = cv2.imread(pic_a)
+img_a = cv2.cvtColor(img_a, cv2.COLOR_BGR2RGB)
+h, w, _ = img_a.shape
+if w != 224 or h != 224:
+    img_a = cv2.resize(img_a, (224, 224))
+img_a = transformer_Arcface(img_a)
+img_id = img_a.view(-1, img_a.shape[0], img_a.shape[1], img_a.shape[2])
+img_id = img_id.cuda()
+
+# Create the latent id; ArcFace expects 112x112 input, hence the 0.5 downsample.
+img_id_downsample = F.interpolate(img_id, scale_factor=0.5)
+latend_id = model.netArc(img_id_downsample)
+latend_id = latend_id.detach().to('cpu')
+latend_id = latend_id / np.linalg.norm(latend_id)
+latend_id = latend_id.to('cuda')
+
+cap = cv2.VideoCapture(opt.video_path)
+
+while cap.isOpened():
+    _, img_b = cap.read()
+    if img_b is None:
+        break
+    h, w, _ = img_b.shape
+    if w != 224 or h != 224:
+        img_b = cv2.resize(img_b, (224, 224))
+    img_b = cv2.cvtColor(img_b, cv2.COLOR_BGR2RGB)
+    img_att = img_b_atte(img_b)
+    img_fake = swap(img_id, img_att, latend_id)
+    cv2.imshow("swap", img_fake)
+    if cv2.waitKey(1) & 0xFF == ord('q'):
+        break
+
+cap.release()
+cv2.destroyAllWindows()
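Note on the frame loop above: cv2.imshow requires a GUI session, so the script as written cannot run on a headless server. Below is a minimal sketch of a file-writing variant of the loop, reusing opt, swap(), img_b_atte(), img_id and latend_id from test_video.py; the mp4v codec, the fps fallback, and the swap.mp4 file name are assumptions, and opt.output_path must already exist.

    # Sketch only: write swapped frames to a video file instead of displaying them.
    # Reuses opt, swap(), img_b_atte(), img_id and latend_id from test_video.py.
    cap = cv2.VideoCapture(opt.video_path)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0          # some containers report 0 fps
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')         # assumed codec
    writer = cv2.VideoWriter(opt.output_path + 'swap.mp4', fourcc, fps, (224, 224))

    while cap.isOpened():
        _, img_b = cap.read()
        if img_b is None:
            break
        img_b = cv2.resize(img_b, (224, 224))
        img_b = cv2.cvtColor(img_b, cv2.COLOR_BGR2RGB)
        writer.write(swap(img_id, img_b_atte(img_b), latend_id))  # swap() returns BGR uint8

    cap.release()
    writer.release()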
From afe0452b428f0da415c170fd9aa9a9dea0dab96e Mon Sep 17 00:00:00 2001
From: leixy
Date: Tue, 15 Jun 2021 17:53:23 +0800
Subject: [PATCH 2/2] support test on video

---
 options/test_options.py | 45 +++++++++++++++++++++--------------------
 1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/options/test_options.py b/options/test_options.py
index 0a4b4beb..0c3f367b 100644
--- a/options/test_options.py
+++ b/options/test_options.py
@@ -1,22 +1,23 @@
-from .base_options import BaseOptions
-
-class TestOptions(BaseOptions):
-    def initialize(self):
-        BaseOptions.initialize(self)
-        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
-        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
-        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
-        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
-        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
-        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
-        self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
-        self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
-        self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
-        self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
-        self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
-        self.parser.add_argument("--Arc_path", type=str, default='models/BEST_checkpoint.tar', help="run ONNX model via TRT")
-        self.parser.add_argument("--pic_a_path", type=str, default='crop_224/gdg.jpg', help="people a")
-        self.parser.add_argument("--pic_b_path", type=str, default='crop_224/zrf.jpg', help="people b")
-        self.parser.add_argument("--output_path", type=str, default='output/', help="people b")
-
-        self.isTrain = False
+from .base_options import BaseOptions
+
+class TestOptions(BaseOptions):
+    def initialize(self):
+        BaseOptions.initialize(self)
+        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
+        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
+        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
+        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
+        self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
+        self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
+        self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
+        self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
+        self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
+        self.parser.add_argument("--Arc_path", type=str, default='models/BEST_checkpoint.tar', help="path to the ArcFace identity checkpoint")
+        self.parser.add_argument("--pic_a_path", type=str, default='crop_224/gdg.jpg', help="path to the source (identity) image")
+        self.parser.add_argument("--pic_b_path", type=str, default='crop_224/zrf.jpg', help="path to the target (attribute) image")
+        self.parser.add_argument("--output_path", type=str, default='output/', help="directory for saved results")
+        self.parser.add_argument("--video_path", type=str, help="path to the target (attribute) video")
+
+        self.isTrain = False
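With both patches applied, the video test can be invoked roughly as follows. The video path is a placeholder, and any flags defined in base_options.py (checkpoint name, GPU ids, and so on) are not shown in this diff and may also be required:

    python test_video.py \
        --Arc_path models/BEST_checkpoint.tar \
        --pic_a_path crop_224/gdg.jpg \
        --video_path demo.mp4

Frames are displayed with cv2.imshow at 224x224; press q to stop the loop early.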