diff --git a/recommenders/models/wide_deep/wide_deep_utils.py b/recommenders/models/wide_deep/wide_deep_utils.py
index c989e3d8a..7544472b7 100644
--- a/recommenders/models/wide_deep/wide_deep_utils.py
+++ b/recommenders/models/wide_deep/wide_deep_utils.py
@@ -1,213 +1,406 @@
 # Copyright (c) Recommenders contributors.
 # Licensed under the MIT License.
 
+from typing import Tuple, Dict, List, Optional, Any, Union
+from dataclasses import dataclass, field
+from pathlib import Path
 
-import tensorflow as tf
-
-from recommenders.utils.constants import DEFAULT_USER_COL, DEFAULT_ITEM_COL
-from recommenders.utils.tf_utils import MODEL_DIR
-
-
-def build_feature_columns(
-    users,
-    items,
-    user_col=DEFAULT_USER_COL,
-    item_col=DEFAULT_ITEM_COL,
-    item_feat_col=None,
-    crossed_feat_dim=1000,
-    user_dim=8,
-    item_dim=8,
-    item_feat_shape=None,
-    model_type="wide_deep",
-):
-    """Build wide and/or deep feature columns for TensorFlow high-level API Estimator.
-
-    Args:
-        users (iterable): Distinct user ids.
-        items (iterable): Distinct item ids.
-        user_col (str): User column name.
-        item_col (str): Item column name.
-        item_feat_col (str): Item feature column name for 'deep' or 'wide_deep' model.
-        crossed_feat_dim (int): Crossed feature dimension for 'wide' or 'wide_deep' model.
-        user_dim (int): User embedding dimension for 'deep' or 'wide_deep' model.
-        item_dim (int): Item embedding dimension for 'deep' or 'wide_deep' model.
-        item_feat_shape (int or an iterable of integers): Item feature array shape for 'deep' or 'wide_deep' model.
-        model_type (str): Model type, either
-            'wide' for a linear model,
-            'deep' for a deep neural networks, or
-            'wide_deep' for a combination of linear model and neural networks.
-
-    Returns:
-        list, list:
-        - The wide feature columns
-        - The deep feature columns. If only the wide model is selected, the deep column list is empty and viceversa.
-    """
-    if model_type not in ["wide", "deep", "wide_deep"]:
-        raise ValueError("Model type should be either 'wide', 'deep', or 'wide_deep'")
-
-    user_ids = tf.feature_column.categorical_column_with_vocabulary_list(
-        user_col, users
-    )
-    item_ids = tf.feature_column.categorical_column_with_vocabulary_list(
-        item_col, items
-    )
-
-    if model_type == "wide":
-        return _build_wide_columns(user_ids, item_ids, crossed_feat_dim), []
-    elif model_type == "deep":
-        return (
-            [],
-            _build_deep_columns(
-                user_ids, item_ids, user_dim, item_dim, item_feat_col, item_feat_shape
-            ),
-        )
-    elif model_type == "wide_deep":
-        return (
-            _build_wide_columns(user_ids, item_ids, crossed_feat_dim),
-            _build_deep_columns(
-                user_ids, item_ids, user_dim, item_dim, item_feat_col, item_feat_shape
-            ),
+import numpy as np
+import pandas as pd
+from tqdm.auto import tqdm
+
+import torch
+from torch.utils.data import DataLoader
+from torch import nn
+
+from recommenders.utils.constants import DEFAULT_USER_COL, DEFAULT_ITEM_COL, DEFAULT_RATING_COL, DEFAULT_PREDICTION_COL
+import recommenders.utils.python_utils as pu
+import recommenders.utils.torch_utils as tu
+
+
+@dataclass(frozen=True)
+class WideAndDeepHyperParams:
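+    """Hyperparameters of the PyTorch Wide & Deep model.
+
+    Attributes (descriptions inferred from how the fields are used in ``WideAndDeepModel``):
+        user_dim: User embedding dimension of the deep part.
+        item_dim: Item embedding dimension of the deep part.
+        crossed_feat_dim: Number of buckets for the crossed user-item feature (wide part).
+        dnn_hidden_units: Sizes of the hidden layers of the deep part.
+        dnn_dropout: Dropout rate applied after each hidden linear layer.
+        dnn_additional_embeddings_sizes: Mapping from feature name to
+            ``(num_embeddings, embedding_dim)`` for extra categorical features.
+        dnn_cont_features: Number of continuous features fed to the deep part.
+    """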
+    user_dim: int = 32
+    item_dim: int = 32
+    crossed_feat_dim: int = 1000
+    dnn_hidden_units: Tuple[int, ...] = (128, 128)
+    dnn_dropout: float = 0.0
+    dnn_additional_embeddings_sizes: Dict[str, Tuple[int, int]] = field(default_factory=dict)
+    dnn_cont_features: int = 0
+
+
+class WideAndDeepModel(nn.Module):
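+    """Wide & Deep rating model (PyTorch).
+
+    The wide part is a one-hot encoding of the bucketed user-item cross feature,
+    with index ``(user * n_items + item) % crossed_feat_dim``. The deep part
+    concatenates the user/item embeddings, any additional embeddings and the
+    continuous features, and passes them through an MLP. Both parts are
+    concatenated and fed to a linear head (followed by a sigmoid when
+    ``binary_output=True``).
+
+    ``forward`` expects ``interactions`` as a ``(batch, 2)`` integer tensor of
+    ``[user_id, item_id]`` pairs, as produced by ``WideAndDeepDataset``.
+    """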
+    def __init__(
+        self,
+        num_users: int,
+        num_items: int,
+        hparams: WideAndDeepHyperParams = WideAndDeepHyperParams(),
+        binary_output: bool = False,
+    ):
+        super().__init__()
+
+        self.hparams = hparams
+        self.n_users = num_users
+        self.n_items = num_items
+
+        self.users_emb = nn.Embedding(num_users, hparams.user_dim)
+        self.items_emb = nn.Embedding(num_items, hparams.item_dim)
+        self.additional_embs = nn.ModuleDict({
+            k: nn.Embedding(num, dim)
+            for k, (num, dim) in hparams.dnn_additional_embeddings_sizes.items()
+        })
+
+        # Accumulate the total embedding width and randomly initialize the additional embeddings
+        total_emb_dim = hparams.user_dim + hparams.item_dim
+        for _, emb in self.additional_embs.items():
+            total_emb_dim += emb.embedding_dim
+            nn.init.uniform_(emb.weight, -1, 1)
+
+        layers = []
+        prev_output = hparams.dnn_cont_features + total_emb_dim
+        for hu in hparams.dnn_hidden_units:
+            layers.append(nn.Linear(prev_output, hu))
+            layers.append(nn.Dropout(hparams.dnn_dropout))
+            layers.append(nn.ReLU())
+            prev_output = hu
+
+        self.deep = nn.Sequential(*layers)
+
+        # Cross product of users-items (wide input)
+        exclusive_wide_input = hparams.crossed_feat_dim
+
+        self.head = nn.Sequential(
+            nn.Linear(exclusive_wide_input + prev_output, 1),
         )
+        if binary_output:
+            self.head.append(nn.Sigmoid())
+
+    def forward(
+        self,
+        interactions: torch.Tensor,
+        additional_embeddings: Dict[str, torch.Tensor] = {},
+        continuous_features: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        users, items = interactions.T
+
+        all_embed = torch.cat([
+            self.users_emb(users),  # Embedding layers receive the indices
+            self.items_emb(items),
+            *[emb(additional_embeddings[k]) for k, emb in self.additional_embs.items()]
+        ], dim=1)
+
+        # TODO: Use hashing to avoid problems with biased distributions
+        cross_product_idx = (users * self.n_items + items) % self.hparams.crossed_feat_dim
+        cross_product = nn.functional.one_hot(cross_product_idx, self.hparams.crossed_feat_dim)
+
+        if self.hparams.dnn_cont_features > 0:
+            deep_input = torch.cat([continuous_features, all_embed], dim=1)
+        else:
+            deep_input = all_embed
+
+        return self.head(torch.cat([
+            cross_product,  # wide input
+            self.deep(deep_input),  # deep output
+        ], dim=1))
+
+
+class WideAndDeepDataset(torch.utils.data.Dataset):
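+    """Dataset of user-item interactions (and optional user/item features) for Wide & Deep.
+
+    Each element is a ``(features, rating)`` pair, where ``features`` contains an
+    ``interactions`` array ``[user_id, item_id]`` and, when user/item features are
+    provided, a flat ``continuous_features`` array.
+
+    Example (illustrative, mirroring the unit tests; column names are the
+    recommenders defaults)::
+
+        ratings = pd.DataFrame({
+            DEFAULT_USER_COL: [1, 1, 2],
+            DEFAULT_ITEM_COL: [1, 2, 1],
+            DEFAULT_RATING_COL: [5, 4, 3],
+        })
+        dataset = WideAndDeepDataset(ratings)
+        features, rating = dataset[0]
+    """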
+    def __init__(
+        self,
+        ratings: pd.DataFrame,
+        user_col: str = DEFAULT_USER_COL,
+        item_col: str = DEFAULT_ITEM_COL,
+        rating_col: str = DEFAULT_RATING_COL,
+        n_users: Optional[int] = None,
+        n_items: Optional[int] = None,
+        item_feat: Optional[pd.DataFrame] = None,
+        user_feat: Optional[pd.DataFrame] = None,
+        n_cont_features: Optional[int] = None,
+    ):
+        self._check_cols_df('ratings', ratings, [user_col, item_col, rating_col])
+        self._check_cols_df('item_feat', item_feat, [item_col])
+        self._check_cols_df('user_feat', user_feat, [user_col])
+        self._check_cont_features('item_feat', item_feat, item_col)
+        self._check_cont_features('user_feat', user_feat, user_col)
+
+        self.user_col = user_col
+        self.item_col = item_col
+        self.rating_col = rating_col
+        self.ratings = ratings.copy()
+        self.item_feat = item_feat.set_index(item_col).copy() if item_feat is not None else pd.DataFrame()
+        self.user_feat = user_feat.set_index(user_col).copy() if user_feat is not None else pd.DataFrame()
+        self.n_cont_features = n_cont_features or self._get_continuous_features(
+            [self.item_feat.index.min()], [self.user_feat.index.min()]
+        ).shape[1]
+
+        self.n_users = n_users or ratings[user_col].max() + 1
+        self.n_items = n_items or ratings[item_col].max() + 1
+
+        self.ratings[rating_col] = self.ratings[rating_col].astype('float32')
+
+    @staticmethod
+    def _check_cols_df(df_name: str, df: Optional[pd.DataFrame], cols: List[str]) -> bool:
+        if df is None or df.empty:
+            return True
+
+        for c in cols:
+            if c not in df.columns:
+                raise ValueError(f"Column '{c}' is not present on {df_name}")
+
+        return True
+
+    @staticmethod
+    def _check_cont_features(df_name: str, df: Optional[pd.DataFrame], col: str) -> bool:
+        if df is None or df.empty:
+            return True
+
+        for c in df.columns:
+            if c == col:
+                continue
+
+            # Check that the dtype is numeric
+            if not pd.api.types.is_numeric_dtype(df[c].dtype):
+                raise ValueError(f"Column '{c}' from {df_name} has dtype {df[c].dtype}. Only numeric dtypes are allowed.")
+
+        return True
+
+    def __len__(self):
+        return len(self.ratings)
+
+    def _get_continuous_features(self, item_ids, user_ids) -> np.ndarray:
+        if len(item_ids) != len(user_ids):
+            raise ValueError('item_ids and user_ids should have the same length')
+
+        if self.item_feat.empty and self.user_feat.empty:
+            # Return an empty (n, 0) array so callers can rely on .shape[1] == 0
+            return np.empty((len(item_ids), 0))
+
+        # Collect the per-row feature arrays (item features first, then user features)
+        continuous_features = []
+
+        if not self.item_feat.empty:
+            feats = self.item_feat.loc[item_ids]
+            continuous_features.extend(feats.values)
+
+        if not self.user_feat.empty:
+            feats = self.user_feat.loc[user_ids]
+            continuous_features.extend(feats.values)
+
+        return np.stack(continuous_features)
+
+    def __getitem__(self, idx):
+        # TODO: Get additional embeddings too (e.g. user demographics)
+        item = self.ratings.iloc[idx]
+
+        ret = {
+            'interactions': self.ratings[[self.user_col, self.item_col]].iloc[idx].values,
+        }
+
+        if self.n_cont_features:
+            # Reshape because it is only one item
+            ret['continuous_features'] = self._get_continuous_features(
+                [item[self.item_col]], [item[self.user_col]]
+            ).reshape(-1)
 
-def _build_wide_columns(user_ids, item_ids, hash_bucket_size=1000):
-    """Build wide feature (crossed) columns. `user_ids` * `item_ids` are hashed into `hash_bucket_size`
-
-    Args:
-        user_ids (tf.feature_column.categorical_column_with_vocabulary_list): User ids.
-        item_ids (tf.feature_column.categorical_column_with_vocabulary_list): Item ids.
-        hash_bucket_size (int): Hash bucket size.
-
-    Returns:
-        list: Wide feature columns.
-    """
-    # Including the original features in addition to the crossed one is recommended to address hash collision problem.
-    return [
-        user_ids,
-        item_ids,
-        tf.feature_column.crossed_column(
-            [user_ids, item_ids], hash_bucket_size=hash_bucket_size
-        ),
-    ]
-
-
-def _build_deep_columns(
-    user_ids, item_ids, user_dim, item_dim, item_feat_col=None, item_feat_shape=1
-):
-    """Build deep feature columns
-
-    Args:
-        user_ids (tf.feature_column.categorical_column_with_vocabulary_list): User ids.
-        item_ids (tf.feature_column.categorical_column_with_vocabulary_list): Item ids.
-        user_dim (int): User embedding dimension.
-        item_dim (int): Item embedding dimension.
-        item_feat_col (str): Item feature column name.
-        item_feat_shape (int or an iterable of integers): Item feature array shape.
-
-    Returns:
-        list: Deep feature columns.
-    """
-    deep_columns = [
-        # User embedding
-        tf.feature_column.embedding_column(
-            categorical_column=user_ids, dimension=user_dim, max_norm=user_dim**0.5
-        ),
-        # Item embedding
-        tf.feature_column.embedding_column(
-            categorical_column=item_ids, dimension=item_dim, max_norm=item_dim**0.5
-        ),
-    ]
-    # Item feature
-    if item_feat_col is not None:
-        deep_columns.append(
-            tf.feature_column.numeric_column(
-                item_feat_col, shape=item_feat_shape, dtype=tf.float32
+        return ret, self.ratings[self.rating_col].iloc[idx]
+
+
+class WideAndDeep(object):
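+    """Trainer for the PyTorch Wide & Deep model. It wraps the train/test
+    datasets, the optimization loop, optional checkpointing, and top-k
+    recommendation.
+
+    Example (illustrative sketch, mirroring the unit tests; ``ratings`` is a
+    pandas DataFrame with the default user/item/rating columns, and the same
+    dataset is reused for evaluation)::
+
+        dataset = WideAndDeepDataset(ratings)
+        model = WideAndDeep(dataset, dataset, epochs=5)
+        model.fit()
+        recs = model.recommend_k_items(top_k=10)
+    """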
+    def __init__(
+        self,
+        train: WideAndDeepDataset,
+        test: WideAndDeepDataset,
+        hparams: WideAndDeepHyperParams = WideAndDeepHyperParams(),
+        *,
+        n_users: Optional[int] = None,
+        n_items: Optional[int] = None,
+        epochs: int = 100,
+        batch_size: int = 128,
+        loss_fn: Union[str, nn.Module] = 'mse',
+        optimizer: str = 'sgd',
+        l1: float = 0.0001,
+        optimizer_params: Dict[str, Any] = dict(),
+        disable_batch_progress: bool = False,
+        disable_iter_progress: bool = False,
+        eval_epoch: int = 1,
+        model_dir: Optional[Union[str, Path]] = None,
+        save_model_iter: int = -1,
+        prediction_col: str = DEFAULT_PREDICTION_COL,
+    ):
+        self.n_users = n_users or max(train.n_users, test.n_users)
+        self.n_items = n_items or max(train.n_items, test.n_items)
+
+        if train.n_cont_features != test.n_cont_features:
+            raise ValueError('The number of cont. features on the train dataset is not the same as in the test dataset')
+        if train.n_cont_features != hparams.dnn_cont_features:
+            raise ValueError(
+                f"The number of cont. features on the dataset ({train.n_cont_features}) "
+                f"is not the same as in the hparams ({hparams.dnn_cont_features})"
             )
+
+        self.train = train
+        self.test = test
+        self.train_dataloader = DataLoader(train, batch_size, shuffle=True)
+        self.test_dataloader = DataLoader(test, len(test))
+
+        self.model = WideAndDeepModel(
+            num_users=self.n_users,
+            num_items=self.n_items,
+            hparams=hparams,
         )
-    return deep_columns
-
-
-def build_model(
-    model_dir=MODEL_DIR,
-    wide_columns=(),
-    deep_columns=(),
-    linear_optimizer="Ftrl",
-    dnn_optimizer="Adagrad",
-    dnn_hidden_units=(128, 128),
-    dnn_dropout=0.0,
-    dnn_batch_norm=True,
-    log_every_n_iter=1000,
-    save_checkpoints_steps=10000,
-    seed=None,
-):
-    """Build wide-deep model.
-
-    To generate wide model, pass wide_columns only.
-    To generate deep model, pass deep_columns only.
-    To generate wide_deep model, pass both wide_columns and deep_columns.
-
-    Args:
-        model_dir (str): Model checkpoint directory.
-        wide_columns (list of tf.feature_column): Wide model feature columns.
-        deep_columns (list of tf.feature_column): Deep model feature columns.
-        linear_optimizer (str or tf.train.Optimizer): Wide model optimizer name or object.
-        dnn_optimizer (str or tf.train.Optimizer): Deep model optimizer name or object.
-        dnn_hidden_units (list of int): Deep model hidden units. E.g., [10, 10, 10] is three layers of 10 nodes each.
-        dnn_dropout (float): Deep model's dropout rate.
-        dnn_batch_norm (bool): Deep model's batch normalization flag.
-        log_every_n_iter (int): Log the training loss for every n steps.
-        save_checkpoints_steps (int): Model checkpoint frequency.
-        seed (int): Random seed.
-
-    Returns:
-        tf.estimator.Estimator: Model
-    """
-    gpu_config = tf.compat.v1.ConfigProto()
-    gpu_config.gpu_options.allow_growth = True  # dynamic memory allocation
-
-    # TensorFlow training setup
-    config = tf.estimator.RunConfig(
-        tf_random_seed=seed,
-        log_step_count_steps=log_every_n_iter,
-        save_checkpoints_steps=save_checkpoints_steps,
-        session_config=gpu_config,
-    )
-
-    if len(wide_columns) > 0 and len(deep_columns) == 0:
-        model = tf.compat.v1.estimator.LinearRegressor(
-            model_dir=model_dir,
-            config=config,
-            feature_columns=wide_columns,
-            optimizer=linear_optimizer,
-        )
-    elif len(wide_columns) == 0 and len(deep_columns) > 0:
-        model = tf.compat.v1.estimator.DNNRegressor(
-            model_dir=model_dir,
-            config=config,
-            feature_columns=deep_columns,
-            hidden_units=dnn_hidden_units,
-            optimizer=dnn_optimizer,
-            dropout=dnn_dropout,
-            batch_norm=dnn_batch_norm,
-        )
-    elif len(wide_columns) > 0 and len(deep_columns) > 0:
-        model = tf.compat.v1.estimator.DNNLinearCombinedRegressor(
-            model_dir=model_dir,
-            config=config,
-            # wide settings
-            linear_feature_columns=wide_columns,
-            linear_optimizer=linear_optimizer,
-            # deep settings
-            dnn_feature_columns=deep_columns,
-            dnn_hidden_units=dnn_hidden_units,
-            dnn_optimizer=dnn_optimizer,
-            dnn_dropout=dnn_dropout,
-            batch_norm=dnn_batch_norm,
+
+        if isinstance(loss_fn, nn.Module):
+            self.loss_fn = loss_fn
+        else:
+            self.loss_fn = tu.LOSS_DICT[loss_fn]()
+
+        self.optimizer = tu.OPTIM_DICT[optimizer](
+            self.model.parameters(),
+            lr=l1,
+            **optimizer_params,
         )
-    else:
-        raise ValueError(
-            "To generate wide model, set wide_columns.\n"
-            "To generate deep model, set deep_columns.\n"
-            "To generate wide_deep model, set both wide_columns and deep_columns."
+
+        self.disable_batch_progress = disable_batch_progress
+        self.disable_iter_progress = disable_iter_progress
+        self.prediction_col = prediction_col
+
+        self.current_epoch = 0
+        self.epochs = epochs
+        self.eval_epoch = eval_epoch
+
+        self.model_dir = Path(model_dir) if model_dir else None
+        self.save_model_iter = save_model_iter
+        self._check_save_model()
+
+        self.train_loss_history = list()
+        self.test_loss_history = list()
+
+    @property
+    def user_col(self) -> str:
+        return self.train.user_col
+
+    @property
+    def item_col(self) -> str:
+        return self.train.item_col
+
+    @property
+    def model_path(self) -> Path:
+        return self.model_dir / f'wide_deep_state_{self.current_epoch:05d}.pth'
+
+    def _check_save_model(self) -> bool:
+        # save_model_iter and model_dir must be set (or left unset) together
+        if (self.save_model_iter == -1) != (self.model_dir is None):
+            raise ValueError('save_model_iter and model_dir must be set together')
+
+        if self.model_dir is not None:
+            # Check that saving works
+            self.save()
+
+        return True
+
+    def fit(self):
+        if self.current_epoch >= self.epochs:
+            print(f"Model is already trained with {self.epochs} epochs. Increment the number of epochs.")
+
+        with tqdm(total=self.epochs, leave=True, disable=self.disable_iter_progress) as pbar:
+            pbar.update(self.current_epoch)
+            for _ in range(self.current_epoch, self.epochs):
+                self.fit_step()
+                pbar.update()
+                pbar.set_postfix(
+                    train_loss=self.train_loss_history[-1],
+                    test_loss=self.test_loss_history[-1][1],
+                )
+
+                if self.save_model_iter != -1 and self.current_epoch % self.save_model_iter == 0:
+                    self.save()
+
+    def save(self, model_path=None):
+        model_path = Path(model_path) if model_path else self.model_path
+        model_path.parent.mkdir(exist_ok=True)
+
+        torch.save(self.model.state_dict(), model_path)
+
+    def load(self, model_path=None):
+        if model_path is None:
+            print('Model path not specified, automatically loading from model dir')
+            model_path = max(self.model_dir.glob('*.pth'), key=lambda f: int(f.stem.split('_')[-1]))
+            print('  Loading', model_path)
+        else:
+            model_path = Path(model_path)
+
+        self.model.load_state_dict(torch.load(model_path))
+        self.current_epoch = int(model_path.stem.split('_')[-1])
+
+    def fit_step(self):
+        self.model.train()
+
+        train_loss = 0.0
+        for X, y in tqdm(self.train_dataloader, 'batch', leave=False, disable=self.disable_batch_progress):
+            pred = self.model(
+                X['interactions'],
+                continuous_features=X.get('continuous_features', None),
+            )
+            # Flatten the (batch, 1) predictions so the loss is computed element-wise against y
+            loss = self.loss_fn(pred.squeeze(1), y)
+            # TODO: Can we use this loss? Or should it be calculated again with no_grad?
+            train_loss += loss.item()
+
+            # Backpropagate the error and update the weights
+            loss.backward()
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+
+        self.train_loss_history.append(train_loss / len(self.train_dataloader))
+        self.model.eval()
+
+        if self.eval_epoch != -1 and self.current_epoch % self.eval_epoch == 0:
+            num_batches = len(self.test_dataloader)
+            test_loss = 0
+
+            with torch.no_grad():
+                for X, y in self.test_dataloader:
+                    pred = self.model(
+                        X['interactions'],
+                        continuous_features=X.get('continuous_features', None),
+                    )
+                    test_loss += self.loss_fn(pred.squeeze(1), y).item()
+
+            test_loss /= num_batches
+            self.test_loss_history.append((self.current_epoch, test_loss))
+
+        self.current_epoch += 1
+
+    def _get_uip_cont(self, user_ids, item_ids, remove_seen: bool):
+        if user_ids is None:
+            user_ids = np.arange(1, self.n_users)
+        if item_ids is None:
+            item_ids = np.arange(1, self.n_items)
+
+        uip = pd.MultiIndex.from_product(
+            [user_ids, item_ids],
+            names=[self.user_col, self.item_col],
         )
-    return model
+
+        if remove_seen:
+            uip = uip.difference(
+                self.train.ratings.set_index([self.user_col, self.item_col]).index
+            )
+
+        uip = uip.to_frame(index=False)
+
+        cont_features = None
+        if self.train.n_cont_features > 0:
+            # _get_continuous_features expects item ids first, then user ids
+            cont_features = torch.from_numpy(
+                self.train._get_continuous_features(
+                    uip[self.item_col].values, uip[self.user_col].values
+                )
+            )
+
+        return uip, cont_features
+
+    def recommend_k_items(
+        self, user_ids=None, item_ids=None, top_k=10, remove_seen=True,
+    ):
+        uip, cont_features = self._get_uip_cont(user_ids, item_ids, remove_seen)
+
+        with torch.no_grad():
+            # Flatten the (n, 1) scores into a 1-D prediction column
+            uip[self.prediction_col] = self.model(
+                torch.from_numpy(uip[[self.user_col, self.item_col]].values),
+                continuous_features=cont_features,
+            ).squeeze(1).numpy()
+
+        return (
+            uip
+            .sort_values([self.user_col, self.prediction_col], ascending=[True, False])
+            .groupby(self.user_col)
+            .head(top_k)
+            .reset_index(drop=True)
+        )
\ No newline at end of file
diff --git a/recommenders/utils/torch_utils.py b/recommenders/utils/torch_utils.py
new file mode 100644
index 000000000..cbd2e42ca
--- /dev/null
+++ b/recommenders/utils/torch_utils.py
@@ -0,0 +1,37 @@
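+# Copyright (c) Recommenders contributors.
+# Licensed under the MIT License.
+
+"""Name-to-class lookup tables for torch optimizers and losses, used by
+``recommenders.models.wide_deep.wide_deep_utils.WideAndDeep`` to resolve the
+``optimizer`` and ``loss_fn`` string arguments."""
+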
+from typing import Dict, Type
+
+import torch
+from torch import nn, optim
+
+OPTIM_DICT: Dict[str, Type[optim.Optimizer]] = {
+    'adadelta': optim.Adadelta,
+    'adagrad': optim.Adagrad,
+    'adam': optim.Adam,
+    'adamw': optim.AdamW,
+    'adamax': optim.Adamax,
+    'asgd': optim.ASGD,
+    'lbfgs': optim.LBFGS,
+    'rmsprop': optim.RMSprop,
+    'rprop': optim.Rprop,
+    'sgd': optim.SGD,
+    'sparseadam': optim.SparseAdam,
+}
+
+LOSS_DICT: Dict[str, Type[nn.Module]] = {
+    'l1': nn.L1Loss,
+    'mse': nn.MSELoss,
+    'cross_entropy': nn.CrossEntropyLoss,
+    'nll': nn.NLLLoss,
+    'bce': nn.BCELoss,
+    'bce_with_logits': nn.BCEWithLogitsLoss,
+    'hinge': nn.HingeEmbeddingLoss,
+    'kl_div': nn.KLDivLoss,
+    'huber': nn.HuberLoss,
+    'smooth_l1': nn.SmoothL1Loss,
+    'soft_margin': nn.SoftMarginLoss,
+    'multi_margin': nn.MultiMarginLoss,
+    'cosine_embedding': nn.CosineEmbeddingLoss,
+    'margin_ranking': nn.MarginRankingLoss,
+    'triplet_margin': nn.TripletMarginLoss,
+    'ctc': nn.CTCLoss,
+}
\ No newline at end of file
diff --git a/tests/unit/recommenders/models/test_wide_deep_utils.py b/tests/unit/recommenders/models/test_wide_deep_utils.py
index 2d4168e99..49bd81b04 100644
--- a/tests/unit/recommenders/models/test_wide_deep_utils.py
+++ b/tests/unit/recommenders/models/test_wide_deep_utils.py
@@ -13,12 +13,11 @@
 )
 
 try:
-    from recommenders.utils.tf_utils import pandas_input_fn, MODEL_DIR
     from recommenders.models.wide_deep.wide_deep_utils import (
-        build_model,
-        build_feature_columns,
+        WideAndDeep,
+        WideAndDeepDataset,
+        WideAndDeepHyperParams,
     )
-    import tensorflow as tf
 except ImportError:
     pass  # skip this import if we are in cpu environment
 
@@ -32,109 +31,87 @@ def pd_df():
         {
             DEFAULT_USER_COL: [1, 1, 1, 2, 2, 2],
             DEFAULT_ITEM_COL: [1, 2, 3, 1, 4, 5],
-            ITEM_FEAT_COL: [
-                [1, 1, 1],
-                [2, 2, 2],
-                [3, 3, 3],
-                [1, 1, 1],
-                [4, 4, 4],
-                [5, 5, 5],
-            ],
             DEFAULT_RATING_COL: [5, 4, 3, 5, 5, 3],
         }
     )
+    item_feat = pd.DataFrame({
+        DEFAULT_ITEM_COL: [1, 2, 3, 4, 5],
+        ITEM_FEAT_COL: [
+            [1, 1, 1],
+            [2, 2, 2],
+            [3, 3, 3],
+            [4, 4, 4],
+            [5, 5, 5],
+        ],
+    })
+
     users = df.drop_duplicates(DEFAULT_USER_COL)[DEFAULT_USER_COL].values
     items = df.drop_duplicates(DEFAULT_ITEM_COL)[DEFAULT_ITEM_COL].values
 
-    return df, users, items
+    return df, users, items, item_feat
+
+
+@pytest.mark.gpu
+def test_wide_deep_dataset(pd_df):
+    data, users, items, item_feat = pd_df
+
+    dataset = WideAndDeepDataset(data)
+    assert len(dataset) == len(data)
+    # Add +1 because ids start at 1, so id 0 is also counted in the vocabulary sizes
+    assert dataset.n_users == len(users) + 1
+    assert dataset.n_items == len(items) + 1
+    assert dataset.n_cont_features == 0
+    item, rating = dataset[0]
+    assert list(item['interactions']) == [1, 1]
+    assert 'continuous_features' not in item
+    assert rating == 5
+
+    # Test using the item features
+    dataset = WideAndDeepDataset(data, item_feat=item_feat)
+    assert len(dataset) == len(data)
+    assert dataset.n_users == len(users) + 1
+    assert dataset.n_items == len(items) + 1
+    assert dataset.n_cont_features == 3
+    item, rating = dataset[0]
+    assert list(item['interactions']) == [1, 1]
+    assert list(item['continuous_features']) == [1, 1, 1]
+    assert rating == 5
+
 
 @pytest.mark.gpu
-def test_wide_model(pd_df, tmp):
-    data, users, items = pd_df
+def test_wide_deep_model(pd_df, tmp):
+    data, users, items, item_feat = pd_df
 
-    # Test wide model
-    # Test if wide column has two original features and one crossed feature
-    wide_columns, _ = build_feature_columns(
-        users, items, model_type="wide", crossed_feat_dim=10
-    )
-    assert len(wide_columns) == 3
-    # Check crossed feature dimension
-    assert wide_columns[2].hash_bucket_size == 10
-    # Check model type
-    model = build_model(
-        os.path.join(tmp, "wide_" + MODEL_DIR), wide_columns=wide_columns
-    )
-    assert isinstance(model, tf.compat.v1.estimator.LinearRegressor)
-    # Test if model train works
-    model.train(
-        input_fn=pandas_input_fn(
-            df=data,
-            y_col=DEFAULT_RATING_COL,
-            batch_size=1,
-            num_epochs=None,
-            shuffle=True,
-        ),
-        steps=1,
+    dataset = WideAndDeepDataset(data)
+    default_hparams = WideAndDeepHyperParams()
+    model = WideAndDeep(
+        dataset,
+        dataset,
     )
+
+    assert model.model.deep[0].in_features == default_hparams.item_dim + default_hparams.user_dim
+    assert model.model.head[-1].out_features == 1
 
-    # Close the event file so that the model folder can be cleaned up.
-    summary_writer = tf.compat.v1.summary.FileWriterCache.get(model.model_dir)
-    summary_writer.close()
-
+    # Test if the model train works
+    model.fit_step()
+    assert model.current_epoch == len(model.train_loss_history) == len(model.test_loss_history) == 1
+
 
 @pytest.mark.gpu
-def test_deep_model(pd_df, tmp):
-    data, users, items = pd_df
-
-    # Test if deep columns have user and item features
-    _, deep_columns = build_feature_columns(users, items, model_type="deep")
-    assert len(deep_columns) == 2
-    # Check model type
-    model = build_model(
-        os.path.join(tmp, "deep_" + MODEL_DIR), deep_columns=deep_columns
-    )
-    assert isinstance(model, tf.compat.v1.estimator.DNNRegressor)
-    # Test if model train works
-    model.train(
-        input_fn=pandas_input_fn(
-            df=data, y_col=DEFAULT_RATING_COL, batch_size=1, num_epochs=1, shuffle=False
-        )
-    )
-
-    # Close the event file so that the model folder can be cleaned up.
-    summary_writer = tf.compat.v1.summary.FileWriterCache.get(model.model_dir)
-    summary_writer.close()
+def test_wide_deep_recs(pd_df, tmp):
+    data, users, items, item_feat = pd_df
 
+    dataset = WideAndDeepDataset(data)
+    model = WideAndDeep(
+        dataset,
+        dataset,
+    )
 
-@pytest.mark.gpu
-def test_wide_deep_model(pd_df, tmp):
-    data, users, items = pd_df
+    recs = model.recommend_k_items(users, items, top_k=4, remove_seen=False)
 
-    # Test if wide and deep columns have correct features
-    wide_columns, deep_columns = build_feature_columns(
-        users, items, model_type="wide_deep"
-    )
-    assert len(wide_columns) == 3
-    assert len(deep_columns) == 2
-    # Check model type
-    model = build_model(
-        os.path.join(tmp, "wide_deep_" + MODEL_DIR),
-        wide_columns=wide_columns,
-        deep_columns=deep_columns,
-    )
-    assert isinstance(model, tf.compat.v1.estimator.DNNLinearCombinedRegressor)
-    # Test if model train works
-    model.train(
-        input_fn=pandas_input_fn(
-            df=data,
-            y_col=DEFAULT_RATING_COL,
-            batch_size=1,
-            num_epochs=None,
-            shuffle=True,
-        ),
-        steps=1,
-    )
+    assert len(recs) == len(users) * 4
+    assert set(recs[DEFAULT_USER_COL].unique()) == set(users)
+    assert set(recs[DEFAULT_ITEM_COL].unique()).issubset(items)
 
-    # Close the event file so that the model folder can be cleaned up.
-    summary_writer = tf.compat.v1.summary.FileWriterCache.get(model.model_dir)
-    summary_writer.close()
+    # Each user has already rated 3 of the 5 items, so only two items per user
+    # remain to be recommended, even if we ask for top_k > 2
+    recs = model.recommend_k_items(users, items, top_k=4)
+    assert len(recs) == 2 * 2
+    assert set(recs[DEFAULT_USER_COL].unique()).issubset(users)
+    assert set(recs[DEFAULT_ITEM_COL].unique()).issubset(items)