Skip to content

Commit

Permalink
Merge pull request #710 from 1Pravi/master
Browse files Browse the repository at this point in the history
Add error handling for brain age model
  • Loading branch information
sarthakpati authored Nov 14, 2023
2 parents e22f8c4 + b9913d6 commit 87af776
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions GANDLF/models/brain_age.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch.nn as nn
import sys
import torchvision

import traceback

def brainage(parameters):
"""
Expand All @@ -18,11 +18,13 @@ def brainage(parameters):
"""

# Check that the input data is 2D
if parameters["model"]["dimension"] != 2:
sys.exit("Brain Age predictions only works on 2D data")
assert parameters["model"]["dimension"] == 2, "Brain Age predictions only work on 2D data"

# Load the pretrained VGG16 model
model = torchvision.models.vgg16(pretrained=True)
try:
# Load the pretrained VGG16 model
model = torchvision.models.vgg16(pretrained=True)
except Exception:
sys.exit("Error: Failed to load VGG16 model: " + traceback.format_exc())

# Remove the final convolutional layer
model.final_convolution_layer = None
Expand All @@ -36,19 +38,13 @@ def brainage(parameters):
features = list(model.classifier.children())[:-1] # Remove the last layer
features.extend(
[
nn.Linear(
num_features, 1024
), # Add a linear layer with 1024 output features
nn.Linear(num_features, 1024), # Add a linear layer with 1024 output features
nn.ReLU(True), # Add a ReLU activation function
nn.Dropout2d(0.8), # Add a 2D dropout layer with a probability of 0.8
nn.Linear(
1024, 1
), # Add a linear layer with 1 output feature (for brain age prediction)
nn.Linear(1024, 1), # Add a linear layer with 1 output feature (for brain age prediction)
]
)
model.classifier = nn.Sequential(
*features
) # Replace the model classifier with the modified one
model.classifier = nn.Sequential(*features) # Replace the model classifier with the modified one

# Set the "amp" parameter to False (not yet implemented for VGG)
parameters["model"]["amp"] = False
Expand Down

0 comments on commit 87af776

Please sign in to comment.