Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed Jan 3, 2024
1 parent 7ddf019 commit ddbce90
Show file tree
Hide file tree
Showing 5 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion flood_forecast/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def handle_later_ev(model, df_train_and_test, end_tensor, params, csv_test_loade
model.model.train() # sets mode to train so the dropout layers will be touched
assert num_prediction_samples > 0
if csv_test_loader.__class__.__name__ == "SeriesIDTestLoader":
raise NotImplementedError("SeriesIDTestLoader not yet supported for predictions")
raise NotImplementedError("SeriesIDTestLoader not yet supported for predictions.")
prediction_samples = generate_prediction_samples(
model,
df_train_and_test,
Expand Down
2 changes: 1 addition & 1 deletion flood_forecast/temporal_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def decoding_function(model, src: torch.Tensor, trg: torch.Tensor, forecast_leng
"""This function is responsible for decoding models that use `TemporalLoader` data. The basic logic of this
function is as follows. The data to the encoder (e.g. src) is not modified at each step of the decoding process.
Instead only the data to the decoder (e.g. the masked trg) is changed when forecasting max_len > forecast_length.
New data is appended (forecast_len == 2) (decoder_seq==10) (max==20) (20 (8)->2 First 8 should
New data is appended (forecast_len == 2) (decoder_seq==10) (max==20) (20 (8)->2 First 8 should be the same).
:param model: The PyTorch time series forecasting model that you want to use forecasting on.
:type model: `torch.nn.Module`
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name='flood_forecast',
version='1.0dev',
version='1.00dev',
packages=[
'flood_forecast',
'flood_forecast.transformer_xl',
Expand Down
2 changes: 1 addition & 1 deletion tests/test_series_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_series_test_loader(self):
print(all_rows_orig)
# self.assertIsInstance(all_rows_orig, pd.DataFrame)
self.assertGreater(forecast_start, 0)
# self.assertIsInstance(df_train_test, pd.DataFrame)..
# self.assertIsInstance(df_train_test, pd.DataFrame)..l

def test_eval_series_loader(self):
# infer_on_torch_model("s") # to-do fill in
Expand Down
1 change: 0 additions & 1 deletion tests/transformer_b_series.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
"inference_params":
{
"datetime_start":"2020-06-06",
"num_prediction_samples": 11,
"hours_to_forecast":4,
"test_csv_path":"tests/test_data/solar_small.csv",
"decoder_params":{
Expand Down

0 comments on commit ddbce90

Please sign in to comment.