From 7abff4947bb92b6450ce0e69deca3ee78303800e Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 07:40:13 -0700 Subject: [PATCH 01/32] refactor: ruff fixes and adding fractional coordinate check --- matsciml/datasets/utils.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index d172f0c0..89dc32fb 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -9,6 +9,7 @@ import lmdb import torch +import numpy as np from einops import einsum, rearrange from joblib import Parallel, delayed from pymatgen.core import Lattice, Structure @@ -302,11 +303,11 @@ def get_lmdb_keys( "Both `ignore_keys` and `_lambda` were passed; arguments are mutually exclusive.", ) if ignore_keys: - _lambda = lambda x: x not in ignore_keys + _lambda = lambda x: x not in ignore_keys # noqa: E731 else: if not _lambda: # escape case where we basically don't filter - _lambda = lambda x: x + _lambda = lambda x: x # noqa: E731 # convert to a sorted list of keys keys = sorted(list(filter(_lambda, keys))) return keys @@ -529,7 +530,7 @@ def divide_data_chunks( assert all( [length != 0 for length in lengths], ), "Too many processes specified and not enough data to split over multiple LMDB files. Decrease `num_procs!`" - p = Parallel(num_procs)( + _ = Parallel(num_procs)( delayed(write_chunk)(chunk, target_dir, index, metadata) for chunk, index in zip(chunks, lmdb_indices) ) @@ -693,6 +694,11 @@ def calculate_periodic_shifts( include_index=True, include_image=True, ) + # check to make sure the cell definition is valid + if np.any(structure.frac_coords > 1.0): + raise ValueError( + f"Structure has fractional coordinates greater than 1! 
Check structure:\n{structure}" + ) def _all_sites_have_neighbors(neighbors): return all([len(n) for n in neighbors]) From 0e4e8bc5f83b3a06cdc7c7ba69b5e939aa1cf692 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 11:19:57 -0700 Subject: [PATCH 02/32] feat: added embedding forward hook check Signed-off-by: Lee, Kin Long Kelvin --- matsciml/lightning/callbacks.py | 107 ++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 66277637..08357f23 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -20,10 +20,12 @@ from torch import distributed as dist from torch import nn from torch.optim import Optimizer +from dgl import DGLGraph from matsciml.common.packages import package_registry from matsciml.datasets.utils import concatenate_keys from matsciml.models.base import BaseTaskModule +from matsciml.common.types import Embeddings, BatchDict class LeaderboardWriter(BasePredictionWriter): @@ -845,3 +847,108 @@ def _second_step( if p.grad is None: continue p.data = org_weights[p] + + + +def embedding_magnitude_hook(module: nn.Module, input: BatchDict, output: Embeddings) -> None: + """ + Forward hook that will inspect an embedding output. + + This checks for two properties of graph-level and node-level embeddings: + the magnitude of the median tells us if the values are a lot larger than + what we might typically expect, and the variance tells us if the embeddings + are effectively collapsing. + + Parameters + ---------- + module : nn.Module + Nominally a PyTorch module, but we actually expect an encoder. + input : BatchDict + Batch of samples to process + output : Embeddings + Expected to be an embedding data structure. If not, we don't + fail the run, but posts a critical message. 
+ """ + logger = getLogger("matsciml.helper") + logger.setLevel("INFO") + if isinstance(output, Embeddings): + # check the magnitude of both node and system level embeddings + if output.system_embedding is not None: + sys_z = output.system_embedding.detach().cpu() + # calculate representative statistics + sys_z_med = sys_z.median().item() + sys_z_var = sys_z.var().item() + if sys_z_med > 10.: + logger.warning( + f"Median system/graph embedding value is greater than 10 ({sys_z_med})" + ) + if sys_z_var <= 1e-2: + logger.warning( + f"Variance in system/graph embedding is quite small ({sys_z_var})" + ) + if output.point_embedding is not None: + node_z = output.point_embedding.detach().cpu() + # calculate representative statistics + node_z_med = node_z.median().item() + node_z_var = node_z.var().item() + if node_z_med > 10.: + logger.warning( + f"Median node embedding value is greater than 10 ({node_z_med})" + ) + if node_z_var <= 1e-2: + logger.warning( + f"Variance in node embedding is quite small ({node_z_var})" + ) + else: + logger.critical( + f"Hooked module does not produce an embedding data structure! 
{module}" + ) + + +class TrainingHelperCallback(Callback): + def __init__(self, + small_grad_thres: float = 1e-3, param_norm_thres: float = 10., update_freq: int = 50 + ) -> None: + super().__init__() + self.logger = getLogger("matsciml.helper") + self.logger.setLevel("INFO") + self.small_grad_thres = small_grad_thres + self.update_freq = update_freq + + def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + self.batch_idx = 0 + + def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int) -> None: + self.batch_idx = batch_idx + if self.is_active: + # look at atom positions for irregularities + if "graph" in batch: + g = batch["graph"] + if isinstance(g, DGLGraph): + pos = g.ndata["pos"] + else: + pos = g.pos + else: + # we assume there are positions, otherwise there are bigger + # problems than running this check + pos = batch["pos"] + min_pos, max_pos = pos.min().item(), pos.max().item() + if min_pos >= 0. 
and max_pos <= 1.: + self.logger.warning("Coordinates are small and might be fractional, which may not be intended.") + + @property + def is_active(self) -> bool: + return (self.batch_idx % self.update_freq) == 0 + + def on_before_optimizer_step(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None: + if self.is_active: + # loop through parameter related checks + for name, parameter in pl_module.named_parameters(): + if parameter.requires_grad: + if parameter.grad is None: + self.logger.warning(f"Parameter {name} has no gradients, but should!") + else: + grad_norm = parameter.grad.norm() + if grad_norm.abs() < self.small_grad_thres: + self.logger.warning(f"Parameter {name} has small gradient norm - {grad_norm}") + param_norm = parameter.norm() From a4336850679037d04a368c9f702e1441b9d4c9d1 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 12:06:52 -0700 Subject: [PATCH 03/32] feat: added encoder forward hook to helper --- matsciml/lightning/callbacks.py | 60 +++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 08357f23..a05dcc67 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -849,8 +849,9 @@ def _second_step( p.data = org_weights[p] - -def embedding_magnitude_hook(module: nn.Module, input: BatchDict, output: Embeddings) -> None: +def embedding_magnitude_hook( + module: nn.Module, input: BatchDict, output: Embeddings +) -> None: """ Forward hook that will inspect an embedding output. 
@@ -878,7 +879,7 @@ def embedding_magnitude_hook(module: nn.Module, input: BatchDict, output: Embedd # calculate representative statistics sys_z_med = sys_z.median().item() sys_z_var = sys_z.var().item() - if sys_z_med > 10.: + if sys_z_med > 10.0: logger.warning( f"Median system/graph embedding value is greater than 10 ({sys_z_med})" ) @@ -891,7 +892,7 @@ def embedding_magnitude_hook(module: nn.Module, input: BatchDict, output: Embedd # calculate representative statistics node_z_med = node_z.median().item() node_z_var = node_z.var().item() - if node_z_med > 10.: + if node_z_med > 10.0: logger.warning( f"Median node embedding value is greater than 10 ({node_z_med})" ) @@ -906,19 +907,39 @@ def embedding_magnitude_hook(module: nn.Module, input: BatchDict, output: Embedd class TrainingHelperCallback(Callback): - def __init__(self, - small_grad_thres: float = 1e-3, param_norm_thres: float = 10., update_freq: int = 50 + def __init__( + self, + small_grad_thres: float = 1e-3, + param_norm_thres: float = 10.0, + update_freq: int = 50, + encoder_hook: bool = True, ) -> None: super().__init__() self.logger = getLogger("matsciml.helper") self.logger.setLevel("INFO") self.small_grad_thres = small_grad_thres self.update_freq = update_freq + self.encoder_hook = encoder_hook - def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def on_fit_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self.encoder_hook: + pl_module.encoder.register_forward_hook(embedding_magnitude_hook) + self.logger.info("Registered embedding monitor") + + def on_train_epoch_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: self.batch_idx = 0 - def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int) -> None: + def on_train_batch_start( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + batch: Any, + batch_idx: int, 
+ ) -> None: self.batch_idx = batch_idx if self.is_active: # look at atom positions for irregularities @@ -933,22 +954,33 @@ def on_train_batch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo # problems than running this check pos = batch["pos"] min_pos, max_pos = pos.min().item(), pos.max().item() - if min_pos >= 0. and max_pos <= 1.: - self.logger.warning("Coordinates are small and might be fractional, which may not be intended.") + if min_pos >= 0.0 and max_pos <= 1.0: + self.logger.warning( + "Coordinates are small and might be fractional, which may not be intended." + ) @property def is_active(self) -> bool: return (self.batch_idx % self.update_freq) == 0 - def on_before_optimizer_step(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None: + def on_before_optimizer_step( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + optimizer: Optimizer, + ) -> None: if self.is_active: # loop through parameter related checks for name, parameter in pl_module.named_parameters(): if parameter.requires_grad: if parameter.grad is None: - self.logger.warning(f"Parameter {name} has no gradients, but should!") + self.logger.warning( + f"Parameter {name} has no gradients, but should!" 
+ ) else: grad_norm = parameter.grad.norm() if grad_norm.abs() < self.small_grad_thres: - self.logger.warning(f"Parameter {name} has small gradient norm - {grad_norm}") - param_norm = parameter.norm() + self.logger.warning( + f"Parameter {name} has small gradient norm - {grad_norm}" + ) + _ = parameter.norm() From 7759f09d3e0f0d9b4daa71931c3c525ad3f9734d Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 12:43:12 -0700 Subject: [PATCH 04/32] feat: added encoder-outputhead compaison --- matsciml/lightning/callbacks.py | 74 ++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index a05dcc67..2a4e8b9f 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -913,6 +913,7 @@ def __init__( param_norm_thres: float = 10.0, update_freq: int = 50, encoder_hook: bool = True, + record_param_norm_history: bool = True, ) -> None: super().__init__() self.logger = getLogger("matsciml.helper") @@ -920,6 +921,7 @@ def __init__( self.small_grad_thres = small_grad_thres self.update_freq = update_freq self.encoder_hook = encoder_hook + self.record_param_norm_history = record_param_norm_history def on_fit_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" @@ -963,6 +965,64 @@ def on_train_batch_start( def is_active(self) -> bool: return (self.batch_idx % self.update_freq) == 0 + @staticmethod + def encoder_head_comparison( + pl_module: pl.LightningModule, log_history, python_logger + ): + """ + Make a comparison of weight norms in the encoder and output head stack. + + The heuristic being checked here is if the encoder weights are a lot smaller + than the output head, the encoder may end up being ignored entirely and + the output heads are just overfitting to the data. This check doesn't prove + that is happening, but provides an indication of it. 
+ + Parameters + ---------- + pl_module + Nominally a generic ``LightningModule``, but we expect the + model to have an encoder and an output head module dict. + log_history : bool + Default True, whether to log the weight norm values to an + experiment tracker. + python_logger : Logger + Logger for the Python side to raise the warning message. + """ + # compare encoder and output head weights + encoder_norm_vals = [] + output_norm_vals = [] + for parameter in pl_module.encoder.parameters(): + encoder_norm_vals.append(parameter.detach().norm().cpu().item()) + for head in pl_module.output_heads.values(): + for parameter in head.parameters(): + output_norm_vals.append(parameter.detach().norm().cpu().item()) + encoder_norm_vals = np.array(encoder_norm_vals) + output_norm_vals = np.array(output_norm_vals) + encoder_median = np.median(encoder_norm_vals) + output_median = np.median(output_norm_vals) + if encoder_median < (2.0 * output_median): + python_logger.warning( + "Median encoder weights are significantly smaller than output heads:" + " encoder median norm: {encoder_median:.3e}," + " output head: {output_median:.3e}" + ) + # optionally record to service as well + if log_history: + pl_module.log( + "encoder_weight_norm", + torch.from_numpy(encoder_norm_vals).float(), + prog_bar=False, + on_step=True, + on_epoch=False, + ) + pl_module.log( + "outputhead_weight_norm", + torch.from_numpy(output_norm_vals).float(), + prog_bar=False, + on_step=True, + on_epoch=False, + ) + def on_before_optimizer_step( self, trainer: "pl.Trainer", @@ -971,6 +1031,7 @@ def on_before_optimizer_step( ) -> None: if self.is_active: # loop through parameter related checks + grad_norm_vals = [] for name, parameter in pl_module.named_parameters(): if parameter.requires_grad: if parameter.grad is None: @@ -983,4 +1044,15 @@ def on_before_optimizer_step( self.logger.warning( f"Parameter {name} has small gradient norm - {grad_norm}" ) - _ = parameter.norm() + 
grad_norm_vals.append(grad_norm.detach().cpu().item()) + # track gradient norm for the whole model + pl_module.log( + "gradient_norms", + torch.FloatTensor(grad_norm_vals), + prog_bar=False, + on_step=True, + on_epoch=False, + ) + self.encoder_head_comparison( + pl_module, self.record_param_norm_history, self.logger + ) From 7221eff21c09cf71f37a85a6969925abf4e4282d Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 14:21:47 -0700 Subject: [PATCH 05/32] feat: working grad norm logging --- matsciml/lightning/callbacks.py | 57 +++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 2a4e8b9f..244ed6ea 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -17,6 +17,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import BasePredictionWriter, Callback from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning import loggers as pl_loggers from torch import distributed as dist from torch import nn from torch.optim import Optimizer @@ -967,7 +968,10 @@ def is_active(self) -> bool: @staticmethod def encoder_head_comparison( - pl_module: pl.LightningModule, log_history, python_logger + pl_module: pl.LightningModule, + log_history, + python_logger, + global_step: int | None = None, ): """ Make a comparison of weight norms in the encoder and output head stack. 
@@ -1007,21 +1011,20 @@ def encoder_head_comparison( " output head: {output_median:.3e}" ) # optionally record to service as well - if log_history: - pl_module.log( - "encoder_weight_norm", - torch.from_numpy(encoder_norm_vals).float(), - prog_bar=False, - on_step=True, - on_epoch=False, - ) - pl_module.log( - "outputhead_weight_norm", - torch.from_numpy(output_norm_vals).float(), - prog_bar=False, - on_step=True, - on_epoch=False, - ) + if log_history and pl_module.logger is not None: + log_service = pl_module.logger.experiment + encoder_norm_vals = torch.from_numpy(encoder_norm_vals).float() + output_norm_vals = torch.from_numpy(output_norm_vals).float() + if isinstance(log_service, pl_loggers.TensorBoardLogger): + log_service.add_histogram( + "encoder_weight_norm", encoder_norm_vals, global_step + ) + log_service.add_histogram( + "outputhead_weight_norm", output_norm_vals, global_step + ) + elif isinstance(log_service, pl_loggers.WandbLogger): + log_service.log({"encoder_weight_norm": encoder_norm_vals}) + log_service.log({"outputhead_weight_norm": output_norm_vals}) def on_before_optimizer_step( self, @@ -1030,6 +1033,7 @@ def on_before_optimizer_step( optimizer: Optimizer, ) -> None: if self.is_active: + log_service = pl_module.logger # loop through parameter related checks grad_norm_vals = [] for name, parameter in pl_module.named_parameters(): @@ -1046,13 +1050,18 @@ def on_before_optimizer_step( ) grad_norm_vals.append(grad_norm.detach().cpu().item()) # track gradient norm for the whole model - pl_module.log( - "gradient_norms", - torch.FloatTensor(grad_norm_vals), - prog_bar=False, - on_step=True, - on_epoch=False, - ) + grad_norm_vals = torch.FloatTensor(grad_norm_vals) + if isinstance(log_service, pl_loggers.TensorBoardLogger): + log_service.experiment.add_histogram( + "gradient_norms", grad_norm_vals, global_step=trainer.global_step + ) + elif isinstance(log_service, pl_loggers.WandbLogger): + log_service.experiment.log( + {"gradient_norms": 
torch.FloatTensor(grad_norm_vals)} + ) self.encoder_head_comparison( - pl_module, self.record_param_norm_history, self.logger + pl_module, + self.record_param_norm_history, + self.logger, + trainer.global_step, ) From 17297f258fe4b42933c1018566ac67403c05afae Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 14:26:20 -0700 Subject: [PATCH 06/32] refactor: changing variance value to a much smaller value --- matsciml/lightning/callbacks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 244ed6ea..94c6800a 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -884,7 +884,7 @@ def embedding_magnitude_hook( logger.warning( f"Median system/graph embedding value is greater than 10 ({sys_z_med})" ) - if sys_z_var <= 1e-2: + if sys_z_var <= 1e-5: logger.warning( f"Variance in system/graph embedding is quite small ({sys_z_var})" ) @@ -897,7 +897,7 @@ def embedding_magnitude_hook( logger.warning( f"Median node embedding value is greater than 10 ({node_z_med})" ) - if node_z_var <= 1e-2: + if node_z_var <= 1e-5: logger.warning( f"Variance in node embedding is quite small ({node_z_var})" ) From 4372d75972ead7b9bc37114371f5797124277481 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 14:40:30 -0700 Subject: [PATCH 07/32] docs: adding docstrings throughout helper Signed-off-by: Lee, Kin Long Kelvin --- matsciml/lightning/callbacks.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 94c6800a..0f63e72f 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -911,7 +911,6 @@ class TrainingHelperCallback(Callback): def __init__( self, small_grad_thres: float = 1e-3, - param_norm_thres: float = 10.0, update_freq: int = 50, encoder_hook: bool = True, 
record_param_norm_history: bool = True, @@ -927,6 +926,11 @@ def __init__( def on_fit_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: + """ + This attaches the embedding hook, which inspects the embeddings + to make sure there is sufficient variance, or if the values are + too big. + """ if self.encoder_hook: pl_module.encoder.register_forward_hook(embedding_magnitude_hook) self.logger.info("Registered embedding monitor") @@ -934,6 +938,7 @@ def on_fit_start( def on_train_epoch_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: + """Sets an internal batch index tracker for activity.""" self.batch_idx = 0 def on_train_batch_start( @@ -943,6 +948,13 @@ def on_train_batch_start( batch: Any, batch_idx: int, ) -> None: + """ + Triggering at the beginning of a training batch, this is where all + the checks pertaining to input data should be made. For now, + we check whether or not the coordinates are bounded between 0,1 + which may indicate that the coordinates are fractional which may + not be intended. + """ self.batch_idx = batch_idx if self.is_active: # look at atom positions for irregularities @@ -964,6 +976,7 @@ def on_train_batch_start( @property def is_active(self) -> bool: + """Determines whether or not to perform an update.""" return (self.batch_idx % self.update_freq) == 0 @staticmethod @@ -1032,6 +1045,13 @@ def on_before_optimizer_step( pl_module: "pl.LightningModule", optimizer: Optimizer, ) -> None: + """ + This stage checks for problems pertaining to parameter weights + and gradients, triggering before the optimizer is stepped. + We check to make sure the gradient norm is reasonably sized, + as well as making sure that the output head weigts don't get + significantly larger than the encoder. 
+ """ if self.is_active: log_service = pl_module.logger # loop through parameter related checks From 1a84d162f7bb29c9c1afc19297a749c41b192dc6 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 16:25:53 -0700 Subject: [PATCH 08/32] feat: added function to log embeddings --- matsciml/models/base.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 1cfa605b..c4029f45 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -11,6 +11,7 @@ from logging import getLogger import pytorch_lightning as pl +from pytorch_lightning import loggers as pl_loggers import torch from einops import reduce from torch import Tensor, nn @@ -676,6 +677,7 @@ def __init__( embedding_reduction_type: str = "mean", normalize_kwargs: dict[str, float] | None = None, scheduler_kwargs: dict[str, dict[str, Any]] | None = None, + log_embeddings: bool = False, **kwargs, ) -> None: super().__init__() @@ -706,6 +708,7 @@ def __init__( if len(self.task_keys) > 0: self.task_loss_scaling = self._task_loss_scaling self.embedding_reduction_type = embedding_reduction_type + self.log_embeddings = log_embeddings self.save_hyperparameters(ignore=["encoder", "loss_func"]) @property @@ -873,6 +876,40 @@ def process_embedding(self, embeddings: Embeddings) -> dict[str, torch.Tensor]: results[key] = output return results + def _log_embedding(self, embeddings: Embeddings) -> None: + """ + This maps the appropriate logging function depending on what + logger was used, and saves the graph and node level embeddings. + + Some services like ``wandb`` are able to do some nifty embedding + analyses online using these embeddings. + + Parameters + ---------- + embeddings : Embeddings + Data structure containing embeddings from the encoder. 
+ """ + if self.logger is not None: + exp = self.logger.experiment + if isinstance(self.logger, pl_loggers.WandbLogger): + exp.log( + {"graph_embeddings": embeddings.system_embedding.detach().cpu()} + ) + if isinstance(embeddings.point_embedding, torch.Tensor): + exp.log( + {"node_embeddings": embeddings.point_embedding.detach().cpu()} + ) + elif isinstance(self.logger, pl_loggers.TensorBoardLogger): + exp.add_embedding( + embeddings.system_embedding.detach().cpu(), tag="graph_embeddings" + ) + if isinstance(embeddings.point_embedding, torch.Tensor): + exp.add_embedding( + embeddings.point_embedding.detach().cpu(), tag="node_embeddings" + ) + else: + pass + def _get_targets( self, batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]], From 552cf1b99ffd13ef99116514dc799e82b5ff9568 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 16:32:22 -0700 Subject: [PATCH 09/32] feat: adding embedding logging call --- matsciml/models/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index c4029f45..c95876d0 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -874,6 +874,8 @@ def process_embedding(self, embeddings: Embeddings) -> dict[str, torch.Tensor]: reduction=self.embedding_reduction_type, ) results[key] = output + if self.log_embeddings: + self._log_embedding(embeddings) return results def _log_embedding(self, embeddings: Embeddings) -> None: @@ -1366,6 +1368,8 @@ def process_embedding(self, embeddings: Embeddings) -> Dict[str, torch.Tensor]: output = head(embeddings.system_embedding[key]) output = reduce(output, "b ... 
d -> b d", reduction="mean") results[key] = output + if self.log_embeddings: + self._log_embedding(embeddings) return results def _compute_losses( @@ -1827,6 +1831,8 @@ def energy_and_force( # this ensures that we get a scalar value for every node # representing the energy contribution outputs["node_energies"] = node_energies + if self.log_embeddings: + self._log_embedding(embeddings) return outputs def predict(self, batch: BatchDict) -> dict[str, torch.Tensor]: From 9922415911b7c63dff8446ee6de90038fef95339 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 16:36:03 -0700 Subject: [PATCH 10/32] refactor: using hparams for log embedding kwarg --- matsciml/models/base.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index c95876d0..ea7617cd 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -708,7 +708,6 @@ def __init__( if len(self.task_keys) > 0: self.task_loss_scaling = self._task_loss_scaling self.embedding_reduction_type = embedding_reduction_type - self.log_embeddings = log_embeddings self.save_hyperparameters(ignore=["encoder", "loss_func"]) @property @@ -874,7 +873,7 @@ def process_embedding(self, embeddings: Embeddings) -> dict[str, torch.Tensor]: reduction=self.embedding_reduction_type, ) results[key] = output - if self.log_embeddings: + if self.hparams.log_embeddings: self._log_embedding(embeddings) return results @@ -1368,7 +1367,7 @@ def process_embedding(self, embeddings: Embeddings) -> Dict[str, torch.Tensor]: output = head(embeddings.system_embedding[key]) output = reduce(output, "b ... 
d -> b d", reduction="mean") results[key] = output - if self.log_embeddings: + if self.hparams.log_embeddings: self._log_embedding(embeddings) return results @@ -1831,7 +1830,7 @@ def energy_and_force( # this ensures that we get a scalar value for every node # representing the energy contribution outputs["node_energies"] = node_energies - if self.log_embeddings: + if self.hparams.log_embeddings: self._log_embedding(embeddings) return outputs From 23493054cc0cb566c6a9205b08035d8cc3660fbe Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 16:37:33 -0700 Subject: [PATCH 11/32] fix: adding global step specification in tensorboard embedding log --- matsciml/models/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index ea7617cd..4b05b609 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -902,11 +902,15 @@ def _log_embedding(self, embeddings: Embeddings) -> None: ) elif isinstance(self.logger, pl_loggers.TensorBoardLogger): exp.add_embedding( - embeddings.system_embedding.detach().cpu(), tag="graph_embeddings" + embeddings.system_embedding.detach().cpu(), + tag="graph_embeddings", + global_step=self.trainer.global_step, ) if isinstance(embeddings.point_embedding, torch.Tensor): exp.add_embedding( - embeddings.point_embedding.detach().cpu(), tag="node_embeddings" + embeddings.point_embedding.detach().cpu(), + tag="node_embeddings", + global_step=self.trainer.global_step, ) else: pass From 1a749beaadac97a755b8dfbc09ba233aa2f68bf7 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Thu, 6 Jun 2024 16:44:30 -0700 Subject: [PATCH 12/32] refactor: adding global step to add embedding --- matsciml/models/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 4b05b609..fb3d3109 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -903,14 +903,12 @@ 
def _log_embedding(self, embeddings: Embeddings) -> None: elif isinstance(self.logger, pl_loggers.TensorBoardLogger): exp.add_embedding( embeddings.system_embedding.detach().cpu(), - tag="graph_embeddings", - global_step=self.trainer.global_step, + tag=f"graph_embeddings_{self.trainer.global_step}", ) if isinstance(embeddings.point_embedding, torch.Tensor): exp.add_embedding( embeddings.point_embedding.detach().cpu(), - tag="node_embeddings", - global_step=self.trainer.global_step, + tag=f"node_embeddings_{self.trainer.global_step}", ) else: pass From cac78fe5e4e922a426ba4e3bf24691f9a01c93e8 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 08:04:34 -0700 Subject: [PATCH 13/32] refactor: making forward generically stash embeddings --- matsciml/models/base.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index fb3d3109..45c54330 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -842,10 +842,11 @@ def forward( batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]], ) -> dict[str, torch.Tensor]: if "embeddings" in batch: - embedding = batch.get("embeddings") + embeddings = batch.get("embeddings") else: - embedding = self.encoder(batch) - outputs = self.process_embedding(embedding) + embeddings = self.encoder(batch) + batch["embeddings"] = embeddings + outputs = self.process_embedding(embeddings) return outputs def process_embedding(self, embeddings: Embeddings) -> dict[str, torch.Tensor]: @@ -873,8 +874,6 @@ def process_embedding(self, embeddings: Embeddings) -> dict[str, torch.Tensor]: reduction=self.embedding_reduction_type, ) results[key] = output - if self.hparams.log_embeddings: - self._log_embedding(embeddings) return results def _log_embedding(self, embeddings: Embeddings) -> None: From 53d640bc09d8dbd953d7d7ff73d15b98dc7105ff Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 08:05:47 -0700 
Subject: [PATCH 14/32] refactor: putting embedding logging in train step That way we don't do a double log as forward might be called multiple times --- matsciml/models/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 45c54330..68c6e9ba 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -1039,6 +1039,8 @@ def training_step( ) batch_size = None self.log_dict(metrics, on_step=True, prog_bar=True, batch_size=batch_size) + if self.hparams.log_embeddings and "embeddings" in batch: + self._log_embedding(batch["embeddings"]) return loss_dict def validation_step( From d7e1f075f1e9d0176aa45bd91d72d4eeb24fd13b Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 08:06:24 -0700 Subject: [PATCH 15/32] refactor: adding embedding logging to independent steps --- matsciml/models/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 68c6e9ba..98de5598 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -1061,6 +1061,8 @@ def validation_step( ) batch_size = None self.log_dict(metrics, batch_size=batch_size) + if self.hparams.log_embeddings and "embeddings" in batch: + self._log_embedding(batch["embeddings"]) return loss_dict def test_step( @@ -1081,6 +1083,8 @@ def test_step( ) batch_size = None self.log_dict(metrics, batch_size=batch_size) + if self.hparams.log_embeddings and "embeddings" in batch: + self._log_embedding(batch["embeddings"]) return loss_dict def _make_normalizers(self) -> dict[str, Normalizer]: From 997a1be6596e4549455847111a32bacd6d973e8c Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 1 Jul 2024 10:50:18 -0700 Subject: [PATCH 16/32] chore: rebasing main to finalize PR --- matsciml/models/base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 98de5598..d7beb634 100644 --- a/matsciml/models/base.py 
+++ b/matsciml/models/base.py @@ -2004,6 +2004,8 @@ def training_step( s.step(loss, self.current_epoch) else: s.step(epoch=self.current_epoch) + if self.hparams.log_embeddings and "embeddings" in batch: + self._log_embedding(batch["embeddings"]) return loss_dict def _make_normalizers(self) -> dict[str, Normalizer]: From 1640ef884818675a830ab3b6a23ce3c5536521e3 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 09:02:44 -0700 Subject: [PATCH 17/32] refactor: added log embedding frequency control --- matsciml/models/base.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index d7beb634..3b5edd27 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -678,6 +678,7 @@ def __init__( normalize_kwargs: dict[str, float] | None = None, scheduler_kwargs: dict[str, dict[str, Any]] | None = None, log_embeddings: bool = False, + log_embeddings_every_n_steps: int = 50, **kwargs, ) -> None: super().__init__() @@ -889,7 +890,9 @@ def _log_embedding(self, embeddings: Embeddings) -> None: embeddings : Embeddings Data structure containing embeddings from the encoder. 
""" - if self.logger is not None: + log_freq = self.hparams.log_embeddings_every_n_steps + global_step = self.trainer.global_step + if self.logger is not None and (global_step % log_freq) == 0: exp = self.logger.experiment if isinstance(self.logger, pl_loggers.WandbLogger): exp.log( From a9dd1255f097394c77fccb161ef9a053c5430071 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 09:30:48 -0700 Subject: [PATCH 18/32] refactor: cleaned up log_embeddings function --- matsciml/models/base.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 3b5edd27..9ad4dc9e 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -892,24 +892,37 @@ def _log_embedding(self, embeddings: Embeddings) -> None: """ log_freq = self.hparams.log_embeddings_every_n_steps global_step = self.trainer.global_step + # only log embeddings at the same cadence as everything else if self.logger is not None and (global_step % log_freq) == 0: exp = self.logger.experiment + sys_z = embeddings.system_embedding.detach().cpu() + node_z = embeddings.point_embedding.detach().cpu() if isinstance(self.logger, pl_loggers.WandbLogger): + # this import is okay here since we need it for the logger anyway + import wandb + + cols = [f"D{i}" for i in range(sys_z.size(-1))] exp.log( - {"graph_embeddings": embeddings.system_embedding.detach().cpu()} + {"graph_embeddings": wandb.Table(columns=cols, data=sys_z.tolist())} ) if isinstance(embeddings.point_embedding, torch.Tensor): + # TODO: should add labels to the nodes based on graph index exp.log( - {"node_embeddings": embeddings.point_embedding.detach().cpu()} + { + "node_embeddings": wandb.Table( + columns=cols, data=node_z.tolist() + ) + } ) elif isinstance(self.logger, pl_loggers.TensorBoardLogger): exp.add_embedding( - embeddings.system_embedding.detach().cpu(), + sys_z, tag=f"graph_embeddings_{self.trainer.global_step}", ) if 
isinstance(embeddings.point_embedding, torch.Tensor): + # TODO: should add labels to the nodes based on graph index exp.add_embedding( - embeddings.point_embedding.detach().cpu(), + node_z, tag=f"node_embeddings_{self.trainer.global_step}", ) else: From 7186df6da98254fac5fa74464b250ed0167597df Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 09:35:32 -0700 Subject: [PATCH 19/32] test: added simple unit test for forward hook Signed-off-by: Lee, Kin Long Kelvin --- matsciml/lightning/tests/test_helper.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 matsciml/lightning/tests/test_helper.py diff --git a/matsciml/lightning/tests/test_helper.py b/matsciml/lightning/tests/test_helper.py new file mode 100644 index 00000000..5ba2c728 --- /dev/null +++ b/matsciml/lightning/tests/test_helper.py @@ -0,0 +1,21 @@ +import torch +from torch import nn + +from matsciml.lightning.callbacks import embedding_magnitude_hook +from matsciml.common.types import Embeddings + + +class DummyEncoder(nn.Module): + def forward(self, g_z, n_z) -> Embeddings: + embeddings = Embeddings(g_z, n_z) + return embeddings + + +def test_hook_manual(caplog): + g_z = torch.rand(8, 64) * 30 + n_z = torch.rand(340, 64) * 30 + encoder = DummyEncoder() + encoder.register_forward_hook(embedding_magnitude_hook) + _ = encoder(g_z, n_z) + assert "WARNING" in caplog.text + assert "embedding value is greater" in caplog.text From 44150f571945f132e9d879e4089a24ddcf984d91 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 12:11:12 -0700 Subject: [PATCH 20/32] feat: implemented working autocorrelation callback --- matsciml/lightning/callbacks.py | 157 ++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 0f63e72f..30f67642 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -10,6 +10,7 @@ from time import 
time from copy import copy from typing import Any, Callable, Dict, Iterator, Optional +from queue import Queue import numpy as np import pytorch_lightning as pl @@ -22,6 +23,7 @@ from torch import nn from torch.optim import Optimizer from dgl import DGLGraph +from scipy.signal import correlate from matsciml.common.packages import package_registry from matsciml.datasets.utils import concatenate_keys @@ -1085,3 +1087,158 @@ def on_before_optimizer_step( self.logger, trainer.global_step, ) + + +class ModelAutocorrelation(Callback): + def __init__( + self, + buffer_size: int = 100, + sampled: bool = True, + sample_frac: float = 0.05, + analyze_grads: bool = True, + analyze_every_n_steps: int = 50, + ) -> None: + super().__init__() + self.buffer_size = buffer_size + if not sampled: + raise NotImplementedError( + "Only sampled analysis mode is currently supported." + ) + self.sampled = sampled + self.sample_frac = sample_frac + self.analyze_grads = analyze_grads + self.analyze_every_n_steps = analyze_every_n_steps + + @staticmethod + def sample_parameters( + model: nn.Module, indices: dict[str, torch.Tensor], collect_grads: bool + ) -> tuple[np.ndarray, np.ndarray | None]: + collected_params = [] + collected_grads = [] + for name, parameter in model.named_parameters(): + idx = indices.get(name, None) + if idx is not None: + elements = parameter.flatten()[idx].detach().cpu().numpy() + collected_params.append(elements) + if collect_grads and parameter.grad is not None: + collected_grads.append( + parameter.grad.flatten()[idx].detach().cpu().numpy() + ) + if collect_grads and len(collected_grads) > 0: + return np.hstack(collected_params), np.hstack(collected_grads) + else: + return np.hstack(collected_params), None + + def run_analysis(self, logger): + param_history = np.vstack(self.history["params"].queue) + param_corr = self._calculate_autocorrelation(param_history) + # now log the spectrum + if isinstance(logger, pl_loggers.WandbLogger): + from wandb.plot import 
line_series + + logger.experiment.log( + { + "param_autocorrelation": line_series( + xs=[i for i in range(param_history.shape[0])], + ys=param_corr.tolist(), + title="Parameter autocorrelation", + xname="Steps", + ) + } + ) + elif isinstance(logger, pl_loggers.TensorBoardLogger): + logger.experiment.add_image( + "param_autocorrelation", + param_corr, + global_step=self.global_step, + dataformats="WH", + ) + + if self.analyze_grads: + grad_history = np.vstack(self.history["grads"].queue) + grad_corr = self._calculate_autocorrelation(grad_history) + if isinstance(logger, pl_loggers.WandbLogger): + from wandb.plot import line_series + + logger.experiment.log( + { + "grad_autocorrelation": line_series( + xs=[i for i in range(grad_history.shape[0])], + ys=grad_corr.tolist(), + title="Gradient autocorrelation", + xname="Steps", + ) + } + ) + elif isinstance(logger, pl_loggers.TensorBoardLogger): + logger.experiment.add_image( + "grad_autocorrelation", + grad_corr, + global_step=self.global_step, + dataformats="WH", + ) + + @staticmethod + def _calculate_autocorrelation(history: np.ndarray) -> np.ndarray: + assert history.ndim == 2, "Expected history to be 2D!" 
+ # normalizing by variance explodes, so just make it relative + corr = correlate(history, history, mode="same") + corr = (corr - corr.min(axis=0)) / (corr.max(axis=0) - corr.min(axis=0)) + return corr + + def on_fit_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + if self.sampled: + indices = {} + for name, parameter in pl_module.named_parameters(): + if not isinstance(parameter, torch.nn.UninitializedParameter): + numel = parameter.numel() + indices[name] = torch.randperm(numel)[ + : int(numel * self.sample_frac) + ] + self.indices = indices + self.history = { + "params": Queue(self.buffer_size), + "grads": Queue(self.buffer_size), + } + + @property + def global_step(self) -> int: + return self._global_step + + @global_step.setter + def global_step(self, value: int) -> None: + self._global_step = value + + @property + def is_active(self) -> bool: + return ( + self.global_step % self.analyze_every_n_steps + ) == 0 and self.global_step != 0 + + def on_train_batch_start( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + batch: Any, + batch_idx: int, + ) -> None: + self.global_step = trainer.global_step + + def on_before_optimizer_step( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor + ) -> None: + params, grads = self.sample_parameters( + pl_module, self.indices, self.analyze_grads + ) + self.history["params"].put(params) + if self.analyze_grads: + self.history["grads"].put(grads) + # remove the oldest part of history first if we're full + if self.history["params"].full(): + _ = self.history["params"].get() + if self.history["grads"].full(): + _ = self.history["grads"].get() + if self.is_active: + self.run_analysis(pl_module.logger) From 8eb876f46e126c155486e7d61fc9d653792bff0a Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 13:00:27 -0700 Subject: [PATCH 21/32] docs: added a variety of docstrings for model autocorrelation --- 
matsciml/lightning/callbacks.py | 139 +++++++++++++++++++++++++++++++- 1 file changed, 138 insertions(+), 1 deletion(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 30f67642..ec6488ae 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -1098,6 +1098,45 @@ def __init__( analyze_grads: bool = True, analyze_every_n_steps: int = 50, ) -> None: + """ + Initializes a ``ModelAutocorrelation`` callback. + + The purpose of this callback is to track parameters and optionally + gradients over time, and periodically calculate the autocorrelation + spectrum to see how correlated parameters and gradients are throughout + the training process. + + Parameters + ---------- + buffer_size : int, default 100 + Number of steps worth of parameters/gradients to keep in + the correlation window. If the buffer is too small, the + autocorrelation might not be particularly meaningful; if + it's too big, it may impact training throughput. + sampled : bool, default True + If True, we ``sample_frac`` worth of elements from every + parameter tensor. The False case has not yet been implemented, + but is intended to track the whole model. + sample_frac : float, default 0.05 + Fraction of a given parameter/gradient tensor to track. + Larger values give a better picture for how the whole + model is behaving, while fewer samples mean less impact + but a poorer description. + analyze_grads : bool, default True + If True, perform the autocorrelation procedure for gradients + as well as parameters. This may give a better indication of + dynamics over parameters alone. + analyze_every_n_steps : int, default 50 + Frequency to carry out the autocorrelation analysis. Note + that sampling is done at every training step, regardless + of this value. Instead, this determines how often we do the + autocorrelation calculation and logging. + + Raises + ------ + NotImplementedError + If ``sampled=False``, which has not yet been implemented. 
+ """ super().__init__() self.buffer_size = buffer_size if not sampled: @@ -1113,6 +1152,32 @@ def __init__( def sample_parameters( model: nn.Module, indices: dict[str, torch.Tensor], collect_grads: bool ) -> tuple[np.ndarray, np.ndarray | None]: + """ + Collect elements from parameter and gradient tensors of the + target model based on a dictionary of indices. + + Indices are expected to run over the number elements flattened + tensors (i.e. ``Tensor.numel``). + + Parameters + ---------- + model : nn.Module + PyTorch model to track + indices : dict[str, torch.Tensor] + Dictionary mapping for layer name and corresponding + parameter tensor + collect_grads + If True, gradients will also be recorded. Evidently + this means twice the storage requirement. + + Returns + ------- + tuple[np.ndarray, np.ndarray | None] + If ``collect_grads`` is True, a 2-tuple of arrays + will be returned, corresponding to the sampled + parameters and gradients. If False, the latter will + just be None. + """ collected_params = [] collected_grads = [] for name, parameter in model.named_parameters(): @@ -1129,7 +1194,22 @@ def sample_parameters( else: return np.hstack(collected_params), None - def run_analysis(self, logger): + def run_analysis(self, logger: pl_loggers.Logger): + """ + Perform the autocorrelation analysis. + + This function will convert the history buffer into arrays + and pass them to ``_calculate_autocorrelation``. If we have + a logger (either ``wandb`` or ``tensorboard``), we will + log the correlation spectra to these services as well. + + Parameters + ---------- + logger : pl_loggers.Logger + Abstract PyTorch Lightning logger instance. 
While it is + technically abstract, only ``WandbLogger`` and ``TensorBoardLogger`` + are supported right now + """ param_history = np.vstack(self.history["params"].queue) param_corr = self._calculate_autocorrelation(param_history) # now log the spectrum @@ -1153,6 +1233,10 @@ def run_analysis(self, logger): global_step=self.global_step, dataformats="WH", ) + else: + raise NotImplementedError( + "Only WandbLogger and TensorBoardLogger are currently supported." + ) if self.analyze_grads: grad_history = np.vstack(self.history["grads"].queue) @@ -1180,6 +1264,27 @@ def run_analysis(self, logger): @staticmethod def _calculate_autocorrelation(history: np.ndarray) -> np.ndarray: + """ + Use ``scipy.signal.correlate`` to calculate the autocorrelation + for parameters and optionally gradients. + + This spectrum tells you the degree of correlation between training + steps in the recent history for every parameter/gradient element + being tracked. The rescaling is done unintelligently, and for + purely aesthetic reasons. + + Parameters + ---------- + history : np.ndarray + NumPy 2D array; the first dimension is time step, and the + second is parameter/gradient element. + + Returns + ------- + np.ndarray + NumPy 2D array; the first dimension is time step, and the + second corresponds to autocorrelation power/signal. + """ assert history.ndim == 2, "Expected history to be 2D!" # normalizing by variance explodes, so just make it relative corr = correlate(history, history, mode="same") @@ -1189,6 +1294,19 @@ def _calculate_autocorrelation(history: np.ndarray) -> np.ndarray: def on_fit_start( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: + """ + Setup the callback tracking. For sampling mode, we generate random + indices for every parameter in the model (that isn't lazy) that + corresponds to parameters/gradients we will consistently track + throughout training. 
+ + Parameters + ---------- + trainer : pl.Trainer + PyTorch Lightning trainer instance + pl_module : pl.LightningModule + PyTorch Lightning module to track + """ if self.sampled: indices = {} for name, parameter in pl_module.named_parameters(): @@ -1198,6 +1316,8 @@ def on_fit_start( : int(numel * self.sample_frac) ] self.indices = indices + # queue structure is used to manage the history with a finite + # number of elements self.history = { "params": Queue(self.buffer_size), "grads": Queue(self.buffer_size), @@ -1213,6 +1333,7 @@ def global_step(self, value: int) -> None: @property def is_active(self) -> bool: + """Used to determine whether the correlation analysis will be carried out.""" return ( self.global_step % self.analyze_every_n_steps ) == 0 and self.global_step != 0 @@ -1229,6 +1350,22 @@ def on_train_batch_start( def on_before_optimizer_step( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", loss: torch.Tensor ) -> None: + """ + Triggers before the optimizer is stepped, adding the parameters and + optionally gradients to the history. + + If the current step matches the analysis frequency, carry out the + autocorrelation analysis and log the spectrum. 
+ + Parameters + ---------- + trainer : pl.Trainer + PyTorch Lightning trainer instance + pl_module : pl.LightningModule + PyTorch Lightning module to track + loss : torch.Tensor + Loss value; unused + """ params, grads = self.sample_parameters( pl_module, self.indices, self.analyze_grads ) From 9c48a5c5b3d5e54b6138a9d0bc1348a6dbd8901a Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 13:10:55 -0700 Subject: [PATCH 22/32] docs: added docstring for helper callback --- matsciml/lightning/callbacks.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index ec6488ae..76ba6608 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -917,6 +917,37 @@ def __init__( encoder_hook: bool = True, record_param_norm_history: bool = True, ) -> None: + """ + Initializes a ``TrainingHelperCallback``. + + The purpose of this callback is to provide some typical + heuristics that are useful for diagnosing how training + is progressing. The behavior of this callback is twofold: + (1) emit warning messages to the user, indicating that + there are irregularities like missing gradients, and low + variance in embeddings; (2) send some of these observations + to loggers like ``TensorBoardLogger`` and ``WandbLogger`` + for asynchronous viewing. + + Parameters + ---------- + small_grad_thres : float, default 1e-3 + Threshold for detecting when gradients for particular + parameters are considered small. This helps identify + layers that could benefit with some residual connections. + update_freq : int, default 50 + Frequency of which to run checks with this callback. + This can be increased to make messages less spammy. + encoder_hook : bool, default True + If True, we register a forward hook with the model's + encoder that is specifically designed for ``matsciml`` + usage. 
This hook will inspect graph and node level + embeddings, particularly variance in dimensions, to + identify feature collapse. + record_param_norm_history : bool, default True + If True, will log tensor norms to ``tensorboard`` or + ``wandb`` services. + """ super().__init__() self.logger = getLogger("matsciml.helper") self.logger.setLevel("INFO") From a8ccae8fb68b0a9dd34722ed204e9da979651d31 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 14:20:01 -0700 Subject: [PATCH 23/32] scripts: added scripts to demonstrate callback usage Signed-off-by: Lee, Kin Long Kelvin --- examples/callbacks/autocorrelation.py | 65 +++++++++++++++++++++++++++ examples/callbacks/helper.py | 57 +++++++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 examples/callbacks/autocorrelation.py create mode 100644 examples/callbacks/helper.py diff --git a/examples/callbacks/autocorrelation.py b/examples/callbacks/autocorrelation.py new file mode 100644 index 00000000..c1566b2d --- /dev/null +++ b/examples/callbacks/autocorrelation.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger + +from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform +from matsciml.lightning.data_utils import MatSciMLDataModule +from matsciml.lightning.callbacks import ModelAutocorrelation +from matsciml.models import SchNet +from matsciml.models.base import ScalarRegressionTask + +""" +This script demonstrates the use of the `ModelAutocorrelation` callback. + +The main utility of this callback is to monitor the degree of correlation +in model parameters and optionally gradients over a time span. The idea +is that for optimization trajectories, steps are ideally as de-correlated +as possible (at least within reason), and indeed is actually a major +assumption of Adam-like optimizers. 
+ +There is no hard coded heuristic for identifying "too much correlation" +yet, however this callback can help do the data collection for you to +develop a sense for yourself. One method for trying this out is to +set varying learning rates, and seeing how the autocorrelation spectra +are different. +""" + +# construct a scalar regression task with SchNet encoder +task = ScalarRegressionTask( + encoder_class=SchNet, + # kwargs to be passed into the creation of SchNet model + encoder_kwargs={ + "encoder_only": True, + "hidden_feats": [128, 128, 128], + "atom_embedding_dim": 128, + }, + # which keys to use as targets + task_keys=["energy_relaxed"], + log_embeddings=False, +) +# Use IS2RE devset to test workflow +# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances +dm = MatSciMLDataModule.from_devset( + "IS2REDataset", + dset_kwargs={ + "transforms": [ + PointCloudToGraphTransform( + "dgl", + cutoff_dist=20.0, + node_keys=["pos", "atomic_numbers"], + ), + DistancesTransform(), + ], + }, +) + +# tensorboard logging if working purely locally, otherwise wandb +logger = WandbLogger( + name="helper-callback", offline=False, project="matsciml", log_model="all" +) +logger = TensorBoardLogger("./") + +# run a quick training loop +trainer = pl.Trainer(max_epochs=30, logger=logger, callbacks=[ModelAutocorrelation()]) +trainer.fit(task, datamodule=dm) diff --git a/examples/callbacks/helper.py b/examples/callbacks/helper.py new file mode 100644 index 00000000..28c6fab5 --- /dev/null +++ b/examples/callbacks/helper.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger + +from matsciml.datasets.transforms import DistancesTransform, PointCloudToGraphTransform +from matsciml.lightning.data_utils import MatSciMLDataModule +from matsciml.lightning.callbacks import TrainingHelperCallback +from matsciml.models import SchNet +from matsciml.models.base import 
ScalarRegressionTask + +""" +This script demonstrates the use of the ``TrainingHelperCallback`` +callback. The purpose of this callback is to provide some +helpful heuristics into the training process by identifying +some common issues like unused weights, small gradients, +and oversmoothed embeddings. +""" + +# construct a scalar regression task with SchNet encoder +task = ScalarRegressionTask( + encoder_class=SchNet, + # kwargs to be passed into the creation of SchNet model + encoder_kwargs={ + "encoder_only": True, + "hidden_feats": [128, 128, 128], + "atom_embedding_dim": 128, + }, + # which keys to use as targets + task_keys=["energy_relaxed"], + log_embeddings=True, +) +# Use IS2RE devset to test workflow +# SchNet uses RBFs, and expects edge features corresponding to atom-atom distances +dm = MatSciMLDataModule.from_devset( + "IS2REDataset", + dset_kwargs={ + "transforms": [ + PointCloudToGraphTransform( + "dgl", + cutoff_dist=20.0, + node_keys=["pos", "atomic_numbers"], + ), + DistancesTransform(), + ], + }, +) + +# tensorboard logging if working purely locally +# logger = TensorBoardLogger("./") +logger = WandbLogger( + name="helper-callback", offline=False, project="matsciml", log_model="all" +) + +# run a quick training loop +trainer = pl.Trainer(max_epochs=10, logger=logger, callbacks=[TrainingHelperCallback()]) +trainer.fit(task, datamodule=dm) From 28e1a33ad5fce3a375465796aecfe3f09b6f2ffe Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Fri, 7 Jun 2024 14:34:01 -0700 Subject: [PATCH 24/32] refactor: using cartesian coordinates as regular inputs --- matsciml/datasets/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/matsciml/datasets/utils.py b/matsciml/datasets/utils.py index 89dc32fb..b0410573 100644 --- a/matsciml/datasets/utils.py +++ b/matsciml/datasets/utils.py @@ -735,12 +735,13 @@ def _all_sites_have_neighbors(neighbors): cell = torch.from_numpy(cell.copy()).float() # get coordinates as well, for 
standardization frac_coords = torch.from_numpy(structure.frac_coords).float() + coords = torch.from_numpy(structure.cart_coords).float() return_dict = { "src_nodes": torch.LongTensor(all_src), "dst_nodes": torch.LongTensor(all_dst), "images": torch.FloatTensor(all_images), "cell": cell, - "pos": frac_coords, + "pos": coords, } # now calculate offsets based on each image for a lattice return_dict["offsets"] = einsum(return_dict["images"], cell, "v i, n i j -> v j") From af745c0fd3c6d95630e8be18a994197d60f9cc99 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Wed, 12 Jun 2024 10:51:26 -0700 Subject: [PATCH 25/32] refactor: taking absolute value of the median for comparison --- matsciml/lightning/callbacks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 76ba6608..847b7afe 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -880,7 +880,7 @@ def embedding_magnitude_hook( if output.system_embedding is not None: sys_z = output.system_embedding.detach().cpu() # calculate representative statistics - sys_z_med = sys_z.median().item() + sys_z_med = sys_z.median().abs().item() sys_z_var = sys_z.var().item() if sys_z_med > 10.0: logger.warning( @@ -893,7 +893,7 @@ def embedding_magnitude_hook( if output.point_embedding is not None: node_z = output.point_embedding.detach().cpu() # calculate representative statistics - node_z_med = node_z.median().item() + node_z_med = node_z.median().abs().item() node_z_var = node_z.var().item() if node_z_med > 10.0: logger.warning( From 515e4b88c2a370243ff7f03ae97cbf885b225452 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Wed, 12 Jun 2024 10:58:38 -0700 Subject: [PATCH 26/32] refactor: now looping over multiple loggers, if any are supplied This will only still function for wandb/tensorboard, but supports multiple --- matsciml/lightning/callbacks.py | 32 +++++++++++++++++--------------- 1 
file changed, 17 insertions(+), 15 deletions(-) diff --git a/matsciml/lightning/callbacks.py b/matsciml/lightning/callbacks.py index 847b7afe..586c7903 100644 --- a/matsciml/lightning/callbacks.py +++ b/matsciml/lightning/callbacks.py @@ -1056,21 +1056,23 @@ def encoder_head_comparison( " encoder median norm: {encoder_median:.3e}," " output head: {output_median:.3e}" ) - # optionally record to service as well - if log_history and pl_module.logger is not None: - log_service = pl_module.logger.experiment - encoder_norm_vals = torch.from_numpy(encoder_norm_vals).float() - output_norm_vals = torch.from_numpy(output_norm_vals).float() - if isinstance(log_service, pl_loggers.TensorBoardLogger): - log_service.add_histogram( - "encoder_weight_norm", encoder_norm_vals, global_step - ) - log_service.add_histogram( - "outputhead_weight_norm", output_norm_vals, global_step - ) - elif isinstance(log_service, pl_loggers.WandbLogger): - log_service.log({"encoder_weight_norm": encoder_norm_vals}) - log_service.log({"outputhead_weight_norm": output_norm_vals}) + # optionally record to a supported service as well + # this nominally should work for multiple loggers + if log_history and len(pl_module.loggers) > 0: + for pl_logger in pl_module.loggers: + log_service = pl_logger.experiment + encoder_norm_vals = torch.from_numpy(encoder_norm_vals).float() + output_norm_vals = torch.from_numpy(output_norm_vals).float() + if isinstance(log_service, pl_loggers.TensorBoardLogger): + log_service.add_histogram( + "encoder_weight_norm", encoder_norm_vals, global_step + ) + log_service.add_histogram( + "outputhead_weight_norm", output_norm_vals, global_step + ) + elif isinstance(log_service, pl_loggers.WandbLogger): + log_service.log({"encoder_weight_norm": encoder_norm_vals}) + log_service.log({"outputhead_weight_norm": output_norm_vals}) def on_before_optimizer_step( self, From 7f726b91fc95ddc051469b4f336eb90de4a3cc99 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Wed, 12 Jun 2024 
def forward(
    self,
    batch: dict[str, torch.Tensor | dgl.DGLGraph | dict[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """
    Encode a batch and map the fresh embeddings onto task outputs.

    The encoder is invoked on every call. Its result is stored under
    ``batch["embeddings"]``, replacing anything already kept there, and
    is then passed to ``process_embedding``.

    Parameters
    ----------
    batch
        Batch of samples; its ``"embeddings"`` entry is (over)written
        in place as a side effect.

    Returns
    -------
    dict[str, torch.Tensor]
        The value returned by ``process_embedding`` for the embeddings
        computed in this call.
    """
    # never reuse an "embeddings" entry that is already in the batch:
    # per the accompanying commit message, reusing them broke the
    # computational graph
    batch["embeddings"] = fresh = self.encoder(batch)
    return self.process_embedding(fresh)
s.step(epoch=self.current_epoch) if self.hparams.log_embeddings and "embeddings" in batch: self._log_embedding(batch["embeddings"]) return loss_dict From c182d909594701229ceeff295d22f0fc618f3c22 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 1 Jul 2024 11:04:40 -0700 Subject: [PATCH 29/32] refactor: allowing multi task subtasks to reuse shared embedding --- matsciml/models/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 4efb1c7b..24436935 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2689,8 +2689,10 @@ def forward( if self.is_multidata: for key, data in batch.items(): data["embeddings"] = self.encoder(data) + embeddings = data["embeddings"] else: batch["embeddings"] = self.encoder(batch) + embeddings = batch["embeddings"] # for single dataset usage, we assume the nested structure isn't used if self.is_multidata: for key, data in batch.items(): @@ -2699,13 +2701,13 @@ def forward( results[key] = {} # finally call the task with the data for task_type, subtask in subtasks.items(): - results[key][task_type] = subtask(data) + results[key][task_type] = subtask.process_embedding(embeddings) else: # in the single dataset case, we can skip the outer loop # and just pass the batch into the subtask tasks = list(self.task_map.values()).pop(0) for task_type, subtask in tasks.items(): - results[task_type] = subtask(batch) + results[task_type] = subtask.process_embedding(embeddings) return results def predict(self, batch: BatchDict) -> dict[str, dict[str, torch.Tensor]]: From 2ca13670b67e206036ad2168090eae3f38a0a8dd Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 1 Jul 2024 11:07:23 -0700 Subject: [PATCH 30/32] refactor: adding log embeddings kwargs to multi task litmodule --- matsciml/models/base.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index 24436935..b6f95c96 100644 
--- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -2311,6 +2311,8 @@ def __init__( *tasks: tuple[str, BaseTaskModule], task_scaling: Iterable[float] | None = None, task_keys: dict[str, list[str]] | None = None, + log_embeddings: bool = False, + log_embeddings_every_n_steps: int = 50, **encoder_opt_kwargs, ) -> None: """ @@ -2354,6 +2356,8 @@ def __init__( "subtask_hparams": subtask_hparams, "task_scaling": task_scaling, "encoder_opt_kwargs": encoder_opt_kwargs, + "log_embeddings": log_embeddings, + "log_embeddings_every_n_steps": log_embeddings_every_n_steps, }, ) self.task_map = task_map From 43175b4ec51455a4925c69574dffdb0d41cda755 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 1 Jul 2024 11:08:46 -0700 Subject: [PATCH 31/32] feat: added log embeddings method to multitask litmodule --- matsciml/models/base.py | 51 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index b6f95c96..ba548fd1 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -3181,6 +3181,57 @@ def from_pretrained_encoder(cls, task_ckpt_path: str | Path, **kwargs): task.encoder.load_state_dict(encoder_weights) return task + def _log_embedding(self, embeddings: Embeddings) -> None: + """ + This maps the appropriate logging function depending on what + logger was used, and saves the graph and node level embeddings. + + Some services like ``wandb`` are able to do some nifty embedding + analyses online using these embeddings. + + Parameters + ---------- + embeddings : Embeddings + Data structure containing embeddings from the encoder. 
+ """ + log_freq = self.hparams.log_embeddings_every_n_steps + global_step = self.trainer.global_step + # only log embeddings at the same cadence as everything else + if self.logger is not None and (global_step % log_freq) == 0: + exp = self.logger.experiment + sys_z = embeddings.system_embedding.detach().cpu() + node_z = embeddings.point_embedding.detach().cpu() + if isinstance(self.logger, pl_loggers.WandbLogger): + # this import is okay here since we need it for the logger anyway + import wandb + + cols = [f"D{i}" for i in range(sys_z.size(-1))] + exp.log( + {"graph_embeddings": wandb.Table(columns=cols, data=sys_z.tolist())} + ) + if isinstance(embeddings.point_embedding, torch.Tensor): + # TODO: should add labels to the nodes based on graph index + exp.log( + { + "node_embeddings": wandb.Table( + columns=cols, data=node_z.tolist() + ) + } + ) + elif isinstance(self.logger, pl_loggers.TensorBoardLogger): + exp.add_embedding( + sys_z, + tag=f"graph_embeddings_{self.trainer.global_step}", + ) + if isinstance(embeddings.point_embedding, torch.Tensor): + # TODO: should add labels to the nodes based on graph index + exp.add_embedding( + node_z, + tag=f"node_embeddings_{self.trainer.global_step}", + ) + else: + pass + @registry.register_task("OpenCatalystInference") class OpenCatalystInference(ABC, pl.LightningModule): From d53cde41b383f122a3810f1034538b406a5d6aec Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Mon, 1 Jul 2024 11:14:05 -0700 Subject: [PATCH 32/32] refactor: logging embeddings for both training and validation steps in multitas --- matsciml/models/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/matsciml/models/base.py b/matsciml/models/base.py index ba548fd1..c19b8e4d 100644 --- a/matsciml/models/base.py +++ b/matsciml/models/base.py @@ -3051,6 +3051,9 @@ def training_step( prog_bar=True, batch_size=batch_info["batch_size"], ) + # optionally log embeddings + if self.hparams.log_embeddings and "embeddings" in batch: + 
self._log_embedding(batch["embeddings"]) return losses def validation_step( @@ -3105,6 +3108,9 @@ def validation_step( prog_bar=True, batch_size=batch_info["batch_size"], ) + # optionally log embeddings + if self.hparams.log_embeddings and "embeddings" in batch: + self._log_embedding(batch["embeddings"]) return losses @classmethod