Skip to content

Commit

Permalink
enh: enable loading model weights from training checkpoint (#3969)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
geoffreyangus and pre-commit-ci[bot] authored Mar 20, 2024
1 parent c09d5dc commit 25e4ac1
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 9 deletions.
31 changes: 26 additions & 5 deletions ludwig/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
MODEL_HYPERPARAMETERS_FILE_NAME,
set_disable_progressbar,
TRAIN_SET_METADATA_FILE_NAME,
TRAINING_CHECKPOINTS_DIR_PATH,
)
from ludwig.models.base import BaseModel
from ludwig.models.calibrator import Calibrator
Expand Down Expand Up @@ -1282,9 +1283,12 @@ def evaluate(
self.model.output_features, predictions, dataset, training_set_metadata
)
eval_stats = {
of_name: {**eval_stats[of_name], **overall_stats[of_name]}
# account for presence of 'combined' key
if of_name in overall_stats else {**eval_stats[of_name]}
of_name: (
{**eval_stats[of_name], **overall_stats[of_name]}
# account for presence of 'combined' key
if of_name in overall_stats
else {**eval_stats[of_name]}
)
for of_name in eval_stats
}

Expand Down Expand Up @@ -1765,6 +1769,7 @@ def load(
gpu_memory_limit: Optional[float] = None,
allow_parallel_threads: bool = True,
callbacks: List[Callback] = None,
from_checkpoint: bool = False,
) -> "LudwigModel": # return is an instance of ludwig.api.LudwigModel class
"""This function allows for loading pretrained models.
Expand All @@ -1788,6 +1793,9 @@ def load(
:param callbacks: (list, default: `None`) a list of
`ludwig.callbacks.Callback` objects that provide hooks into the
Ludwig pipeline.
:param from_checkpoint: (bool, default: `False`) if `True`, the model
will be loaded from the latest checkpoint (training_checkpoints/)
instead of the final model weights.
# Return
Expand Down Expand Up @@ -1834,7 +1842,7 @@ def load(
ludwig_model.model = LudwigModel.create_model(config_obj)

# load model weights
ludwig_model.load_weights(model_dir)
ludwig_model.load_weights(model_dir, from_checkpoint)

# The LoRA layers appear to be loaded again (perhaps due to a potential bug); hence, we merge and unload again.
if ludwig_model.is_merge_and_unload_set():
Expand All @@ -1851,12 +1859,16 @@ def load(
def load_weights(
self,
model_dir: str,
from_checkpoint: bool = False,
) -> None:
"""Loads weights from a pre-trained model.
# Inputs
:param model_dir: (str) filepath string to location of a pre-trained
model
:param from_checkpoint: (bool, default: `False`) if `True`, the model
will be loaded from the latest checkpoint (training_checkpoints/)
instead of the final model weights.
# Return
:return: `None`
Expand All @@ -1868,7 +1880,16 @@ def load_weights(
```
"""
if self.backend.is_coordinator():
self.model.load(model_dir)
if from_checkpoint:
with self.backend.create_trainer(
model=self.model,
config=self.config_obj.trainer,
) as trainer:
checkpoint = trainer.create_checkpoint_handle()
training_checkpoints_path = os.path.join(model_dir, TRAINING_CHECKPOINTS_DIR_PATH)
trainer.resume_weights_and_optimizer(training_checkpoints_path, checkpoint)
else:
self.model.load(model_dir)

self.backend.sync_model(self.model)

Expand Down
9 changes: 6 additions & 3 deletions ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,11 @@ def save_checkpoint(self, progress_tracker: ProgressTracker, save_path: str, che
# Callback that the checkpoint was reached, regardless of whether the model was evaluated.
self.callback(lambda c: c.on_checkpoint(self, progress_tracker))

def create_checkpoint_handle(self):
    """Create a checkpoint handle for this trainer via its distributed strategy.

    Delegates to ``self.distributed.create_checkpoint_handle``, forwarding the trainer's
    ``dist_model``, ``model``, ``optimizer`` and ``scheduler`` as keyword arguments.

    :return: the object returned by the distributed strategy's ``create_checkpoint_handle``.
    """
    # Resolve the factory first so attribute access order matches a direct call expression.
    make_handle = self.distributed.create_checkpoint_handle
    handle_kwargs = {
        "dist_model": self.dist_model,
        "model": self.model,
        "optimizer": self.optimizer,
        "scheduler": self.scheduler,
    }
    return make_handle(**handle_kwargs)

def train(
self,
training_set,
Expand Down Expand Up @@ -873,9 +878,7 @@ def train(
)

# ====== Setup session =======
checkpoint = self.distributed.create_checkpoint_handle(
dist_model=self.dist_model, model=self.model, optimizer=self.optimizer, scheduler=self.scheduler
)
checkpoint = self.create_checkpoint_handle()
checkpoint_manager = CheckpointManager(checkpoint, training_checkpoints_path, device=self.device)

# ====== Setup Tensorboard writers =======
Expand Down
2 changes: 1 addition & 1 deletion ludwig/utils/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,4 @@ def load_latest_checkpoint(checkpoint: Checkpoint, directory: str, device: torch
if last_ckpt:
checkpoint.load(last_ckpt, device)
else:
logger.error(f"No checkpoints found in {directory}.")
raise FileNotFoundError(f"No checkpoints found in {directory}.")
73 changes: 73 additions & 0 deletions tests/integration_tests/test_model_save_and_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,79 @@
)


def test_model_load_from_checkpoint(tmpdir, csv_filename, tmp_path):
    """Check that `LudwigModel.load(..., from_checkpoint=True)` reproduces the trained model.

    Trains a small binary-output model for one epoch, reloads it from the training checkpoint
    directory, and asserts that both predictions and encoder / combiner / decoder weights match
    those of the in-memory trained model.
    """
    torch.manual_seed(1)
    random.seed(1)
    np.random.seed(1)

    input_features = [
        binary_feature(),
        number_feature(),
    ]

    output_features = [
        binary_feature(),
    ]

    data_csv_path = generate_data(input_features, output_features, csv_filename, num_examples=50)

    config = {
        "input_features": input_features,
        "output_features": output_features,
        TRAINER: {"epochs": 1, BATCH_SIZE: 2},
    }
    backend = LocalTestBackend()

    # create sub-directory to store results
    results_dir = tmp_path / "results"
    results_dir.mkdir()

    data_df = read_csv(data_csv_path)
    splitter = get_splitter("random")
    training_set, validation_set, test_set = splitter.split(data_df, backend)
    ludwig_model1 = LudwigModel(config, backend=backend)
    _, _, output_dir = ludwig_model1.train(
        training_set=training_set,
        validation_set=validation_set,
        test_set=test_set,
        # Write into the per-test temp directory rather than a CWD-relative "results" folder,
        # so runs do not leak artifacts or collide with each other.
        output_directory=str(results_dir),
    )

    model_dir = os.path.join(output_dir, "model")
    ludwig_model_loaded = LudwigModel.load(model_dir, backend=backend, from_checkpoint=True)
    preds_1, _ = ludwig_model1.predict(dataset=validation_set)

    def assert_same_parameters(module_1, module_2):
        # Materialize both parameter lists so a count mismatch fails loudly instead of being
        # silently truncated by zip().
        params_1 = list(module_1.parameters())
        params_2 = list(module_2.parameters())
        assert len(params_1) == len(params_2)
        for w1, w2 in zip(params_1, params_2):
            assert torch.allclose(w1, w2)

    def check_model_equal(ludwig_model2):
        # Compare model predictions
        preds_2, _ = ludwig_model2.predict(dataset=validation_set)
        assert set(preds_1.keys()) == set(preds_2.keys())
        for key in preds_1:
            assert preds_1[key].dtype == preds_2[key].dtype, key
            values_1 = list(preds_1[key])
            values_2 = list(preds_2[key])
            assert len(values_1) == len(values_2), key
            # NOTE: np.all(<generator>) is always truthy, so each pair is compared explicitly.
            for a, b in zip(values_1, values_2):
                assert np.array_equal(np.asarray(a), np.asarray(b)), key

        # Compare model weights
        for if_name in ludwig_model1.model.input_features:
            if1 = ludwig_model1.model.input_features.get(if_name)
            if2 = ludwig_model2.model.input_features.get(if_name)
            assert_same_parameters(if1.encoder_obj, if2.encoder_obj)

        assert_same_parameters(ludwig_model1.model.combiner, ludwig_model2.model.combiner)

        for of_name in ludwig_model1.model.output_features:
            of1 = ludwig_model1.model.output_features.get(of_name)
            of2 = ludwig_model2.model.output_features.get(of_name)
            assert_same_parameters(of1.decoder_obj, of2.decoder_obj)

    check_model_equal(ludwig_model_loaded)


def test_model_save_reload_api(tmpdir, csv_filename, tmp_path):
torch.manual_seed(1)
random.seed(1)
Expand Down

0 comments on commit 25e4ac1

Please sign in to comment.