Skip to content

Commit

Permalink
Training
Browse files Browse the repository at this point in the history
  • Loading branch information
madhavkhoslaa committed Aug 26, 2019
1 parent b0cb46e commit 081844e
Showing 1 changed file with 11 additions and 15 deletions.
26 changes: 11 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,22 @@
transforms_compose = transforms.Compose(
[transforms.ToTensor()])
params = hyperparameters(
train_percentage=0.6,
train_percentage=1,
batch_size=1,
epoch=200,
n_classes=1)
epoch=50,
n_classes=2)

if torch.cuda.is_available():
net = UNeT(n_classes=29, n_channels=3).cuda()
net = UNeT(n_classes=2, n_channels=3).cuda()
else:
net = UNeT(n_classes=29, n_channels=3)
net = UNeT(n_classes=2, n_channels=3)
encoder= HotEncoder(is_binary= False, dir= ANNOTATIONS_DIR, extension="png")
color_dict= encoder.gen_colors()
Images = ImageList(
Images=IMAGE_DIR,
Annotations=ANNOTATIONS_DIR,
train_percentage=1,
extension="png")
loss_val = Loss()
Train = ImageLoader(
encoder_obj= encoder,
data=Images.train_set,
Expand All @@ -58,7 +58,6 @@
Test,
batch_size=params.hyperparameters["batch_size"],
shuffle=True)
color_dict= encoder.gen_colors()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for i, data in enumerate(TrainLoader, 0):
Expand All @@ -68,6 +67,7 @@
metrics = defaultdict()
running_loss = 0.0
for i, data in enumerate(TrainLoader, 0):
net.train()
if torch.cuda.is_available():
inputs, labels = data["Image"].cuda(), data["Label"].cuda()
else:
Expand All @@ -80,11 +80,7 @@
loss.backward()
optimizer.step()
running_loss += loss.item()
print(
"Epoch: {} | Loss: {} | Instance: {}".format(
int(epoch),
loss.item(),
i))
print("Running loss|", running_loss)
torch.save(net.state_dict(), MODEL_SAVE + "/model.pt")
print("Model saved")
print(loss)
torch.save(net.state_dict(), MODEL_SAVE + "/model_roof_epoch50_batch_5.pt")

print("Model saved")

0 comments on commit 081844e

Please sign in to comment.