Commit

chore: use built-in types
thomaschhh committed Sep 24, 2024
1 parent b79d04d commit e93e412
Showing 63 changed files with 302 additions and 330 deletions.
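The pattern applied across all files in this commit is PEP 585: the deprecated `typing.Dict`, `typing.List`, and `typing.Tuple` aliases are replaced by the built-in `dict`, `list`, and `tuple`, which are subscriptable from Python 3.9 on; `typing.Union` similarly gives way to the `|` operator in some files. A minimal before/after sketch with illustrative names (not taken from the repository):

```python
from typing import Dict, List, Optional  # old-style aliases, kept here only for the "before" example


def summarize_old(scores: Dict[str, float], tags: List[str]) -> Optional[str]:
    return tags[0] if tags else None


# After: built-in generics (PEP 585, Python 3.9+). Optional[...] either stays, or is
# written as the PEP 604 union "X | None" (Python 3.10+), as some files below do.
def summarize_new(scores: dict[str, float], tags: list[str]) -> str | None:
    return tags[0] if tags else None
```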
10 changes: 5 additions & 5 deletions src/modalities/__main__.py
@@ -5,7 +5,7 @@
import shutil
from datetime import datetime
from pathlib import Path
-from typing import List, Tuple, Type
+from typing import Type

import click
import click_pathlib
@@ -198,7 +198,7 @@ def entry_point_pack_encoded_data(config_path: FilePath):
@data.command(name="merge_packed_data")
@click.argument("src_paths", type=click.types.Path(exists=True, path_type=Path), nargs=-1, required=True)
@click.argument("target_path", type=click.types.Path(file_okay=False, dir_okay=False, path_type=Path))
-def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path):
+def entry_point_merge_packed_data(src_paths: list[Path], target_path: Path):
"""Utility for merging different pbin-files into one.
This is especially useful if different datasets were created at different points in time, or if one encoding takes
so long that the overall process was done in chunks.
@@ -207,7 +207,7 @@ def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path):
Specify an arbitrary number of pbin-files and/or directories containing such files as input.
Args:
-src_paths (List[Path]): List of paths to the pbin-files or directories containing such.
+src_paths (list[Path]): List of paths to the pbin-files or directories containing such.
target_path (Path): The path to the merged pbin-file, that will be created.
"""
input_files = []
@@ -364,7 +364,7 @@ def get_logging_publishers(
results_subscriber: MessageSubscriberIF[EvaluationResultBatch],
global_rank: int,
local_rank: int,
-) -> Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]:
+) -> tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]:
"""Returns the logging publishers for the training.
These publishers are used to pass the evaluation results and the progress updates to the message broker.
@@ -377,7 +377,7 @@ def get_logging_publishers(
local_rank (int): The local rank of the current process on the current node
Returns:
-Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation
+tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation
result publisher and the progress publisher
"""
message_broker = MessageBroker()
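As a usage aside for the `merge_packed_data` entry point shown above, here is a hedged sketch of invoking it programmatically with click's test runner. The top-level click group name (`main`) and the exact argument layout are assumptions inferred from the decorators in this file, not verified against the installed package; the file paths are placeholders.

```python
from click.testing import CliRunner

from modalities.__main__ import main  # assumed name of the top-level click group

runner = CliRunner()
# Merge two packed data files into one target file (placeholder paths).
result = runner.invoke(main, ["data", "merge_packed_data", "part1.pbin", "part2.pbin", "merged.pbin"])
print(result.exit_code, result.output)
```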
18 changes: 9 additions & 9 deletions src/modalities/batch.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Optional

import torch

@@ -32,8 +32,8 @@ class Batch(ABC):
class DatasetBatch(Batch, TorchDeviceMixin):
"""A batch of samples and its targets. Used to batch train a model."""

-samples: Dict[str, torch.Tensor]
-targets: Dict[str, torch.Tensor]
+samples: dict[str, torch.Tensor]
+targets: dict[str, torch.Tensor]
batch_dim: int = 0

def to(self, device: torch.device):
@@ -58,8 +58,8 @@ def __len__(self) -> int:
class InferenceResultBatch(Batch, TorchDeviceMixin):
"""Stores targets and predictions of an entire batch."""

-targets: Dict[str, torch.Tensor]
-predictions: Dict[str, torch.Tensor]
+targets: dict[str, torch.Tensor]
+predictions: dict[str, torch.Tensor]
batch_dim: int = 0

def to_cpu(self):
@@ -106,12 +106,12 @@ class EvaluationResultBatch(Batch):

dataloader_tag: str
num_train_steps_done: int
-losses: Dict[str, ResultItem] = field(default_factory=dict)
-metrics: Dict[str, ResultItem] = field(default_factory=dict)
-throughput_metrics: Dict[str, ResultItem] = field(default_factory=dict)
+losses: dict[str, ResultItem] = field(default_factory=dict)
+metrics: dict[str, ResultItem] = field(default_factory=dict)
+throughput_metrics: dict[str, ResultItem] = field(default_factory=dict)

def __str__(self) -> str:
-def _round_result_item_dict(result_item_dict: Dict[str, ResultItem]) -> Dict[str, ResultItem]:
+def _round_result_item_dict(result_item_dict: dict[str, ResultItem]) -> dict[str, ResultItem]:
rounded_result_item_dict = {}
for k, item in result_item_dict.items():
if item.decimal_places is not None:
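The `DatasetBatch` fields above are plain per-modality tensor dictionaries. A minimal sketch of constructing and moving such a batch, assuming `DatasetBatch` is the dataclass whose fields and `to(device)` method appear in this diff; the `input_ids` key and the tensor shapes are illustrative only.

```python
import torch

from modalities.batch import DatasetBatch  # module path follows src/modalities/batch.py above

# One "input_ids" modality, batched along dim 0 (the batch_dim default shown above).
batch = DatasetBatch(
    samples={"input_ids": torch.randint(0, 50_000, (8, 128))},
    targets={"input_ids": torch.randint(0, 50_000, (8, 128))},
)
batch.to(torch.device("cpu"))  # device move via the `to` method shown in the diff
```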
5 changes: 2 additions & 3 deletions src/modalities/checkpointing/checkpoint_saving.py
@@ -1,5 +1,4 @@
from enum import Enum
-from typing import Dict

import torch.nn as nn
from torch.optim import Optimizer
@@ -43,7 +42,7 @@ def __init__(
def save_checkpoint(
self,
training_progress: TrainingProgress,
-evaluation_result: Dict[str, EvaluationResultBatch],
+evaluation_result: dict[str, EvaluationResultBatch],
model: nn.Module,
optimizer: Optimizer,
early_stoppping_criterion_fulfilled: bool = False,
@@ -53,7 +52,7 @@ def save_checkpoint(
Args:
training_progress (TrainingProgress): The training progress.
-evaluation_result (Dict[str, EvaluationResultBatch]): The evaluation result.
+evaluation_result (dict[str, EvaluationResultBatch]): The evaluation result.
model (nn.Module): The model to be saved.
optimizer (Optimizer): The optimizer to be saved.
early_stoppping_criterion_fulfilled (bool, optional):
5 changes: 2 additions & 3 deletions src/modalities/checkpointing/checkpoint_saving_instruction.py
@@ -1,5 +1,4 @@
from dataclasses import dataclass, field
-from typing import List

from modalities.training.training_progress import TrainingProgress

@@ -11,8 +10,8 @@ class CheckpointingInstruction:
Attributes:
save_current (bool): Indicates whether to save the current checkpoint.
-checkpoints_to_delete (List[TrainingProgress]): List of checkpoint IDs to delete.
+checkpoints_to_delete (list[TrainingProgress]): List of checkpoint IDs to delete.
"""

save_current: bool = False
-checkpoints_to_delete: List[TrainingProgress] = field(default_factory=list)
+checkpoints_to_delete: list[TrainingProgress] = field(default_factory=list)
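A side note on the `field(default_factory=list)` default kept in this change: dataclasses reject mutable defaults such as a bare `[]`, so the factory is required whether the annotation is written as `List[...]` or as the built-in `list[...]`. A self-contained sketch with an illustrative class (not the repository's):

```python
from dataclasses import dataclass, field


@dataclass
class Instruction:
    save_current: bool = False
    # A bare default of [] would raise "mutable default ... is not allowed";
    # default_factory builds a fresh list per instance instead.
    ids_to_delete: list[int] = field(default_factory=list)


a, b = Instruction(), Instruction()
a.ids_to_delete.append(7)
print(b.ids_to_delete)  # [] -- instances do not share the list
```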
18 changes: 9 additions & 9 deletions src/modalities/checkpointing/checkpoint_saving_strategies.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Optional

from modalities.batch import EvaluationResultBatch
from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction
@@ -13,15 +13,15 @@ class CheckpointSavingStrategyIF(ABC):
def get_checkpoint_instruction(
self,
training_progress: TrainingProgress,
-evaluation_result: Optional[Dict[str, EvaluationResultBatch]] = None,
+evaluation_result: Optional[dict[str, EvaluationResultBatch]] = None,
early_stoppping_criterion_fulfilled: bool = False,
) -> CheckpointingInstruction:
"""
Returns the checkpointing instruction.
Parameters:
training_progress (TrainingProgress): The training progress.
-evaluation_result (Dict[str, EvaluationResultBatch] | None, optional):
+evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
The evaluation result. Defaults to None.
early_stoppping_criterion_fulfilled (bool, optional):
Whether the early stopping criterion is fulfilled. Defaults to False.
@@ -45,29 +45,29 @@ def __init__(self, k: int = -1):
Set to a positive integer to save the specified number of checkpoints.
Strategy for saving the k most recent checkpoints only.
"""
-self.saved_step_checkpoints: List[TrainingProgress] = []
+self.saved_step_checkpoints: list[TrainingProgress] = []
self.k = k

def get_checkpoint_instruction(
self,
training_progress: TrainingProgress,
-evaluation_result: Dict[str, EvaluationResultBatch] | None = None,
+evaluation_result: dict[str, EvaluationResultBatch] | None = None,
early_stoppping_criterion_fulfilled: bool = False,
) -> CheckpointingInstruction:
"""
Generates a checkpointing instruction based on the given parameters.
Args:
training_progress (TrainingProgress): The training progress.
-evaluation_result (Dict[str, EvaluationResultBatch] | None, optional):
+evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
The evaluation result. Defaults to None.
early_stoppping_criterion_fulfilled (bool, optional):
Whether the early stopping criterion is fulfilled. Defaults to False.
Returns:
CheckpointingInstruction: The generated checkpointing instruction.
"""
-checkpoints_to_delete: List[TrainingProgress] = []
+checkpoints_to_delete: list[TrainingProgress] = []
save_current = True

if self.k > 0:
@@ -100,15 +100,15 @@ def __init__(self, k: int):
def get_checkpoint_instruction(
self,
training_progress: TrainingProgress,
-evaluation_result: Dict[str, EvaluationResultBatch] | None = None,
+evaluation_result: dict[str, EvaluationResultBatch] | None = None,
early_stoppping_criterion_fulfilled: bool = False,
) -> CheckpointingInstruction:
"""
Returns a CheckpointingInstruction object.
Args:
training_progress (TrainingProgress): The training progress.
-evaluation_result (Dict[str, EvaluationResultBatch] | None, optional):
+evaluation_result (dict[str, EvaluationResultBatch] | None, optional):
The evaluation result. Defaults to None.
early_stoppping_criterion_fulfilled (bool, optional):
Whether the early stopping criterion is fulfilled. Defaults to False.
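The visible part of the k-most-recent strategy ends just after `if self.k > 0:`, so the following is a hedged reconstruction of how such a policy typically works rather than the repository's exact code: record each checkpointed step, and once more than `k` steps are tracked, mark the oldest ones for deletion. Stand-in types are used in place of `TrainingProgress` and `CheckpointingInstruction`.

```python
from dataclasses import dataclass, field


@dataclass
class Instruction:
    save_current: bool = True
    to_delete: list[int] = field(default_factory=list)


class KeepKMostRecent:
    def __init__(self, k: int = -1):
        self.k = k  # -1 keeps everything; a positive k keeps only the k most recent
        self.saved_steps: list[int] = []

    def on_checkpoint(self, step: int) -> Instruction:
        instruction = Instruction(save_current=True)
        self.saved_steps.append(step)
        if self.k > 0 and len(self.saved_steps) > self.k:
            # Everything older than the k most recent entries is scheduled for deletion.
            instruction.to_delete = self.saved_steps[: -self.k]
            self.saved_steps = self.saved_steps[-self.k :]
        return instruction


strategy = KeepKMostRecent(k=2)
for step in (100, 200, 300):
    print(strategy.on_checkpoint(step).to_delete)  # [], [], [100]
```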
5 changes: 2 additions & 3 deletions src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py
@@ -1,5 +1,4 @@
from pathlib import Path
-from typing import List

import torch
import torch.nn as nn
@@ -17,7 +16,7 @@ class FSDPCheckpointLoading(CheckpointLoadingIF):
def __init__(
self,
global_rank: int,
-block_names: List[str],
+block_names: list[str],
mixed_precision_settings: MixedPrecisionSettings,
sharding_strategy: ShardingStrategy,
):
@@ -26,7 +25,7 @@ def __init__(
Args:
global_rank (int): The global rank of the process.
-block_names (List[str]): The names of the blocks.
+block_names (list[str]): The names of the blocks.
mixed_precision_settings (MixedPrecisionSettings): The settings for mixed precision.
sharding_strategy (ShardingStrategy): The sharding strategy.
3 changes: 1 addition & 2 deletions src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py
@@ -1,6 +1,5 @@
from enum import Enum
from pathlib import Path
-from typing import List

import torch
import torch.distributed as dist
@@ -124,7 +123,7 @@ def _save_checkpoint(self, model: FSDP, optimizer: Optimizer, training_progress:
# leading to wrong throughput measurements.
dist.barrier()

-def _get_paths_to_delete(self, training_progress: TrainingProgress) -> List[Path]:
+def _get_paths_to_delete(self, training_progress: TrainingProgress) -> list[Path]:
return [
self._get_checkpointing_path(
experiment_id=self.experiment_id,
26 changes: 13 additions & 13 deletions src/modalities/config/component_factory.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Type, TypeVar, Union
+from typing import Any, Type, TypeVar

from pydantic import BaseModel

@@ -19,12 +19,12 @@ def __init__(self, registry: Registry) -> None:
"""
self.registry = registry

-def build_components(self, config_dict: Dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild:
+def build_components(self, config_dict: dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild:
"""Builds the components from a config dictionary. All components specified in `components_model_type`
are built from the config dictionary in a recursive manner.
Args:
-config_dict (Dict): Dictionary with the configuration of the components.
+config_dict (dict): Dictionary with the configuration of the components.
components_model_type (Type[BaseModelChild]): Base model type defining the components to be built.
Returns:
@@ -35,7 +35,7 @@ def build_components(self, config_dict: Dict, components_model_type: Type[BaseMo
components = components_model_type(**component_dict)
return components

-def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[str, Any]:
+def _build_config(self, config_dict: dict, component_names: list[str]) -> dict[str, Any]:
component_dict_filtered = {name: config_dict[name] for name in component_names}
components, _ = self._build_component(
current_component_config=component_dict_filtered,
@@ -47,10 +47,10 @@ def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[s

def _build_component(
self,
-current_component_config: Union[Dict, List, Any],
-component_config: Union[Dict, List, Any],
-top_level_components: Dict[str, Any],
-traversal_path: List,
+current_component_config: dict | list | Any,
+component_config: dict | list | Any,
+top_level_components: dict[str, Any],
+traversal_path: list,
) -> Any:
# build sub components first via recursion
if isinstance(current_component_config, dict):
@@ -130,16 +130,16 @@ def _build_component(
return current_component_config, top_level_components

@staticmethod
-def _is_component_config(config_dict: Dict) -> bool:
+def _is_component_config(config_dict: dict) -> bool:
# TODO instead of field checks, we should introduce an enum for the config type.
return "component_key" in config_dict.keys()

@staticmethod
-def _is_reference_config(config_dict: Dict) -> bool:
+def _is_reference_config(config_dict: dict) -> bool:
# TODO instead of field checks, we should introduce an enum for the config type.
return {"instance_key", "pass_type"} == config_dict.keys()

-def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: Dict) -> BaseModel:
+def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel:
component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key)
self._assert_valid_config_keys(
component_key=component_key,
@@ -151,7 +151,7 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co
return comp_config

def _assert_valid_config_keys(
-self, component_key: str, variant_key: str, config_dict: Dict, component_config_type: Type[BaseModelChild]
+self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild]
) -> None:
required_keys = []
optional_keys = []
@@ -178,7 +178,7 @@ def _instantiate_component(self, component_key: str, variant_key: str, component
return component

@staticmethod
-def _base_model_to_dict(base_model: BaseModel) -> Dict:
+def _base_model_to_dict(base_model: BaseModel) -> dict:
# converts top level structure of base_model into dictionary while maintaining substructure
output = {}
for name, _ in base_model.model_fields.items():
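`build_components` walks the config dictionary recursively: a sub-dictionary carrying a `component_key` is instantiated via the registry, and a dictionary whose keys are exactly `instance_key`/`pass_type` is resolved as a reference to an already-built component. The sketch below mirrors that traversal only; the registry lookup and pydantic validation are replaced by a stand-in tuple, and the function name is illustrative, not the repository's.

```python
from typing import Any


def build(node: Any, built_top_level: dict[str, Any]) -> Any:
    """Recursively replace component configs with (stand-in) instances."""
    if isinstance(node, dict):
        # Reference config: exactly {"instance_key", "pass_type"} points at an
        # already-built top-level component (cf. _is_reference_config above).
        if set(node.keys()) == {"instance_key", "pass_type"}:
            return built_top_level[node["instance_key"]]
        # Depth-first: resolve children before the current node.
        resolved = {key: build(value, built_top_level) for key, value in node.items()}
        if "component_key" in resolved:  # cf. _is_component_config above
            # Stand-in for registry lookup, pydantic validation, and instantiation.
            return ("instance", resolved["component_key"], resolved.get("config", {}))
        return resolved
    if isinstance(node, list):
        return [build(item, built_top_level) for item in node]
    return node


print(build({"component_key": "optimizer", "config": {"lr": 1e-4}}, built_top_level={}))
```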
(Diff truncated: the remaining changed files are not shown.)
