diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index f21ac96e..792a00f6 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -48,7 +48,7 @@ def log_image(self, key, images): def log_model(self, model): # Create model signature - #signature = infer_signature(X.numpy(), model(X).detach().numpy()) + #signature = infer_signature(train_dataset.numpy(), model(train_dataset).detach().numpy()) mlflow.pytorch.log_model(model, "model") @@ -361,7 +361,7 @@ def main(input_args=None): # Log the model training_logger.log_model(model) - + # data_module.train_dataloader().dataset.data if __name__ == "__main__": main()