From 48f75e88f4a1d640e48f76ef8646a693a7a9757b Mon Sep 17 00:00:00 2001 From: harborn Date: Tue, 9 Apr 2024 15:38:16 +0800 Subject: [PATCH] upgrade Checkpoint usage to fit ray 3.0 (#388) * upgrade Checkpoint and session usage to fit ray 3.0 * upgrade ray version from 2.4.0 to 2.8.0 * ray 2.8.0 not available, use ray 2.7.0 * fix lint * fix lint * update * update --- python/raydp/tf/estimator.py | 11 +++-------- python/raydp/torch/estimator.py | 19 +++---------------- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/python/raydp/tf/estimator.py b/python/raydp/tf/estimator.py index 88916da2..5cd714f2 100644 --- a/python/raydp/tf/estimator.py +++ b/python/raydp/tf/estimator.py @@ -22,10 +22,9 @@ from tensorflow import DType, TensorShape from tensorflow.keras.callbacks import Callback -from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard +from ray.train.tensorflow import TensorflowTrainer, TensorflowCheckpoint, prepare_dataset_shard from ray.air import session from ray.air.config import ScalingConfig, RunConfig, FailureConfig -from ray.air.checkpoint import Checkpoint from ray.data import read_parquet from ray.data.dataset import Dataset from ray.data.preprocessors import Concatenator @@ -185,9 +184,7 @@ def train_func(config): if config["evaluate"]: test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks) results.append(test_history) - session.report({}, checkpoint=Checkpoint.from_dict({ - "model_weights": multi_worker_model.get_weights() - })) + session.report({}, checkpoint=TensorflowCheckpoint.from_model(multi_worker_model)) def fit(self, train_ds: Dataset, @@ -271,6 +268,4 @@ def fit_on_spark(self, def get_model(self) -> Any: assert self._trainer, "Trainer has not been created" - model = keras.models.model_from_json(self._serialized_model) - model.set_weights(self._results.checkpoint.to_dict()["model_weights"]) - return model + return self._results.checkpoint.get_model() diff --git a/python/raydp/torch/estimator.py b/python/raydp/torch/estimator.py index 20803e7d..a9546837 100644 --- a/python/raydp/torch/estimator.py +++ b/python/raydp/torch/estimator.py @@ -30,9 +30,8 @@ import ray from ray import train -from ray.train.torch import TorchTrainer +from ray.train.torch import TorchTrainer, TorchCheckpoint from ray.air.config import ScalingConfig, RunConfig, FailureConfig -from ray.air.checkpoint import Checkpoint from ray.air import session from ray.data.dataset import Dataset from ray.tune.search.sample import Domain @@ -255,9 +254,7 @@ def train_func(config): else: # if num_workers = 1, model is not wrapped states = model.state_dict() - session.report({}, checkpoint=Checkpoint.from_dict({ - "state_dict": states - })) + session.report({}, checkpoint=TorchCheckpoint.from_state_dict(states)) @staticmethod def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None): @@ -381,14 +378,4 @@ def fit_on_spark(self, def get_model(self): assert self._trainer is not None, "Must call fit first" - states = self._trained_results.checkpoint.to_dict()["state_dict"] - if isinstance(self._model, torch.nn.Module): - model = self._model - elif callable(self._model): - model = self._model() - else: - raise Exception( - "Unsupported parameter, we only support torch.nn.Model instance " - "or a function(dict -> model)") - model.load_state_dict(states) - return model + return self._trained_results.checkpoint.get_model()