diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py index fd139fef..3ad8469b 100644 --- a/src/modalities/__main__.py +++ b/src/modalities/__main__.py @@ -5,7 +5,7 @@ import shutil from datetime import datetime from pathlib import Path -from typing import List, Tuple, Type +from typing import Type import click import click_pathlib @@ -198,7 +198,7 @@ def entry_point_pack_encoded_data(config_path: FilePath): @data.command(name="merge_packed_data") @click.argument("src_paths", type=click.types.Path(exists=True, path_type=Path), nargs=-1, required=True) @click.argument("target_path", type=click.types.Path(file_okay=False, dir_okay=False, path_type=Path)) -def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path): +def entry_point_merge_packed_data(src_paths: list[Path], target_path: Path): """Utility for merging different pbin-files into one. This is especially useful, if different datasets were at different points in time or if one encoding takes so long, that the overall process was done in chunks. @@ -207,7 +207,7 @@ def entry_point_merge_packed_data(src_paths: List[Path], target_path: Path): Specify an arbitrary amount of pbin-files and/or directory containing such as input. Args: - src_paths (List[Path]): List of paths to the pbin-files or directories containing such. + src_paths (list[Path]): List of paths to the pbin-files or directories containing such. target_path (Path): The path to the merged pbin-file, that will be created. """ input_files = [] @@ -364,7 +364,7 @@ def get_logging_publishers( results_subscriber: MessageSubscriberIF[EvaluationResultBatch], global_rank: int, local_rank: int, - ) -> Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: + ) -> tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: """Returns the logging publishers for the training. These publishers are used to pass the evaluation results and the progress updates to the message broker. @@ -377,7 +377,7 @@ def get_logging_publishers( local_rank (int): The local rank of the current process on the current node Returns: - Tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation + tuple[MessagePublisher[EvaluationResultBatch], MessagePublisher[ProgressUpdate]]: The evaluation result publisher and the progress publisher """ message_broker = MessageBroker() diff --git a/src/modalities/batch.py b/src/modalities/batch.py index 6a8b12e6..19aa673e 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Dict, Optional +from typing import Optional import torch @@ -32,8 +32,8 @@ class Batch(ABC): class DatasetBatch(Batch, TorchDeviceMixin): """A batch of samples and its targets. Used to batch train a model.""" - samples: Dict[str, torch.Tensor] - targets: Dict[str, torch.Tensor] + samples: dict[str, torch.Tensor] + targets: dict[str, torch.Tensor] batch_dim: int = 0 def to(self, device: torch.device): @@ -58,8 +58,8 @@ def __len__(self) -> int: class InferenceResultBatch(Batch, TorchDeviceMixin): """Stores targets and predictions of an entire batch.""" - targets: Dict[str, torch.Tensor] - predictions: Dict[str, torch.Tensor] + targets: dict[str, torch.Tensor] + predictions: dict[str, torch.Tensor] batch_dim: int = 0 def to_cpu(self): @@ -106,12 +106,12 @@ class EvaluationResultBatch(Batch): dataloader_tag: str num_train_steps_done: int - losses: Dict[str, ResultItem] = field(default_factory=dict) - metrics: Dict[str, ResultItem] = field(default_factory=dict) - throughput_metrics: Dict[str, ResultItem] = field(default_factory=dict) + losses: dict[str, ResultItem] = field(default_factory=dict) + metrics: dict[str, ResultItem] = field(default_factory=dict) + throughput_metrics: dict[str, ResultItem] = field(default_factory=dict) def __str__(self) -> str: - def _round_result_item_dict(result_item_dict: Dict[str, ResultItem]) -> Dict[str, ResultItem]: + def _round_result_item_dict(result_item_dict: dict[str, ResultItem]) -> dict[str, ResultItem]: rounded_result_item_dict = {} for k, item in result_item_dict.items(): if item.decimal_places is not None: diff --git a/src/modalities/checkpointing/checkpoint_saving.py b/src/modalities/checkpointing/checkpoint_saving.py index 986e71f7..2a47648e 100644 --- a/src/modalities/checkpointing/checkpoint_saving.py +++ b/src/modalities/checkpointing/checkpoint_saving.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Dict import torch.nn as nn from torch.optim import Optimizer @@ -43,7 +42,7 @@ def __init__( def save_checkpoint( self, training_progress: TrainingProgress, - evaluation_result: Dict[str, EvaluationResultBatch], + evaluation_result: dict[str, EvaluationResultBatch], model: nn.Module, optimizer: Optimizer, early_stoppping_criterion_fulfilled: bool = False, @@ -53,7 +52,7 @@ def save_checkpoint( Args: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch]): The evaluation result. + evaluation_result (dict[str, EvaluationResultBatch]): The evaluation result. model (nn.Module): The model to be saved. optimizer (Optimizer): The optimizer to be saved. early_stoppping_criterion_fulfilled (bool, optional): diff --git a/src/modalities/checkpointing/checkpoint_saving_instruction.py b/src/modalities/checkpointing/checkpoint_saving_instruction.py index 1bd42470..c9a80b70 100644 --- a/src/modalities/checkpointing/checkpoint_saving_instruction.py +++ b/src/modalities/checkpointing/checkpoint_saving_instruction.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from typing import List from modalities.training.training_progress import TrainingProgress @@ -11,8 +10,8 @@ class CheckpointingInstruction: Attributes: save_current (bool): Indicates whether to save the current checkpoint. - checkpoints_to_delete (List[TrainingProgress]): List of checkpoint IDs to delete. + checkpoints_to_delete (list[TrainingProgress]): List of checkpoint IDs to delete. """ save_current: bool = False - checkpoints_to_delete: List[TrainingProgress] = field(default_factory=list) + checkpoints_to_delete: list[TrainingProgress] = field(default_factory=list) diff --git a/src/modalities/checkpointing/checkpoint_saving_strategies.py b/src/modalities/checkpointing/checkpoint_saving_strategies.py index e4902600..50c72bee 100644 --- a/src/modalities/checkpointing/checkpoint_saving_strategies.py +++ b/src/modalities/checkpointing/checkpoint_saving_strategies.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Optional from modalities.batch import EvaluationResultBatch from modalities.checkpointing.checkpoint_saving_instruction import CheckpointingInstruction @@ -13,7 +13,7 @@ class CheckpointSavingStrategyIF(ABC): def get_checkpoint_instruction( self, training_progress: TrainingProgress, - evaluation_result: Optional[Dict[str, EvaluationResultBatch]] = None, + evaluation_result: Optional[dict[str, EvaluationResultBatch]] = None, early_stoppping_criterion_fulfilled: bool = False, ) -> CheckpointingInstruction: """ @@ -21,7 +21,7 @@ def get_checkpoint_instruction( Parameters: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch] | None, optional): + evaluation_result (dict[str, EvaluationResultBatch] | None, optional): The evaluation result. Defaults to None. early_stoppping_criterion_fulfilled (bool, optional): Whether the early stopping criterion is fulfilled. Defaults to False. @@ -45,13 +45,13 @@ def __init__(self, k: int = -1): Set to a positive integer to save the specified number of checkpointsStrategy for saving the k most recent checkpoints only. """ - self.saved_step_checkpoints: List[TrainingProgress] = [] + self.saved_step_checkpoints: list[TrainingProgress] = [] self.k = k def get_checkpoint_instruction( self, training_progress: TrainingProgress, - evaluation_result: Dict[str, EvaluationResultBatch] | None = None, + evaluation_result: dict[str, EvaluationResultBatch] | None = None, early_stoppping_criterion_fulfilled: bool = False, ) -> CheckpointingInstruction: """ @@ -59,7 +59,7 @@ def get_checkpoint_instruction( Args: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch] | None, optional): + evaluation_result (dict[str, EvaluationResultBatch] | None, optional): The evaluation result. Defaults to None. early_stoppping_criterion_fulfilled (bool, optional): Whether the early stopping criterion is fulfilled. Defaults to False. @@ -67,7 +67,7 @@ def get_checkpoint_instruction( Returns: CheckpointingInstruction: The generated checkpointing instruction. """ - checkpoints_to_delete: List[TrainingProgress] = [] + checkpoints_to_delete: list[TrainingProgress] = [] save_current = True if self.k > 0: @@ -100,7 +100,7 @@ def __init__(self, k: int): def get_checkpoint_instruction( self, training_progress: TrainingProgress, - evaluation_result: Dict[str, EvaluationResultBatch] | None = None, + evaluation_result: dict[str, EvaluationResultBatch] | None = None, early_stoppping_criterion_fulfilled: bool = False, ) -> CheckpointingInstruction: """ @@ -108,7 +108,7 @@ def get_checkpoint_instruction( Args: training_progress (TrainingProgress): The training progress. - evaluation_result (Dict[str, EvaluationResultBatch] | None, optional): + evaluation_result (dict[str, EvaluationResultBatch] | None, optional): The evaluation result. Defaults to None. early_stoppping_criterion_fulfilled (bool, optional): Whether the early stopping criterion is fulfilled. Defaults to False. diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py index dc3b9de0..556ba5b8 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_loading.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import torch import torch.nn as nn @@ -17,7 +16,7 @@ class FSDPCheckpointLoading(CheckpointLoadingIF): def __init__( self, global_rank: int, - block_names: List[str], + block_names: list[str], mixed_precision_settings: MixedPrecisionSettings, sharding_strategy: ShardingStrategy, ): @@ -26,7 +25,7 @@ def __init__( Args: global_rank (int): The global rank of the process. - block_names (List[str]): The names of the blocks. + block_names (list[str]): The names of the blocks. mixed_precision_settings (MixedPrecisionSettings): The settings for mixed precision. sharding_strategy (ShardingStrategy): The sharding strategy. diff --git a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py index 684847f0..0c5757a9 100644 --- a/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py +++ b/src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py @@ -1,6 +1,5 @@ from enum import Enum from pathlib import Path -from typing import List import torch import torch.distributed as dist @@ -124,7 +123,7 @@ def _save_checkpoint(self, model: FSDP, optimizer: Optimizer, training_progress: # leading to wrong throughput measurements. dist.barrier() - def _get_paths_to_delete(self, training_progress: TrainingProgress) -> List[Path]: + def _get_paths_to_delete(self, training_progress: TrainingProgress) -> list[Path]: return [ self._get_checkpointing_path( experiment_id=self.experiment_id, diff --git a/src/modalities/config/component_factory.py b/src/modalities/config/component_factory.py index 94ed71e6..17a64dcc 100644 --- a/src/modalities/config/component_factory.py +++ b/src/modalities/config/component_factory.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Type, TypeVar, Union +from typing import Any, Type, TypeVar from pydantic import BaseModel @@ -19,12 +19,12 @@ def __init__(self, registry: Registry) -> None: """ self.registry = registry - def build_components(self, config_dict: Dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild: + def build_components(self, config_dict: dict, components_model_type: Type[BaseModelChild]) -> BaseModelChild: """Builds the components from a config dictionary. All components specified in `components_model_type` are built from the config dictionary in a recursive manner. Args: - config_dict (Dict): Dictionary with the configuration of the components. + config_dict (dict[): Dictionary with the configuration of the components. components_model_type (Type[BaseModelChild]): Base model type defining the components to be build. Returns: @@ -35,7 +35,7 @@ def build_components(self, config_dict: Dict, components_model_type: Type[BaseMo components = components_model_type(**component_dict) return components - def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[str, Any]: + def _build_config(self, config_dict: dict, component_names: list[str]) -> dict[str, Any]: component_dict_filtered = {name: config_dict[name] for name in component_names} components, _ = self._build_component( current_component_config=component_dict_filtered, @@ -47,10 +47,10 @@ def _build_config(self, config_dict: Dict, component_names: List[str]) -> Dict[s def _build_component( self, - current_component_config: Union[Dict, List, Any], - component_config: Union[Dict, List, Any], - top_level_components: Dict[str, Any], - traversal_path: List, + current_component_config: dict | list | Any, + component_config: dict | list | Any, + top_level_components: dict[str, Any], + traversal_path: list, ) -> Any: # build sub components first via recursion if isinstance(current_component_config, dict): @@ -130,16 +130,16 @@ def _build_component( return current_component_config, top_level_components @staticmethod - def _is_component_config(config_dict: Dict) -> bool: + def _is_component_config(config_dict: dict) -> bool: # TODO instead of field checks, we should introduce an enum for the config type. return "component_key" in config_dict.keys() @staticmethod - def _is_reference_config(config_dict: Dict) -> bool: + def _is_reference_config(config_dict: dict) -> bool: # TODO instead of field checks, we should introduce an enum for the config type. return {"instance_key", "pass_type"} == config_dict.keys() - def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: Dict) -> BaseModel: + def _instantiate_component_config(self, component_key: str, variant_key: str, config_dict: dict) -> BaseModel: component_config_type: Type[BaseModel] = self.registry.get_config(component_key, variant_key) self._assert_valid_config_keys( component_key=component_key, @@ -151,7 +151,7 @@ def _instantiate_component_config(self, component_key: str, variant_key: str, co return comp_config def _assert_valid_config_keys( - self, component_key: str, variant_key: str, config_dict: Dict, component_config_type: Type[BaseModelChild] + self, component_key: str, variant_key: str, config_dict: dict, component_config_type: Type[BaseModelChild] ) -> None: required_keys = [] optional_keys = [] @@ -178,7 +178,7 @@ def _instantiate_component(self, component_key: str, variant_key: str, component return component @staticmethod - def _base_model_to_dict(base_model: BaseModel) -> Dict: + def _base_model_to_dict(base_model: BaseModel) -> dict: # converts top level structure of base_model into dictionary while maintaining substructure output = {} for name, _ in base_model.model_fields.items(): diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 96da831a..80cce3b9 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -1,7 +1,7 @@ import os from functools import partial from pathlib import Path -from typing import Annotated, Dict, List, Literal, Optional, Tuple +from typing import Annotated, Literal, Optional import torch from omegaconf import OmegaConf @@ -87,7 +87,7 @@ def parse_device(cls, device) -> PydanticPytorchDeviceType: class FSDPCheckpointLoadingConfig(BaseModel): global_rank: Annotated[int, Field(strict=True, ge=0)] - block_names: List[str] + block_names: list[str] mixed_precision_settings: MixedPrecisionSettings sharding_strategy: ShardingStrategy @@ -122,19 +122,19 @@ class CheckpointSavingConfig(BaseModel): class AdamOptimizerConfig(BaseModel): lr: float wrapped_model: PydanticPytorchModuleType - betas: Tuple[float, float] + betas: tuple[float, float] eps: float weight_decay: float - weight_decay_groups_excluded: List[str] + weight_decay_groups_excluded: list[str] class AdamWOptimizerConfig(BaseModel): lr: float wrapped_model: PydanticPytorchModuleType - betas: Tuple[float, float] + betas: tuple[float, float] eps: float weight_decay: float - weight_decay_groups_excluded: List[str] + weight_decay_groups_excluded: list[str] class DummyLRSchedulerConfig(BaseModel): @@ -151,17 +151,17 @@ class StepLRSchedulerConfig(BaseModel): class OneCycleLRSchedulerConfig(BaseModel): optimizer: PydanticOptimizerIFType - max_lr: Annotated[float, Field(strict=True, gt=0.0)] | List[Annotated[float, Field(strict=True, gt=0.0)]] + max_lr: Annotated[float, Field(strict=True, gt=0.0)] | list[Annotated[float, Field(strict=True, gt=0.0)]] total_steps: Optional[Annotated[int, Field(strict=True, gt=0)]] = None epochs: Optional[Annotated[int, Field(strict=True, gt=0)]] = None steps_per_epoch: Optional[Annotated[int, Field(strict=True, gt=0)]] = None pct_start: Annotated[float, Field(strict=True, gt=0.0, le=1.0)] anneal_strategy: str cycle_momentum: bool = True - base_momentum: Annotated[float, Field(strict=True, gt=0)] | List[ + base_momentum: Annotated[float, Field(strict=True, gt=0)] | list[ Annotated[float, Field(strict=True, gt=0.0)] ] = 0.85 - max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | List[ + max_momentum: Annotated[float, Field(strict=True, gt=0.0)] | list[ Annotated[float, Field(strict=True, gt=0.0)] ] = 0.95 div_factor: Annotated[float, Field(strict=True, gt=0.0)] @@ -211,7 +211,7 @@ class FSDPWrappedModelConfig(BaseModel): sync_module_states: bool mixed_precision_settings: MixedPrecisionSettings sharding_strategy: ShardingStrategy - block_names: List[str] + block_names: list[str] @field_validator("mixed_precision_settings", mode="before") def parse_mixed_precision_setting_by_name(cls, name): @@ -241,7 +241,7 @@ class WeightInitializedModelConfig(BaseModel): class ActivationCheckpointedModelConfig(BaseModel): model: PydanticFSDPModuleType - activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list) + activation_checkpointing_modules: Optional[list[str]] = Field(default_factory=list) class PreTrainedHFTokenizerConfig(BaseModel): @@ -249,7 +249,7 @@ class PreTrainedHFTokenizerConfig(BaseModel): max_length: Optional[Annotated[int, Field(strict=True, ge=0)]] = None truncation: bool = False padding: bool | str = False - special_tokens: Optional[Dict[str, str]] = None + special_tokens: Optional[dict[str, str]] = None class PreTrainedSPTokenizerConfig(BaseModel): @@ -323,7 +323,7 @@ class DummyProgressSubscriberConfig(BaseModel): class RichProgressSubscriberConfig(BaseModel): - eval_dataloaders: Optional[List[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) train_dataloader_tag: str num_seen_steps: Annotated[int, Field(strict=True, ge=0)] num_target_steps: Annotated[int, Field(strict=True, gt=0)] @@ -348,7 +348,7 @@ class RichResultSubscriberConfig(BaseModel): global_rank: int -def load_app_config_dict(config_file_path: Path) -> Dict: +def load_app_config_dict(config_file_path: Path) -> dict: """Load the application configuration from the given YAML file. The function defines custom resolvers for the OmegaConf library to resolve environment variables and Modalities-specific variables. @@ -357,7 +357,7 @@ def load_app_config_dict(config_file_path: Path) -> Dict: config_file_path (Path): YAML config file. Returns: - Dict: Dictionary representation of the config file. + dict: Dictionary representation of the config file. """ def cuda_env_resolver_fun(var_name: str) -> int: diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 690559b3..bd203c9c 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Annotated, Any, Dict, List, Optional +from typing import Annotated, Any, Optional from pydantic import BaseModel, ConfigDict, Field, FilePath, field_validator, model_validator, root_validator @@ -69,7 +69,7 @@ class Config: extra = "allow" @root_validator(pre=True) - def _validate_all_paths(cls, values: Dict[str, Any]) -> Dict[str, Any]: + def _validate_all_paths(cls, values: dict[str, Any]) -> dict[str, Any]: for field_name, value in values.items(): if isinstance(value, str): # If a value is a string, convert it to Path values[field_name] = Path(value) @@ -83,7 +83,7 @@ class WarmstartCheckpointPaths(BaseModel): experiment_id: str config_file_path: FilePath - referencing_keys: Dict[str, str] + referencing_keys: dict[str, str] cuda_env: CudaEnvSettings paths: Paths intervals: Intervals @@ -171,7 +171,7 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel loss_fn: PydanticLossIFType train_dataset: PydanticDatasetIFType train_dataloader: PydanticLLMDataLoaderIFType - eval_dataloaders: List[PydanticLLMDataLoaderIFType] + eval_dataloaders: list[PydanticLLMDataLoaderIFType] progress_subscriber: PydanticMessageSubscriberIFType evaluation_subscriber: PydanticMessageSubscriberIFType checkpoint_saving: PydanticCheckpointSavingIFType @@ -212,7 +212,7 @@ class TextGenerationSettings(BaseModel): model_path: FilePath sequence_length: int device: PydanticPytorchDeviceType - referencing_keys: Dict[str, str] + referencing_keys: dict[str, str] # avoid warning about protected namespace 'model_', see # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces @@ -246,10 +246,10 @@ def __init__( self.training_progress = training_progress def get_report(self) -> str: - def _get_formatted_dict_str(d: Dict[str, Any]) -> str: + def _get_formatted_dict_str(d: dict[str, Any]) -> str: return "\n\t".join([f"{k}: {v}" for k, v in d.items()]) - def _get_formatted_list_str(lst: List[str]) -> str: + def _get_formatted_list_str(lst: list[str]) -> str: return "\n\t".join(lst) training_target_str = _get_formatted_dict_str(dict(self.training_target)) @@ -273,7 +273,7 @@ def _get_formatted_list_str(lst: List[str]) -> str: ) return report - def _get_issue_warnings(self) -> List[str]: + def _get_issue_warnings(self) -> list[str]: issue_warnings = [] num_tokens = ( self.step_profile.local_train_micro_batch_size diff --git a/src/modalities/config/utils.py b/src/modalities/config/utils.py index fe47cfaf..a1d414fe 100644 --- a/src/modalities/config/utils.py +++ b/src/modalities/config/utils.py @@ -1,10 +1,10 @@ -from typing import Any, Dict +from typing import Any import torch from pydantic import BaseModel -def convert_base_model_config_to_dict(config: BaseModel) -> Dict[Any, Any]: +def convert_base_model_config_to_dict(config: BaseModel) -> dict[Any, Any]: """ "Converts non-recursively a Pydantic BaseModel to a dictionary.""" return {key: getattr(config, key) for key in config.model_dump().keys()} diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py index 33775fcd..c695cc16 100644 --- a/src/modalities/dataloader/create_packed_data.py +++ b/src/modalities/dataloader/create_packed_data.py @@ -6,7 +6,7 @@ import warnings from io import BufferedWriter from pathlib import Path -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Callable, Iterator, Optional import jq import numpy as np @@ -197,8 +197,8 @@ def _writer_thread(self, dst_path: Path) -> Callable: def writer(): # writes a batch received from the processed_samples_queue to the destination file def _write_batch( - batch: List[Tuple[int, bytes]], prev_line_id: int, curr_offset: int, index_list: List, f: BufferedWriter - ) -> Tuple[int, int]: + batch: list[tuple[int, bytes]], prev_line_id: int, curr_offset: int, index_list: list, f: BufferedWriter + ) -> tuple[int, int]: # write the tokens for each document for line_id, tokens_as_bytes in batch: if prev_line_id + 1 != line_id: @@ -293,7 +293,7 @@ def _process_thread(self, process_id: int): f"Raised the following error: {exception=}" ) - def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: List[Tuple[int, int]]): + def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: list[tuple[int, int]]): # Update the length of the data section in the pre-allocated header of the destination file. # The data segment length is sum of the starting position and the length of the last document. length_of_byte_encoded_data_section = index_list[-1][0] + index_list[-1][1] @@ -356,18 +356,18 @@ def __init__(self, data_path: Path): pkl_encoded_index = f.read() # contains the start offset and length of each segment # as byte positions in the data section - self.index_base: List[Tuple[int, int]] = pickle.loads(pkl_encoded_index) + self.index_base: list[tuple[int, int]] = pickle.loads(pkl_encoded_index) # initialize memmapped data section self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,)) -def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): +def join_embedded_stream_data(stream_data: list[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048): """ Joins the embedded stream data into a single file. Args: - stream_data (List[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. + stream_data (list[EmbeddedStreamData]): A list of EmbeddedStreamData objects representing the stream data. target_file (Path): The target file to write the joined data to. chunk_size (int, optional): The size of each data chunk. Defaults to 2048. @@ -391,7 +391,7 @@ def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file num_entries = sum(len(d.index_base) for d in stream_data) - def index_stream_generator() -> Iterator[Tuple[int, int]]: + def index_stream_generator() -> Iterator[tuple[int, int]]: # generates a stream of index offsets and segment lengths. curr_offset = 0 for embedded_stream_data in stream_data: diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index fbf5cc36..808e9b58 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Union +from typing import Iterable, Optional from torch.utils.data import Dataset, DistributedSampler, Sampler from torch.utils.data.dataloader import DataLoader, T_co, _collate_fn_t, _worker_init_fn_t @@ -15,7 +15,7 @@ def __init__( batch_sampler: ResumableBatchSampler, dataset: Dataset[T_co], batch_size: Optional[int] = 1, - sampler: Union[Sampler, Iterable, None] = None, + sampler: Sampler | Iterable | None = None, num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None, pin_memory: bool = False, @@ -37,7 +37,7 @@ def __init__( batch_sampler (ResumableBatchSampler): The batch sampler used for sampling batches. dataset (Dataset[T_co]): The dataset to load the data from. batch_size (Optional[int], optional): The number of samples per batch. Defaults to 1. - sampler (Union[Sampler, Iterable, None], optional): The sampler used for sampling data. Defaults to None. + sampler (Sampler | Iterable | None, optional): The sampler used for sampling data. Defaults to None. num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 0. collate_fn (Optional[_collate_fn_t], optional): The function used to collate the data samples. Defaults to None. diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index fa855e5e..67c3585a 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -2,7 +2,7 @@ from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Optional import jq import numpy as np @@ -56,13 +56,13 @@ class DummySampleConfig(BaseModel): Attributes: sample_key (str): The key of the sample. - sample_shape (Tuple[int, ...]): The shape of the sample. + sample_shape (tuple[int, ...]): The shape of the sample. sample_type (DummySampleDataType): The type of the sample. """ sample_key: str - sample_shape: Tuple[int, ...] + sample_shape: tuple[int, ...] sample_type: DummySampleDataType @@ -72,24 +72,24 @@ class DummyDatasetConfig(BaseModel): Attributes: num_samples (int): The number of samples in the dataset. - sample_definition (List[DummySampleConfig]): The list of sample definitions in the dataset. + sample_definition (list[DummySampleConfig]): The list of sample definitions in the dataset. """ num_samples: int - sample_definition: List[DummySampleConfig] + sample_definition: list[DummySampleConfig] class DummyDataset(Dataset): """DummyDataset class.""" - def __init__(self, num_samples: int, sample_definition: Tuple[DummySampleConfig]): + def __init__(self, num_samples: int, sample_definition: tuple[DummySampleConfig]): """ Initializes a DummyDataset object with the given number of samples and sample definition. When calling the __getitem__ method, the dataset will return a random sample based on the sample definition. Args: num_samples (int): The number of samples in the dataset. - sample_definition (Tuple[DummySampleConfig]): A list of tuples defining the dataset output. + sample_definition (tuple[DummySampleConfig]): A list of tuples defining the dataset output. Each touple contains the sample key, shape and data type. Returns: @@ -108,7 +108,7 @@ def __len__(self) -> int: """ return self.num_samples - def __getitem__(self, idx: int) -> Dict: + def __getitem__(self, idx: int) -> dict: """ Retrieves an item from the dataset at the specified index. @@ -123,7 +123,7 @@ def __getitem__(self, idx: int) -> Dict: """ return self._create_random_sample() - def _create_random_sample(self) -> Dict: + def _create_random_sample(self) -> dict: # creates a random sample based on the sample definition sample = dict() for s in self.sample_definition: @@ -238,7 +238,7 @@ def __init__(self, raw_data_path: Path, sample_key: str): ) self._index = self._generate_packing_index() - def _generate_packing_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> list[tuple[int, int]]: # Generates the packing index for the dataset. # The index is list of tuples, where each tuple contains the offset and length in bytes. @@ -308,7 +308,7 @@ def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): self.block_size = block_size super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) - def _generate_packing_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> list[tuple[int, int]]: # Generates the packing index for the dataset. # A list of tuples representing the index, where each tuple contains the offset and length in bytes. @@ -339,7 +339,7 @@ def __init__(self, raw_data_path: Path, sample_key: str, block_size: int): self.block_size = block_size super().__init__(raw_data_path=raw_data_path, sample_key=sample_key) - def _generate_packing_index(self) -> List[Tuple[int, int]]: + def _generate_packing_index(self) -> list[tuple[int, int]]: index = [] curr_offset = self.HEADER_SIZE_IN_BYTES curr_len = 0 diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 7ce204c2..d5580b8c 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -1,6 +1,6 @@ import pickle from pathlib import Path -from typing import List, Optional, Tuple +from typing import Optional from transformers import PreTrainedTokenizer @@ -17,13 +17,13 @@ class DatasetFactory: """DatasetFactory for building the different dataset types.""" @staticmethod - def get_dummy_dataset(num_samples: int, sample_definition: Tuple[DummySampleConfig]) -> DummyDataset: + def get_dummy_dataset(num_samples: int, sample_definition: tuple[DummySampleConfig]) -> DummyDataset: """ Returns a DummyDataset object. Args: num_samples (int): The number of samples the dataset should generate. - sample_definition (Tuple[DummySampleConfig]): A list of tuples defining the dataset output. + sample_definition (tuple[DummySampleConfig]): A list of tuples defining the dataset output. Each tuple contains the sample key, shape and data type. Returns: @@ -64,7 +64,7 @@ def get_mem_map_dataset( return dataset @staticmethod - def get_raw_index(raw_index_path: Path) -> List[Tuple[int, int]]: + def get_raw_index(raw_index_path: Path) -> list[tuple[int, int]]: with raw_index_path.open("rb") as f: index = pickle.load(f) return index diff --git a/src/modalities/dataloader/large_file_lines_reader.py b/src/modalities/dataloader/large_file_lines_reader.py index 220f95bb..3d896dcd 100644 --- a/src/modalities/dataloader/large_file_lines_reader.py +++ b/src/modalities/dataloader/large_file_lines_reader.py @@ -1,7 +1,7 @@ import pickle from abc import ABC, abstractmethod from pathlib import Path -from typing import List, Optional +from typing import Optional class BaseReader(ABC): @@ -10,7 +10,7 @@ def __len__(self) -> int: raise NotImplementedError @abstractmethod - def __getitem__(self, key: int | slice) -> str | List[str]: + def __getitem__(self, key: int | slice) -> str | list[str]: raise NotImplementedError @@ -72,7 +72,7 @@ def __len__(self) -> int: """ return len(self.index) - def __getitem__(self, key: int | slice) -> str | List[str]: + def __getitem__(self, key: int | slice) -> str | list[str]: """ Retrieves an item from the LargeFileLinesReader. @@ -80,7 +80,7 @@ def __getitem__(self, key: int | slice) -> str | List[str]: key (int | slice): The index or slice used to retrieve the item(s). Returns: - str | List[str]: The item(s) retrieved from the LargeFileLinesReader. + str | list[str]: The item(s) retrieved from the LargeFileLinesReader. Raises: IndexError: If the key is out of range. diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 0db60af7..456fcb47 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List +from typing import Callable import torch import torch.distributed as dist @@ -55,22 +55,22 @@ def evaluate_batch( def evaluate( self, model: nn.Module, - data_loaders: List[LLMDataLoader], + data_loaders: list[LLMDataLoader], loss_fun: Callable[[InferenceResultBatch], torch.Tensor], num_train_steps_done: int, - ) -> Dict[str, EvaluationResultBatch]: + ) -> dict[str, EvaluationResultBatch]: """Evaluate the model on a set of datasets. Args: model (nn.Module): The model to evaluate - data_loaders (List[LLMDataLoader]): List of dataloaders to evaluate the model on + data_loaders (list[LLMDataLoader]): List of dataloaders to evaluate the model on loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss num_train_steps_done (int): The number of training steps done so far for logging purposes Returns: - Dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader + dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader """ - result_dict: Dict[str, EvaluationResultBatch] = {} + result_dict: dict[str, EvaluationResultBatch] = {} model.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 8f69082d..270e1dae 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -1,6 +1,6 @@ from datetime import datetime from functools import partial -from typing import Callable, List +from typing import Callable import torch.nn as nn from torch.optim import Optimizer @@ -41,7 +41,7 @@ def run( checkpointing_interval_in_steps: int, evaluation_interval_in_steps: int, train_data_loader: LLMDataLoader, - evaluation_data_loaders: List[LLMDataLoader], + evaluation_data_loaders: list[LLMDataLoader], checkpoint_saving: CheckpointSaving, ): """Runs the model training, including evaluation and checkpointing. @@ -54,7 +54,7 @@ def run( checkpointing_interval_in_steps (int): Interval in steps to save checkpoints. evaluation_interval_in_steps (int): Interval in steps to perform evaluation. train_data_loader (LLMDataLoader): Data loader with the training data. - evaluation_data_loaders (List[LLMDataLoader]): List of data loaders with the evaluation data. + evaluation_data_loaders (list[LLMDataLoader]): List of data loaders with the evaluation data. checkpoint_saving (CheckpointSaving): Routine for saving checkpoints. """ evaluation_callback: Callable[[int], None] = partial( @@ -109,7 +109,7 @@ def _run_evaluation( self, model: nn.Module, num_train_steps_done: int, - evaluation_data_loaders: List[LLMDataLoader], + evaluation_data_loaders: list[LLMDataLoader], evaluation_interval_in_steps: int, ): if num_train_steps_done % evaluation_interval_in_steps == 0: diff --git a/src/modalities/logging_broker/message_broker.py b/src/modalities/logging_broker/message_broker.py index 7b38e58f..d81f86b7 100644 --- a/src/modalities/logging_broker/message_broker.py +++ b/src/modalities/logging_broker/message_broker.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import Dict, List from modalities.logging_broker.messages import Message, MessageTypes from modalities.logging_broker.subscriber import MessageSubscriberIF @@ -22,7 +21,7 @@ class MessageBroker(MessageBrokerIF): """The MessageBroker sends notifications to its subscribers.""" def __init__(self) -> None: - self.subscriptions: Dict[MessageTypes, List[MessageSubscriberIF]] = defaultdict(list) + self.subscriptions: dict[MessageTypes, list[MessageSubscriberIF]] = defaultdict(list) def add_subscriber(self, subscription: MessageTypes, subscriber: MessageSubscriberIF): """Adds a single subscriber.""" diff --git a/src/modalities/logging_broker/subscriber.py b/src/modalities/logging_broker/subscriber.py index 9d62c17a..5bdc885a 100644 --- a/src/modalities/logging_broker/subscriber.py +++ b/src/modalities/logging_broker/subscriber.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Generic, TypeVar +from typing import Any, Generic, TypeVar from modalities.logging_broker.messages import Message @@ -14,5 +14,5 @@ def consume_message(self, message: Message[T]): raise NotImplementedError @abstractmethod - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): raise NotImplementedError diff --git a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py index 26475b92..0c5e7f07 100644 --- a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Tuple +from typing import Any from rich.console import Group from rich.live import Live @@ -14,7 +14,7 @@ class DummyProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): def consume_message(self, message: Message[ProgressUpdate]): pass - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): pass @@ -25,8 +25,8 @@ class RichProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): def __init__( self, - train_split_num_steps: Dict[str, Tuple[int, int]], - eval_splits_num_steps: Dict[str, int], + train_split_num_steps: dict[str, tuple[int, int]], + eval_splits_num_steps: dict[str, int], ) -> None: # train split progress bar self.train_splits_progress = Progress( @@ -96,5 +96,5 @@ def consume_message(self, message: Message[ProgressUpdate]): completed=batch_progress.num_steps_done, ) - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): raise NotImplementedError diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index d8c81ae6..8086f9a1 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict +from typing import Any import rich import wandb @@ -18,7 +18,7 @@ def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker.""" pass - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): pass @@ -50,7 +50,7 @@ def consume_message(self, message: Message[EvaluationResultBatch]): if losses or metrics: rich.print(Panel(Group(*group_content))) - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): raise NotImplementedError @@ -75,7 +75,7 @@ def __init__( self.run.log_artifact(config_file_path, name=f"config_{wandb.run.id}", type="config") - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): for k, v in mesasge_dict.items(): self.run.config[k] = v diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py index c55f4ce0..c5ae3c4f 100644 --- a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import List, Optional +from typing import Optional from modalities.config.config import WandbMode from modalities.dataloader.dataloader import LLMDataLoader @@ -18,7 +18,7 @@ class ProgressSubscriberFactory: @staticmethod def get_rich_progress_subscriber( - eval_dataloaders: List[LLMDataLoader], + eval_dataloaders: list[LLMDataLoader], train_dataloader_tag: str, num_seen_steps: int, num_target_steps: int, diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index cf6b96eb..e299d768 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -1,4 +1,4 @@ -from typing import Annotated, Dict, Tuple +from typing import Annotated import torch from einops import repeat @@ -181,7 +181,7 @@ def __init__( attention_config=text_decoder_config.attention_config, ) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the CoCa model. @@ -200,7 +200,7 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: self.text_cls_prediction_key: text_cls_token, } - def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_encode_vision(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: """ Encodes the input image using the vision encoder. @@ -208,7 +208,7 @@ def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch inputs (dict[str, torch.Tensor]): Dictionary containing vision inputs. Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple containing encoded vision embeddings and classification token. + tuple[torch.Tensor, torch.Tensor]: Tuple containing encoded vision embeddings and classification token. """ vision_embd = self.vision_encoder(inputs)[self.vision_embd_prediction_key] queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0]) @@ -216,7 +216,7 @@ def _forward_encode_vision(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch vision_embd, vision_cls_token = vision_embd[:, :-1, :], vision_embd[:, -1:, :] return vision_embd, vision_cls_token - def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + def _forward_encode_text(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: """ Encodes the input text using the text decoder. @@ -224,7 +224,7 @@ def _forward_encode_text(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.T inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing the encoded text tensor + tuple[torch.Tensor, torch.Tensor]: A tuple containing the encoded text tensor and the classification token tensor. """ text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index c6044782..437db1ec 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -1,5 +1,4 @@ from dataclasses import field -from typing import Dict, List import torch from pydantic import BaseModel @@ -13,14 +12,14 @@ class CoCaCollateFnConfig(BaseModel): Configuration class for CoCaCollateFn. Args: - sample_keys (List[str]): List of samples keys. - target_keys (List[str]): List of target keys. + sample_keys (list[str]): List of samples keys. + target_keys (list[str]): List of target keys. text_sample_key (str): Key for the text samples. text_target_key (str): Key for the text targets. """ - sample_keys: List[str] - target_keys: List[str] + sample_keys: list[str] + target_keys: list[str] text_sample_key: str text_target_key: str @@ -28,13 +27,13 @@ class CoCaCollateFnConfig(BaseModel): class CoCaCollatorFn(CollateFnIF): """Collator function for CoCa model.""" - def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_key: str, text_target_key: str): + def __init__(self, sample_keys: list[str], target_keys: list[str], text_sample_key: str, text_target_key: str): """ Initializes the CoCaCollatorFn object. Args: - sample_keys (List[str]): List of samples keys. - target_keys (List[str]): List of target keys. + sample_keys (list[str]): List of samples keys. + target_keys (list[str]): List of target keys. text_sample_key (str): Key for the text samples. text_target_key (str): Key for the text targets. @@ -58,12 +57,12 @@ def __init__(self, sample_keys: List[str], target_keys: List[str], text_sample_k self.text_sample_key = text_sample_key self.text_target_key = text_target_key - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: """ Process a batch of data. Args: - batch (List[Dict[str, torch.Tensor]]): A list of dictionaries containing tensors + batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing tensors representing the batch data. Returns: diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index 09997c2f..6c616523 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -1,5 +1,4 @@ from functools import partial -from typing import Dict import torch from torch import nn @@ -166,7 +165,7 @@ def __init__( ) self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size, bias=False) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the MultiModalTextDecoder module. diff --git a/src/modalities/models/coca/text_decoder.py b/src/modalities/models/coca/text_decoder.py index 4e69bed2..39204d18 100644 --- a/src/modalities/models/coca/text_decoder.py +++ b/src/modalities/models/coca/text_decoder.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from torch import nn @@ -77,7 +75,7 @@ def __init__( ) ) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the TextDecoder module. diff --git a/src/modalities/models/gpt2/collator.py b/src/modalities/models/gpt2/collator.py index 7211dd25..4e0256cb 100644 --- a/src/modalities/models/gpt2/collator.py +++ b/src/modalities/models/gpt2/collator.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Dict, List import torch @@ -10,12 +9,12 @@ class CollateFnIF(ABC): """CollateFnIF class to define a collate function interface.""" @abstractmethod - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: """ Process a batch of data. Args: - batch (List[Dict[str, torch.Tensor]]): A list of dictionaries containing tensors. + batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing tensors. Returns: DatasetBatch: The processed batch of data. @@ -40,12 +39,12 @@ def __init__(self, sample_key: str, target_key: str): self.sample_key = sample_key self.target_key = target_key - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: """ Process a batch of data. Args: - batch (List[Dict[str, torch.Tensor]]): A list of dictionaries containing tensors. + batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing tensors. Returns: DatasetBatch: A processed batch of data where sample and target sequences are created. diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index ab43dfd1..fc54f512 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -1,7 +1,7 @@ import math from copy import deepcopy from enum import Enum -from typing import Annotated, Dict, List, Tuple +from typing import Annotated import torch import torch.nn as nn @@ -42,7 +42,7 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Perform forward pass for transforming queries/keys/values. @@ -52,7 +52,7 @@ def forward( v (torch.Tensor): The value tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the output tensors. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the output tensors. """ pass @@ -65,7 +65,7 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass of the IdentityTransform which does not apply any transform. @@ -75,7 +75,7 @@ def forward( v (torch.Tensor): The value tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The tensors q, k, and v. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The tensors q, k, and v. """ return q, k, v @@ -160,7 +160,7 @@ def apply_rotary_pos_emb(self, x, cos, sin): def forward( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass of the RotaryTransform module. @@ -170,7 +170,7 @@ def forward( v (torch.Tensor): Value tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple containing the modified query tensor, key tensor, and value tensor. """ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k) @@ -213,7 +213,7 @@ class AttentionConfig(BaseModel): Configuration class for attention mechanism. Attributes: - qkv_transforms (List[QueryKeyValueTransformConfig]): List of configurations for query-key-value transforms. + qkv_transforms (list[QueryKeyValueTransformConfig]): List of configurations for query-key-value transforms. """ class QueryKeyValueTransformConfig(BaseModel): @@ -222,7 +222,7 @@ class QueryKeyValueTransformConfig(BaseModel): Attributes: type_hint (QueryKeyValueTransformType): The type hint for the transform. - config (Union[RotaryTransformConfig, IdentityTransformConfig]): The configuration for the transform. + config (RotaryTransformConfig | IdentityTransformConfig): The configuration for the transform. """ class IdentityTransformConfig(BaseModel): @@ -262,7 +262,7 @@ def parse_sharding_strategy_by_name(cls, name): type_hint: QueryKeyValueTransformType config: RotaryTransformConfig | IdentityTransformConfig - qkv_transforms: List[QueryKeyValueTransformConfig] + qkv_transforms: list[QueryKeyValueTransformConfig] class GPT2LLMConfig(BaseModel): @@ -422,7 +422,7 @@ def __init__( for transform_config in attention_config.qkv_transforms ) - def projection(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def projection(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies projections to the input tensor to get queries, keys, and values. @@ -430,7 +430,7 @@ def projection(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch x (torch.Tensor): The input tensor. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the query, key, and value tensors. + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the query, key, and value tensors. """ # calculate query, key, values for all heads in batch and move head forward to be the batch dim return self.q_attn(x), self.k_attn(x), self.v_attn(x) @@ -438,7 +438,7 @@ def projection(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch @staticmethod def execute_qkv_transforms( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, qkv_transforms: nn.ModuleList, n_head_q: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Applies a series of transformations to the query, key, and value tensors. @@ -450,7 +450,7 @@ def execute_qkv_transforms( n_head_q (int): The number of heads for the query tensors. Returns: - Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing the transformed query, key, and value tensors. """ batch_size, sequence_length, embedding_dim = q.size() @@ -826,16 +826,16 @@ def __init__( # not 100% sure what this is, so far seems to be harmless. TODO investigate self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying - def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward_impl(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass implementation of the GPT2LLM module. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - sample_key (str): Key for the input tensor containing token ids. Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. - prediction_key (str): Key for the output tensor containing logits. """ input_ids = inputs[self.sample_key] @@ -861,16 +861,16 @@ def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso logits = self.lm_head(x) return {self.prediction_key: logits} - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the GPT2LLM module. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - sample_key (str): Key for the input tensor containing token ids. Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. - prediction_key (str): Key for the output tensor containing logits. """ return self.forward_impl(inputs) diff --git a/src/modalities/models/gpt2/pretrained_gpt_model.py b/src/modalities/models/gpt2/pretrained_gpt_model.py index 7251229a..ea896624 100644 --- a/src/modalities/models/gpt2/pretrained_gpt_model.py +++ b/src/modalities/models/gpt2/pretrained_gpt_model.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch from transformers import PreTrainedModel @@ -38,7 +36,7 @@ def forward(self, tensor): """ model_input = {"input_ids": tensor} - model_forward_output: Dict[str, torch.Tensor] = self.model.forward(model_input) + model_forward_output: dict[str, torch.Tensor] = self.model.forward(model_input) return model_forward_output[self.config.config.prediction_key] diff --git a/src/modalities/models/huggingface/huggingface_model.py b/src/modalities/models/huggingface/huggingface_model.py index 2d1bfa30..9d9a7414 100644 --- a/src/modalities/models/huggingface/huggingface_model.py +++ b/src/modalities/models/huggingface/huggingface_model.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from pydantic import BaseModel, ConfigDict @@ -102,26 +102,26 @@ def __init__( model_name, local_files_only=False, *model_args, **kwargs ) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the model. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. """ output = self.huggingface_model.forward(inputs[self.sample_key]) return {self.prediction_key: output[self.huggingface_prediction_subscription_key]} @property - def fsdp_block_names(self) -> List[str]: + def fsdp_block_names(self) -> list[str]: """ Returns a list of FSDP block names. Returns: - List[str]: A list of FSDP block names. + list[str]: A list of FSDP block names. """ return self.huggingface_model._no_split_modules diff --git a/src/modalities/models/huggingface_adapters/hf_adapter.py b/src/modalities/models/huggingface_adapters/hf_adapter.py index a111e1f8..09c75183 100644 --- a/src/modalities/models/huggingface_adapters/hf_adapter.py +++ b/src/modalities/models/huggingface_adapters/hf_adapter.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from pathlib import PosixPath -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Optional import torch from transformers import PretrainedConfig, PreTrainedModel @@ -49,8 +49,8 @@ def to_json_string(self, use_diff: bool = True) -> str: return json.dumps(json_dict) def _convert_posixpath_to_str( - self, data_to_be_formatted: Union[Dict[str, Any], List[Any], PosixPath, Any] - ) -> Union[Dict[str, Any], List[Any], PosixPath, Any]: + self, data_to_be_formatted: dict[str, Any] | list[Any] | PosixPath | Any + ) -> dict[str, Any] | list[Any] | PosixPath | Any: # Recursively converts any PosixPath objects within a nested data structure to strings. if isinstance(data_to_be_formatted, dict): @@ -108,13 +108,13 @@ def forward( output_hidden_states (bool, optional): Whether to output hidden states. Defaults to False. Returns: - Union[ModalitiesModelOutput, torch.Tensor]: The output of the forward pass. + ModalitiesModelOutput | torch.Tensor: The output of the forward pass. """ # These parameters are required by HuggingFace. We do not use them and hence don't implement them. if output_attentions or output_hidden_states: raise NotImplementedError model_input = {"input_ids": input_ids, "attention_mask": attention_mask} - model_forward_output: Dict[str, torch.Tensor] = self.model.forward(model_input) + model_forward_output: dict[str, torch.Tensor] = self.model.forward(model_input) if return_dict: return ModalitiesModelOutput(**model_forward_output) else: @@ -122,7 +122,7 @@ def forward( def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor = None, **kwargs - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Prepares the inputs for generation. @@ -132,7 +132,7 @@ def prepare_inputs_for_generation( **kwargs: Additional keyword arguments. Returns: - Dict[str, Any]: A dictionary containing the prepared inputs for generation. + dict[str, Any]: A dictionary containing the prepared inputs for generation. Note: Implement in subclasses of :class:`~transformers.PreTrainedModel` @@ -151,10 +151,10 @@ class ModalitiesModelOutput(ModelOutput): Args: logits (torch.FloatTensor, optional): The logits output of the model. Defaults to None. - hidden_states (Tuple[torch.FloatTensor], optional): The hidden states output of the model. Defaults to None. - attentions (Tuple[torch.FloatTensor], optional): The attentions output of the model. Defaults to None. + hidden_states (tuple[torch.FloatTensor], optional): The hidden states output of the model. Defaults to None. + attentions (tuple[torch.FloatTensor], optional): The attentions output of the model. Defaults to None. """ logits: Optional[torch.FloatTensor] = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None + hidden_states: Optional[tuple[torch.FloatTensor]] = None + attentions: Optional[tuple[torch.FloatTensor]] = None diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index 0bbaccc3..fd7703e5 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -1,13 +1,13 @@ from abc import abstractmethod from enum import Enum -from typing import Dict, List, Optional +from typing import Optional import torch import torch.nn as nn from modalities.batch import DatasetBatch, InferenceResultBatch -WeightDecayGroups = Dict[str, List[str]] +WeightDecayGroups = dict[str, list[str]] class ActivationType(str, Enum): @@ -50,19 +50,19 @@ def weight_decay_groups(self) -> WeightDecayGroups: return self._weight_decay_groups @abstractmethod - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the model. Args: - inputs (Dict[str, torch.Tensor]): A dictionary containing input tensors. + inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. Returns: - Dict[str, torch.Tensor]: A dictionary containing output tensors. + dict[str, torch.Tensor]: A dictionary containing output tensors. """ raise NotImplementedError - def get_parameters(self) -> Dict[str, torch.Tensor]: + def get_parameters(self) -> dict[str, torch.Tensor]: """ Returns a dictionary of the model's parameters. diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 96ea93a5..0a02fd8a 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import torch import torch.distributed as dist @@ -46,7 +45,7 @@ def get_checkpointed_model( def get_fsdp_wrapped_model( model: nn.Module, sync_module_states: bool, - block_names: List[str], + block_names: list[str], mixed_precision_settings: MixedPrecisionSettings, sharding_strategy: ShardingStrategy, ) -> FSDP: @@ -56,7 +55,7 @@ def get_fsdp_wrapped_model( Args: model (nn.Module): The original model to be wrapped. sync_module_states (bool): Whether to synchronize module states across ranks. - block_names (List[str]): List of block names. + block_names (list[str]): List of block names. mixed_precision_settings (MixedPrecisionSettings): Mixed precision settings. sharding_strategy (ShardingStrategy): Sharding strategy. @@ -108,12 +107,12 @@ def get_weight_initalized_model(model: nn.Module, model_initializer: ModelInitia return model @staticmethod - def get_activation_checkpointed_model(model: FSDP, activation_checkpointing_modules: List[str]) -> FSDP: + def get_activation_checkpointed_model(model: FSDP, activation_checkpointing_modules: list[str]) -> FSDP: """Apply activation checkpointing to the given model (in-place operation). Args: model (FSDP): The FSDP-wrapped model to apply activation checkpointing to. - activation_checkpointing_modules (List[str]): List of module names to apply activation checkpointing to. + activation_checkpointing_modules (list[str]): List of module names to apply activation checkpointing to. Raises: ValueError: Activation checkpointing can only be applied to FSDP-wrapped models! diff --git a/src/modalities/models/utils.py b/src/modalities/models/utils.py index debf0cab..77c75b6a 100644 --- a/src/modalities/models/utils.py +++ b/src/modalities/models/utils.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import Dict from pydantic import BaseModel @@ -22,12 +21,12 @@ class ModelTypeEnum(Enum): CHECKPOINTED_MODEL = "checkpointed_model" -def get_model_from_config(config: Dict, model_type: ModelTypeEnum): +def get_model_from_config(config: dict, model_type: ModelTypeEnum): """ Retrieves a model from the given configuration based on the specified model type. Args: - config (Dict): The configuration dictionary. + config (dict): The configuration dictionary. model_type (ModelTypeEnum): The type of the model to retrieve. Returns: diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index fd954574..0b504bd3 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -1,5 +1,5 @@ from math import floor -from typing import Annotated, Dict, Optional, Tuple, Union +from typing import Annotated, Optional import torch from einops.layers.torch import Rearrange @@ -18,7 +18,7 @@ class VisionTransformerConfig(BaseModel): Args: sample_key (str): The key for the input sample. prediction_key (str): The key for the model prediction. - img_size (Union[Tuple[int, int], int], optional): The size of the input image. Defaults to 224. + img_size (tuple[int, int] | int, optional): The size of the input image. Defaults to 224. n_classes (int, optional): The number of output classes. Defaults to 1000. n_layer (int): The number of layers in the model. Defaults to 12. attention_config (AttentionConfig, optional): The configuration for the attention mechanism. Defaults to None. @@ -34,7 +34,7 @@ class VisionTransformerConfig(BaseModel): sample_key: str prediction_key: str - img_size: Annotated[Union[Tuple[int, int], int], Field(ge=1)] = 224 + img_size: Annotated[tuple[int, int] | int, Field(ge=1)] = 224 n_classes: Optional[Annotated[int, Field(ge=1)]] = 1000 n_layer: Annotated[int, Field(ge=1)] = 12 attention_config: AttentionConfig = None @@ -176,7 +176,7 @@ def __init__( self, sample_key: str, prediction_key: str, - img_size: Union[Tuple[int, int], int] = 224, + img_size: tuple[int, int] | int = 224, n_classes: int = 1000, n_layer: int = 12, attention_config: AttentionConfig = None, @@ -196,7 +196,7 @@ def __init__( Args: sample_key (str): The key for the samples. prediction_key (str): The key for the predictions. - img_size (Union[Tuple[int, int], int], optional): The size of the input image. Defaults to 224. + img_size (tuple[int, int] | int, optional): The size of the input image. Defaults to 224. n_classes (int, optional): The number of classes. Defaults to 1000. n_layer (int, optional): The number of layers. Defaults to 12. attention_config (AttentionConfig, optional): The attention configuration. Defaults to None. @@ -257,7 +257,7 @@ def forward_images(self, x: torch.Tensor) -> torch.Tensor: x = block(x) return x - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the VisionTransformer module. @@ -279,12 +279,12 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: return {self.prediction_key: x} @staticmethod - def _calculate_block_size(img_size: Tuple[int, int], patch_size: int, patch_stride: int, add_cls_token: bool): + def _calculate_block_size(img_size: tuple[int, int], patch_size: int, patch_stride: int, add_cls_token: bool): """ Calculates the block size. Args: - img_size (Tuple[int, int]): The size of the input image. + img_size (tuple[int, int]): The size of the input image. patch_size (int): The size of each patch. patch_stride (int): The stride of each patch. add_cls_token (bool): Flag indicating whether to add a classification token. diff --git a/src/modalities/nn/attention.py b/src/modalities/nn/attention.py index dd8b5db5..789602a1 100644 --- a/src/modalities/nn/attention.py +++ b/src/modalities/nn/attention.py @@ -1,6 +1,6 @@ import math from enum import Enum -from typing import Optional, Tuple +from typing import Optional import torch import torch.nn.functional as F @@ -78,7 +78,7 @@ def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor: y = self.resid_dropout(self.c_proj(y)) return y - def _forward_input_projection(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + def _forward_input_projection(self, x: Tensor, context: Tensor) -> tuple[Tensor, Tensor, Tensor]: B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) _, Tc, Cc = context.shape # batch size, context length, context embedding dimensionality # Note that the context length (Tc), sequence length (T) and embedding dimensionalities (C and Cc) diff --git a/src/modalities/nn/model_initialization/composed_initialization.py b/src/modalities/nn/model_initialization/composed_initialization.py index 3b02d0bd..4a4d9f33 100644 --- a/src/modalities/nn/model_initialization/composed_initialization.py +++ b/src/modalities/nn/model_initialization/composed_initialization.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional import torch.nn as nn from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -15,7 +15,7 @@ class ModelInitializerWrapperConfig(BaseModel): - model_initializers: List[PydanticModelInitializationIFType] + model_initializers: list[PydanticModelInitializationIFType] # avoid warning about protected namespace 'model_', see # https://docs.pydantic.dev/2.7/api/config/#pydantic.config.ConfigDict.protected_namespaces @@ -78,7 +78,7 @@ def _check_values(self): class ModelInitializerWrapper(ModelInitializationIF): - def __init__(self, model_initializers: List[ModelInitializationIF]): + def __init__(self, model_initializers: list[ModelInitializationIF]): self.model_initializers = model_initializers def initialize_in_place(self, model: nn.Module): @@ -88,7 +88,7 @@ def initialize_in_place(self, model: nn.Module): class ComposedInitializationRoutines: @staticmethod - def get_model_initializer_wrapper(model_initializers: List[ModelInitializationIF]) -> ModelInitializationIF: + def get_model_initializer_wrapper(model_initializers: list[ModelInitializationIF]) -> ModelInitializationIF: initializer_wrapper = ModelInitializerWrapper(model_initializers) return initializer_wrapper diff --git a/src/modalities/nn/model_initialization/initialization_routines.py b/src/modalities/nn/model_initialization/initialization_routines.py index aa14ebf8..743297ec 100644 --- a/src/modalities/nn/model_initialization/initialization_routines.py +++ b/src/modalities/nn/model_initialization/initialization_routines.py @@ -1,6 +1,6 @@ import math import re -from typing import Annotated, List, Optional +from typing import Annotated, Optional import torch.nn as nn from pydantic import BaseModel, Field, model_validator @@ -12,7 +12,7 @@ class PlainInitializationConfig(BaseModel): mean: float std: Annotated[float, Field(strict=True, ge=0.0)] | str # can be float or "auto" - parameter_name_regexes: List[str] # here we filter for the parameter names, e.g., "c_proj.weight" + parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" hidden_dim: Optional[int] = None @model_validator(mode="after") @@ -30,12 +30,12 @@ class ScaledInitializationConfig(BaseModel): mean: float std: Annotated[float, Field(strict=True, ge=0.0)] num_layers: Annotated[int, Field(strict=True, gt=0)] - parameter_name_regexes: List[str] # here we filter for the parameter names, e.g., "c_proj.weight" + parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" class ScaledEmbedInitializationConfig(BaseModel): mean: float - parameter_name_regexes: List[str] # here we filter for the parameter names, e.g., "c_proj.weight" + parameter_name_regexes: list[str] # here we filter for the parameter names, e.g., "c_proj.weight" class NamedParameterwiseNormalInitialization(ModelInitializationIF): @@ -59,7 +59,7 @@ def initialize_in_place(self, model: nn.Module): class InitializationRoutines: @staticmethod def get_plain_initialization( - mean: float, std: float | str, parameter_name_regexes: List[str], hidden_dim: Optional[int] = None + mean: float, std: float | str, parameter_name_regexes: list[str], hidden_dim: Optional[int] = None ) -> NamedParameterwiseNormalInitialization: """Initializes the weights of a model by sampling from a normal distribution. NOTE: This class supports the initialization of nn.Linear and nn.Embedding layers. @@ -86,7 +86,7 @@ def get_plain_initialization( @staticmethod def get_scaled_initialization( - mean: float, std: float, num_layers: int, parameter_name_regexes: List[str] + mean: float, std: float, num_layers: int, parameter_name_regexes: list[str] ) -> ModelInitializationIF: """Implementation of scaled weight initialization. As defined in https://arxiv.org/abs/2312.16903 @@ -94,7 +94,7 @@ def get_scaled_initialization( mean (float): Mean of the normal distribution std (float): Standard deviation of the normal distribution used to initialize the other weights num_layers (int): Number of layers in the model which we use to downscale std with - parameter_name_regexes (List[str]): List of parameter name regexes to which the initialization + parameter_name_regexes (list[str]): List of parameter name regexes to which the initialization should be applied Returns: @@ -109,13 +109,13 @@ def get_scaled_initialization( return initialization @staticmethod - def get_scaled_embed_initialization(mean: float, parameter_name_regexes: List[str]) -> ModelInitializationIF: + def get_scaled_embed_initialization(mean: float, parameter_name_regexes: list[str]) -> ModelInitializationIF: """Implementation of scaled weight initialization for embeddings, see https://arxiv.org/abs/2312.16903 We fix the standard deviation to sqrt(0.4). Args: mean (float): Mean of the normal distribution - parameter_name_regexes (List[str], optional): List of parameter name regexes to which the initialization + parameter_name_regexes (list[str], optional): List of parameter name regexes to which the initialization should be applied Defaults to None. Returns: diff --git a/src/modalities/nn/model_initialization/parameter_name_filters.py b/src/modalities/nn/model_initialization/parameter_name_filters.py index 4f24c5aa..ff4edede 100644 --- a/src/modalities/nn/model_initialization/parameter_name_filters.py +++ b/src/modalities/nn/model_initialization/parameter_name_filters.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Optional +from typing import Optional from pydantic import BaseModel, Field @@ -16,8 +16,8 @@ class SupportWeightInitModels(Enum): class RegexFilter(BaseModel): - weights: List[str] - biases: Optional[List[str]] = Field(default_factory=list) + weights: list[str] + biases: Optional[list[str]] = Field(default_factory=list) NAMED_PARAMETER_INIT_GROUPS = { diff --git a/src/modalities/optimizers/lr_schedulers.py b/src/modalities/optimizers/lr_schedulers.py index bf53fc18..5e0e6f5b 100644 --- a/src/modalities/optimizers/lr_schedulers.py +++ b/src/modalities/optimizers/lr_schedulers.py @@ -1,5 +1,4 @@ import warnings -from typing import List from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler @@ -9,7 +8,7 @@ class DummyLRScheduler(LRScheduler): def __init__(self, optimizer: Optimizer, last_epoch=-1, verbose=False): super().__init__(optimizer, last_epoch, verbose) - def get_lr(self) -> List[float]: + def get_lr(self) -> list[float]: if not self._get_lr_called_within_step: # type error expected due to internal pytorch implementation warnings.warn( "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", UserWarning @@ -17,5 +16,5 @@ def get_lr(self) -> List[float]: return [group["lr"] for group in self.optimizer.param_groups] - def _get_closed_form_lr(self) -> List[float]: + def _get_closed_form_lr(self) -> list[float]: return self.base_lrs diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py index 371c9c48..fdff5d0f 100644 --- a/src/modalities/optimizers/optimizer_factory.py +++ b/src/modalities/optimizers/optimizer_factory.py @@ -1,6 +1,5 @@ import re from pathlib import Path -from typing import Dict, List, Tuple import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -11,16 +10,16 @@ from modalities.models.model import NNModel from modalities.util import get_local_number_of_trainable_parameters, print_rank_0 -OptimizerGroups = List[Dict[str, List[nn.Parameter] | float]] +OptimizerGroups = list[dict[str, list[nn.Parameter] | float]] class OptimizerFactory: def get_adam( lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, - weight_decay_groups_excluded: List[str], + weight_decay_groups_excluded: list[str], wrapped_model: nn.Module, ) -> Optimizer: optimizer_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded) @@ -29,10 +28,10 @@ def get_adam( def get_adam_w( lr: float, - betas: Tuple[float, float], + betas: tuple[float, float], eps: float, weight_decay: float, - weight_decay_groups_excluded: List[str], + weight_decay_groups_excluded: list[str], wrapped_model: nn.Module, ) -> Optimizer: optimizer_groups = get_optimizer_groups(wrapped_model, weight_decay, weight_decay_groups_excluded) @@ -49,7 +48,7 @@ def get_checkpointed_optimizer( return wrapped_optimizer -def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_excluded: List[str]) -> OptimizerGroups: +def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_excluded: list[str]) -> OptimizerGroups: """ divide model parameters into optimizer groups (with or without weight decay) @@ -73,7 +72,7 @@ def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_e return optimizer_groups -def _assert_existence_of_weight_decay_groups_excluded(model: FSDP, weight_decay_groups_excluded: List[str]) -> None: +def _assert_existence_of_weight_decay_groups_excluded(model: FSDP, weight_decay_groups_excluded: list[str]) -> None: """ checks the existence of all groups that are to be excluded from weight decay @@ -93,8 +92,8 @@ def _assert_existence_of_weight_decay_groups_excluded(model: FSDP, weight_decay_ def _create_optimizer_groups( - model: FSDP, weight_decay: float, weight_decay_groups_excluded: List[str] -) -> Tuple[OptimizerGroups, List[str]]: + model: FSDP, weight_decay: float, weight_decay_groups_excluded: list[str] +) -> tuple[OptimizerGroups, list[str]]: """ create optimizer groups of parameters with different weight decays that are to be used in Adam or AdamW """ @@ -118,8 +117,8 @@ def _create_optimizer_groups( def _filter_params_for_weight_decay_group( - params: Dict[str, List[nn.Parameter]], regex_expressions: List[str] -) -> List[nn.Parameter]: + params: dict[str, list[nn.Parameter]], regex_expressions: list[str] +) -> list[nn.Parameter]: """ filter parameters by their name. a parameter is kept if and only if it contains at least one of the regex expressions. @@ -139,7 +138,7 @@ def _print_params(params) -> None: print_rank_0(f"{i + 1} {name}") -def _print_optimizer_groups_overview(optimizer_groups: OptimizerGroups, optimizer_groups_names: List[str]) -> None: +def _print_optimizer_groups_overview(optimizer_groups: OptimizerGroups, optimizer_groups_names: list[str]) -> None: """ for each optimizer group, the following is printed: - the number of modules diff --git a/src/modalities/registry/registry.py b/src/modalities/registry/registry.py index aebd8cea..02010700 100644 --- a/src/modalities/registry/registry.py +++ b/src/modalities/registry/registry.py @@ -1,25 +1,25 @@ from dataclasses import asdict -from typing import Dict, List, Optional, Tuple, Type +from typing import Optional, Type from pydantic import BaseModel from modalities.registry.components import ComponentEntity -Entity = Tuple[Type, Type[BaseModel]] +Entity = tuple[Type, Type[BaseModel]] class Registry: """Registry class to store the components and their config classes.""" - def __init__(self, components: Optional[List[ComponentEntity]] = None) -> None: + def __init__(self, components: Optional[list[ComponentEntity]] = None) -> None: """Initializes the Registry class with an optional list of components. Args: - components (List[ComponentEntity], optional): List of components to + components (list[ComponentEntity], optional): List of components to intialize the registry with . Defaults to None. """ # maps component_key -> variant_key -> entity = (component, config) - self._registry_dict: Dict[str, Dict[str, Entity]] = {} + self._registry_dict: dict[str, dict[str, Entity]] = {} if components is not None: for component in components: self.add_entity(**asdict(component)) diff --git a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py index 22d463f3..97ad6976 100644 --- a/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py +++ b/src/modalities/running_env/fsdp/fsdp_auto_wrapper.py @@ -1,7 +1,7 @@ import functools import logging from abc import ABC, abstractmethod -from typing import Callable, List +from typing import Callable import torch.nn as nn from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy @@ -17,13 +17,13 @@ def get_auto_wrap_policy(self) -> Callable: class FSDPTransformerAutoWrapPolicyFactory(FSDPAutoWrapFactoryIF): - def __init__(self, model: nn.Module, block_names: List[str]) -> None: + def __init__(self, model: nn.Module, block_names: list[str]) -> None: # TODO it's problematic that we store the model in-memory here. Might get too large in RAM... self.model = model self.block_names = block_names @staticmethod - def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str]) -> List[nn.Module]: + def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: list[str]) -> list[nn.Module]: fsdp_block_types = [] for cls_block_name in block_names: # TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct diff --git a/src/modalities/tokenization/tokenizer_wrapper.py b/src/modalities/tokenization/tokenizer_wrapper.py index 479c7954..e9e778fc 100644 --- a/src/modalities/tokenization/tokenizer_wrapper.py +++ b/src/modalities/tokenization/tokenizer_wrapper.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Dict, List, Optional +from typing import Optional import sentencepiece as spm from transformers import AutoTokenizer @@ -8,7 +8,7 @@ class TokenizerWrapper(ABC): """Abstract interface for tokenizers.""" - def tokenize(self, text: str) -> List[int]: + def tokenize(self, text: str) -> list[int]: """Tokenizes a text into a list of token IDs. Args: @@ -18,15 +18,15 @@ def tokenize(self, text: str) -> List[int]: NotImplementedError: Must be implemented by a subclass. Returns: - List[int]: List of token IDs. + list[int]: List of token IDs. """ raise NotImplementedError - def decode(self, input_ids: List[int]) -> str: + def decode(self, input_ids: list[int]) -> str: """Decodes a list of token IDs into the original text. Args: - input_ids (List[int]): List of token IDs. + input_ids (list[int]): List of token IDs. Raises: NotImplementedError: Must be implemented by a subclass. @@ -72,7 +72,7 @@ def __init__( truncation: Optional[bool] = False, padding: Optional[bool | str] = False, max_length: Optional[int] = None, - special_tokens: Optional[Dict[str, str]] = None, + special_tokens: Optional[dict[str, str]] = None, ) -> None: """Initializes the PreTrainedHFTokenizer. @@ -81,7 +81,7 @@ def __init__( truncation (bool, optional): Flag whether to apply truncation. Defaults to False. padding (bool | str, optional): Defines the padding strategy. Defaults to False. max_length (int, optional): Maximum length of the tokenization output. Defaults to None. - special_tokens (Dict[str, str], optional): Added token keys should be in the list + special_tokens (dict[str, str], optional): Added token keys should be in the list of predefined special attributes: [bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens]. Example: {"pad_token": "[PAD]"} @@ -113,22 +113,22 @@ def vocab_size(self) -> int: return self.tokenizer.vocab_size @property - def special_tokens(self) -> Dict[str, str | List[str]]: + def special_tokens(self) -> dict[str, str | list[str]]: """Returns the special tokens of the tokenizer. Returns: - Dict[str, str | List[str]]: Special tokens dictionary. + dict[str, str | list[str]]: Special tokens dictionary. """ return self.tokenizer.special_tokens_map - def tokenize(self, text: str) -> List[int]: + def tokenize(self, text: str) -> list[int]: """Tokenizes a text into a list of token IDs. Args: text (str): Text to be tokenized. Returns: - List[int]: List of token IDs. + list[int]: List of token IDs. """ tokens = self.tokenizer.__call__( text, @@ -138,11 +138,11 @@ def tokenize(self, text: str) -> List[int]: )["input_ids"] return tokens - def decode(self, token_ids: List[int]) -> str: + def decode(self, token_ids: list[int]) -> str: """Decodes a list of token IDs into the original text. Args: - input_ids (List[int]): List of token IDs. + input_ids (list[int]): List of token IDs. Returns: str: Decoded text. @@ -180,23 +180,23 @@ def __init__(self, tokenizer_model_file: str): self.tokenizer = spm.SentencePieceProcessor() self.tokenizer.Load(tokenizer_model_file) - def tokenize(self, text: str) -> List[int]: + def tokenize(self, text: str) -> list[int]: """Tokenizes a text into a list of token IDs. Args: text (str): Text to be tokenized. Returns: - List[int]: List of token IDs. + list[int]: List of token IDs. """ tokens = self.tokenizer.encode(text) return tokens - def decode(self, token_ids: List[int]) -> str: + def decode(self, token_ids: list[int]) -> str: """Decodes a list of token IDs into the original text. Args: - input_ids (List[int]): List of token IDs. + input_ids (list[int]): List of token IDs. Returns: str: Decoded text. diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index e15e47c0..195d050a 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Callable, Optional, Tuple +from typing import Callable, Optional import torch import torch.distributed as dist @@ -89,7 +89,7 @@ def _train_batch( scheduler: LRScheduler, loss_fun: Loss, micro_batch_id: int, - ) -> Tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: """ Conducts a training step on batch of data. @@ -102,7 +102,7 @@ def _train_batch( micro_batch_id (int): The ID of the micro batch. Returns: - Tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple containing the following: - step_performed (bool): Indicates whether a training step was performed. - num_train_steps_done (int): The number of training steps done. diff --git a/src/modalities/training/activation_checkpointing.py b/src/modalities/training/activation_checkpointing.py index 4da52687..ee7c7152 100644 --- a/src/modalities/training/activation_checkpointing.py +++ b/src/modalities/training/activation_checkpointing.py @@ -1,5 +1,4 @@ from functools import partial -from typing import List import torch from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -13,12 +12,12 @@ def is_module_to_apply_activation_checkpointing( - submodule: torch.nn.Module, activation_checkpointing_modules: List[type] + submodule: torch.nn.Module, activation_checkpointing_modules: list[type] ) -> bool: return isinstance(submodule, tuple(activation_checkpointing_modules)) -def apply_activation_checkpointing_inplace(model: torch.nn.Module, activation_checkpointing_modules: List[str]): +def apply_activation_checkpointing_inplace(model: torch.nn.Module, activation_checkpointing_modules: list[str]): activation_checkpointing_module_types = [ get_module_class_from_name(model, m) for m in activation_checkpointing_modules ] diff --git a/src/modalities/util.py b/src/modalities/util.py index 25d033e6..1bbe3ff4 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -5,7 +5,7 @@ from enum import Enum from pathlib import Path from types import TracebackType -from typing import Callable, Dict, Generic, Optional, Type, TypeVar +from typing import Callable, Generic, Optional, Type, TypeVar import torch import torch.distributed as dist @@ -151,7 +151,7 @@ def __repr__(self) -> str: class Aggregator(Generic[T]): def __init__(self): - self.key_to_value: Dict[T, torch.Tensor] = {} + self.key_to_value: dict[T, torch.Tensor] = {} def add_value(self, key: T, value: torch.Tensor): if key not in self.key_to_value: diff --git a/src/modalities/utils/mfu.py b/src/modalities/utils/mfu.py index d33e165e..48d9a604 100644 --- a/src/modalities/utils/mfu.py +++ b/src/modalities/utils/mfu.py @@ -1,5 +1,5 @@ import warnings -from typing import Optional, Tuple +from typing import Optional import torch from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -76,7 +76,7 @@ def get_theoretical_gpu_peak_performance(model: FSDP, world_size: int) -> Option return None -def get_theoretical_flops_per_token(model: FSDP) -> Tuple[Optional[int], Optional[int]]: +def get_theoretical_flops_per_token(model: FSDP) -> tuple[Optional[int], Optional[int]]: """ Calculates the theoretical number of floating point operations (FLOPs) per token for a given model. compute theoretical_flops_per_token = 6*N + 12*L*T*H @@ -86,7 +86,7 @@ def get_theoretical_flops_per_token(model: FSDP) -> Tuple[Optional[int], Optiona model (FSDP): The model for which to calculate the FLOPs per token. Returns: - Tuple[(int, optional), (int, optional)]: A tuple containing the theoretical FLOPs per token + tuple[(int, optional), (int, optional)]: A tuple containing the theoretical FLOPs per token and the sequence length. - Theoretical FLOPs per token: The estimated number of FLOPs required to process each token in the model. - Sequence length: The length of the input sequence. Needed to convert samples to tokens in compute_mfu. diff --git a/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py b/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py index 6c661511..080bcfc9 100644 --- a/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py +++ b/tests/checkpointing/pytorch/test_torch_checkpoint_loading.py @@ -1,5 +1,3 @@ -from typing import Dict - import pytest import torch import torch.nn as nn @@ -13,7 +11,7 @@ def __init__(self): super().__init__() self._weights = nn.Linear(2, 3) - def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: output = self._weights(**inputs) return {"output": output} diff --git a/tests/checkpointing/test_checkpoint_strategies.py b/tests/checkpointing/test_checkpoint_strategies.py index 9aef57df..9fd34580 100644 --- a/tests/checkpointing/test_checkpoint_strategies.py +++ b/tests/checkpointing/test_checkpoint_strategies.py @@ -1,5 +1,3 @@ -from typing import List - import pytest from modalities.checkpointing.checkpoint_saving_strategies import SaveKMostRecentCheckpointsStrategy @@ -25,7 +23,7 @@ ], ) def test_checkpoint_strategy_k( - k: int, saved_instances: List[TrainingProgress], checkpoints_to_delete: List[int], save_current: bool + k: int, saved_instances: list[TrainingProgress], checkpoints_to_delete: list[int], save_current: bool ) -> None: training_progress = TrainingProgress( num_seen_steps_current_run=10, num_seen_tokens_current_run=10, num_target_steps=20, num_target_tokens=40 diff --git a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py index 9a56d274..cdcbecb7 100644 --- a/tests/checkpointing/test_fsdp_to_disc_checkpointing.py +++ b/tests/checkpointing/test_fsdp_to_disc_checkpointing.py @@ -2,7 +2,6 @@ import tempfile from copy import deepcopy from pathlib import Path -from typing import Dict import pytest import torch @@ -42,7 +41,7 @@ reason="This e2e test requires 2 GPUs and a torchrun distributed environment.", ) class TestFSDPToDiscCheckpointing: - def get_gpt2_model_from_config(self, gpt2_model_config_dict: Dict) -> GPT2LLM: + def get_gpt2_model_from_config(self, gpt2_model_config_dict: dict) -> GPT2LLM: class GPT2InstantationModel(BaseModel): model: PydanticPytorchModuleType @@ -57,7 +56,7 @@ class GPT2InstantationModel(BaseModel): return model @pytest.fixture(scope="function") - def gpt2_model_config_dict(self) -> Dict: + def gpt2_model_config_dict(self) -> dict: config_file_path = working_dir / "gpt2_config.yaml" config_dict = load_app_config_dict(config_file_path=config_file_path) return config_dict @@ -108,7 +107,7 @@ def _clone_parameters(fsdp_wrapped_model): return [p.clone() for p in fsdp_wrapped_model.parameters() if p.requires_grad and p.numel() > 0] @staticmethod - def _generate_batch(gpt2_model_config: Dict): + def _generate_batch(gpt2_model_config: dict): # prepare input and targets data = torch.randint( 0, # lowest token_id @@ -122,10 +121,10 @@ def _generate_batch(gpt2_model_config: Dict): @staticmethod def _forward_backward_pass( - gpt2_model_config: Dict, + gpt2_model_config: dict, model: FSDP, optimizer: Optimizer, - batch_input_ids_dict: Dict, + batch_input_ids_dict: dict, batch_target_ids: torch.Tensor, ): ce_loss = CrossEntropyLoss() @@ -148,7 +147,7 @@ def _forward_backward_pass( @staticmethod def _assert_equality_optimizer_param_group( - optimizer_1_state_dict: Dict, optimizer_2_state_dict: Dict, must_be_equal: bool + optimizer_1_state_dict: dict, optimizer_2_state_dict: dict, must_be_equal: bool ): if must_be_equal: assert ( @@ -161,7 +160,7 @@ def _assert_equality_optimizer_param_group( @staticmethod def _assert_equality_optimizer_state( - optimizer_1_state_dict: Dict, optimizer_2_state_dict: Dict, must_be_equal: bool + optimizer_1_state_dict: dict, optimizer_2_state_dict: dict, must_be_equal: bool ): optimizer_1_state = optimizer_1_state_dict["state"] optimizer_2_state = optimizer_2_state_dict["state"] @@ -195,7 +194,7 @@ def test_save_checkpoint_after_backward_pass( optimizer: Optimizer, temporary_checkpoint_folder_path: Path, gpt2_model_2: GPT2LLM, - gpt2_model_config_dict: Dict, + gpt2_model_config_dict: dict, ): experiment_id = "0" num_train_steps_done = 1 diff --git a/tests/config/components.py b/tests/config/components.py index 67c9e9a3..f67c5531 100644 --- a/tests/config/components.py +++ b/tests/config/components.py @@ -1,5 +1,4 @@ from enum import Enum -from typing import List class Component_V_W_X_IF: @@ -30,7 +29,7 @@ def __init__(self, val_x: str, single_dependency: Component_V_W_X_IF) -> None: class ComponentY: - def __init__(self, val_y: str, multi_dependency: List[Component_V_W_X_IF]) -> None: + def __init__(self, val_y: str, multi_dependency: list[Component_V_W_X_IF]) -> None: self.val_y = val_y self.multi_dependency = multi_dependency diff --git a/tests/config/configs.py b/tests/config/configs.py index 569c4d11..2ed597eb 100644 --- a/tests/config/configs.py +++ b/tests/config/configs.py @@ -1,4 +1,4 @@ -from typing import Annotated, List +from typing import Annotated from pydantic import BaseModel @@ -23,7 +23,7 @@ class CompXConfig(BaseModel): class CompYConfig(BaseModel): val_y: str - multi_dependency: List[PydanticComponent_V_W_X_IF_Type] + multi_dependency: list[PydanticComponent_V_W_X_IF_Type] class CompZConfig(BaseModel): diff --git a/tests/conftest.py b/tests/conftest.py index 1aa87f87..be2fa571 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,6 @@ import os import pickle from pathlib import Path -from typing import Dict from unittest.mock import MagicMock import pytest @@ -51,7 +50,7 @@ def dummy_config_path() -> Path: @pytest.fixture -def dummy_config(monkeypatch, dummy_config_path) -> Dict: +def dummy_config(monkeypatch, dummy_config_path) -> dict: monkeypatch.setenv("RANK", "0") monkeypatch.setenv("LOCAL_RANK", "0") monkeypatch.setenv("WORLD_SIZE", "1") diff --git a/tests/dataloader/dummy_sequential_dataset.py b/tests/dataloader/dummy_sequential_dataset.py index 8eb412a4..0d3a8dee 100644 --- a/tests/dataloader/dummy_sequential_dataset.py +++ b/tests/dataloader/dummy_sequential_dataset.py @@ -1,5 +1,3 @@ -from typing import Dict - from pydantic import BaseModel from torch.utils.data.dataset import Dataset as TorchdataSet @@ -11,7 +9,7 @@ def __init__(self, num_samples: int): def __len__(self) -> int: return len(self.samples) - def __getitem__(self, idx: int) -> Dict: + def __getitem__(self, idx: int) -> dict: return self.samples[idx] diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 3b0ef7be..65139ce6 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from pathlib import Path -from typing import Any, Dict, List +from typing import Any import numpy as np import pytest @@ -44,7 +44,7 @@ def test_resumable_dataloader(): assert (flat_samples == original_samples).all() -def test_dataloader_from_config(dummy_config: Dict): +def test_dataloader_from_config(dummy_config: dict): start_index = 2 dummy_config["train_dataloader"]["config"]["skip_num_batches"] = start_index @@ -248,7 +248,7 @@ class DataloaderTestModel(BaseModel): fixed_num_batches: int class IdentityCollateFn(CollateFnIF): - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> List[Dict[str, torch.Tensor]]: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> list[dict[str, torch.Tensor]]: return batch root_dir = Path(__file__).parents[0] diff --git a/tests/end2end_tests/test_fsdp_warmstart.py b/tests/end2end_tests/test_fsdp_warmstart.py index f41233ff..3261eb4b 100644 --- a/tests/end2end_tests/test_fsdp_warmstart.py +++ b/tests/end2end_tests/test_fsdp_warmstart.py @@ -2,7 +2,7 @@ import os import tempfile from pathlib import Path -from typing import Any, Dict, List +from typing import Any import pytest import torch @@ -31,13 +31,13 @@ class SaveAllResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): def __init__(self): - self.message_list: List[Message[EvaluationResultBatch]] = [] + self.message_list: list[Message[EvaluationResultBatch]] = [] def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker.""" self.message_list.append(message) - def consume_dict(self, mesasge_dict: Dict[str, Any]): + def consume_dict(self, mesasge_dict: dict[str, Any]): pass @@ -55,7 +55,7 @@ class TrainDataloaderInstantiationModel(BaseModel): ) class TestWarmstart: @staticmethod - def get_loss_scores(messages: List[Message[EvaluationResultBatch]], loss_key: str) -> List[float]: + def get_loss_scores(messages: list[Message[EvaluationResultBatch]], loss_key: str) -> list[float]: return [message.payload.losses[loss_key].value.item() for message in messages] def test_warm_start(self): @@ -130,7 +130,7 @@ def test_warm_start(self): # we collect the loss values from rank 0 and store them in the temporary experiment folder if dist.get_rank() == 0: - messages_0: List[Message[EvaluationResultBatch]] = components_0.evaluation_subscriber.message_list + messages_0: list[Message[EvaluationResultBatch]] = components_0.evaluation_subscriber.message_list loss_scores_0 = TestWarmstart.get_loss_scores(messages_0, "train loss avg") with open(loss_values_experiment_0_path, "w") as f: json.dump(loss_scores_0, f) @@ -156,7 +156,7 @@ def test_warm_start(self): # we collect the loss values from rank 0 for the warmstart model # and store them in the temporary experiment folder if dist.get_rank() == 0: - messages_1: List[Message[EvaluationResultBatch]] = components_1.evaluation_subscriber.message_list + messages_1: list[Message[EvaluationResultBatch]] = components_1.evaluation_subscriber.message_list loss_scores_1 = TestWarmstart.get_loss_scores(messages_1, "train loss avg") with open(loss_values_experiment_1_path, "w") as f: json.dump(loss_scores_1, f) diff --git a/tests/test_initialization.py b/tests/test_initialization.py index 9edbb3c0..a169ade8 100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -2,7 +2,7 @@ import os import re from pathlib import Path -from typing import Dict, Optional +from typing import Optional import pytest import torch @@ -27,7 +27,7 @@ # $(which pytest) path/to/test_initialization.py -def get_model_from_config(model_config_dict: Dict) -> GPT2LLM | CoCa: +def get_model_from_config(model_config_dict: dict) -> GPT2LLM | CoCa: """get gpt2 or coca model from config_dict""" class InstantationModel(BaseModel): @@ -44,7 +44,7 @@ class InstantationModel(BaseModel): return model -def _replace_config_dict(_config_dict: Dict, _initialization_type: str, _std: str) -> Dict: +def _replace_config_dict(_config_dict: dict, _initialization_type: str, _std: str) -> dict: """dynamically replace initialization_type, std and dependent fields in config_dict""" _config_dict["model"]["config"]["model_initializer"]["config"]["weight_init_type"] = _initialization_type # replace _config_dict["model"]["config"]["model_initializer"]["config"]["std"] = _std # replace @@ -120,7 +120,7 @@ def _load_model(model_name: str, initialization: str = "plain", std: float | str } -def get_group_params(model: FSDP, model_name: str) -> Dict[str, Optional[torch.Tensor]]: +def get_group_params(model: FSDP, model_name: str) -> dict[str, Optional[torch.Tensor]]: """ divide all model parameters into initialization groups """ diff --git a/tests/test_optimizer_factory.py b/tests/test_optimizer_factory.py index 840003d3..4f273ad0 100644 --- a/tests/test_optimizer_factory.py +++ b/tests/test_optimizer_factory.py @@ -1,6 +1,5 @@ import os from pathlib import Path -from typing import Dict import pytest import torch @@ -26,7 +25,7 @@ # $(which pytest) path/to/test_optimizer_factory.py -def get_gpt2_model_from_config(gpt2_model_config_dict: Dict) -> GPT2LLM: +def get_gpt2_model_from_config(gpt2_model_config_dict: dict) -> GPT2LLM: class GPT2InstantationModel(BaseModel): model: PydanticPytorchModuleType diff --git a/tests/utils/test_mfu.py b/tests/utils/test_mfu.py index e0672004..ea1d424a 100644 --- a/tests/utils/test_mfu.py +++ b/tests/utils/test_mfu.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, Optional +from typing import Optional import pytest import torch @@ -26,7 +26,7 @@ # $(which pytest) path/to/test_mfu.py -def get_model_from_config(model_config_dict: Dict) -> GPT2LLM: +def get_model_from_config(model_config_dict: dict) -> GPT2LLM: """get gpt2 model from config_dict""" class InstantationModel(BaseModel): diff --git a/tutorials/library_usage/README.md b/tutorials/library_usage/README.md index e23abe32..74729a28 100644 --- a/tutorials/library_usage/README.md +++ b/tutorials/library_usage/README.md @@ -28,7 +28,7 @@ class CustomGPT2LLMCollateFn(CollateFnIF): self.target_key = target_key self.custom_attribute = custom_attribute - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch]) samples = {self.sample_key: sample_tensor[:, :-1]} targets = {self.target_key: sample_tensor[:, 1:]} diff --git a/tutorials/library_usage/main.py b/tutorials/library_usage/main.py index 6e65e4e6..727a4db3 100644 --- a/tutorials/library_usage/main.py +++ b/tutorials/library_usage/main.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List import torch from pydantic import BaseModel @@ -24,7 +23,7 @@ def __init__(self, sample_key: str, target_key: str, custom_attribute: str): self.target_key = target_key self.custom_attribute = custom_attribute - def __call__(self, batch: List[List[int]]) -> DatasetBatch: + def __call__(self, batch: list[list[int]]) -> DatasetBatch: sample_tensor = torch.tensor(batch) samples = {self.sample_key: sample_tensor[:, :-1]} targets = {self.target_key: sample_tensor[:, 1:]}