Skip to content

Commit

Permalink
Add metadata file when saving model (#607)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqtg authored Mar 28, 2024
1 parent 93a3852 commit c1f3383
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions cornac/models/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import warnings
from datetime import datetime
from glob import glob
import json

import numpy as np

Expand Down Expand Up @@ -219,7 +220,7 @@ def clone(self, new_params=None):

return self.__class__(**init_params)

def save(self, save_dir=None, save_trainset=False):
def save(self, save_dir=None, save_trainset=False, metadata=None):
"""Save a recommender model to the filesystem.
Parameters
Expand All @@ -232,6 +233,10 @@ def save(self, save_dir=None, save_trainset=False):
if we want to deploy model later because train_set is
required for certain evaluation steps.
metadata: dict, default: None
Metadata to be saved with the model. This is useful
to store model details.
Returns
-------
model_file : str
Expand All @@ -246,16 +251,27 @@ def save(self, save_dir=None, save_trainset=False):
model_file = os.path.join(model_dir, "{}.pkl".format(timestamp))

saved_model = copy.deepcopy(self)
pickle.dump(saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL)
pickle.dump(
saved_model, open(model_file, "wb"), protocol=pickle.HIGHEST_PROTOCOL
)
if self.verbose:
print("{} model is saved to {}".format(self.name, model_file))

metadata = {} if metadata is None else metadata
metadata["model_classname"] = type(saved_model).__name__
metadata["model_file"] = os.path.basename(model_file)

if save_trainset:
trainset_file = model_file + ".trainset"
pickle.dump(
self.train_set,
open(model_file + ".trainset", "wb"),
open(trainset_file, "wb"),
protocol=pickle.HIGHEST_PROTOCOL,
)
metadata["trainset_file"] = os.path.basename(trainset_file)

with open(model_file + ".meta", "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=4)

return model_file

Expand Down Expand Up @@ -502,9 +518,7 @@ def rank(self, user_idx, item_indices=None, k=-1, **kwargs):
)
item_scores = all_item_scores[item_indices]

if (
k != -1
): # O(n + k log k), faster for small k which is usually the case
if k != -1: # O(n + k log k), faster for small k which is usually the case
partitioned_idx = np.argpartition(item_scores, -k)
top_k_idx = partitioned_idx[-k:]
sorted_top_k_idx = top_k_idx[np.argsort(item_scores[top_k_idx])]
Expand Down Expand Up @@ -545,7 +559,9 @@ def recommend(self, user_id, k=-1, remove_seen=False, train_set=None):
raise ValueError(f"{user_id} is unknown to the model.")

if k < -1 or k > self.total_items:
raise ValueError(f"k={k} is invalid, there are {self.total_users} users in total.")
raise ValueError(
f"k={k} is invalid, there are {self.total_users} users in total."
)

item_indices = np.arange(self.total_items)
if remove_seen:
Expand Down Expand Up @@ -622,7 +638,11 @@ def early_stop(self, train_set, val_set, min_delta=0.0, patience=0):

if self.stopped_epoch > 0:
print("Early stopping:")
print("- best epoch = {}, stopped epoch = {}".format(self.best_epoch, self.stopped_epoch))
print(
"- best epoch = {}, stopped epoch = {}".format(
self.best_epoch, self.stopped_epoch
)
)
print(
"- best monitored value = {:.6f} (delta = {:.6f})".format(
self.best_value, current_value - self.best_value
Expand Down

0 comments on commit c1f3383

Please sign in to comment.