Skip to content

Commit

Permalink
fixes to bugs final.
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Feb 9, 2021
1 parent 79b2cc6 commit 81184d8
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 9 deletions.
7 changes: 0 additions & 7 deletions flood_forecast/pytorch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,15 +231,10 @@ def compute_loss(labels, output, src, criterion, validation_dataset, probabilist
output = validation_dataset.inverse_scale(output.cpu())
labels = validation_dataset.inverse_scale(labels.cpu())
elif len(output.shape) == 3:
print("original output shape ")
print(output.shape)
output = output.cpu().numpy().transpose(0, 2, 1)
labels = labels.cpu().numpy().transpose(0, 2, 1)
output = validation_dataset.inverse_scale(torch.from_numpy(output))
labels = validation_dataset.inverse_scale(torch.from_numpy(labels))
print("Output shape is ")
print(output)
print(labels)
stuff = src.cpu().numpy().transpose(0, 2, 1)
src = validation_dataset.inverse_scale(torch.from_numpy(stuff))
else:
Expand All @@ -255,8 +250,6 @@ def compute_loss(labels, output, src, criterion, validation_dataset, probabilist
assert len(labels.shape) == len(output.shape)
loss = criterion(labels.float(), output, src, m)
else:
print(output.shape)
print(labels.shape)
assert len(labels.shape) == len(output.shape)
assert labels.shape[0] == output.shape[0]
loss = criterion(output, labels.float())
Expand Down
4 changes: 2 additions & 2 deletions tests/transformer_bottleneck.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"test_path": "tests/test_data/keag_small.csv",
"batch_size":4,
"forecast_history":5,
"forecast_length":5,
"forecast_length":1,
"train_end": 100,
"valid_start":301,
"valid_end": 401,
Expand Down Expand Up @@ -67,7 +67,7 @@
"dataset_params":{
"file_path": "tests/test_data/keag_small.csv",
"forecast_history":5,
"forecast_length":5,
"forecast_length":1,
"relevant_cols": ["cfs", "precip", "temp"],
"target_col": ["cfs"],
"scaling": "RobustScaler",
Expand Down

0 comments on commit 81184d8

Please sign in to comment.