diff --git a/src/eva/core/models/modules/head.py b/src/eva/core/models/modules/head.py index 903a8100..8e09bc55 100644 --- a/src/eva/core/models/modules/head.py +++ b/src/eva/core/models/modules/head.py @@ -11,7 +11,7 @@ from eva.core.metrics import structs as metrics_lib from eva.core.models.modules import module -from eva.core.models.modules.typings import INPUT_BATCH, MODEL_TYPE +from eva.core.models.modules.typings import DATA_SAMPLE, MODEL_TYPE from eva.core.models.modules.utils import batch_postprocess, grad @@ -72,23 +72,23 @@ def forward(self, tensor: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tens return self.head(features).squeeze(-1) @override - def training_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + def training_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> STEP_OUTPUT: return self._batch_step(batch) @override - def validation_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + def validation_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> STEP_OUTPUT: return self._batch_step(batch) @override - def test_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + def test_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> STEP_OUTPUT: return self._batch_step(batch) @override - def predict_step(self, batch: INPUT_BATCH, *args: Any, **kwargs: Any) -> torch.Tensor: - tensor = INPUT_BATCH(*batch).data + def predict_step(self, batch: DATA_SAMPLE, *args: Any, **kwargs: Any) -> torch.Tensor: + tensor = DATA_SAMPLE(*batch).data return tensor if self.backbone is None else self.backbone(tensor) - def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT: + def _batch_step(self, batch: DATA_SAMPLE) -> STEP_OUTPUT: """Performs a model forward step and calculates the loss. Args: @@ -97,7 +97,7 @@ def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT: Returns: The batch step output. """ - data, targets, metadata = INPUT_BATCH(*batch) + data, targets, metadata = DATA_SAMPLE(*batch) predictions = self(data) loss = self.criterion(predictions, targets) return { @@ -105,4 +105,4 @@ def _batch_step(self, batch: INPUT_BATCH) -> STEP_OUTPUT: "targets": targets, "predictions": predictions, "metadata": metadata, - } \ No newline at end of file + }