diff --git a/rayTrain_Tune.py b/rayTrain_Tune.py index d9884666..7ec4a861 100644 --- a/rayTrain_Tune.py +++ b/rayTrain_Tune.py @@ -58,7 +58,7 @@ def train_func(config): torch.save(model.state_dict(), checkpoint_path) checkpoint = Checkpoint.from_directory(checkpoint_dir) - ray.train.report({"loss": loss.item()}, checkpoint=checkpoint) + ray.train.report({"loss": loss}, checkpoint=checkpoint) if __name__ == "__main__":