Skip to content

Commit

Permalink
upgrade Checkpoint usage to fit ray 3.0 (#388)
Browse files Browse the repository at this point in the history
* upgrade Checkpoint and session usage to fit ray 3.0

* upgrade ray version from 2.4.0 to 2.8.0

* ray 2.8.0 not available, use ray 2.7.0

* fix lint

* fix lint

* update

* update
  • Loading branch information
harborn authored Apr 9, 2024
1 parent 01e851f commit 48f75e8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 24 deletions.
11 changes: 3 additions & 8 deletions python/raydp/tf/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
from tensorflow import DType, TensorShape
from tensorflow.keras.callbacks import Callback

from ray.train.tensorflow import TensorflowTrainer, prepare_dataset_shard
from ray.train.tensorflow import TensorflowTrainer, TensorflowCheckpoint, prepare_dataset_shard
from ray.air import session
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
from ray.air.checkpoint import Checkpoint
from ray.data import read_parquet
from ray.data.dataset import Dataset
from ray.data.preprocessors import Concatenator
Expand Down Expand Up @@ -185,9 +184,7 @@ def train_func(config):
if config["evaluate"]:
test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks)
results.append(test_history)
session.report({}, checkpoint=Checkpoint.from_dict({
"model_weights": multi_worker_model.get_weights()
}))
session.report({}, checkpoint=TensorflowCheckpoint.from_model(multi_worker_model))

def fit(self,
train_ds: Dataset,
Expand Down Expand Up @@ -271,6 +268,4 @@ def fit_on_spark(self,

def get_model(self) -> Any:
assert self._trainer, "Trainer has not been created"
model = keras.models.model_from_json(self._serialized_model)
model.set_weights(self._results.checkpoint.to_dict()["model_weights"])
return model
return self._results.checkpoint.get_model()
19 changes: 3 additions & 16 deletions python/raydp/torch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@

import ray
from ray import train
from ray.train.torch import TorchTrainer
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
from ray.air.checkpoint import Checkpoint
from ray.air import session
from ray.data.dataset import Dataset
from ray.tune.search.sample import Domain
Expand Down Expand Up @@ -255,9 +254,7 @@ def train_func(config):
else:
# if num_workers = 1, model is not wrapped
states = model.state_dict()
session.report({}, checkpoint=Checkpoint.from_dict({
"state_dict": states
}))
session.report({}, checkpoint=TorchCheckpoint.from_state_dict(states))

@staticmethod
def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):
Expand Down Expand Up @@ -381,14 +378,4 @@ def fit_on_spark(self,

def get_model(self):
assert self._trainer is not None, "Must call fit first"
states = self._trained_results.checkpoint.to_dict()["state_dict"]
if isinstance(self._model, torch.nn.Module):
model = self._model
elif callable(self._model):
model = self._model()
else:
raise Exception(
"Unsupported parameter, we only support torch.nn.Model instance "
"or a function(dict -> model)")
model.load_state_dict(states)
return model
return self._trained_results.checkpoint.get_model()

0 comments on commit 48f75e8

Please sign in to comment.