diff --git a/didactic/data/cardinal/utils.py b/didactic/data/cardinal/utils.py index 67c95b01..1db71b70 100644 --- a/didactic/data/cardinal/utils.py +++ b/didactic/data/cardinal/utils.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd -from vital.data.cardinal.config import CardinalTag, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum from vital.data.cardinal.utils.attributes import build_attributes_dataframe from vital.data.cardinal.utils.data_struct import Patient @@ -107,25 +107,25 @@ def _to_num(attr_data: pd.Series) -> pd.Series: return neigh_data -def build_img_attr_by_patient_group_dataframe( +def build_time_series_attr_by_patient_group_dataframe( patients_groups: Mapping[Hashable, Iterable[Patient]], - attr: Tuple[ViewEnum, ImageAttribute], + attr: Tuple[ViewEnum, TimeSeriesAttribute], group_desc: str = "group", mask_tag: str = CardinalTag.mask, resampling_rate: int = 128, ) -> pd.DataFrame: - """Builds a dataframe with the average curve of an image attribute by patient group. + """Builds a dataframe with the average curve of a time-series attribute by patient group. Args: patients_groups: Mapping between group ID and the patients in that group. - attr: A pair of view and image attribute to compute the average curve of. + attr: A pair of view and time-series attribute to compute the average curve of. group_desc: Description of the semantic meaning of the groups. - mask_tag: Tag of the segmentation mask for which to extract the image attribute data. - resampling_rate: Number of points at which to resample the image attribute curves from each patient, so that - they can be easily compared and aggregated together. + mask_tag: Tag of the segmentation mask for which to extract the time-series attribute data. + resampling_rate: Number of points at which to resample the time-series attribute curves from each patient, so + that they can be easily compared and aggregated together. Returns: - Dataframe with the average curve of an image attribute by patient group, in long format. + Dataframe with the average curve of a time-series attribute by patient group, in long format.
""" resampling_fn = Interp1d(resampling_rate) diff --git a/didactic/tasks/cardiac_multimodal_representation.py b/didactic/tasks/cardiac_multimodal_representation.py index fa929167..7fdb53e2 100644 --- a/didactic/tasks/cardiac_multimodal_representation.py +++ b/didactic/tasks/cardiac_multimodal_representation.py @@ -14,10 +14,10 @@ from torch.nn import Parameter, ParameterDict, init from torchmetrics.functional import accuracy, mean_absolute_error from vital.data.augmentation.base import mask_tokens, random_masking -from vital.data.cardinal.config import CardinalTag, ClinicalAttribute, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TabularAttribute, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_image_attributes -from vital.data.cardinal.utils.attributes import CLINICAL_CAT_ATTR_LABELS +from vital.data.cardinal.datapipes import MISSING_CAT_ATTR, PatientData, filter_time_series_attributes +from vital.data.cardinal.utils.attributes import TABULAR_CAT_ATTR_LABELS from vital.tasks.generic import SharedStepsTask from vital.utils.decorators import auto_move_data @@ -25,7 +25,7 @@ from didactic.models.time_series import TimeSeriesEmbedding logger = logging.getLogger(__name__) -CardiacAttribute = ClinicalAttribute | Tuple[ViewEnum, ImageAttribute] +CardiacAttribute = TabularAttribute | Tuple[ViewEnum, TimeSeriesAttribute] class CardiacMultimodalRepresentationTask(SharedStepsTask): @@ -34,15 +34,15 @@ class CardiacMultimodalRepresentationTask(SharedStepsTask): def __init__( self, embed_dim: int, - clinical_attrs: Sequence[ClinicalAttribute | str], - img_attrs: Sequence[ImageAttribute], + tabular_attrs: Sequence[TabularAttribute | str], + time_series_attrs: Sequence[TimeSeriesAttribute], views: Sequence[ViewEnum] = tuple(ViewEnum), - predict_losses: Dict[ClinicalAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None, + predict_losses: Dict[TabularAttribute | str, Callable[[Tensor, Tensor], Tensor]] | DictConfig = None, ordinal_mode: bool = True, contrastive_loss: Callable[[Tensor, Tensor], Tensor] | DictConfig = None, contrastive_loss_weight: float = 0, - clinical_tokenizer: Optional[FeatureTokenizer | DictConfig] = None, - img_tokenizer: Optional[TimeSeriesEmbedding | DictConfig] = None, + tabular_tokenizer: Optional[FeatureTokenizer | DictConfig] = None, + time_series_tokenizer: Optional[TimeSeriesEmbedding | DictConfig] = None, cls_token: bool = True, sequential_pooling: bool = False, mtr_p: float | Tuple[float, float] = 0, @@ -54,6 +54,9 @@ def __init__( Args: embed_dim: Size of the tokens/embedding for all the modalities. + tabular_attrs: Tabular attributes to provide to the model. + time_series_attrs: Time-series attributes to provide to the model. + views: Views from which to include time-series attributes. predict_losses: Supervised criteria to measure the error between the predicted attributes and their real value. ordinal_mode: Whether to consider applicable targets as ordinal variables, which means: @@ -63,11 +66,8 @@ def __init__( of feature vectors, in a contrastive learning step that follows the SCARF pretraining. (see ref: https://arxiv.org/abs/2106.15147) contrastive_loss_weight: Factor by which to weight the `contrastive_loss` in the overall loss. - clinical_attrs: Clinical attributes to provide to the model. - img_attrs: Image attributes to provide to the model. - views: Views from which to include image attributes. 
- clinical_tokenizer: Tokenizer that can process clinical, i.e. patient records, data. - img_tokenizer: Tokenizer that can process imaging data. + tabular_tokenizer: Tokenizer that can process tabular, i.e. patient records, data. + time_series_tokenizer: Tokenizer that can process time-series data. cls_token: Whether to add a CLS token to use as the encoder's output token. sequential_pooling: Whether to perform sequential pooling on the encoder's output tokens. Otherwise, the full sequence of tokens is concatenated before being fed to the prediction head. @@ -82,9 +82,9 @@ """ # Ensure string tags are converted to their appropriate enum types # And do it before the call to the parent's `init` so that the converted values are saved in `hparams` - clinical_attrs = tuple(ClinicalAttribute[e] for e in clinical_attrs) + tabular_attrs = tuple(TabularAttribute[e] for e in tabular_attrs) views = tuple(ViewEnum[e] for e in views) - img_attrs = tuple(ImageAttribute[e] for e in img_attrs) + time_series_attrs = tuple(TimeSeriesAttribute[e] for e in time_series_attrs) # If dropout/masking are not single numbers, make sure they are tuples (and not another container type) if not isinstance(mtr_p, (int, float)): @@ -104,66 +104,71 @@ "dimensionality of the encoder's output from a sequence of tokens to only one token." ) - if not clinical_tokenizer and clinical_attrs: + if not tabular_tokenizer and tabular_attrs: raise ValueError( - f"You have requested the following attributes derived from clinical data: " - f"{[str(attr) for attr in clinical_attrs]}, but have not configured a tokenizer for clinical-based " - f"attributes. Either provide this tokenizer (through the `clinical_tokenizer` parameter) or remove any " - f"clinical-based attributes (by setting the `clinical_attrs` to be an empty list)." + f"You have requested the following tabular attributes: " + f"{[str(attr) for attr in tabular_attrs]}, but have not configured a tokenizer for tabular attributes. " + f"Either provide this tokenizer (through the `tabular_tokenizer` parameter) or remove any tabular " + f"attributes (by setting the `tabular_attrs` to be an empty list)." ) - if img_attrs: - if not img_tokenizer: + if time_series_attrs: + if not time_series_tokenizer: raise ValueError( - f"You have requested the following attributes derived from imaging data: " - f"{[str(attr) for attr in img_attrs]}, but have not configured a tokenizer for image-based " - f"attributes. Either provide this tokenizer (through the `img_tokenizer` parameter) or remove any " - f"image-based attributes (by setting the `img_attrs` to be an empty list)." + f"You have requested the following time-series attributes: " + f"{[str(attr) for attr in time_series_attrs]}, but have not configured a tokenizer for time-series " + f"attributes. Either provide this tokenizer (through the `time_series_tokenizer` parameter) or " + f"remove any time-series attributes (by setting the `time_series_attrs` to be an empty list)." ) if ( - img_tokenizer.model if isinstance(img_tokenizer, TimeSeriesEmbedding) else img_tokenizer.get("model") + time_series_tokenizer.model + if isinstance(time_series_tokenizer, TimeSeriesEmbedding) + else time_series_tokenizer.get("model") ) is None: logger.warning( - f"You have requested the following attributes derived from imaging data: " - f"{[str(attr) for attr in img_attrs]}, but have not configured a model for the tokenizer for " - f"image-based attributes.
The tokenizer's model is optional, but highly recommended, so this is " - f"likely an oversight. You can provide this model through the `img_tokenizer.model` parameter." + f"You have requested the following time-series attributes: " + f"{[str(attr) for attr in time_series_attrs]}, but have not configured a model for the tokenizer " + f"for time-series attributes. The tokenizer's model is optional, but highly recommended, so this " + f"is likely an oversight. You can provide this model through the `time_series_tokenizer.model` " + f"parameter." ) - if not (clinical_attrs or img_attrs): + if not (tabular_attrs or time_series_attrs): raise ValueError( - "You configured neither clinical attributes nor image attributes as input variables to the model, but " - "the model requires at least one input. Set non-empty values for either or both `clinical_attrs` and " - "`img_attrs`." + "You configured neither tabular attributes nor time-series attributes as input variables to the model, " + "but the model requires at least one input. Set non-empty values for either or both `tabular_attrs` " + "and `time_series_attrs`." ) super().__init__(*args, **kwargs) - # TOFIX: Hack to log image tokenizer model's hparams when it's a config for a `torch.nn.Sequential` object + # TOFIX: Hack to log time-series tokenizer model's hparams when it's a config for a `torch.nn.Sequential` object # In that case, we have to use a `ListConfig` for the reserved `_args_` key. However, the automatic # serialization of `_args_` fails (w/ a '`DictConfig' not JSON serializable' error). Therefore, we fix it by # manually unpacking and logging the first and only element in the `_args_` `ListConfig` - if isinstance(img_tokenizer, DictConfig): - if img_tokenizer.get("model", {}).get("_target_") == "torch.nn.Sequential": - self.save_hyperparameters({"img_tokenizer/model/_args_/0": img_tokenizer.model._args_[0]}) + if isinstance(time_series_tokenizer, DictConfig): + if time_series_tokenizer.get("model", {}).get("_target_") == "torch.nn.Sequential": + self.save_hyperparameters( + {"time_series_tokenizer/model/_args_/0": time_series_tokenizer.model._args_[0]} + ) # Add shortcut to lr to work with Lightning's learning rate finder self.hparams.lr = None # Add shortcut to token labels to avoid downstream applications having to determine them from hyperparameters - self.token_tags = clinical_attrs + tuple( - "/".join([view, attr]) for view, attr in itertools.product(views, img_attrs) + self.token_tags = tabular_attrs + tuple( + "/".join([view, attr]) for view, attr in itertools.product(views, time_series_attrs) ) if cls_token: self.token_tags = self.token_tags + ("LAT",) - # Categorise the clinical attributes (tabular data) in terms of their type (numerical vs categorical) - self.clinical_num_attrs = [ - attr for attr in self.hparams.clinical_attrs if attr in ClinicalAttribute.numerical_attrs() + # Categorise the tabular attributes in terms of their type (numerical vs categorical) + self.tabular_num_attrs = [ + attr for attr in self.hparams.tabular_attrs if attr in TabularAttribute.numerical_attrs() ] - self.clinical_cat_attrs = [ - attr for attr in self.hparams.clinical_attrs if attr in ClinicalAttribute.categorical_attrs() + self.tabular_cat_attrs = [ + attr for attr in self.hparams.tabular_attrs if attr in TabularAttribute.categorical_attrs() ] - self.clinical_cat_attrs_cardinalities = [ - len(CLINICAL_CAT_ATTR_LABELS[cat_attr]) for cat_attr in self.clinical_cat_attrs + self.tabular_cat_attrs_cardinalities = [ + 
len(TABULAR_CAT_ATTR_LABELS[cat_attr]) for cat_attr in self.tabular_cat_attrs ] # Extract train/test masking probabilities from their configs @@ -180,23 +185,23 @@ def __init__( self.predict_losses = {} if predict_losses: self.predict_losses = { - ClinicalAttribute[attr]: hydra.utils.instantiate(attr_loss) + TabularAttribute[attr]: hydra.utils.instantiate(attr_loss) if isinstance(attr_loss, DictConfig) else attr_loss for attr, attr_loss in predict_losses.items() } - self.hparams.target_clinical_attrs = tuple( + self.hparams.target_tabular_attrs = tuple( self.predict_losses ) # Hyperparameter to easily access target attributes for attr in self.predict_losses: - if attr in ClinicalAttribute.numerical_attrs(): + if attr in TabularAttribute.numerical_attrs(): self.metrics[attr] = {"mae": mean_absolute_error} - elif attr in ClinicalAttribute.binary_attrs(): + elif attr in TabularAttribute.binary_attrs(): self.metrics[attr] = {"acc": functools.partial(accuracy, task="binary")} - else: # attr in ClinicalAttribute.categorical_attrs() + else: # attr in TabularAttribute.categorical_attrs() self.metrics[attr] = { "acc": functools.partial( - accuracy, task="multiclass", num_classes=len(CLINICAL_CAT_ATTR_LABELS[attr]) + accuracy, task="multiclass", num_classes=len(TABULAR_CAT_ATTR_LABELS[attr]) ) } @@ -211,8 +216,8 @@ def __init__( # Compute shapes relevant for defining the models' architectures self.sequence_length = ( - len(self.hparams.clinical_attrs) - + (len(self.hparams.img_attrs) * len(self.hparams.views)) + len(self.hparams.tabular_attrs) + + (len(self.hparams.time_series_attrs) * len(self.hparams.views)) + self.hparams.cls_token ) @@ -232,25 +237,25 @@ def __init__( f"of the number of attention heads for your configuration above this warning." ) - if clinical_attrs: - if isinstance(clinical_tokenizer, DictConfig): - clinical_tokenizer = hydra.utils.instantiate( - clinical_tokenizer, - n_num_features=len(self.clinical_num_attrs), - cat_cardinalities=self.clinical_cat_attrs_cardinalities, + if tabular_attrs: + if isinstance(tabular_tokenizer, DictConfig): + tabular_tokenizer = hydra.utils.instantiate( + tabular_tokenizer, + n_num_features=len(self.tabular_num_attrs), + cat_cardinalities=self.tabular_cat_attrs_cardinalities, ) else: # Set tokenizer to `None` if it's not going to be used - clinical_tokenizer = None - self.clinical_tokenizer = clinical_tokenizer + tabular_tokenizer = None + self.tabular_tokenizer = tabular_tokenizer - if img_attrs: - if isinstance(img_tokenizer, DictConfig): - img_tokenizer = hydra.utils.instantiate(img_tokenizer) + if time_series_attrs: + if isinstance(time_series_tokenizer, DictConfig): + time_series_tokenizer = hydra.utils.instantiate(time_series_tokenizer) else: # Set tokenizer to `None` if it's not going to be used - img_tokenizer = None - self.img_tokenizer = img_tokenizer + time_series_tokenizer = None + self.time_series_tokenizer = time_series_tokenizer # Initialize modules/parameters dependent on the encoder's configuration @@ -287,18 +292,18 @@ def _init_mask_token() -> Parameter: @property def example_input_array( self, - ) -> Tuple[Dict[ClinicalAttribute, Tensor], Dict[Tuple[ViewEnum, ImageAttribute], Tensor]]: + ) -> Tuple[Dict[TabularAttribute, Tensor], Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]]: """Redefine example input array based on the cardiac attributes provided to the model.""" # 2 is the size of the batch in the example - clinical_attrs = {attr: torch.randn(2) for attr in self.clinical_num_attrs} + tab_attrs = {attr: torch.randn(2) 
for attr in self.tabular_num_attrs} # Only generate 0/1 labels, to avoid generating labels bigger than the number of classes, which would lead to # an index out of range error when looking up the embedding of the class in the categorical feature tokenizer - clinical_attrs.update({attr: torch.randint(2, (2,)) for attr in self.clinical_cat_attrs}) - img_attrs = { - (view, attr): torch.randn(2, self.hparams.data_params.in_shape[CardinalTag.image_attrs][1]) - for view, attr in itertools.product(self.hparams.views, self.hparams.img_attrs) + tab_attrs.update({attr: torch.randint(2, (2,)) for attr in self.tabular_cat_attrs}) + time_series_attrs = { + (view, attr): torch.randn(2, self.hparams.data_params.in_shape[CardinalTag.time_series_attrs][1]) + for view, attr in itertools.product(self.hparams.views, self.hparams.time_series_attrs) } - return clinical_attrs, img_attrs + return tab_attrs, time_series_attrs def configure_model( self, @@ -312,50 +317,52 @@ if self.contrastive_loss: contrastive_head = hydra.utils.instantiate(self.hparams.model.contrastive_head) - # Build the prediction heads (one by clinical attribute to predict) following the architecture proposed in + # Build the prediction heads (one per tabular attribute to predict) following the architecture proposed in # https://arxiv.org/pdf/2106.11959 prediction_heads = None if self.predict_losses: prediction_heads = nn.ModuleDict() - for target_clinical_attr in self.predict_losses: + for target_tab_attr in self.predict_losses: if ( - target_clinical_attr in ClinicalAttribute.categorical_attrs() - and target_clinical_attr not in ClinicalAttribute.binary_attrs() + target_tab_attr in TabularAttribute.categorical_attrs() + and target_tab_attr not in TabularAttribute.binary_attrs() ): # Multi-class classification target - output_size = len(CLINICAL_CAT_ATTR_LABELS[target_clinical_attr]) + output_size = len(TABULAR_CAT_ATTR_LABELS[target_tab_attr]) else: # Binary classification or regression target output_size = 1 - if self.hparams.ordinal_mode and target_clinical_attr in ClinicalAttribute.ordinal_attrs(): + if self.hparams.ordinal_mode and target_tab_attr in TabularAttribute.ordinal_attrs(): # For ordinal targets, use a separate prediction head config - prediction_heads[target_clinical_attr] = hydra.utils.instantiate( + prediction_heads[target_tab_attr] = hydra.utils.instantiate( self.hparams.model.ordinal_head, num_logits=output_size ) else: - prediction_heads[target_clinical_attr] = hydra.utils.instantiate( + prediction_heads[target_tab_attr] = hydra.utils.instantiate( self.hparams.model.prediction_head, out_features=output_size ) return encoder, contrastive_head, prediction_heads def configure_optimizers(self) -> Dict[Literal["optimizer", "lr_scheduler"], Any]: - """Configure optimizer to ignore parameters that should remain frozen (e.g. image tokenizer).""" + """Configure optimizer to ignore parameters that should remain frozen (e.g. tokenizers).""" return super().configure_optimizers(params=filter(lambda p: p.requires_grad, self.parameters())) @auto_move_data def tokenize( - self, clinical_attrs: Dict[ClinicalAttribute, Tensor], img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor] + self, + tabular_attrs: Dict[TabularAttribute, Tensor], + time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor], ) -> Tuple[Tensor, Tensor]: - """Tokenizes the input clinical and image attributes, providing a mask of non-missing attributes.
+ """Tokenizes the input tabular and time-series attributes, providing a mask of non-missing attributes. Args: - clinical_attrs: (K: S, V: N), Sequence of batches of clinical attributes. To indicate an item is missing an + tabular_attrs: (K: S, V: N), Sequence of batches of tabular attributes. To indicate an item is missing an attribute, the flags `MISSING_NUM_ATTR`/`MISSING_CAT_ATTR` can be used for numerical and categorical attributes, respectively. - img_attrs: (K: S, V: (N, ?)), Sequence of batches of image attributes, where the dimensionality of each - attribute can vary. + time_series_attrs: (K: S, V: (N, ?)), Sequence of batches of time-series attributes, where the + dimensionality of each attribute can vary. Returns: Batch of i) (N, S, E) tokens for each attribute, and ii) (N, S) mask of non-missing attributes. @@ -364,17 +371,17 @@ def tokenize( tokens, notna_mask = [], [] # Tokenize the attributes - if clinical_attrs: - clinical_num_attrs, clinical_cat_attrs = None, None - if self.clinical_num_attrs: - # Group the numerical attributes from the `clinical_attrs` input in a single tensor - clinical_num_attrs = torch.hstack( - [clinical_attrs[attr].unsqueeze(1) for attr in self.clinical_num_attrs] + if tabular_attrs: + num_attrs, cat_attrs = None, None + if self.tabular_num_attrs: + # Group the numerical attributes from the `tabular_attrs` input in a single tensor + num_attrs = torch.hstack( + [tabular_attrs[attr].unsqueeze(1) for attr in self.tabular_num_attrs] ) # (N, S_num) - if self.clinical_cat_attrs: - # Group the categorical attributes from the `clinical_attrs` input in a single tensor - clinical_cat_attrs = torch.hstack( - [clinical_attrs[attr].unsqueeze(1) for attr in self.clinical_cat_attrs] + if self.tabular_cat_attrs: + # Group the categorical attributes from the `tabular_attrs` input in a single tensor + cat_attrs = torch.hstack( + [tabular_attrs[attr].unsqueeze(1) for attr in self.tabular_cat_attrs] ) # (N, S_cat) # Use "sanitized" version of the inputs, where invalid values are replaced by null/default values, for the # tokenization process. This is done to avoid propagating NaNs to available/valid values. @@ -383,30 +390,32 @@ def tokenize( # instead of their current null/default values. # 1) Convert missing numerical attributes (NaNs) to numbers to avoid propagating NaNs # 2) Clip categorical labels to convert indicators of missing data (-1) into valid indices (0) - clinical_attrs_tokens = self.clinical_tokenizer( - x_num=torch.nan_to_num(clinical_num_attrs) if clinical_num_attrs is not None else None, - x_cat=clinical_cat_attrs.clip(0) if clinical_cat_attrs is not None else None, - ) # (N, S_clinical, E) - tokens.append(clinical_attrs_tokens) - - # Identify missing data in clinical attributes - if self.clinical_num_attrs: - notna_mask.append(~(clinical_num_attrs.isnan())) - if self.clinical_cat_attrs: - notna_mask.append(clinical_cat_attrs != MISSING_CAT_ATTR) - - if img_attrs: - img_attrs_tokens = self.img_tokenizer(img_attrs) # S * (N, ?) 
-> (N, S_img, E) - tokens.append(img_attrs_tokens) - - # Indicate that, when image tokens are requested, they are always available - image_notna_mask = torch.full(img_attrs_tokens.shape[:2], True, device=img_attrs_tokens.device) - notna_mask.append(image_notna_mask) + tab_attrs_tokens = self.tabular_tokenizer( + x_num=torch.nan_to_num(num_attrs) if num_attrs is not None else None, + x_cat=cat_attrs.clip(0) if cat_attrs is not None else None, + ) # (N, S_tab, E) + tokens.append(tab_attrs_tokens) + + # Identify missing data in tabular attributes + if self.tabular_num_attrs: + notna_mask.append(~(num_attrs.isnan())) + if self.tabular_cat_attrs: + notna_mask.append(cat_attrs != MISSING_CAT_ATTR) + + if time_series_attrs: + time_series_attrs_tokens = self.time_series_tokenizer(time_series_attrs) # S * (N, ?) -> (N, S_ts, E) + tokens.append(time_series_attrs_tokens) + + # Indicate that, when time-series tokens are requested, they are always available + time_series_notna_mask = torch.full( + time_series_attrs_tokens.shape[:2], True, device=time_series_attrs_tokens.device + ) + notna_mask.append(time_series_notna_mask) # Cast to float to make sure tokens are not represented using double - tokens = torch.cat(tokens, dim=1).float() # (N, S_clinical + S_img, E) + tokens = torch.cat(tokens, dim=1).float() # (N, S_tab + S_ts, E) # Cast to bool to make sure attention mask is represented by bool - notna_mask = torch.cat(notna_mask, dim=1).bool() # (N, S_clinical + S_img) + notna_mask = torch.cat(notna_mask, dim=1).bool() # (N, S_tab + S_ts) return tokens, notna_mask @@ -464,18 +473,18 @@ def encode(self, tokens: Tensor, avail_mask: Tensor, disable_augments: bool = Fa @auto_move_data def forward( self, - clinical_attrs: Dict[ClinicalAttribute, Tensor], - img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor], + tabular_attrs: Dict[TabularAttribute, Tensor], + time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor], task: Literal["encode", "predict", "unimodal_param", "unimodal_tau"] = "encode", - ) -> Tensor | Dict[ClinicalAttribute, Tensor]: + ) -> Tensor | Dict[TabularAttribute, Tensor]: """Performs a forward pass through i) the tokenizer, ii) the transformer encoder and iii) the prediction head. Args: - clinical_attrs: (K: S, V: N) Sequence of batches of clinical attributes. To indicate an item is missing an + tabular_attrs: (K: S, V: N) Sequence of batches of tabular attributes. To indicate an item is missing an attribute, the flags `MISSING_NUM_ATTR`/`MISSING_CAT_ATTR` can be used for numerical and categorical attributes, respectively. - img_attrs: (K: S, V: (N, ?)), Sequence of batches of image attributes, where the dimensionality of each - attribute can vary. + time_series_attrs: (K: S, V: (N, ?)), Sequence of batches of time-series attributes, where the + dimensionality of each attribute can vary. task: Flag indicating which type of inference task to perform. Returns: @@ -500,7 +509,7 @@ def forward( "the requested inference task." 
) - in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S) + in_tokens, avail_mask = self.tokenize(tabular_attrs, time_series_attrs) # (N, S, E), (N, S) out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) # Early return if requested task requires no prediction heads @@ -515,7 +524,7 @@ def forward( case "predict": if self.hparams.ordinal_mode: predictions = { - attr: pred[0] if attr in ClinicalAttribute.ordinal_attrs() else pred + attr: pred[0] if attr in TabularAttribute.ordinal_attrs() else pred for attr, pred in predictions.items() } case "unimodal_param": @@ -530,11 +539,13 @@ def forward( return predictions def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]: - # Extract clinical and image attributes from the batch - clinical_attrs = {attr: attr_data for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs} - img_attrs = filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.img_attrs) + # Extract tabular and time-series attributes from the batch + tabular_attrs = {attr: attr_data for attr, attr_data in batch.items() if attr in self.hparams.tabular_attrs} + time_series_attrs = filter_time_series_attributes( + batch, views=self.hparams.views, attributes=self.hparams.time_series_attrs + ) - in_tokens, avail_mask = self.tokenize(clinical_attrs, img_attrs) # (N, S, E), (N, S) + in_tokens, avail_mask = self.tokenize(tabular_attrs, time_series_attrs) # (N, S, E), (N, S) out_features = self.encode(in_tokens, avail_mask) # (N, S, E) -> (N, E) metrics = {} @@ -557,7 +568,7 @@ def _prediction_shared_step( predictions = {} for attr, prediction_head in self.prediction_heads.items(): pred = prediction_head(out_features) - if self.hparams.ordinal_mode and attr in ClinicalAttribute.ordinal_attrs(): + if self.hparams.ordinal_mode and attr in TabularAttribute.ordinal_attrs(): # For ordinal targets, extract the logits from the multiple outputs of unimodal logits head pred = pred[0] predictions[attr] = pred.squeeze(dim=1) @@ -567,16 +578,16 @@ def _prediction_shared_step( for attr, loss in self.predict_losses.items(): target, y_hat = batch[attr], predictions[attr] - if attr in ClinicalAttribute.categorical_attrs(): + if attr in TabularAttribute.categorical_attrs(): notna_mask = target != MISSING_CAT_ATTR - else: # attr in ClinicalAttribute.numerical_attrs(): + else: # attr in TabularAttribute.numerical_attrs(): notna_mask = ~target.isnan() losses[f"{loss.__class__.__name__.lower().replace('loss', '')}/{attr}"] = loss( y_hat[notna_mask], # For BCE losses (e.g. `BCELoss`, BCEWithLogitsLoss`, etc.), the targets have to be floats, # so convert them from long to float - target[notna_mask] if attr not in ClinicalAttribute.binary_attrs() else target[notna_mask].float(), + target[notna_mask] if attr not in TabularAttribute.binary_attrs() else target[notna_mask].float(), ) for metric_tag, metric in self.metrics[attr].items(): @@ -608,34 +619,34 @@ def predict_step( # noqa: D102 self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0 ) -> Tuple[ Tensor, - Optional[Dict[ClinicalAttribute, Tensor]], - Optional[Dict[ClinicalAttribute, Tensor]], - Optional[Dict[ClinicalAttribute, Tensor]], + Optional[Dict[TabularAttribute, Tensor]], + Optional[Dict[TabularAttribute, Tensor]], + Optional[Dict[TabularAttribute, Tensor]], ]: - # Extract clinical and image attributes from the patient and add batch dimension - clinical_attrs = { - attr: attr_data[None, ...] 
for attr, attr_data in batch.items() if attr in self.hparams.clinical_attrs + # Extract tabular and time-series attributes from the patient and add batch dimension + tabular_attrs = { + attr: attr_data[None, ...] for attr, attr_data in batch.items() if attr in self.hparams.tabular_attrs } - img_attrs = { + time_series_attrs = { attr: attr_data[None, ...] - for attr, attr_data in filter_image_attributes( - batch, views=self.hparams.views, attributes=self.hparams.img_attrs + for attr, attr_data in filter_time_series_attributes( + batch, views=self.hparams.views, attributes=self.hparams.time_series_attrs ).items() } # Encoder's output - out_features = self(clinical_attrs, img_attrs) + out_features = self(tabular_attrs, time_series_attrs) # If the model has targets to predict, output the predictions predictions = None if self.prediction_heads: - predictions = self(clinical_attrs, img_attrs, task="predict") + predictions = self(tabular_attrs, time_series_attrs, task="predict") # If the model enforces unimodal constraint on ordinal targets, output the unimodal parametrization unimodal_params, unimodal_taus = None, None if self.hparams.ordinal_mode: - unimodal_params = self(clinical_attrs, img_attrs, task="unimodal_param") - unimodal_taus = self(clinical_attrs, img_attrs, task="unimodal_tau") + unimodal_params = self(tabular_attrs, time_series_attrs, task="unimodal_param") + unimodal_taus = self(tabular_attrs, time_series_attrs, task="unimodal_tau") # Remove unnecessary batch dimension from the different outputs # (only do this once all downstream inferences have been performed) diff --git a/didactic/tasks/cardiac_representation_clustering.py b/didactic/tasks/cardiac_representation_clustering.py index d9f052eb..e6e55091 100644 --- a/didactic/tasks/cardiac_representation_clustering.py +++ b/didactic/tasks/cardiac_representation_clustering.py @@ -12,7 +12,7 @@ from sklearn.mixture import GaussianMixture from sklearn.model_selection import GridSearchCV from tqdm.auto import tqdm -from vital.data.cardinal.config import ClinicalAttribute +from vital.data.cardinal.config import TabularAttribute class GridSearchEnsembleClustering: @@ -266,7 +266,7 @@ def main(): "--mask_tag", type=str, default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the image attributes", + help="Tag of the segmentation mask for which to extract the time-series attributes", ) parser.add_argument( "--output_dir", @@ -311,9 +311,9 @@ def main(): ) parser.add_argument( "--order_clusters_by", - type=ClinicalAttribute, - choices=list(ClinicalAttribute), - default=ClinicalAttribute.ht_grade, + type=TabularAttribute, + choices=list(TabularAttribute), + default=TabularAttribute.ht_grade, help="Attribute used to order the clusters by ascending mean value by cluster", ) args = parser.parse_args() diff --git a/didactic/tasks/cardiac_sequence_attrs_ae.py b/didactic/tasks/cardiac_sequence_attrs_ae.py index 828ecead..7b2c1347 100644 --- a/didactic/tasks/cardiac_sequence_attrs_ae.py +++ b/didactic/tasks/cardiac_sequence_attrs_ae.py @@ -10,9 +10,9 @@ from strenum import SnakeCaseStrEnum from torch import Tensor, nn from torch.nn import functional as F -from vital.data.cardinal.config import CardinalTag, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum -from vital.data.cardinal.datapipes import PatientData, filter_image_attributes +from vital.data.cardinal.datapipes import PatientData, filter_time_series_attributes from 
vital.tasks.generic import SharedStepsTask from vital.utils.decorators import auto_move_data from vital.utils.norm import minmax_scaling, scale @@ -23,29 +23,29 @@ @unique class _AttributeNormalization(SnakeCaseStrEnum): - """Names of the available strategies for normalizing image attributes values.""" + """Names of the available strategies for normalizing time-series attributes values.""" data = auto() """Normalize the attributes' values w.r.t. statistics computed on the whole training dataset.""" metrics = auto() - """Normalize the metrics computed on the image attributes w.r.t. statistics computed on each mini-batch.""" + """Normalize the metrics computed on the time-series attributes w.r.t. statistics computed on each mini-batch.""" @unique class _AttributeStatistic(SnakeCaseStrEnum): - """Statistics about the image attributes that are computed on the dataset and stored inside the model.""" + """Statistics about the time-series attributes that are computed on the dataset and stored inside the model.""" min = auto() max = auto() class CardiacSequenceAttributesAutoencoder(SharedStepsTask): - """Autoencoder pipeline specialized for cardiac sequences image attributes.""" + """Autoencoder pipeline specialized for cardiac sequences time-series attributes.""" def __init__( self, views: Sequence[ViewEnum] = tuple(ViewEnum), - attrs: Sequence[ImageAttribute] = tuple(ImageAttribute), + attrs: Sequence[TimeSeriesAttribute] = tuple(TimeSeriesAttribute), normalization: _AttributeNormalization = _AttributeNormalization.data, reconstruction_loss: nn.Module | DictConfig = nn.L1Loss(), *args, @@ -55,10 +55,10 @@ def __init__( Args: views: Views to train the model on. - attrs: Image attributes to train the model on. - normalization: Strategy to use to normalize image attributes values. - reconstruction_loss: Criterion to measure the reconstruction error on the image attribute curves, or Hydra - config object describing how to instantiate such criterion. + attrs: Time-series attributes to train the model on. + normalization: Strategy to use to normalize time-series attributes values. + reconstruction_loss: Criterion to measure the reconstruction error on the time-series attribute curves, or + Hydra config object describing how to instantiate such criterion. *args: Positional arguments to pass to the parent's constructor. **kwargs: Keyword arguments to pass to the parent's constructor. 
""" @@ -76,7 +76,7 @@ def __init__( self.reconstruction_loss = reconstruction_loss self._reconstruction_loss_name = self.reconstruction_loss.__class__.__name__.lower().replace("loss", "") - # Register buffers for image attributes (needs to be in `__init__`) + # Register buffers for time-series attributes (needs to be in `__init__`) attrs_stats_defaults = { _AttributeStatistic.min: torch.finfo().max, _AttributeStatistic.max: torch.finfo().min, @@ -88,9 +88,9 @@ def __init__( @property def example_input_array(self) -> Tensor: - """Redefine example input array based only on the image attributes modality.""" - img_attrs_shape = self.hparams.data_params.in_shape[CardinalTag.image_attrs] - return torch.randn((2, 1, img_attrs_shape[1])) + """Redefine example input array based only on the time-series attributes modality.""" + attrs_shape = self.hparams.data_params.in_shape[CardinalTag.time_series_attrs] + return torch.randn((2, 1, attrs_shape[1])) @property def latent_dim(self) -> int: @@ -117,31 +117,33 @@ def reconstruction_loss_scale(self) -> int: ) return reconstruction_loss_scale - def _get_attr_bounds(self, attr: Tuple[ViewEnum, ImageAttribute]) -> Tuple[Tensor, Tensor]: - """Access the stored min/max bounds related to an image attribute. + def _get_attr_bounds(self, attr: Tuple[ViewEnum, TimeSeriesAttribute]) -> Tuple[Tensor, Tensor]: + """Access the stored min/max bounds related to a time-series attribute. Args: attr: Key identifying the attribute for which to look up the bounds. Returns: - Min/max bounds for the requested image attribute. + Min/max bounds for the requested time-series attribute. """ return self._get_attr_stat(attr, _AttributeStatistic.min), self._get_attr_stat(attr, _AttributeStatistic.max) - def _get_attr_stat(self, attr: Tuple[ViewEnum, ImageAttribute], stat: _AttributeStatistic) -> Tensor: - """Access a statistic related to an image attribute, saved as a torch buffer inside the model. + def _get_attr_stat(self, attr: Tuple[ViewEnum, TimeSeriesAttribute], stat: _AttributeStatistic) -> Tensor: + """Access a statistic related to a time-series attribute, saved as a torch buffer inside the model. Args: attr: Key identifying the attribute for which to look up the statistic. stat: Statistic to look up. Returns: - Statistic for the requested image attribute. + Statistic for the requested time-series attribute. """ return getattr(self, "_".join((*attr, stat))) - def _set_attr_stat(self, attr: Tuple[ViewEnum, ImageAttribute], stat: _AttributeStatistic, val: Tensor) -> None: - """Sets the value of a statistic related to an image attribute, saved as a torch buffer inside the model. + def _set_attr_stat( + self, attr: Tuple[ViewEnum, TimeSeriesAttribute], stat: _AttributeStatistic, val: Tensor + ) -> None: + """Sets the value of a statistic related to a time-series attribute, saved as a torch buffer inside the model. Args: attr: Key identifying the attribute for which to set the statistic. @@ -152,12 +154,12 @@ def _set_attr_stat(self, attr: Tuple[ViewEnum, ImageAttribute], stat: _Attribute def configure_model(self) -> nn.Module: """Configure the network architecture used by the system.""" - attrs_shape = self.hparams.data_params.in_shape[CardinalTag.image_attrs] + attrs_shape = self.hparams.data_params.in_shape[CardinalTag.time_series_attrs] model = hydra.utils.instantiate(self.hparams.model, input_shape=(1, attrs_shape[-1])) return model def on_fit_start(self) -> None: - """Computes global statistics for the image attributes on the training subset. 
+ """Computes global statistics for the time-series attributes on the training subset. These stats will be used during training and inference to normalize attributes values or metrics """ @@ -169,7 +171,7 @@ def on_fit_start(self) -> None: for batch in train_dl: for (stat, update_fn), (attr, attr_data) in itertools.product( attrs_stats_update_fn.items(), - filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs).items(), + filter_time_series_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs).items(), ): self._set_attr_stat(attr, stat, update_fn(self._get_attr_stat(attr, stat), attr_data)) @@ -222,7 +224,7 @@ def forward( self, x: Tensor, task: Literal["encode", "decode", "reconstruct"] = "reconstruct", - attr: Tuple[ViewEnum, ImageAttribute] = None, + attr: Tuple[ViewEnum, TimeSeriesAttribute] = None, out_shape: Tuple[int, ...] = None, ) -> Tensor: """Performs test-time inference on the input. @@ -280,12 +282,12 @@ def forward( return x def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]: # noqa: D102 - attrs = filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) + attrs = filter_time_series_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) if self.hparams.normalization == _AttributeNormalization.data: attrs = self._normalize_attrs(attrs) - # Forward on image attributes + # Forward on time-series attributes attrs_x_hat, attrs_z = {}, {} for attr_key, attr_data in attrs.items(): attrs_z[attr_key] = self.encoder(attr_data.unsqueeze(1)) @@ -312,8 +314,8 @@ def _shared_step(self, batch: PatientData, batch_idx: int) -> Dict[str, Tensor]: return metrics def _normalize_attrs( - self, attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor] - ) -> Dict[Tuple[ViewEnum, ImageAttribute], Tensor]: + self, attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor] + ) -> Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]: """Normalizes attributes with different range of values. Args: @@ -328,8 +330,8 @@ def _normalize_attrs( } def _normalize_attrs_reconstruction_metrics( - self, attrs_metrics: Dict[Tuple[ViewEnum, ImageAttribute], Tensor] - ) -> Dict[Tuple[ViewEnum, ImageAttribute], Tensor]: + self, attrs_metrics: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor] + ) -> Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]: """Normalizes reconstruction metrics computed on attributes with different range of values. Args: @@ -353,9 +355,9 @@ def _normalize_attrs_reconstruction_metrics( @torch.inference_mode() def predict_step( # noqa: D102 self, batch: PatientData, batch_idx: int, dataloader_idx: int = 0 - ) -> Dict[Tuple[ViewEnum, ImageAttribute], Tuple[Tensor, Tensor]]: - # Reconstruct the image attributes - attrs = filter_image_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) + ) -> Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tuple[Tensor, Tensor]]: + # Reconstruct the time-series attributes + attrs = filter_time_series_attributes(batch, views=self.hparams.views, attributes=self.hparams.attrs) prediction = {} for attr_key, attr_data in attrs.items(): if not (is_batch := attr_data.ndim == 2): @@ -383,12 +385,12 @@ def __init__(self, cardiac_sequence_attrs_model: str | Path | CardiacSequenceAtt """Initializes class instance. Args: - cardiac_sequence_attrs_model: Pretrained image attributes autoencoder model used to compress the attributes - into tokens. Mutually exclusive parameter with `embed_dim`. 
+ cardiac_sequence_attrs_model: Pretrained time-series attributes autoencoder model used to compress the + attributes into tokens. Mutually exclusive parameter with `embed_dim`. """ super().__init__() - # If the image attributes encoder is a checkpoint rather than an instantiated network, load the model from + # If the time-series attributes encoder is a checkpoint rather than an instantiated network, load the model from # the checkpoint if isinstance(cardiac_sequence_attrs_model, (str, Path)): cardiac_sequence_attrs_model = load_from_checkpoint(cardiac_sequence_attrs_model) @@ -396,11 +398,11 @@ def __init__(self, cardiac_sequence_attrs_model: str | Path | CardiacSequenceAtt # Also, the backend model needs to be saved as a class member even if it's not necessary so that the # tokenizer as a whole can behave as expected of a module # (e.g. moving it across devices is applied recursively to the backend model, etc.) - self.img_attrs_ae = cardiac_sequence_attrs_model.eval().requires_grad_(False) + self.autoencoder = cardiac_sequence_attrs_model.eval().requires_grad_(False) @torch.inference_mode() - def forward(self, attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor]) -> Tensor: - """Encodes image attributes using the autoencoder. + def forward(self, attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor]) -> Tensor: + """Encodes time-series attributes using the autoencoder. Args: attrs: (K: S, V: (N, ?)): Attributes to tokenize, where the dimensionality of each attribute can vary. @@ -408,4 +410,4 @@ def forward(self, attrs: Dict[Tuple[ViewEnum, ImageAttribute], Tensor]) -> Tenso Returns: (N, S, E), Tokenized version of the attributes. """ - return torch.stack([self.img_attrs_ae(x, task="encode", attr=attr) for attr, x in attrs.items()], dim=1) + return torch.stack([self.autoencoder(x, task="encode", attr=attr) for attr, x in attrs.items()], dim=1) diff --git a/didactic/tasks/cardiac_sequence_attrs_pca.py b/didactic/tasks/cardiac_sequence_attrs_pca.py index 8d0a6159..76c34c58 100644 --- a/didactic/tasks/cardiac_sequence_attrs_pca.py +++ b/didactic/tasks/cardiac_sequence_attrs_pca.py @@ -5,7 +5,7 @@ import numpy as np from sklearn.decomposition import PCA -from vital.data.cardinal.config import ImageAttribute +from vital.data.cardinal.config import TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum from vital.data.transforms import Interp1d from vital.utils.decorators import auto_cast_data @@ -13,7 +13,7 @@ class CardiacSequenceAttributesPCA: - """PCA model specialized for cardiac sequences image attributes, where attributes have different range of values.""" + """PCA model specialized for cardiac sequences time-series attrs, where attrs have different range of values.""" def __init__( self, @@ -26,7 +26,7 @@ def __init__( Args: n_components: Dimensionality of the PCA model's latent space. pca_kwargs: Parameters that will be passed along to the `PCA`'s init. - strategy: Strategies available for handling multi-domain image attributes values. + strategy: Strategies available for handling multi-domain time-series attributes values. 'global_pca': one global PCA model is trained on the normalized attributes. 'attr_pca': one PCA model is fitted to each attribute, side-stepping the need for normalization. """ @@ -117,7 +117,7 @@ def __call__( self, x: np.ndarray, task: Literal["encode", "decode", "reconstruct"] = "reconstruct", - attr: Tuple[ViewEnum, ImageAttribute] = None, + attr: Tuple[ViewEnum, TimeSeriesAttribute] = None, out_shape: Tuple[int, ...] 
= None, ) -> np.ndarray: """Performs test-time inference on the input. @@ -170,11 +170,12 @@ def __call__( x = x.squeeze(axis=1) return x - def fit(self, samples: Dict[Tuple[ViewEnum, ImageAttribute], np.ndarray]) -> "CardiacSequenceAttributesPCA": + def fit(self, samples: Dict[Tuple[ViewEnum, TimeSeriesAttribute], np.ndarray]) -> "CardiacSequenceAttributesPCA": """Fits one or multiple PCA model(s) to the attributes samples, depending on the chosen strategy. Args: - samples: Mapping between attributes and their samples, of shape + samples: Mapping between attributes and their samples, of shape (N, L), where N is the number of samples and + L is the length of the time-series. The samples are divided by attributes to allow the model to learn models/statistics on each attribute independently. @@ -219,7 +220,7 @@ def main(): from matplotlib import pyplot as plt from tqdm.auto import tqdm from vital.data.cardinal.config import CardinalTag - from vital.data.cardinal.utils.attributes import IMAGE_ATTR_LABELS + from vital.data.cardinal.utils.attributes import TIME_SERIES_ATTR_LABELS from vital.data.cardinal.utils.itertools import Patients from vital.utils.logging import configure_logging from vital.utils.parsing import yaml_flow_collection @@ -239,7 +240,7 @@ def main(): "--mask_tag", type=str, default=CardinalTag.mask, - help="Tag of the segmentation mask for which to extract the image attributes", + help="Tag of the segmentation mask for which to extract the time-series attributes", ) parser.add_argument( "--n_features", @@ -260,7 +261,7 @@ def main(): type=str, choices=["global_pca", "attr_pca"], default="global_pca", - help="Strategies available for handling multi-domain image attributes values. \n" + help="Strategies available for handling multi-domain time-series attributes values. \n" "'global_pca': one global PCA model is trained on the normalized attributes. \n" "'attr_pca': one PCA model is fitted to each attribute, side-stepping the need for normalization.", ) @@ -357,7 +358,7 @@ def _save_cur_fig(title: str, folder: Path) -> None: lambda x: model(x, task="decode", attr=attr_key), attr_samples_embedding, sweep_coeffs, - plots_kwargs={"ylabel": IMAGE_ATTR_LABELS[attr_key[1]]}, + plots_kwargs={"ylabel": TIME_SERIES_ATTR_LABELS[attr_key[1]]}, ): _save_cur_fig(title, output_dir / "sweep" / "_".join(attr_key)) diff --git a/didactic/tasks/utils.py b/didactic/tasks/utils.py index 22321c52..c3d6dd06 100644 --- a/didactic/tasks/utils.py +++ b/didactic/tasks/utils.py @@ -5,7 +5,7 @@ import numpy as np import torch from tqdm.auto import tqdm -from vital.data.cardinal.config import CardinalTag, ClinicalAttribute, ImageAttribute +from vital.data.cardinal.config import CardinalTag, TabularAttribute, TimeSeriesAttribute from vital.data.cardinal.config import View as ViewEnum from vital.data.cardinal.datapipes import process_patient from vital.data.cardinal.utils.data_dis import check_subsets @@ -31,7 +31,7 @@ def encode_patients( Args: model: Transformer encoder model to use for inference. patients: (N) Patients to encode. - mask_tag: Tag of the segmentation mask for which to extract the image attributes. + mask_tag: Tag of the segmentation mask for which to extract the time-series attributes. progress_bar: If ``True``, enables progress bars detailing the progress of the processing and encoding patients data. **forward_kwargs: Keyword arguments to pass along to the encoder's inference method. @@ -39,16 +39,16 @@ def encode_patients( Returns: (N, E), encodings of the patients. 
""" - clinical_attrs, img_attrs = model.hparams.clinical_attrs, model.hparams.img_attrs + tab_attrs, time_series_attrs = model.hparams.tabular_attrs, model.hparams.time_series_attrs if progress_bar: patients = tqdm(patients, desc="Processing patients' data to prepare it for inference", unit="patient") patients_attrs = [ - process_patient(patient, clinical_attributes=clinical_attrs, image_attributes=img_attrs, mask_tag=mask_tag) + process_patient(patient, tabular_attrs=tab_attrs, time_series_attrs=time_series_attrs, mask_tag=mask_tag) for patient in patients ] # Run inference on one patient at a time, instead of in batches, to avoid a few possible issues: - # i) having to resample image attributes to be of the same constant shape to be able to stack them + # i) having to resample time-series attributes to be of the same constant shape to be able to stack them # ii) out of memory errors, in case of very large collections of patients if progress_bar: patients_attrs = tqdm( @@ -60,8 +60,8 @@ def encode_patients( [ encode_patients_attrs( model, - {attr: patient_attrs[attr] for attr in clinical_attrs}, - {(view, attr): patient_attrs[view][attr] for view in model.hparams.views for attr in img_attrs}, + {attr: patient_attrs[attr] for attr in tab_attrs}, + {(view, attr): patient_attrs[view][attr] for view in model.hparams.views for attr in time_series_attrs}, **forward_kwargs, ) for patient_attrs in patients_attrs @@ -73,30 +73,30 @@ def encode_patients( def encode_patients_attrs( model: CardiacMultimodalRepresentationTask, - clinical_attrs: Dict[ClinicalAttribute, np.ndarray], - img_attrs: Dict[Tuple[ViewEnum, ImageAttribute], np.ndarray], + tabular_attrs: Dict[TabularAttribute, np.ndarray], + time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], np.ndarray], **forward_kwargs: Dict[str, Any], ) -> np.ndarray: """Wrapper around encoder inference to handle boilerplate code (e.g. numpy to torch, batching/unbatching, etc.). Args: model: Transformer encoder model to use for inference. - clinical_attrs: (K: S, V: [N]) Sequence of (batch of) clinical attributes. - img_attrs: (K: S, V: ([N,] L)), Sequence of (batch of) image attributes, where L is the dimensionality of each - attribute. + tabular_attrs: (K: S, V: [N]) Sequence of (batch of) tabular attributes. + time_series_attrs: (K: S, V: ([N,] L)), Sequence of (batch of) time-series attributes, where L is the + dimensionality of each attribute. **forward_kwargs: Keyword arguments to pass along to the encoder's inference method. Returns: ([N,], E), encoding(s) of the patient/batch of patients. """ - is_batch = list(clinical_attrs.values())[0].ndim == 1 + is_batch = list(tabular_attrs.values())[0].ndim == 1 # If the input isn't a batch of data, add the batch dimension - clinical_attrs = {k: v if is_batch else v[None, ...] for k, v in clinical_attrs.items()} - img_attrs = {k: v if is_batch else v[None, ...] for k, v in img_attrs.items()} + tabular_attrs = {k: v if is_batch else v[None, ...] for k, v in tabular_attrs.items()} + time_series_attrs = {k: v if is_batch else v[None, ...] 
for k, v in time_series_attrs.items()} with torch.inference_mode(): - out_features = model(numpy_to_torch(clinical_attrs), numpy_to_torch(img_attrs), **forward_kwargs) + out_features = model(numpy_to_torch(tabular_attrs), numpy_to_torch(time_series_attrs), **forward_kwargs) # Squeeze to remove batch dimension, if it wasn't there in the input if not is_batch: @@ -119,7 +119,7 @@ def summarize_patient_attn( Args: model: Transformer encoder model for which we want to analyze the attention. patient: Patient for which to summarize the model's attention. - mask_tag: Tag of the segmentation mask for which to extract the image attributes. + mask_tag: Tag of the segmentation mask for which to extract the time-series attributes. use_attention_rollout: Whether to use attention rollout to compute the summary of the attention. attention_rollout_kwargs: When using attention rollout (`use_attention_rollout` is True), parameters to forward to `didactic.models.explain.attention_rollout`.
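
Reviewer note: the rename is mechanical (clinical -> tabular, image -> time-series), but every call site and Hydra config keyed on the old names must migrate with it. The sketch below illustrates the updated input contract of `CardiacMultimodalRepresentationTask`; it is a minimal toy example mirroring `example_input_array`, and it assumes an already-instantiated `model` (e.g. restored from a checkpoint) — the helper itself, the batch size and the curve length are illustrative, not part of this diff.

import torch

def make_toy_inputs(model, batch_size=2, curve_len=128):
    """Build toy inputs keyed by the renamed attribute enums (cf. `example_input_array`)."""
    # Tabular attributes: one (N,) tensor per attribute; numerical attributes as floats...
    tabular_attrs = {attr: torch.randn(batch_size) for attr in model.tabular_num_attrs}
    # ...and categorical attributes as 0/1 labels, to stay within every attribute's cardinality
    tabular_attrs.update({attr: torch.randint(2, (batch_size,)) for attr in model.tabular_cat_attrs})
    # Time-series attributes: one (N, L) tensor per (view, attribute) pair
    time_series_attrs = {
        (view, attr): torch.randn(batch_size, curve_len)
        for view in model.hparams.views
        for attr in model.hparams.time_series_attrs
    }
    return tabular_attrs, time_series_attrs

# tabular_attrs, time_series_attrs = make_toy_inputs(model)
# tokens, avail_mask = model.tokenize(tabular_attrs, time_series_attrs)  # (N, S, E), (N, S)
# features = model(tabular_attrs, time_series_attrs)                     # task="encode" -> (N, E)
# predictions = model(tabular_attrs, time_series_attrs, task="predict")  # {TabularAttribute: Tensor}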
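
Similarly, the renamed helpers used by the downstream analysis scripts keep their behavior; only the names change. A hedged usage sketch follows (the `model`, the `patients` collection and the two-cluster split are placeholders, not part of this diff):

from vital.data.cardinal.config import CardinalTag
from vital.data.cardinal.config import View as ViewEnum

from didactic.data.cardinal.utils import build_time_series_attr_by_patient_group_dataframe
from didactic.tasks.utils import encode_patients

# (N, E) patient encodings, e.g. to feed the representation clustering task
# features = encode_patients(model, patients, mask_tag=CardinalTag.mask, progress_bar=True)

# Average curve of one (view, attribute) pair per group, in long format
# view = list(ViewEnum)[0]                      # any available view
# ts_attr = model.hparams.time_series_attrs[0]  # any attribute the model was trained on
# curves_df = build_time_series_attr_by_patient_group_dataframe(
#     patients_groups={0: cluster_0_patients, 1: cluster_1_patients},  # hypothetical split
#     attr=(view, ts_attr),
#     group_desc="cluster",
#     mask_tag=CardinalTag.mask,
#     resampling_rate=128,
# )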