Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
Galaxies99 committed Aug 27, 2021
1 parent e7b2edb commit 8cd470e
Show file tree
Hide file tree
Showing 11 changed files with 416 additions and 42 deletions.
2 changes: 2 additions & 0 deletions configs/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

"dataset":
"data_dir": "data"
"use_augmentation": True
"augmentation_probability": 0.8

"trainer":
"batch_size": 4
Expand Down
19 changes: 11 additions & 8 deletions dataset/transparent_grasp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
"""
Transparent Grasp Dataset.
Author: Hongjie Fang.
"""
import os
import json
import torch
import numpy as np
from PIL import Image
import torch.nn as nn
from torch.utils.data import Dataset
from utils.data_preparation import process_data


class TransparentGrasp(Dataset):
Expand All @@ -24,6 +30,7 @@ def __init__(self, data_dir, split = 'train', **kwargs):
if split not in ['train', 'test']:
raise AttributeError('Invalid split option.')
self.data_dir = data_dir
self.split = split
with open(os.path.join(self.data_dir, 'metadata.json'), 'r') as fp:
self.dataset_metadata = json.load(fp)
self.scene_num = self.dataset_metadata['total_scenes']
Expand Down Expand Up @@ -52,20 +59,16 @@ def __init__(self, data_dir, split = 'train', **kwargs):
])
# Integrity double-check
assert len(self.sample_info) == self.total_samples, "Error in total samples, expect {} samples, found {} samples.".format(self.total_samples, len(self.sample_info))
self.use_aug = kwargs.get('use_augmentation', True)
self.aug_prob = kwargs.get('augmentation_probability', 0.8)

def __getitem__(self, id):
img_path, camera_type, scene_type = self.sample_info[id]
rgb = np.array(Image.open(os.path.join(img_path, 'rgb{}.png'.format(camera_type))), dtype = np.float32) / 255.0
rgb = rgb.transpose(2, 0, 1) # HWC -> CHW
rgb = np.array(Image.open(os.path.join(img_path, 'rgb{}.png'.format(camera_type))), dtype = np.float32)
depth = np.array(Image.open(os.path.join(img_path, 'depth{}.png'.format(camera_type))), dtype = np.float32)
depth = depth / (1000 if camera_type == 1 else 4000) # depth sensor scaling
depth = np.where(depth > 10, 1, depth / 10)
depth_gt = np.array(Image.open(os.path.join(img_path, 'depth{}-gt.png'.format(camera_type))), dtype = np.float32)
depth_gt = depth_gt / (1000 if camera_type == 1 else 4000) # depth sensor scaling
depth_gt = np.where(depth_gt > 10, 1, depth_gt / 10)
depth_gt_mask = np.array(Image.open(os.path.join(img_path, 'depth{}-gt-mask.png'.format(camera_type))), dtype = np.bool)
scene_mask = np.array([1 if scene_type == 'cluttered' else 0], dtype = np.bool)
return torch.FloatTensor(rgb), torch.FloatTensor(depth), torch.FloatTensor(depth_gt), torch.BoolTensor(depth_gt_mask), torch.BoolTensor(scene_mask)
return process_data(rgb, depth, depth_gt, depth_gt_mask, scene_type, camera_type, split = self.split, use_aug = self.use_aug, aug_prob = self.aug_prob)

def __len__(self):
return self.total_samples
8 changes: 8 additions & 0 deletions models/DFNet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
Depth Filler Network.
Author: Hongjie Fang.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -7,6 +12,9 @@


class DFNet(nn.Module):
"""
Depth Filler Network (DFNet).
"""
def __init__(self, in_channels = 4, hidden_channels = 64, L = 5, k = 12, **kwargs):
super(DFNet, self).__init__()
self.in_channels = in_channels
Expand Down
5 changes: 5 additions & 0 deletions models/dense.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
Dense Block.
Author: Hongjie Fang.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
9 changes: 9 additions & 0 deletions models/duc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
"""
Dense upsampling convolution layer.
Author: Hongjie Fang.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class DenseUpsamplingConvolution(nn.Module):
"""
Dense upsampling convolution module.
"""
def __init__(self, inplanes, planes, upscale_factor = 2):
super(DenseUpsamplingConvolution, self).__init__()
self.layer = nn.Sequential(
Expand Down
6 changes: 5 additions & 1 deletion models/weight_init.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
'''
From mmcv lib (https://github.com/open-mmlab/mmcv)
Weight Initialization.
Author: Authors from [mmcv] repository.
Ref:
1. [mmcv] repository: https://github.com/open-mmlab/mmcv.
'''
import torch.nn as nn

Expand Down
24 changes: 16 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
Training scripts.
Authors: Hongjie Fang.
"""
import os
import yaml
import torch
Expand Down Expand Up @@ -128,7 +133,7 @@ def test_one_epoch(epoch):
logger.info('Threshold 1.05 (w/o mask): {:.6f}, {:.6f}'.format(metrics_result[9], metrics_result[8]))
logger.info('Threshold 1.10 (w/o mask): {:.6f}, {:.6f}'.format(metrics_result[11], metrics_result[10]))
logger.info('Threshold 1.25 (w/o mask): {:.6f}, {:.6f}'.format(metrics_result[13], metrics_result[12]))
return mean_loss
return mean_loss, metrics_result


def train(start_epoch):
Expand All @@ -137,18 +142,21 @@ def train(start_epoch):
for epoch in range(start_epoch, max_epoch):
logger.info('--> Epoch {}/{}'.format(epoch + 1, max_epoch))
train_one_epoch(epoch)
loss = test_one_epoch(epoch)
loss, metrics_result = test_one_epoch(epoch)
if lr_scheduler is not None:
lr_scheduler.step()
save_dict = {
'epoch': epoch + 1,
'model_state_dict': model.module.state_dict() if builder.multigpu() else model.state_dict(),
'mean_loss': loss,
'metrics': list(metrics_result)
}
torch.save(save_dict, os.path.join(stats_dir, 'checkpoint-ep{}.tar'.format(epoch)))
if loss < min_loss:
min_loss = loss
min_loss_epoch = epoch + 1
save_dict = {
'epoch': epoch + 1,
'model_state_dict': model.module.state_dict() if builder.multigpu() else model.state_dict(),
}
torch.save(save_dict, os.path.join(stats_dir, 'checkpoint.tar'))
logger.info('Training Finished. Max accuracy: {:.6f}, in epoch {}'.format(min_loss, min_loss_epoch))
torch.save(save_dict, os.path.join(stats_dir, 'checkpoint.tar'.format(epoch)))
logger.info('Training Finished. Min testing loss: {:.6f}, in epoch {}'.format(min_loss, min_loss_epoch))


if __name__ == '__main__':
Expand Down
59 changes: 40 additions & 19 deletions utils/builder.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
"""
Configuration builder.
Authors: Hongjie Fang.
"""
import os
from dataset.transparent_grasp import TransparentGrasp
from models.DFNet import DFNet


class ConfigBuilder(object):
'''
Configuration Builder
'''
"""
Configuration Builder.
Features includes:
- build model from configuration;
- build optimizer from configuration;
- build learning rate scheduler from configuration;
- build dataset & dataloader from configuration;
- build statistics directory from configuration;
- fetch training parameters (e.g., max_epoch, multigpu) from configuration.
"""
def __init__(self, **params):
'''
"""
Set the default configuration for the configuration builder.
Parameters
----------
params: the configuration parameters.
'''
"""
super(ConfigBuilder, self).__init__()
self.params = params

def get_model(self, model_params = None):
'''
"""
Get the model from configuration.
Parameters
Expand All @@ -28,7 +42,7 @@ def get_model(self, model_params = None):
Returns
-------
A model, which is usually a torch.nn.Module object.
'''
"""
if model_params is None:
model_params = self.params.get('model', {})
type = model_params.get('type', 'DFNet')
Expand All @@ -40,7 +54,7 @@ def get_model(self, model_params = None):
return model

def get_optimizer(self, model, optimizer_params = None):
'''
"""
Get the optimizer from configuration.
Parameters
Expand All @@ -51,7 +65,7 @@ def get_optimizer(self, model, optimizer_params = None):
Returns
-------
An optimizer for the given model.
'''
"""
from torch.optim import SGD, ASGD, Adagrad, Adamax, Adadelta, Adam, AdamW, RMSprop
if optimizer_params is None:
optimizer_params = self.params.get('optimizer', {})
Expand All @@ -78,7 +92,7 @@ def get_optimizer(self, model, optimizer_params = None):
return optimizer

def get_lr_scheduler(self, optimizer, lr_scheduler_params = None):
'''
"""
Get the learning rate scheduler from configuration.
Parameters
Expand All @@ -89,7 +103,7 @@ def get_lr_scheduler(self, optimizer, lr_scheduler_params = None):
Returns
-------
A learning rate scheduler for the given optimizer.
'''
"""
from torch.optim.lr_scheduler import MultiStepLR, ExponentialLR, CyclicLR, CosineAnnealingLR, LambdaLR, StepLR
if lr_scheduler_params is None:
lr_scheduler_params = self.params.get('lr_scheduler', {})
Expand All @@ -114,7 +128,7 @@ def get_lr_scheduler(self, optimizer, lr_scheduler_params = None):
return scheduler

def get_dataset(self, dataset_params = None, split = 'train'):
'''
"""
Get the dataset from configuration.
Parameters
Expand All @@ -125,13 +139,13 @@ def get_dataset(self, dataset_params = None, split = 'train'):
Returns
-------
A torch.utils.data.Dataset item.
'''
"""
if dataset_params is None:
dataset_params = self.params.get('dataset', {"data_dir": "data"})
return TransparentGrasp(split = split, **dataset_params)

def get_dataloader(self, dataset_params = None, split = 'train', batch_size = None, num_workers = None, shuffle = None):
'''
"""
Get the dataloader from configuration.
Parameters
Expand All @@ -145,7 +159,7 @@ def get_dataloader(self, dataset_params = None, split = 'train', batch_size = No
Returns
-------
A torch.utils.data.DataLoader item.
'''
"""
from torch.utils.data import DataLoader
if batch_size is None:
batch_size = self.params.get('trainer', {}).get('batch_size', 4)
Expand All @@ -162,17 +176,17 @@ def get_dataloader(self, dataset_params = None, split = 'train', batch_size = No
)

def get_max_epoch(self):
'''
"""
Get the max epoch from configuration.
Returns
-------
An integer, which is the max epoch (default: 50).
'''
"""
return self.params.get('trainer', {}).get('max_epoch', 50)

def get_stats_dir(self, stats_params = None):
'''
"""
Get the statistics directory from configuration.
Parameters
Expand All @@ -182,7 +196,7 @@ def get_stats_dir(self, stats_params = None):
Returns
-------
A string, the statistics directory.
'''
"""
if stats_params is None:
stats_params = self.params.get('stats', {})
stats_dir = stats_params.get('stats_dir', 'stats')
Expand All @@ -193,4 +207,11 @@ def get_stats_dir(self, stats_params = None):
return stats_res_dir

def multigpu(self):
"""
Get the multigpu settings from configuration.
Returns
-------
A boolean value, whether to use the multigpu training/testing (default: False).
"""
return self.params.get('trainer', {}).get('multigpu', False)
Loading

0 comments on commit 8cd470e

Please sign in to comment.