diff --git a/dataset/videomatte.py b/dataset/videomatte.py
index 555911b..2d86849 100644
--- a/dataset/videomatte.py
+++ b/dataset/videomatte.py
@@ -9,14 +9,14 @@ class VideoMatteDataset(Dataset):
     def __init__(self,
                  videomatte_dir,
-                 background_image_dir,
+                 # background_image_dir,
                  background_video_dir,
                  size,
                  seq_length,
                  seq_sampler,
                  transform=None):
-        self.background_image_dir = background_image_dir
-        self.background_image_files = os.listdir(background_image_dir)
+        # self.background_image_dir = background_image_dir
+        # self.background_image_files = os.listdir(background_image_dir)
         self.background_video_dir = background_video_dir
         self.background_video_clips = sorted(os.listdir(background_video_dir))
         self.background_video_frames = [sorted(os.listdir(os.path.join(background_video_dir, clip)))
@@ -38,10 +38,10 @@ def __len__(self):
         return len(self.videomatte_idx)
     
     def __getitem__(self, idx):
-        if random.random() < 0.5:
-            bgrs = self._get_random_image_background()
-        else:
-            bgrs = self._get_random_video_background()
+        # if random.random() < 0.5:
+        #     bgrs = self._get_random_image_background()
+        # else:
+        bgrs = self._get_random_video_background()
         
         fgrs, phas = self._get_videomatte(idx)
@@ -50,11 +50,11 @@ def __getitem__(self, idx):
         
         return fgrs, phas, bgrs
     
-    def _get_random_image_background(self):
-        with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
-            bgr = self._downsample_if_needed(bgr.convert('RGB'))
-        bgrs = [bgr] * self.seq_length
-        return bgrs
+    # def _get_random_image_background(self):
+    #     with Image.open(os.path.join(self.background_image_dir, random.choice(self.background_image_files))) as bgr:
+    #         bgr = self._downsample_if_needed(bgr.convert('RGB'))
+    #     bgrs = [bgr] * self.seq_length
+    #     return bgrs
     
     def _get_random_video_background(self):
         clip_idx = random.choice(range(len(self.background_video_clips)))
diff --git a/inference.py b/inference.py
index a116754..3a1d1b3 100644
--- a/inference.py
+++ b/inference.py
@@ -120,6 +120,10 @@ def convert_video(model,
             rec = [None] * 4
             for src in reader:
+                if src.shape[-1] %2 == 1:
+                    src = src[:, :, :, :-1]
+                if src.shape[-2] %2 == 1:
+                    src = src[:, :, :-1, :]
 
                 if downsample_ratio is None:
                     downsample_ratio = auto_downsample_ratio(*src.shape[2:])
diff --git a/inference_utils.py b/inference_utils.py
index d651dc0..c6b4111 100644
--- a/inference_utils.py
+++ b/inference_utils.py
@@ -5,7 +5,7 @@
 from torch.utils.data import Dataset
 from torchvision.transforms.functional import to_pil_image
 from PIL import Image
-
+import torch
 
 class VideoReader(Dataset):
     def __init__(self, path, transform=None):
@@ -55,18 +55,23 @@ def close(self):
 
 class ImageSequenceReader(Dataset):
     def __init__(self, path, transform=None):
         self.path = path
-        self.files = sorted(os.listdir(path))
+        self.files_fgr = sorted(os.listdir(path + "fgr/"))
+        self.files_bgr = sorted(os.listdir(path + "bgr/"))
         self.transform = transform
     
     def __len__(self):
-        return len(self.files)
+        return len(self.files_fgr)
     
     def __getitem__(self, idx):
-        with Image.open(os.path.join(self.path, self.files[idx])) as img:
-            img.load()
+        with Image.open(os.path.join(self.path + "fgr/", self.files_fgr[idx])) as fgr_img:
+            fgr_img.load()
+
+        with Image.open(os.path.join(self.path + "bgr/", self.files_bgr[idx])) as bgr_img:
+            bgr_img.load()
+
         if self.transform is not None:
-            return self.transform(img)
-        return img
+            return torch.cat([self.transform(fgr_img), self.transform(bgr_img)], dim = 0)
+        return fgr_img
 
 
 class ImageSequenceWriter:
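With this change, ImageSequenceReader no longer reads a flat folder of frames: it expects paired fgr/ and bgr/ subfolders under the given path (the trailing slash matters, because the subfolder names are appended by plain string concatenation) and, when a transform is supplied, it returns the two frames stacked into one 6-channel tensor. A minimal usage sketch, not part of the patch, assuming a hypothetical ./input/ directory with that layout and a ToTensor transform:

from torchvision.transforms import ToTensor
from inference_utils import ImageSequenceReader

# Hypothetical layout: ./input/fgr/0000.png, ./input/bgr/0000.png, ...
# The trailing slash is required because the reader builds paths via `path + "fgr/"`.
reader = ImageSequenceReader('./input/', transform=ToTensor())
sample = reader[0]
print(sample.shape)  # (6, H, W): source RGB stacked on clean-plate background RGB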
diff --git a/model/decoder.py b/model/decoder.py
index 7307435..7569429 100644
--- a/model/decoder.py
+++ b/model/decoder.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 from torch import nn
-from torch.nn import functional as F
+# from torch.nn import functional as F
 from typing import Tuple, Optional
 
 class RecurrentDecoder(nn.Module):
@@ -9,10 +9,10 @@ def __init__(self, feature_channels, decoder_channels):
         super().__init__()
         self.avgpool = AvgPool()
         self.decode4 = BottleneckBlock(feature_channels[3])
-        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 3, decoder_channels[0])
-        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 3, decoder_channels[1])
-        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 3, decoder_channels[2])
-        self.decode0 = OutputBlock(decoder_channels[2], 3, decoder_channels[3])
+        self.decode3 = UpsamplingBlock(feature_channels[3], feature_channels[2], 6, decoder_channels[0])
+        self.decode2 = UpsamplingBlock(decoder_channels[0], feature_channels[1], 6, decoder_channels[1])
+        self.decode1 = UpsamplingBlock(decoder_channels[1], feature_channels[0], 6, decoder_channels[2])
+        self.decode0 = OutputBlock(decoder_channels[2], 6, decoder_channels[3])
 
     def forward(self,
                 s0: Tensor, f1: Tensor, f2: Tensor, f3: Tensor, f4: Tensor,
diff --git a/model/mobilenetv3.py b/model/mobilenetv3.py
index 712a298..271a93e 100644
--- a/model/mobilenetv3.py
+++ b/model/mobilenetv3.py
@@ -3,6 +3,21 @@
 from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
 from torchvision.transforms.functional import normalize
 
+def load_matched_state_dict(model, state_dict, print_stats=True):
+    """
+    Only loads weights that matched in key and shape. Ignore other weights.
+    """
+    num_matched, num_total = 0, 0
+    curr_state_dict = model.state_dict()
+    for key in curr_state_dict.keys():
+        num_total += 1
+        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
+            curr_state_dict[key] = state_dict[key]
+            num_matched += 1
+    model.load_state_dict(curr_state_dict)
+    if print_stats:
+        print(f'Loaded state_dict: {num_matched}/{num_total} matched')
+
 class MobileNetV3LargeEncoder(MobileNetV3):
     def __init__(self, pretrained: bool = False):
         super().__init__(
@@ -27,14 +42,24 @@ def __init__(self, pretrained: bool = False):
         )
 
         if pretrained:
-            self.load_state_dict(torch.hub.load_state_dict_from_url(
-                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
+            pretrained_state_dict = torch.hub.load_state_dict_from_url(
+                'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth')
+
+            # print("pretrained_state_dict keys \n \n ", pretrained_state_dict.keys())
+
+            # print("\n\ncurrent model state dict keys \n\n", self.state_dict().keys())
+
+            load_matched_state_dict(self, pretrained_state_dict)
+
+            # self.load_state_dict(torch.hub.load_state_dict_from_url(
+            #     'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'))
 
         del self.avgpool
         del self.classifier
 
     def forward_single_frame(self, x):
-        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        # print(x.shape)
+        x = torch.cat((normalize(x[:, :3, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), normalize(x[:, 3:, ...], [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])), dim = -3)
 
         x = self.features[0](x)
         x = self.features[1](x)
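The load_matched_state_dict helper added above is what lets the modified encoder still reuse the ImageNet checkpoint: any parameter whose key or shape no longer matches (for example, a first convolution widened to accept the extra background channels) is simply left at its current value. A small self-contained sketch of that behavior on toy modules, not part of the patch (the import path follows the file above):

from torch import nn
from model.mobilenetv3 import load_matched_state_dict

# Toy stand-ins: the second module widens its first conv from 3 to 6 input
# channels, so only that one weight tensor fails the shape check.
old = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 8, 3))
new = nn.Sequential(nn.Conv2d(6, 8, 3), nn.Conv2d(8, 8, 3))
load_matched_state_dict(new, old.state_dict())  # prints: Loaded state_dict: 3/4 matched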
diff --git a/model/model.py b/model/model.py
index 71fc684..bd47f0c 100644
--- a/model/model.py
+++ b/model/model.py
@@ -1,6 +1,7 @@
 import torch
 from torch import Tensor
 from torch import nn
+from torchsummary import summary
 from torch.nn import functional as F
 from typing import Optional, List
 
@@ -58,8 +59,8 @@ def forward(self,
        if not segmentation_pass:
            fgr_residual, pha = self.project_mat(hid).split([3, 1], dim=-3)
            if downsample_ratio != 1:
-                fgr_residual, pha = self.refiner(src, src_sm, fgr_residual, pha, hid)
-            fgr = fgr_residual + src
+                fgr_residual, pha = self.refiner(src[:, :, :3, ...], src_sm[:, :, :3, ...], fgr_residual, pha, hid)
+            fgr = fgr_residual + src[:, :, :3, ...]
            fgr = fgr.clamp(0., 1.)
            pha = pha.clamp(0., 1.)
            return [fgr, pha, *rec]
diff --git a/requirements_training.txt b/requirements_training.txt
index 70fd4b1..0ac0666 100644
--- a/requirements_training.txt
+++ b/requirements_training.txt
@@ -1,5 +1,7 @@
 easing_functions==1.0.4
-tensorboard==2.5.0
-torch==1.9.0
-torchvision==0.10.0
-tqdm==4.61.1
\ No newline at end of file
+tensorboard
+torch
+torchvision
+tqdm==4.61.1
+opencv-python==4.6.0.66
+torchsummary
\ No newline at end of file
diff --git a/train.py b/train.py
index 462bd1f..a542add 100644
--- a/train.py
+++ b/train.py
@@ -121,7 +121,8 @@
 from model import MattingNetwork
 from train_config import DATA_PATHS
 from train_loss import matting_loss, segmentation_loss
-
+import kornia
+from torchvision import transforms as T
 
 class Trainer:
     def __init__(self, rank, world_size):
@@ -189,7 +190,7 @@ def init_datasets(self):
         if self.args.dataset == 'videomatte':
             self.dataset_lr_train = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
+                # background_image_dir=DATA_PATHS['background_images']['train'],
                 background_video_dir=DATA_PATHS['background_videos']['train'],
                 size=self.args.resolution_lr,
                 seq_length=self.args.seq_length_lr,
@@ -198,7 +199,7 @@ def init_datasets(self):
             if self.args.train_hr:
                 self.dataset_hr_train = VideoMatteDataset(
                     videomatte_dir=DATA_PATHS['videomatte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
+                    # background_image_dir=DATA_PATHS['background_images']['train'],
                     background_video_dir=DATA_PATHS['background_videos']['train'],
                     size=self.args.resolution_hr,
                     seq_length=self.args.seq_length_hr,
@@ -206,38 +207,38 @@ def init_datasets(self):
                     transform=VideoMatteTrainAugmentation(size_hr))
             self.dataset_valid = VideoMatteDataset(
                 videomatte_dir=DATA_PATHS['videomatte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
+                # background_image_dir=DATA_PATHS['background_images']['valid'],
                 background_video_dir=DATA_PATHS['background_videos']['valid'],
                 size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
                 seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
                 seq_sampler=ValidFrameSampler(),
                 transform=VideoMatteValidAugmentation(size_hr if self.args.train_hr else size_lr))
-        else:
-            self.dataset_lr_train = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['train'],
-                background_image_dir=DATA_PATHS['background_images']['train'],
-                background_video_dir=DATA_PATHS['background_videos']['train'],
-                size=self.args.resolution_lr,
-                seq_length=self.args.seq_length_lr,
-                seq_sampler=TrainFrameSampler(),
-                transform=ImageMatteAugmentation(size_lr))
-            if self.args.train_hr:
-                self.dataset_hr_train = ImageMatteDataset(
-                    imagematte_dir=DATA_PATHS['imagematte']['train'],
-                    background_image_dir=DATA_PATHS['background_images']['train'],
-                    background_video_dir=DATA_PATHS['background_videos']['train'],
-                    size=self.args.resolution_hr,
-                    seq_length=self.args.seq_length_hr,
-                    seq_sampler=TrainFrameSampler(),
-                    transform=ImageMatteAugmentation(size_hr))
-            self.dataset_valid = ImageMatteDataset(
-                imagematte_dir=DATA_PATHS['imagematte']['valid'],
-                background_image_dir=DATA_PATHS['background_images']['valid'],
-                background_video_dir=DATA_PATHS['background_videos']['valid'],
-                size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
-                seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
-                seq_sampler=ValidFrameSampler(),
-                transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
+        # else:
+        #     self.dataset_lr_train = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #         background_image_dir=DATA_PATHS['background_images']['train'],
+        #         background_video_dir=DATA_PATHS['background_videos']['train'],
+        #         size=self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_lr,
+        #         seq_sampler=TrainFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_lr))
+        #     if self.args.train_hr:
+        #         self.dataset_hr_train = ImageMatteDataset(
+        #             imagematte_dir=DATA_PATHS['imagematte']['train'],
+        #             background_image_dir=DATA_PATHS['background_images']['train'],
+        #             background_video_dir=DATA_PATHS['background_videos']['train'],
+        #             size=self.args.resolution_hr,
+        #             seq_length=self.args.seq_length_hr,
+        #             seq_sampler=TrainFrameSampler(),
+        #             transform=ImageMatteAugmentation(size_hr))
+        #     self.dataset_valid = ImageMatteDataset(
+        #         imagematte_dir=DATA_PATHS['imagematte']['valid'],
+        #         background_image_dir=DATA_PATHS['background_images']['valid'],
+        #         background_video_dir=DATA_PATHS['background_videos']['valid'],
+        #         size=self.args.resolution_hr if self.args.train_hr else self.args.resolution_lr,
+        #         seq_length=self.args.seq_length_hr if self.args.train_hr else self.args.seq_length_lr,
+        #         seq_sampler=ValidFrameSampler(),
+        #         transform=ImageMatteAugmentation(size_hr if self.args.train_hr else size_lr))
 
         # Matting dataloaders:
         self.datasampler_lr_train = DistributedSampler(
@@ -270,49 +271,49 @@ def init_datasets(self):
             pin_memory=True)
         
         # Segementation datasets
-        self.log('Initializing image segmentation datasets')
-        self.dataset_seg_image = ConcatDataset([
-            CocoPanopticDataset(
-                imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
-                anndir=DATA_PATHS['coco_panoptic']['anndir'],
-                annfile=DATA_PATHS['coco_panoptic']['annfile'],
-                transform=CocoPanopticTrainAugmentation(size_lr)),
-            SuperviselyPersonDataset(
-                imgdir=DATA_PATHS['spd']['imgdir'],
-                segdir=DATA_PATHS['spd']['segdir'],
-                transform=CocoPanopticTrainAugmentation(size_lr))
-        ])
-        self.datasampler_seg_image = DistributedSampler(
-            dataset=self.dataset_seg_image,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_image = DataLoader(
-            dataset=self.dataset_seg_image,
-            batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_image,
-            pin_memory=True)
+        # self.log('Initializing image segmentation datasets')
+        # self.dataset_seg_image = ConcatDataset([
+        #     CocoPanopticDataset(
+        #         imgdir=DATA_PATHS['coco_panoptic']['imgdir'],
+        #         anndir=DATA_PATHS['coco_panoptic']['anndir'],
+        #         annfile=DATA_PATHS['coco_panoptic']['annfile'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr)),
+        #     SuperviselyPersonDataset(
+        #         imgdir=DATA_PATHS['spd']['imgdir'],
+        #         segdir=DATA_PATHS['spd']['segdir'],
+        #         transform=CocoPanopticTrainAugmentation(size_lr))
+        # ])
+        # self.datasampler_seg_image = DistributedSampler(
+        #     dataset=self.dataset_seg_image,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_image = DataLoader(
+        #     dataset=self.dataset_seg_image,
+        #     batch_size=self.args.batch_size_per_gpu * self.args.seq_length_lr,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_image,
+        #     pin_memory=True)
 
-        self.log('Initializing video segmentation datasets')
-        self.dataset_seg_video = YouTubeVISDataset(
-            videodir=DATA_PATHS['youtubevis']['videodir'],
-            annfile=DATA_PATHS['youtubevis']['annfile'],
-            size=self.args.resolution_lr,
-            seq_length=self.args.seq_length_lr,
-            seq_sampler=TrainFrameSampler(speed=[1]),
-            transform=YouTubeVISAugmentation(size_lr))
-        self.datasampler_seg_video = DistributedSampler(
-            dataset=self.dataset_seg_video,
-            rank=self.rank,
-            num_replicas=self.world_size,
-            shuffle=True)
-        self.dataloader_seg_video = DataLoader(
-            dataset=self.dataset_seg_video,
-            batch_size=self.args.batch_size_per_gpu,
-            num_workers=self.args.num_workers,
-            sampler=self.datasampler_seg_video,
-            pin_memory=True)
+        # self.log('Initializing video segmentation datasets')
+        # self.dataset_seg_video = YouTubeVISDataset(
+        #     videodir=DATA_PATHS['youtubevis']['videodir'],
+        #     annfile=DATA_PATHS['youtubevis']['annfile'],
+        #     size=self.args.resolution_lr,
+        #     seq_length=self.args.seq_length_lr,
+        #     seq_sampler=TrainFrameSampler(speed=[1]),
+        #     transform=YouTubeVISAugmentation(size_lr))
+        # self.datasampler_seg_video = DistributedSampler(
+        #     dataset=self.dataset_seg_video,
+        #     rank=self.rank,
+        #     num_replicas=self.world_size,
+        #     shuffle=True)
+        # self.dataloader_seg_video = DataLoader(
+        #     dataset=self.dataset_seg_video,
+        #     batch_size=self.args.batch_size_per_gpu,
+        #     num_workers=self.args.num_workers,
+        #     sampler=self.datasampler_seg_video,
+        #     pin_memory=True)
 
     def init_model(self):
         self.log('Initializing model')
@@ -359,12 +360,12 @@ def train(self):
                    self.train_mat(true_fgr, true_pha, true_bgr, downsample_ratio=self.args.downsample_ratio, tag='hr')
                
                # Segmentation pass
-                if self.step % 2 == 0:
-                    true_img, true_seg = self.load_next_seg_video_sample()
-                    self.train_seg(true_img, true_seg, log_label='seg_video')
-                else:
-                    true_img, true_seg = self.load_next_seg_image_sample()
-                    self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
+                # if self.step % 2 == 0:
+                #     true_img, true_seg = self.load_next_seg_video_sample()
+                #     self.train_seg(true_img, true_seg, log_label='seg_video')
+                # else:
+                #     true_img, true_seg = self.load_next_seg_image_sample()
+                #     self.train_seg(true_img.unsqueeze(1), true_seg.unsqueeze(1), log_label='seg_image')
                
                if self.step % self.args.checkpoint_save_interval == 0:
                    self.save()
@@ -376,10 +377,47 @@ def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
         true_pha = true_pha.to(self.rank, non_blocking=True)
         true_bgr = true_bgr.to(self.rank, non_blocking=True)
         true_fgr, true_pha, true_bgr = self.random_crop(true_fgr, true_pha, true_bgr)
-        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
+        true_src = true_bgr.clone()
+
+        # Augment bgr with shadow
+        aug_shadow_idx = torch.rand(len(true_src)) < 0.3
+        if aug_shadow_idx.any():
+            aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()).flatten(start_dim = 0, end_dim = 1)
+            aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
+            aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
+            expected_shape = torch.tensor(true_src[aug_shadow_idx].shape)
+            expected_shape[2] = -1
+            true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow.reshape(expected_shape.tolist())).clamp_(0, 1)
+            del aug_shadow
+        del aug_shadow_idx
+
+        # Composite foreground onto source
+        true_src = true_fgr * true_pha + true_src * (1 - true_pha)
+
+        # Augment with noise
+        aug_noise_idx = torch.rand(len(true_src)) < 0.4
+        if aug_noise_idx.any():
+            true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+            true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
+        del aug_noise_idx
+
+        # Augment background with jitter
+        aug_jitter_idx = torch.rand(len(true_src)) < 0.8
+        if aug_jitter_idx.any():
+            true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx].flatten(start_dim = 0, end_dim = 1)).reshape(true_bgr[aug_jitter_idx].shape)
+        del aug_jitter_idx
+
+        # Augment background with affine
+        aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
+        if aug_affine_idx.any():
+            true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx].flatten(start_dim = 0, end_dim = 1)).reshape(true_bgr[aug_affine_idx].shape)
+        del aug_affine_idx
+
+        fg_bg_input = torch.cat((true_src, true_bgr), dim = -3)
+
         with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_fgr, pred_pha = self.model_ddp(true_src, downsample_ratio=downsample_ratio)[:2]
+            pred_fgr, pred_pha = self.model_ddp(fg_bg_input, downsample_ratio=downsample_ratio)[:2]
             loss = matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)
 
         self.scaler.scale(loss['total']).backward()
@@ -397,29 +435,30 @@ def train_mat(self, true_fgr, true_pha, true_bgr, downsample_ratio, tag):
            self.writer.add_image(f'train_{tag}_true_fgr', make_grid(true_fgr.flatten(0, 1), nrow=true_fgr.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_pha', make_grid(true_pha.flatten(0, 1), nrow=true_pha.size(1)), self.step)
            self.writer.add_image(f'train_{tag}_true_src', make_grid(true_src.flatten(0, 1), nrow=true_src.size(1)), self.step)
-
-    def train_seg(self, true_img, true_seg, log_label):
-        true_img = true_img.to(self.rank, non_blocking=True)
-        true_seg = true_seg.to(self.rank, non_blocking=True)
+
+    # does not get called
+    # def train_seg(self, true_img, true_seg, log_label):
+    #     true_img = true_img.to(self.rank, non_blocking=True)
+    #     true_seg = true_seg.to(self.rank, non_blocking=True)
 
-        true_img, true_seg = self.random_crop(true_img, true_seg)
+    #     true_img, true_seg = self.random_crop(true_img, true_seg)
 
-        with autocast(enabled=not self.args.disable_mixed_precision):
-            pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
-            loss = segmentation_loss(pred_seg, true_seg)
+    #     with autocast(enabled=not self.args.disable_mixed_precision):
+    #         pred_seg = self.model_ddp(true_img, segmentation_pass=True)[0]
+    #         loss = segmentation_loss(pred_seg, true_seg)
 
-        self.scaler.scale(loss).backward()
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
-        self.optimizer.zero_grad()
+    #     self.scaler.scale(loss).backward()
+    #     self.scaler.step(self.optimizer)
+    #     self.scaler.update()
+    #     self.optimizer.zero_grad()
 
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
-            self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_loss_interval == 0:
+    #         self.writer.add_scalar(f'{log_label}_loss', loss, self.step)
 
-        if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
-            self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
-            self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    #     if self.rank == 0 and (self.step - self.step % 2) % self.args.log_train_images_interval == 0:
+    #         self.writer.add_image(f'{log_label}_pred_seg', make_grid(pred_seg.flatten(0, 1).float().sigmoid(), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_seg', make_grid(true_seg.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
+    #         self.writer.add_image(f'{log_label}_true_img', make_grid(true_img.flatten(0, 1), nrow=self.args.seq_length_lr), self.step)
 
     def load_next_mat_hr_sample(self):
         try:
@@ -430,23 +469,23 @@ def load_next_mat_hr_sample(self):
             sample = next(self.dataiterator_mat_hr)
         return sample
 
-    def load_next_seg_video_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_video)
-        except:
-            self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
-            self.dataiterator_seg_video = iter(self.dataloader_seg_video)
-            sample = next(self.dataiterator_seg_video)
-        return sample
+    # def load_next_seg_video_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_video)
+    #     except:
+    #         self.datasampler_seg_video.set_epoch(self.datasampler_seg_video.epoch + 1)
+    #         self.dataiterator_seg_video = iter(self.dataloader_seg_video)
+    #         sample = next(self.dataiterator_seg_video)
+    #     return sample
 
-    def load_next_seg_image_sample(self):
-        try:
-            sample = next(self.dataiterator_seg_image)
-        except:
-            self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
-            self.dataiterator_seg_image = iter(self.dataloader_seg_image)
-            sample = next(self.dataiterator_seg_image)
-        return sample
+    # def load_next_seg_image_sample(self):
+    #     try:
+    #         sample = next(self.dataiterator_seg_image)
+    #     except:
+    #         self.datasampler_seg_image.set_epoch(self.datasampler_seg_image.epoch + 1)
+    #         self.dataiterator_seg_image = iter(self.dataloader_seg_image)
+    #         sample = next(self.dataiterator_seg_image)
+    #     return sample
 
     def validate(self):
         if self.rank == 0:
@@ -461,7 +500,9 @@ def validate(self):
                        true_bgr = true_bgr.to(self.rank, non_blocking=True)
                        true_src = true_fgr * true_pha + true_bgr * (1 - true_pha)
                        batch_size = true_src.size(0)
-                        pred_fgr, pred_pha = self.model(true_src)[:2]
+
+                        fg_bg_input = torch.cat((true_src, true_bgr), dim = -3)
+                        pred_fgr, pred_pha = self.model(fg_bg_input)[:2]
                        total_loss += matting_loss(pred_fgr, pred_pha, true_fgr, true_pha)['total'].item() * batch_size
                        total_count += batch_size
            avg_loss = total_loss / total_count
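train_mat above applies per-frame image transforms (torchvision RandomAffine, kornia ColorJitter and box_blur) to clip tensors by flattening the batch and time dimensions first and reshaping afterwards. A minimal sketch of that pattern with assumed shapes, not part of the patch:

import torch
from torchvision import transforms as T

# Clips are (B, T, C, H, W); image transforms expect (N, C, H, W), so B and T
# are flattened together, augmented, then reshaped back, as in train_mat.
clip = torch.rand(2, 4, 3, 64, 64)
aug = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))
clip = aug(clip.flatten(start_dim=0, end_dim=1)).reshape(2, 4, 3, 64, 64)
print(clip.shape)  # torch.Size([2, 4, 3, 64, 64])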
diff --git a/train_config.py b/train_config.py
index 0792696..c122a55 100644
--- a/train_config.py
+++ b/train_config.py
@@ -37,32 +37,32 @@
         'train': '../matting-data/VideoMatte240K_JPEG_SD/train',
         'valid': '../matting-data/VideoMatte240K_JPEG_SD/valid',
     },
-    'imagematte': {
-        'train': '../matting-data/ImageMatte/train',
-        'valid': '../matting-data/ImageMatte/valid',
-    },
-    'background_images': {
-        'train': '../matting-data/Backgrounds/train',
-        'valid': '../matting-data/Backgrounds/valid',
-    },
+    # 'imagematte': {
+    #     'train': '../matting-data/ImageMatte/train',
+    #     'valid': '../matting-data/ImageMatte/valid',
+    # },
+    # 'background_images': {
+    #     'train': '../matting-data/Backgrounds/train',
+    #     'valid': '../matting-data/Backgrounds/valid',
+    # },
     'background_videos': {
         'train': '../matting-data/BackgroundVideos/train',
         'valid': '../matting-data/BackgroundVideos/valid',
     },
-    'coco_panoptic': {
-        'imgdir': '../matting-data/coco/train2017/',
-        'anndir': '../matting-data/coco/panoptic_train2017/',
-        'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
-    },
-    'spd': {
-        'imgdir': '../matting-data/SuperviselyPersonDataset/img',
-        'segdir': '../matting-data/SuperviselyPersonDataset/seg',
-    },
-    'youtubevis': {
-        'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
-        'annfile': '../matting-data/YouTubeVIS/train/instances.json',
-    }
+    # 'coco_panoptic': {
+    #     'imgdir': '../matting-data/coco/train2017/',
+    #     'anndir': '../matting-data/coco/panoptic_train2017/',
+    #     'annfile': '../matting-data/coco/annotations/panoptic_train2017.json',
+    # },
+    # 'spd': {
+    #     'imgdir': '../matting-data/SuperviselyPersonDataset/img',
+    #     'segdir': '../matting-data/SuperviselyPersonDataset/seg',
+    # },
+    # 'youtubevis': {
+    #     'videodir': '../matting-data/YouTubeVIS/train/JPEGImages',
+    #     'annfile': '../matting-data/YouTubeVIS/train/instances.json',
+    # }
 }
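Across the patch, the background-conditioned input follows a single convention: source RGB and clean-plate background RGB concatenated on the channel axis (dim=-3), with each half normalized separately on the encoder side. A minimal shape-level sketch of that convention, not part of the patch and with assumed tensor sizes:

import torch
from torchvision.transforms.functional import normalize

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

src = torch.rand(2, 4, 3, 256, 256)   # (B, T, C, H, W) training clip
bgr = torch.rand(2, 4, 3, 256, 256)   # matching clean-plate background clip
x = torch.cat((src, bgr), dim=-3)     # (2, 4, 6, 256, 256), like fg_bg_input in train.py

# Per-frame view of the split normalization from mobilenetv3.py:
frames = x.flatten(0, 1)              # (B*T, 6, H, W)
frames = torch.cat((normalize(frames[:, :3], mean, std),
                    normalize(frames[:, 3:], mean, std)), dim=-3)
print(frames.shape)                   # torch.Size([8, 6, 256, 256])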