Skip to content

Commit

Permalink
feat: add mtcf validation and test ops
Browse files Browse the repository at this point in the history
Ref.: #49
  • Loading branch information
vitalwarley committed Dec 23, 2023
1 parent 6871b35 commit 238ddac
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 72 deletions.
17 changes: 13 additions & 4 deletions ours/datasets/mtcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,26 @@ class MTCFDataset(Dataset):
Similar to Zhang et al. (2021)
"""

def __init__(self, root_dir: Path, sample_path: Path, negatives_per_sample: int = 1, transform=None):
def __init__(
self,
root_dir: Path,
sample_path: Path,
negatives_per_sample: int = 1,
extend_with_same_gen: bool = True,
transform=None,
):
self.root_dir = root_dir
self.sample_path = sample_path
self.extend_with_same_gen = extend_with_same_gen
self.transform = transform
self.samples = self.load_sample()
print(
f"Loaded {len(self.samples)} samples from {sample_path} "
+ "(with duplicated samples for same generation bb, ss, sibs)."
)
self.add_negative_samples(negatives_per_sample)
print(f"Added negative samples, now we have {len(self.samples)} samples.")
if negatives_per_sample:
self.add_negative_samples(negatives_per_sample)
print(f"Added negative samples, now we have {len(self.samples)} samples.")

def load_sample(self):
sample_list = []
Expand All @@ -36,7 +45,7 @@ def load_sample(self):
tmp = line.split(" ")
sample = Sample(tmp[0], tmp[1], tmp[2], tmp[-2], tmp[-1])
sample_list.append(sample)
if sample.is_same_generation:
if sample.is_same_generation and self.extend_with_same_gen:
# Create new sample swapping f1 and f2
sample_list.append(Sample(tmp[0], tmp[2], tmp[1], tmp[-2], tmp[-1]))
return sample_list
Expand Down
166 changes: 107 additions & 59 deletions ours/guild.yml
Original file line number Diff line number Diff line change
@@ -1,61 +1,109 @@
sota2020:
description: Sample training script
main: train_fc
sourcecode:
- dataset.py
- model.py
- mytypes.py
- train_fc.py
- transforms.py
- utils.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: ../fitw2020/
target-type: link
- file: insightface
- file: models/
target-type: link
- model: sota
operations:
sota2020:
description: Sample training script
main: train_fc
sourcecode:
- dataset.py
- model.py
- mytypes.py
- train_fc.py
- transforms.py
- utils.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: ../fitw2020/
target-type: link
- file: insightface
- file: models/
target-type: link

sota2021:
description: Sample training script
main: train_kv
sourcecode:
- dataset.py
- model.py
- mytypes.py
- train_kv.py
- utils.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: models
target-type: link
- file: runs
target-type: link
- file: ../rfiw2021/
target-type: link
- file: insightface
sota2021:
description: Sample training script
main: train_kv
sourcecode:
- dataset.py
- model.py
- mytypes.py
- train_kv.py
- utils.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: models
target-type: link
- file: runs
target-type: link
- file: ../rfiw2021/
target-type: link
- file: insightface

mtcf:
description: Reproduction of Hörmann et al. (2020)
main: tasks.mtcf
sourcecode:
- utils.py
- models/mtcf.py
- datasets/mtcf.py
- datasets/utils.py
- tasks/mtcf.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: weights
target-type: link
- file: ../rfiw2021/
target-type: link
- file: models/insightface
rename: models/insightface
target-type: link
- model: mtcf
operations:
train:
description: Reproduction of Hörmann et al. (2020)
main: tasks.mtcf train
sourcecode:
- utils.py
- models/mtcf.py
- datasets/mtcf.py
- datasets/utils.py
- tasks/mtcf.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: weights
target-type: link
- file: ../rfiw2021/
target-type: link
- file: models/insightface
rename: models/insightface
target-type: link
val:
description: Reproduction of Hörmann et al. (2020)
main: tasks.mtcf val
sourcecode:
- utils.py
- models/mtcf.py
- datasets/mtcf.py
- datasets/utils.py
- tasks/mtcf.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: weights
target-type: link
- file: ../rfiw2021/
target-type: link
- file: models/insightface
rename: models/insightface
target-type: link
- operation: mtcf
select: exp
test:
description: Reproduction of Hörmann et al. (2020)
main: tasks.mtcf test
sourcecode:
- utils.py
- models/mtcf.py
- datasets/mtcf.py
- datasets/utils.py
- tasks/mtcf.py
flags-import: all
flags-dest: args
output-scalars: '(\key):\s+(\value)'
requires:
- file: weights
target-type: link
- file: ../rfiw2021/
target-type: link
- file: models/insightface
rename: models/insightface
target-type: link
- operation: mtcf
select: exp
117 changes: 108 additions & 9 deletions ours/tasks/mtcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from utils import predict_kinship_mtcf, update_lr_mtcf, validate_pairs
from utils import predict_kinship_mtcf, test_pairs, update_lr_mtcf, validate_pairs


def log(loss, metric, epoch, step, global_step, cur_lr):
Expand All @@ -23,9 +23,6 @@ def log(loss, metric, epoch, step, global_step, cur_lr):


def train(args):
# Set random seed to 100
torch.manual_seed(100)

# Define transformations for training and validation sets
transform_img_train = transforms.Compose(
[
Expand Down Expand Up @@ -144,9 +141,80 @@ def train(args):
)


def create_parser():
parser = ArgumentParser(description="Configuration for the training script")
def val(args):
    """Evaluate saved MTCF weights on a validation pair list and print AUC/threshold.

    Reads from ``args``: ``root_dir``, ``dataset_path``, ``weights``,
    ``batch_size`` and ``device``. The dataset is built with the
    ``MTCFDataset`` defaults for ``negatives_per_sample`` and
    ``extend_with_same_gen``. Nothing is returned; the result is printed in a
    ``key: value`` form.
    """
    # Validation images are only converted to tensors; no augmentation.
    transform_img_val = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define the validation dataset
    val_dataset = MTCFDataset(Path(args.root_dir), Path(args.dataset_path), transform=transform_img_val)

    # Define the model
    model = MTCFNet()
    # Load the model weights. map_location lets a checkpoint that was saved on
    # a GPU be restored on a CPU-only host instead of failing inside torch.load.
    model.load_state_dict(torch.load(args.weights, map_location=device))
    model.to(device)

    # Define the DataLoader for the validation set (order kept fixed)
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,  # 4 worker processes for loading data
        pin_memory=True,
    )

    # NOTE(review): the model lives on the local `device`, while the metrics are
    # placed on `args.device` -- confirm the two always refer to the same device.
    auc, thresh = validate_pairs(
        model, val_dataloader, device=args.device, return_thresh=True, predict=predict_kinship_mtcf
    )

    print(f"auc: {auc:.3f} | thresh: {thresh:.3f}")


def test(args):
    """Evaluate saved MTCF weights on a test pair list and print the accuracy.

    Reads from ``args``: ``root_dir``, ``dataset_path``, ``weights``,
    ``threshold``, ``batch_size`` and ``device``. The dataset is built with
    ``negatives_per_sample=0`` and ``extend_with_same_gen=False`` so the pair
    list is used as given, without added negatives or swapped duplicates.
    Nothing is returned; the result is printed in a ``key: value`` form.
    """
    # Test images are only converted to tensors; no augmentation.
    transform_img_test = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Define the test dataset
    test_dataset = MTCFDataset(
        Path(args.root_dir),
        Path(args.dataset_path),
        negatives_per_sample=0,
        extend_with_same_gen=False,
        transform=transform_img_test,
    )

    # Define the model
    model = MTCFNet()
    # Load the model weights. map_location lets a checkpoint that was saved on
    # a GPU be restored on a CPU-only host instead of failing inside torch.load.
    model.load_state_dict(torch.load(args.weights, map_location=device))
    model.to(device)

    # Define the DataLoader for the test set (order kept fixed)
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=12,  # Assuming 12 workers for loading data
        pin_memory=True,
    )

    # NOTE(review): the model lives on the local `device`, while the metric is
    # placed on `args.device` -- confirm the two always refer to the same device.
    acc = test_pairs(model, test_dataloader, device=args.device, thresh=args.threshold, predict=predict_kinship_mtcf)

    print(f"acc: {acc:.3f}")


def create_parser_train(subparsers):
parser = subparsers.add_parser("train", help="Train the model")
parser.add_argument("--root-dir", type=str, required=True)
parser.add_argument("--train-dataset-path", type=str, required=True)
parser.add_argument("--val-dataset-path", type=str, required=True)
Expand All @@ -159,12 +227,43 @@ def create_parser():
parser.add_argument("--batch-size", type=int, default=200, help="Batch size")
parser.add_argument("--loss-log-step", type=int, default=100, help="Steps for logging loss")
parser.add_argument("--device", type=str, default="0", help="Device to use for training")
parser.set_defaults(func=train)

return parser

def create_parser_val(subparsers):
    """Register the ``val`` sub-command on ``subparsers``.

    The sub-parser's ``func`` default is set to ``val`` so the entry point can
    dispatch with ``args.func(args)``.
    """
    # The help texts previously described training (copy-paste from the train
    # sub-command); they now describe what this sub-command actually does.
    parser = subparsers.add_parser("val", help="Validate the model")
    parser.add_argument("--root-dir", type=str, required=True)
    parser.add_argument("--dataset-path", type=str, required=True)
    parser.add_argument("--weights", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=1024, help="Batch size")
    parser.add_argument("--device", type=str, default="0", help="Device to use for validation")
    parser.set_defaults(func=val)


def create_parser_test(subparsers):
    """Register the ``test`` sub-command on ``subparsers``.

    The sub-parser's ``func`` default is set to ``test`` so the entry point can
    dispatch with ``args.func(args)``.
    """
    # The help texts previously described training (copy-paste from the train
    # sub-command); they now describe what this sub-command actually does.
    parser = subparsers.add_parser("test", help="Test the model")
    parser.add_argument("--root-dir", type=str, required=True)
    parser.add_argument("--dataset-path", type=str, required=True)
    parser.add_argument("--weights", type=str, required=True)
    # Decision threshold passed on to the accuracy computation.
    parser.add_argument("--threshold", type=float, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=1024, help="Batch size")
    parser.add_argument("--device", type=str, default="0", help="Device to use for testing")
    parser.set_defaults(func=test)


if __name__ == "__main__":
parser = create_parser()
# Set random seed to 100
torch.manual_seed(100)

parser = ArgumentParser(description="Configuration for the MTCF strategy")
subparsers = parser.add_subparsers()

create_parser_train(subparsers)
create_parser_val(subparsers)
create_parser_test(subparsers)

args = parser.parse_args()

args.output_dir = Path(args.output_dir)
Expand All @@ -187,4 +286,4 @@ def create_parser():
print("CUDA is not available.")
args.device = torch.device("cpu")

train(args)
args.func(args)
12 changes: 12 additions & 0 deletions ours/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,18 @@ def validate_pairs(model, dataloader, device, return_thresh=False, predict=predi
return auc


def test_pairs(model, dataloader, device, thresh, predict=predict_kinship):
    """Compute binary accuracy of ``model`` over ``dataloader`` at a fixed threshold.

    The model is put in eval mode, ``predict(model, dataloader)`` supplies the
    predictions and ground-truth labels, and both are moved to ``device``
    before being scored with a binary ``tm.Accuracy`` metric that uses
    ``thresh`` as its threshold. The value produced by the metric call is
    returned as is.
    """
    model.eval()
    scores, labels = predict(model, dataloader)

    # Predictions, labels and the metric must all sit on the same device.
    scores_on_device = scores.to(device)
    labels_on_device = labels.to(device)

    accuracy_metric = tm.Accuracy(task="binary", threshold=thresh).to(device)
    return accuracy_metric(scores_on_device, labels_on_device)


def test(model, dataloader, threshold):
model.eval()
predictions, y_true = predict(model, dataloader)
Expand Down

0 comments on commit 238ddac

Please sign in to comment.