Skip to content

Commit

Permalink
Only forward pass from trained model file.
Browse files Browse the repository at this point in the history
  • Loading branch information
madhavkhoslaa committed Jul 15, 2019
1 parent afe34b5 commit 1dae603
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
import torchvision
from model.Unet import UNeT
import sys
import torchvision

to_pred= sys.argv[1:]
model= UNeT(n_channels= 3, n_classes=30)
model.load_state_dict('./model.pt')
model.eval()
for _ in to_pred:
output= model(_)
torchvision.utils.save_image(output, _+'prediction.jpeg')
print("Output saved at {}+prediction.jpeg".format(_))

0 comments on commit 1dae603

Please sign in to comment.