diff --git a/didactic/tasks/utils.py b/didactic/tasks/utils.py index c3d6dd0..ea8e1c9 100644 --- a/didactic/tasks/utils.py +++ b/didactic/tasks/utils.py @@ -24,7 +24,7 @@ def encode_patients( patients: Iterable[Patient], mask_tag: str = CardinalTag.mask, progress_bar: bool = False, - **forward_kwargs: Dict[str, Any], + **forward_kwargs, ) -> np.ndarray: """Wrapper around encoder inference to handle boilerplate code (e.g. extracting attributes from patients, etc.). @@ -75,7 +75,7 @@ def encode_patients_attrs( model: CardiacMultimodalRepresentationTask, tabular_attrs: Dict[TabularAttribute, np.ndarray], time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], np.ndarray], - **forward_kwargs: Dict[str, Any], + **forward_kwargs, ) -> np.ndarray: """Wrapper around encoder inference to handle boilerplate code (e.g. numpy to torch, batching/unbatching, etc.).