diff --git a/config_files/config_example_audio_mem_map_dataset.yaml b/config_files/config_example_audio_mem_map_dataset.yaml
new file mode 100644
index 000000000..95f7e226f
--- /dev/null
+++ b/config_files/config_example_audio_mem_map_dataset.yaml
@@ -0,0 +1,16 @@
+features:
+  - jq_pattern: .transcript
+    codec:
+      type_hint: HfTokenizerCodec
+      config:
+        add_eos_token: true
+        tokenizer:
+          type_hint: GPT2TokenizerFast
+          config:
+            tokenizer_file: ./data/tokenizer/tokenizer.json
+  - jq_pattern: .audio_path
+    codec:
+      type_hint: TorchaudioAudioCodec
+      config:
+        target_sample_rate: 16_000
+        n_mels: 80
diff --git a/config_files/data_config.yaml b/config_files/data_config.yaml
new file mode 100644
index 000000000..f68990765
--- /dev/null
+++ b/config_files/data_config.yaml
@@ -0,0 +1,15 @@
+features:
+  - jq_pattern: .cls
+    codec:
+      type_hint: HfTokenizerCodec
+      config:
+        add_eos_token: true
+        tokenizer:
+          type_hint: GPT2TokenizerFast
+          config:
+            tokenizer_file: ./data/tokenizer/tokenizer.json
+  - jq_pattern: .img_path
+    codec:
+      type_hint: PillowImageCodec
+      config:
+        save_format: png
diff --git a/pyproject.toml b/pyproject.toml
index caa789cd9..a7ac3d9fb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,12 @@ dependencies = [
     "jq",
     "xformers",
     "class_resolver",
-    "wandb"
+    "wandb",
+    "pillow",
+    "scipy",
+    "torchaudio",
+    "ffmpeg",
+    "soundfile"
 ]

 [project.optional-dependencies]
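The packing entry point consumes such a feature config directly (see the reworked src/modalities/__main__.py below). A hypothetical invocation, with placeholder paths and assuming the index file was created beforehand via `modalities create_memmap_index`:

    modalities create_packed_data data/train.jsonl config_files/data_config.yaml --idx_path data/train.idx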
diff --git a/src/modalities/__main__.py b/src/modalities/__main__.py
index 8f7091302..5ebc4eb27 100644
--- a/src/modalities/__main__.py
+++ b/src/modalities/__main__.py
@@ -15,7 +15,7 @@
 from modalities.batch import EvaluationResultBatch
 from modalities.checkpointing.checkpointing import Checkpointing, CheckpointingIF
 from modalities.checkpointing.checkpointing_factory import CheckpointingFactory
-from modalities.config.config import AppConfig, ModalitiesSetupConfig, RunMode
+from modalities.config.config import AppConfig, ModalitiesSetupConfig, PreparationAppConfig, RunMode
 from modalities.config.lookup_types import TokenizerTypes
 from modalities.dataloader.create_index import IndexGenerator
 from modalities.dataloader.create_packed_data import PackedDataGenerator
@@ -104,6 +104,7 @@ def entry_point_create_memmap_index(src_path, index_path):

 @main.command(name="create_packed_data")
 @click.argument("src_path", type=Path)
+@click.argument("config_file_path", type=Path)
 @click.option(
     "--dst_path",
     type=str,
@@ -111,41 +112,23 @@ def entry_point_create_memmap_index(src_path, index_path):
     help="output path for packed data file. will use parent directory of src_path if none.",
 )
 @click.option(
-    "--index_path",
+    "--idx_path",
     type=Path,
     default=None,
     help="input path for index. will search in parent directory of src_path if none.",
 )
-@click.option(
-    "--tokenizer_type",
-    type=TokenizerTypes,
-    show_default=True,
-    default=TokenizerTypes.GPT2TokenizerFast,
-    help="Specify which Tokenizer (inheriting from transformers.PretrainedTokenizers) should get used.",
-)
-@click.option(
-    "--tokenizer_file",
-    type=Path,
-    show_default=True,
-    default=Path(__file__).parents[2] / Path("data/tokenizer/tokenizer.json"),
-    help="path to tokenizer json",
-)
-@click.option(
-    "--jq_pattern",
-    type=str,
-    show_default=True,
-    default=".text",
-    help="jq pattern to extract the data from the json line.",
-)
-def entry_point_create_packed_data(src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern):
-    # TODO: if we want to use alternative entrypoints together with the ResolverRegistry,
-    #  we can currently not rely on the existing class resolver.
-    #  This is based on its connection to the overall `AppConfig`.
-    #  One would requires an object of it to instantiate the ResolverRegistry.
-    #  This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing
-    #  ResolverRegistry to work dynamically with any type-hinted config object from config.py.
-    tokenizer = tokenizer_type.value(tokenizer_file=str(tokenizer_file))
-    generator = PackedDataGenerator(src_path, index_path=index_path, tokenizer=tokenizer, jq_pattern=jq_pattern)
+def entry_point_create_packed_data(src_path, config_file_path, dst_path, idx_path):
+    config_dict = load_app_config_dict(config_file_path)
+    config = PreparationAppConfig.model_validate(config_dict)
+    # build codec components
+    resolvers = ResolverRegister()
+    codecs = {f.jq_pattern: resolvers.build_component_by_config(f.codec) for f in config.features}
+    # generate packed data
+    generator = PackedDataGenerator(
+        codecs,
+        src_path=src_path,
+        idx_path=idx_path,
+    )
     generator.run(dst_path)
diff --git a/src/modalities/batch.py b/src/modalities/batch.py
index bc6c62c07..7cf3f34ee 100644
--- a/src/modalities/batch.py
+++ b/src/modalities/batch.py
@@ -103,12 +103,15 @@ class EvaluationResultBatch(Batch):
     losses: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())
     metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())
     throughput_metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())
+
     def __str__(self) -> str:
         eval_str = (
             f"Evaluation result on dataset tag {self.dataloader_tag} after {self.global_train_sample_id + 1} samples:"
         )
         eval_str += "\n\nlosses: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.losses.items()])
         eval_str += "\n\nmetrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.metrics.items()])
-        eval_str += "\n\nthroughput metrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()])
+        eval_str += "\n\nthroughput metrics: " + "\n\t".join(
+            [f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()]
+        )
         eval_str += "\n==============================================="
         return eval_str
diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 0e166242a..c1b096632 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -11,6 +11,7 @@
     BatchSamplerTypes,
     CheckpointingExectionTypes,
     CheckpointingStrategyTypes,
+    CodecTypes,
     CollatorTypes,
     DataloaderTypes,
     DatasetTypes,
@@ -51,6 +52,54 @@ class GPT2TokenizerFastConfig(BaseModel):
     config: GPT2TokenizerFastConfig


+class CodecConfig(BaseModel):
+    class HfTokenizerCodecConfig(BaseModel):
+        tokenizer: TokenizerConfig
+        max_length: Optional[int] = None
+        add_eos_token: bool = True
+
+    class PillowImageCodecConfig(BaseModel):
+        save_format: str = "png"
+
+    class TorchaudioAudioCodecConfig(BaseModel):
+        target_sample_rate: int = 16_000
+        n_mels: int = 80
+
+    type_hint: CodecTypes
+    config: Union[
+        HfTokenizerCodecConfig,
+        PillowImageCodecConfig,
+        TorchaudioAudioCodecConfig,
+    ] = Field(union_mode="left_to_right")
+
+    @model_validator(mode="before")
+    def _resolve_type(cls, data):
+        if isinstance(data, dict):
+            # resolve config type from type hint
+            type_hint = data["type_hint"]
+            CONFIG_RESOLVER = {
+                CodecTypes.HfTokenizerCodec.name: cls.HfTokenizerCodecConfig,
+                CodecTypes.PillowImageCodec.name: cls.PillowImageCodecConfig,
+                CodecTypes.TorchaudioAudioCodec.name: cls.TorchaudioAudioCodecConfig,
+            }
+            # create config object of correct type
+            config_type = CONFIG_RESOLVER.get(type_hint)
+            config = config_type(**data["config"])
+            # return codec config
+            return {"type_hint": type_hint, "config": config}
+
+        return data
+
+
+class FeatureConfig(BaseModel):
+    codec: CodecConfig
+    jq_pattern: str
+
+
+class PreparationAppConfig(BaseModel):
+    features: List[FeatureConfig]
+
+
 class DatasetConfig(BaseModel):
     class MemMapDatasetConfig(BaseModel):
         raw_data_path: FilePath
@@ -284,6 +333,7 @@ class RunMode(Enum):
     FROM_SCRATCH = "FROM_SCRATCH"
     WARM_START = "WARM_START"

+
 class ModalitiesSetupConfig(BaseModel):
     class WarmStartSettings(BaseModel):
         checkpoint_model_path: Path
diff --git a/src/modalities/config/lookup_types.py b/src/modalities/config/lookup_types.py
index 46147480f..2f94cd8ca 100644
--- a/src/modalities/config/lookup_types.py
+++ b/src/modalities/config/lookup_types.py
@@ -9,6 +9,7 @@
     SaveEveryKStepsCheckpointingStrategy,
     SaveKMostRecentCheckpointsStrategy,
 )
+from modalities.dataloader.codecs import HfTokenizerCodec, PillowImageCodec, TorchaudioAudioCodec
 from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader
 from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron
 from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDatasetBuilder
@@ -47,6 +48,12 @@ class TokenizerTypes(LookupEnum):
     GPT2TokenizerFast = GPT2TokenizerFast


+class CodecTypes(LookupEnum):
+    HfTokenizerCodec = HfTokenizerCodec
+    PillowImageCodec = PillowImageCodec
+    TorchaudioAudioCodec = TorchaudioAudioCodec
+
+
 class DatasetTypes(LookupEnum):
     MemMapDataset = MemMapDataset
     PackedMemMapDatasetContinuous = PackedMemMapDatasetContinuous
diff --git a/src/modalities/dataloader/codecs.py b/src/modalities/dataloader/codecs.py
new file mode 100644
index 000000000..639115061
--- /dev/null
+++ b/src/modalities/dataloader/codecs.py
@@ -0,0 +1,196 @@
+from abc import ABC, abstractmethod
+from io import BytesIO
+from typing import Generic, Optional, TypeVar
+
+import numpy as np
+import torch
+import torchaudio
+from PIL import Image
+from transformers import PreTrainedTokenizer
+
+T = TypeVar("T")
+
+
+class Codec(ABC, Generic[T]):
+    @abstractmethod
+    def encode(self, obj: T) -> bytes:
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def decode(serialized_obj: bytes) -> T:
+        pass
+
+
+class FixSizedCodec(Codec[T]):
+    """Base class for fix-sized codecs.
+
+    Fix-sized codecs are special in that they encode a sequence of values where
+    each value is encoded by a fixed number of bytes. The length of the generated
+    bytestring is an integer multiple of `num_bytes_per_value`.
+    """
+
+    @classmethod
+    @abstractmethod
+    def num_bytes_per_value(cls) -> int:
+        raise NotImplementedError
+
+
+class HfTokenizerCodec(FixSizedCodec[str]):
+    TOKEN_SIZE_IN_BYTES = 4
+
+    @classmethod
+    def num_bytes_per_value(cls) -> int:
+        return cls.TOKEN_SIZE_IN_BYTES
+
+    def __init__(
+        self, tokenizer: PreTrainedTokenizer, max_length: Optional[int] = None, add_eos_token: bool = True
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.add_eos_token = add_eos_token
+
+        if add_eos_token:
+            # get eos token in bytes to append to the end of each sequence
+            eos_token = self.tokenizer.convert_tokens_to_ids(self.tokenizer.eos_token)
+            self.eos_token = eos_token.to_bytes(type(self).TOKEN_SIZE_IN_BYTES, byteorder="big")
+
+        self.tokenizer_kwargs = (
+            {} if max_length is None else dict(max_length=max_length - int(add_eos_token), truncation=True)
+        )
+
+    def encode(self, text: str) -> bytes:
+        # tokenize text and convert the token ids to bytes
+        tokens = [
+            t.to_bytes(type(self).TOKEN_SIZE_IN_BYTES, byteorder="big")
+            for t in self.tokenizer(text, **self.tokenizer_kwargs)["input_ids"]
+        ]
+        if len(tokens) == 0:
+            raise ValueError("Received empty sample")
+        # add special eos token
+        if self.add_eos_token:
+            tokens.append(self.eos_token)
+
+        # join byte strings
+        return b"".join(tokens)
+
+    @classmethod
+    def decode(cls, serialized_tokens: bytes) -> list[int]:
+        return [
+            int.from_bytes(serialized_tokens[i : i + cls.TOKEN_SIZE_IN_BYTES], byteorder="big")
+            for i in range(0, len(serialized_tokens), cls.TOKEN_SIZE_IN_BYTES)
+        ]
+
+
+class PillowImageCodec(Codec[str]):
+    def __init__(self, save_format: str = "png") -> None:
+        self._format = save_format
+
+    def encode(self, img_file_path: str) -> bytes:
+        buf = BytesIO()
+        # write image to buffer
+        with Image.open(img_file_path) as img:
+            img.save(buf, format=self._format)
+        # return buffer content
+        buf.seek(0)
+        return buf.read()
+
+    @staticmethod
+    def decode(serialized_img: bytes) -> Image.Image:
+        return Image.open(BytesIO(serialized_img))
+
+
+class TorchaudioAudioCodec(Codec[str]):
+    N_FFT = 400
+    HOP_LENGTH = 160
+
+    def __init__(
+        self,
+        target_sample_rate: int = 16_000,
+        n_mels: int = 80,
+    ) -> None:
+        self.target_sample_rate = target_sample_rate
+        self.extract_mel_spec = torchaudio.transforms.MelSpectrogram(
+            sample_rate=target_sample_rate,
+            n_mels=n_mels,
+            n_fft=type(self).N_FFT,
+            hop_length=type(self).HOP_LENGTH,
+        )
+
+    def load_audio(
+        self,
+        audio_file_path: str,
+    ) -> tuple[torch.Tensor, int]:
+        audio, sample_rate = torchaudio.load(
+            audio_file_path,
+        )
+        # average the channels into a mono signal
+        return (
+            audio.mean(dim=0),
+            sample_rate,
+        )
+
+    def extract_log_mel_spectrogram(
+        self,
+        audio: torch.Tensor,
+    ) -> torch.Tensor:
+        ############################################
+        # Feature extraction is quite similar to how it is done
+        # in Radford, Alec, et al. "Robust speech recognition
+        # via large-scale weak supervision." 2023, AKA Whisper.
+        # Their code can be found here:
+        # https://github.com/openai/whisper/blob/main/whisper/audio.py
+        # MIT LICENSE: https://github.com/openai/whisper/blob/main/LICENSE
+        ############################################
+
+        mel_spec = self.extract_mel_spec(audio)
+        # log-scale the spectrogram, clip it to an 8 dB dynamic range and
+        # normalize it to roughly [-1, 1], following the Whisper recipe
+        log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
+        log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
+        log_mel_spec = (log_mel_spec + 4.0) / 4.0
+        return log_mel_spec.transpose(1, 0)
+
+    def resample(
+        self,
+        audio: torch.Tensor,
+        sample_rate: int,
+    ) -> torch.Tensor:
+        resampler = torchaudio.transforms.Resample(
+            sample_rate,
+            self.target_sample_rate,
+            dtype=audio.dtype,
+        )
+        return resampler(audio)
+
+    def encode(
+        self,
+        audio_file_path: str,
+    ) -> bytes:
+        audio, sample_rate = self.load_audio(
+            audio_file_path,
+        )
+
+        audio = (
+            self.resample(
+                audio,
+                sample_rate,
+            )
+            if sample_rate != self.target_sample_rate
+            else audio
+        )
+
+        log_mel_spec = self.extract_log_mel_spectrogram(
+            audio,
+        ).numpy()
+
+        buf = BytesIO()
+        np.save(buf, log_mel_spec)
+        buf.seek(0)
+
+        return buf.read()
+
+    @staticmethod
+    def decode(
+        serialized_audio: bytes,
+    ) -> np.ndarray:
+        return np.load(BytesIO(serialized_audio))
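Taken together, a codec maps one jsonl field to a bytestring and back. A minimal round-trip sketch under the API above (the tokenizer file path is a placeholder):

    from transformers import GPT2TokenizerFast

    from modalities.dataloader.codecs import HfTokenizerCodec

    tokenizer = GPT2TokenizerFast(tokenizer_file="./data/tokenizer/tokenizer.json")
    codec = HfTokenizerCodec(tokenizer=tokenizer, add_eos_token=True)
    raw = codec.encode("hello world")   # one big-endian 4-byte block per token id
    ids = HfTokenizerCodec.decode(raw)  # list[int], ending with the eos token id
    assert len(raw) == len(ids) * HfTokenizerCodec.TOKEN_SIZE_IN_BYTES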
diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py
index 6e8d4d3cf..c79133504 100644
--- a/src/modalities/dataloader/create_packed_data.py
+++ b/src/modalities/dataloader/create_packed_data.py
@@ -1,56 +1,59 @@
 import pickle
 import warnings
 from pathlib import Path
-from typing import IO
+from typing import IO, Dict

 import jq
 import numpy as np
 from tqdm import tqdm
-from transformers import PreTrainedTokenizer

+from modalities.dataloader.codecs import Codec
 from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader


 class PackedDataGenerator:
-    # amount of bytes to represent tokens as integers.
-    # If the vocabulary exceeds 2^(8*`size_in_bytes`), this requires adaptation.
-    TOKEN_SIZE_IN_BYTES = 4
-    # amount of bytes to represent number of all tokens in dataset.
+    """
+    Packed data file layout: | HEAD | DATA | CODECS | INDEX |,
+    where HEAD itself consists of | DATA_HEAD | CODECS_HEAD |.
+    """
+
+    # amount of bytes used to represent the byte-size of the data section.
     # If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation.
     # Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides
-    HEAD_SIZE_IN_BYTES = 8
+    DATA_HEAD_SIZE_IN_BYTES = 8
+    CODECS_HEAD_SIZE_IN_BYTES = 8

-    def __init__(
-        self,
-        src_path: Path,
-        tokenizer: PreTrainedTokenizer,
-        index_path: Path = None,
-        jq_pattern: str = ".text",
-        max_number_of_tokens: int = None,
-    ):
+    def __init__(self, codecs: Dict[str, Codec], src_path: Path, idx_path: Path = None, max_num_of_bytes: int = None):
         """
         Reads in a jsonl file and the corresponding index file and packs dataset file for LLM training.
+        :param codecs: mapping from jq-pattern to the codec used to encode the extracted value into bytes.
+                Each jq-pattern is applied on every jsonl-entry and the results are encoded and packed.
         :param src_path: Path to a jsonl file, which holds text data
-        :param index_path: Path to an index file, which indicates the start character position
+        :param idx_path: Path to an index file, which indicates the start character position
                 and length of samples given in `src_path`.
                 If not defined, an index file next to `src_path` is picked,
                 by replacing its suffix with ".idx".
-        :param tokenizer: PretrainedTokenizer object, which is used to pre-tokenize the provided data in `src_path`.
-               Tokenization is necessary to work on final lengths of token sequences.
-        :param jq_pattern: jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed
-        :param max_number_of_tokens: Limit the total amount of tokens in the packed dataset.
-                If not specified, the whole data is packed into the dataset.
+        :param max_num_of_bytes: Limit the total size of the data section in bytes.
+                If not specified, the whole data is packed into the dataset.
         """
+
+        jq_patterns, codecs = zip(*codecs.items())
+
+        self.codecs = codecs
+        self.jq_filters = [jq.compile(pattern) for pattern in jq_patterns]
+
         self.src_path = src_path
-        self.tokenizer = tokenizer
-        self.jq_filter = jq.compile(jq_pattern)
-        self.max_tokens = max_number_of_tokens
+        self._reader = LargeFileLinesReader(src_path, index_path=idx_path)
+
+        # keep track of file size
+        self._total_data_bytes = 0
+        self._max_data_bytes = max_num_of_bytes

-        self._reader = LargeFileLinesReader(src_path, index_path=index_path)
-        self._total_num_of_tokens = 0
-        self._curr_offset = self.HEAD_SIZE_IN_BYTES
         self._index_list = []

+    @property
+    def _current_offset(self) -> int:
+        return self._total_data_bytes + type(self).DATA_HEAD_SIZE_IN_BYTES + type(self).CODECS_HEAD_SIZE_IN_BYTES
+
     def _default_destination_path(self, destination_path: Path = None) -> Path:
         if destination_path is None:
             default_destination_path = Path(self.src_path.parent, f"{self.src_path.stem}.pbin")
@@ -62,23 +65,26 @@ def _default_destination_path(self, destination_path: Path = None) -> Path:
         return Path(destination_path)

     def run(self, dst_path: Path = None):
-        assert self._total_num_of_tokens == 0, f"This {self.__name__} was already used and is exhausted. Use another!"
+        assert self._total_data_bytes == 0, f"This {type(self).__name__} was already used and is exhausted. Use another!"
         dst_path = self._default_destination_path(destination_path=dst_path)

         if dst_path.exists():
             raise ValueError(f"file already exists at destination path '{dst_path}'.")

-        encoded_eos_token = self.tokenizer(self.tokenizer.eos_token)["input_ids"][0]
-        encoded_eos_token_as_bytes = encoded_eos_token.to_bytes(self.TOKEN_SIZE_IN_BYTES, byteorder="big")
         with dst_path.open("wb") as f:
-            # allocate first self.header_size_in_bytes bytes for header (encodes length of data section)
-            # not possible to prepend header after determining size of data section
-            f.write((0).to_bytes(self.HEAD_SIZE_IN_BYTES, byteorder="big"))
+            # store the type-hints of the codec types
+            # TODO: get the type hints from the enum in case they
+            #  don't match the class name exactly
+            codecs_bytes = pickle.dumps([type(codec).__name__ for codec in self.codecs])
+
+            # allocate bytes for the data header and write the codecs header
+            f.write((0).to_bytes(type(self).DATA_HEAD_SIZE_IN_BYTES, byteorder="big"))
+            f.write(len(codecs_bytes).to_bytes(type(self).CODECS_HEAD_SIZE_IN_BYTES, byteorder="big"))

-            # write data section (tokens)
+            # write data section
             for idx, line in tqdm(enumerate(self._reader)):
                 try:
-                    self._process_line(encoded_eos_token_as_bytes, f, line)
+                    self._process_line(f, line)
                 except ValueError:
                     warnings.warn(f"Encountered empty sample in line {idx} of file {self.src_path}")
                 except StopIteration:
@@ -86,36 +92,41 @@ def run(self, dst_path: Path = None):
                 except Exception as exception:
                     warnings.warn(f"could not process line: {exception=}")

-            # write index
+            # write codecs and index section to file
+            f.write(codecs_bytes)
             f.write(pickle.dumps(self._index_list))

         self._update_data_length_in_pre_allocated_header(dst_path)
     def _update_data_length_in_pre_allocated_header(self, dst_path: Path):
-        start_of_index_in_bytes = self._index_list[-1][0] + self._index_list[-1][1]
-        length_of_byte_encoded_data_section = start_of_index_in_bytes - self.HEAD_SIZE_IN_BYTES
-        header_content = length_of_byte_encoded_data_section.to_bytes(self.HEAD_SIZE_IN_BYTES, byteorder="big")
+        header_content = self._total_data_bytes.to_bytes(type(self).DATA_HEAD_SIZE_IN_BYTES, byteorder="big")
         header_content = np.frombuffer(header_content, dtype="uint8")
         # write the header content to the packed dataset file
-        m = np.memmap(dst_path, mode="r+", offset=0, shape=(self.HEAD_SIZE_IN_BYTES,))
+        m = np.memmap(dst_path, mode="r+", offset=0, shape=(type(self).DATA_HEAD_SIZE_IN_BYTES,))
         m[:] = header_content[:]

-    def _process_line(self, eos_token_as_bytes: bytes, f: IO, line: str):
-        jq_retrieved_text = self.jq_filter.input_text(line).first()
-        tokens = self.tokenizer(jq_retrieved_text)["input_ids"]
-        if len(tokens) == 0:
-            raise ValueError("Received empty sample...")
-        token_idx = 0
-        for token in tokens:
-            token_as_bytes = token.to_bytes(self.TOKEN_SIZE_IN_BYTES, byteorder="big")
-            f.write(token_as_bytes)
-            self._total_num_of_tokens += 1
-            if self._total_num_of_tokens == self.max_tokens:
-                segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES
-                self._index_list.append((self._curr_offset, segment_length))
-                raise StopIteration
-            token_idx += 1
-        f.write(eos_token_as_bytes)
-        segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES  # segment_length in bytes
-        self._index_list.append((self._curr_offset, segment_length))
-        self._curr_offset += segment_length
+    def _process_line(self, f: IO, line: str):
+        sizes = [None] * len(self.codecs)
+
+        for i, (codec, jq_filter) in enumerate(
+            zip(self.codecs, self.jq_filters),
+        ):
+            # get object to encode and encode using codec
+            jq_retrieved_text = jq_filter.input_text(line).first()
+            bytestring = codec.encode(jq_retrieved_text)
+            num_bytes = len(bytestring)
+
+            if num_bytes == 0:
+                raise ValueError("Detected empty sample")
+
+            # write bytestring to file and update size array
+            f.write(bytestring)
+            sizes[i] = num_bytes
+
+        # update index and total number of bytes written
+        self._index_list.append([self._current_offset] + sizes)
+        self._total_data_bytes += sum(sizes)
+
+        # abort once the data section exceeds the size limit
+        if (self._max_data_bytes is not None) and (self._total_data_bytes >= self._max_data_bytes):
+            raise StopIteration
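The resulting file can be inspected without any of the dataset classes; a small sketch that mirrors the layout written above (the file name is a placeholder):

    import pickle

    with open("train.pbin", "rb") as f:
        data_len = int.from_bytes(f.read(8), byteorder="big")    # DATA_HEAD
        codecs_len = int.from_bytes(f.read(8), byteorder="big")  # CODECS_HEAD
        f.seek(16 + data_len)                                    # skip the data section
        codec_names = pickle.loads(f.read(codecs_len))           # e.g. ["HfTokenizerCodec"]
        index = pickle.loads(f.read())                           # [[offset, size_1, size_2, ...], ...]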
diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py
index 8e7a4c3bb..1633df817 100644
--- a/src/modalities/dataloader/dataset.py
+++ b/src/modalities/dataloader/dataset.py
@@ -11,14 +11,29 @@
 from tqdm import tqdm
 from transformers import BatchEncoding, PreTrainedTokenizer

-from ..dataloader.large_file_lines_reader import LargeFileLinesReader
+from .codecs import FixSizedCodec
+from .create_packed_data import PackedDataGenerator
+from .large_file_lines_reader import LargeFileLinesReader
+
+
+class SampleKeysMismatchException(Exception):
+    pass


 class Dataset(TorchdataSet):
-    def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
+    def __init__(self, raw_data_path: Path, sample_keys: list[str]):
         self.raw_data_path = raw_data_path
-        self.block_size = block_size
-        self.sample_key = sample_key
+        self.sample_keys = sample_keys
+        # must provide a sample key for each element of an item
+        if len(self.sample_keys) != self.num_elements_per_item:
+            raise SampleKeysMismatchException(
+                "Expected %i sample keys, got %s" % (self.num_elements_per_item, self.sample_keys)
+            )
+
+    @property
+    def num_elements_per_item(self) -> int:
+        # subclasses storing several encoded elements per item override this
+        return len(self.sample_keys)

     def _check_if_inbounds(self, idx: int):
         if not 0 <= idx < len(self):
@@ -50,7 +65,9 @@ def __init__(
         TODO: If this setting should support multi-modal features using separately encoded inputs,
             this needs to get replaced with a list of sample keys!
         """
-        super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key)
+        super().__init__(raw_data_path=raw_data_path, sample_keys=[sample_key])
+        self.sample_key = sample_key
+        self.block_size = block_size

         self.reader = LargeFileLinesReader(self.raw_data_path, index_path=index_path)
         self.jq_filter = jq.compile(jq_pattern)
@@ -70,24 +85,48 @@ def __getitem__(self, idx: int) -> BatchEncoding:


 class PackedMemMapDatasetBase(Dataset):
-    INT_SIZE_IN_BYTES = 4
-    HEADER_SIZE_IN_BYTES = 8
+    def _read_bytes(self, offset: int, size: int) -> bytes:
+        return np.memmap(
+            self.raw_data_path,
+            mode="r",
+            offset=offset,
+            shape=(size,),
+        ).tobytes()

-    def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
+    @property
+    def num_elements_per_item(self) -> int:
+        return len(self._codec_types)
+
+    def __init__(self, raw_data_path: Path, sample_keys: list[str]):
         """
         Base class for packed memmapped datasets. The underlying dataset file has the structure:
-        | header | data | index |
-        The header contains information about the length of the subsequent data sequence. The index contains
-        the tuple information (start, end) in terms of byte positions.
+        | data_header | codecs_header | data | codecs | index |
+
+        The data and codecs headers contain information about the length of the data and codecs sections.
+
+        The codecs section contains the codec type hints required to decode the bytes to the expected
+        data type. Specifically, it is an encoded list of codec type names:
+
+            (codec_1, codec_2, ...)
+
+        The index stores byte positions of the dataset items in the following format:
+
+            (offset, size_1, size_2, ...)
+
+        The start and end byte positions of the j-th value of an item are given by:
+
+            (offset + sum_{i < j} size_i, offset + sum_{i <= j} size_i)
+        """
+        self.raw_data_path = raw_data_path
+        if not self.raw_data_path.is_file():
+            raise FileNotFoundError(
+                f"Packed data was not found at {self.raw_data_path}. "
+                f"Create it in advance by using `modalities create_packed_data`."
+            )
+        self.total_bytes = self.raw_data_path.stat().st_size
+
+        def read_header(offset: int, size: int) -> int:
+            # read bytes from file
+            return int.from_bytes(self._read_bytes(offset, size), byteorder="big")
+
+        # read headers
+        self.data_len = read_header(offset=0, size=PackedDataGenerator.DATA_HEAD_SIZE_IN_BYTES)
+        self.codecs_len = read_header(
+            offset=PackedDataGenerator.DATA_HEAD_SIZE_IN_BYTES, size=PackedDataGenerator.CODECS_HEAD_SIZE_IN_BYTES
+        )
+
+        # compute offsets to index the raw data file
+        self.data_offset = PackedDataGenerator.DATA_HEAD_SIZE_IN_BYTES + PackedDataGenerator.CODECS_HEAD_SIZE_IN_BYTES
+        self.codecs_offset = self.data_offset + self.data_len
+        self.index_offset = self.codecs_offset + self.codecs_len
+
+        # read codecs
+        self._codec_type_hints = self._read_bytes(offset=self.codecs_offset, size=self.codecs_len)
+        self._codec_type_hints = pickle.loads(self._codec_type_hints)
+        # needs to be here to avoid circular import
+        # TODO: find a better way to avoid the circular import
+        from ..config.lookup_types import CodecTypes
+
+        # resolve codec types
+        self._codec_types = [getattr(CodecTypes, codec_type_hint).value for codec_type_hint in self._codec_type_hints]

         # get index
-        self.index_base = np.memmap(
-            self.raw_data_path,
-            mode="r",
-            offset=self.HEADER_SIZE_IN_BYTES + self.data_len,
-            shape=(self.total_bytes - self.data_len - self.HEADER_SIZE_IN_BYTES,),
-        ).view(f"S{self.total_bytes-self.data_len-self.HEADER_SIZE_IN_BYTES}")
-        self.index_base = pickle.loads(self.index_base)
+        self._index_base = self._read_bytes(offset=self.index_offset, size=self.total_bytes - self.index_offset)
+        self._index_base = pickle.loads(self._index_base)
+        assert all(len(idx) == len(self._codec_types) + 1 for idx in self._index_base)
+
+        # initialize after codec types are defined because
+        # num_elements_per_item depends on them
+        super().__init__(raw_data_path=raw_data_path, sample_keys=sample_keys)
+
+
+class PackedMemMapDataset(PackedMemMapDatasetBase):
+    """Packed Memory Map Dataset"""
+
+    def __len__(self) -> int:
+        return len(self._index_base)
+
+    def __getitem__(self, idx: int) -> BatchEncoding:
+        # get index values
+        self._check_if_inbounds(idx)
+        idx = self._index_base[idx]
+
+        enc = {}
+        offset = idx[0]
+        for key, size, codec_type in zip(self.sample_keys, idx[1:], self._codec_types):
+            # decode item
+            bytestring = self._read_bytes(offset, size)
+            enc[key] = codec_type.decode(bytestring)
+            # update offset
+            offset += size
+
+        return BatchEncoding(data=enc)
+

 class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase):
-    def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
+    def __init__(self, raw_data_path: Path, sample_key: str, block_size: int):
         """
         PackedMemMapDatasetContinuous iterates through the data in block_size sized chunks,
         irrespective of the samples' start and end position, as defined in the index.
@@ -130,28 +208,50 @@ def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
            Use `modalities create_packed_data` to create one based on a jsonl-file.
         :param block_size: alias for max sequence length. The amount of tokens the model can handle.
         :param sample_key: model-specific parameter to indicate where in the BatchEncoding the input_token_ids are.
-           TODO: If this setting should support multi-modal features using separately encoded inputs,
-            this needs to get replaced with a list of sample keys!
         """
-        super().__init__(raw_data_path=raw_data_path, block_size=block_size, sample_key=sample_key)
+        try:
+            super().__init__(raw_data_path=raw_data_path, sample_keys=[sample_key])
+        except SampleKeysMismatchException as e:
+            raise ValueError(
+                "Can only read continuously from packed data files of single-element datasets, i.e. "
+                "datasets with a single element per item. The specified dataset has %i elements per item."
+                % self.num_elements_per_item
+            ) from e
+
+        # check if codec is supported
+        if not issubclass(self.codec_type, FixSizedCodec):
+            raise ValueError("Can only read continuously from fix-sized codecs, got %s." % self.codec_type)
+        self.block_size = block_size

         # get number of total tokens in file
-        total_tokens = self.data_len // self.INT_SIZE_IN_BYTES
-        self._num_samples = total_tokens // self.block_size
+        total_values = self.data_len // self._num_bytes_per_value
+        self._num_samples = total_values // self.block_size
+
+    @property
+    def sample_key(self) -> str:
+        return self.sample_keys[0]
+
+    @property
+    def codec_type(self) -> type[FixSizedCodec]:
+        return self._codec_types[0]
+
+    @property
+    def _num_bytes_per_value(self) -> int:
+        return self.codec_type.num_bytes_per_value()

     def __len__(self) -> int:
         return self._num_samples

     def __getitem__(self, idx: int) -> BatchEncoding:
         self._check_if_inbounds(idx)
-        tokens_as_byte_strings = np.memmap(
-            self.raw_data_path,
-            mode="r",
-            offset=self.HEADER_SIZE_IN_BYTES + idx * self.INT_SIZE_IN_BYTES * self.block_size,
-            shape=(self.INT_SIZE_IN_BYTES * self.block_size,),
-        ).view(f"S{self.INT_SIZE_IN_BYTES}")
-        tokens = [int.from_bytes(token, byteorder="big") for token in tokens_as_byte_strings]
-        return BatchEncoding(data={self.sample_key: tokens})
+        # read block-sized chunk of bytes
+        byte_string = self._read_bytes(
+            offset=self.data_offset + idx * self.block_size * self._num_bytes_per_value,
+            size=self.block_size * self._num_bytes_per_value,
+        )
+        # decode and pack into batch encoding
+        values = self.codec_type.decode(byte_string)
+        return BatchEncoding(data={self.sample_key: values})


 class PackedMemMapDatasetMegatron(PackedMemMapDatasetBase):
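Reading a packed multimodal file then only requires matching sample keys; a minimal usage sketch of the reader defined above (file name and keys are placeholders and must line up, in order, with the codecs used at packing time):

    from pathlib import Path

    from modalities.dataloader.dataset import PackedMemMapDataset

    ds = PackedMemMapDataset(Path("data.pbin"), sample_keys=["input_ids", "img", "audio_feat"])
    item = ds[0]
    # item is a BatchEncoding: item["input_ids"] -> list[int], item["img"] -> PIL image,
    # item["audio_feat"] -> np.ndarray log-mel spectrogram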
diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py
index c5e5e3a22..07e344d52 100644
--- a/src/modalities/exceptions.py
+++ b/src/modalities/exceptions.py
@@ -15,4 +15,4 @@ class RunningEnvError(Exception):


 class TimeRecorderStateError(Exception):
-    pass
\ No newline at end of file
+    pass
diff --git a/src/modalities/logging_broker/message_broker.py b/src/modalities/logging_broker/message_broker.py
index d5f4aec23..7b38e58ff 100644
--- a/src/modalities/logging_broker/message_broker.py
+++ b/src/modalities/logging_broker/message_broker.py
@@ -1,12 +1,14 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
+from typing import Dict, List
+
 from modalities.logging_broker.messages import Message, MessageTypes
 from modalities.logging_broker.subscriber import MessageSubscriberIF
-from typing import Dict, List


 class MessageBrokerIF(ABC):
     """Interface for message broker objects."""
+
     @abstractmethod
     def add_subscriber(self, subscription: MessageTypes, subscriber: MessageSubscriberIF):
         raise NotImplementedError
@@ -18,6 +20,7 @@ def distribute_message(self, message: Message):

 class MessageBroker(MessageBrokerIF):
     """The MessageBroker sends notifications to its subscribers."""
+
     def __init__(self) -> None:
         self.subscriptions: Dict[MessageTypes, List[MessageSubscriberIF]] = defaultdict(list)
diff --git a/src/modalities/logging_broker/publisher.py b/src/modalities/logging_broker/publisher.py
index 34ff834ba..28cc27de2 100644
--- a/src/modalities/logging_broker/publisher.py
+++ b/src/modalities/logging_broker/publisher.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar

-from modalities.logging_broker.message_broker import Message, MessageBroker
+from modalities.logging_broker.message_broker import Message, MessageBroker
 from modalities.logging_broker.messages import MessageTypes

 T = TypeVar("T")
@@ -15,6 +15,7 @@ def publish_message(self, payload: T, message_type: MessageTypes):

 class MessagePublisher(MessagePublisherIF[T]):
     """The MessagePublisher sends messages through a message broker."""
+
     def __init__(
         self,
         message_broker: MessageBroker,
diff --git a/src/modalities/logging_broker/subscriber.py b/src/modalities/logging_broker/subscriber.py
index 7e965b75e..6b4e5c2d4 100644
--- a/src/modalities/logging_broker/subscriber.py
+++ b/src/modalities/logging_broker/subscriber.py
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod
 from typing import Generic, TypeVar
+
 from modalities.logging_broker.messages import Message

 T = TypeVar("T")
@@ -11,4 +12,3 @@ class MessageSubscriberIF(ABC, Generic[T]):
     @abstractmethod
     def consume_message(self, message: Message[T]):
         raise NotImplementedError
-
diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
index b29657258..92fe0fc1b 100644
--- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
+++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py
@@ -2,14 +2,15 @@
 from typing import Optional

 import rich
+import wandb
 from rich.console import Group
 from rich.panel import Panel

-import wandb
 from modalities.batch import EvaluationResultBatch
+from modalities.config.config import AppConfig, WandbConfig
 from modalities.logging_broker.messages import Message
 from modalities.logging_broker.subscriber import MessageSubscriberIF
-from modalities.config.config import AppConfig, WandbConfig
+

 class DummyResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]):
     def consume_message(self, message: Message[EvaluationResultBatch]):
@@ -49,8 +50,15 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
 class WandBEvaluationResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]):
     """A subscriber object for the WandBEvaluationResult observable."""

-    def __init__(self, num_ranks: int, project: str, experiment_id: str, mode: WandbConfig.WandbMode, dir: Path,
-                 experiment_config: Optional[AppConfig] = None) -> None:
+    def __init__(
+        self,
+        num_ranks: int,
+        project: str,
+        experiment_id: str,
+        mode: WandbConfig.WandbMode,
+        dir: Path,
+        experiment_config: Optional[AppConfig] = None,
+    ) -> None:
         super().__init__()

         self.num_ranks = num_ranks
@@ -82,6 +90,4 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
             f"{eval_result.dataloader_tag} {metric_key}": metric_values
             for metric_key, metric_values in eval_result.throughput_metrics.items()
         }
-        wandb.log(
-            data=throughput_metrics, step=eval_result.global_train_sample_id + 1
-        )
+        wandb.log(data=throughput_metrics, step=eval_result.global_train_sample_id + 1)
diff --git a/src/modalities/models/gpt2/preprocess_dataset.py b/src/modalities/models/gpt2/preprocess_dataset.py
index 99afb069e..fe5b223eb 100644
--- a/src/modalities/models/gpt2/preprocess_dataset.py
+++ b/src/modalities/models/gpt2/preprocess_dataset.py
@@ -1,21 +1,25 @@
+import os
 from itertools import chain
-from datasets import load_dataset
-from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config
+
 from accelerate import Accelerator
-import os
+from datasets import load_dataset
+from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast


 def main():
-
     def group_texts(examples):
         # Concatenate all texts.
         concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
         total_length = len(concatenated_examples[list(examples.keys())[0]])
-        # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
-        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
+        # We drop the small remainder, and if the total_length < block_size we exclude
+        # this batch and return an empty dict. We could add padding if the model
+        # supported it instead of this drop, you can customize this part to your needs.
         total_length = (total_length // block_size) * block_size
         # Split by chunks of max_len.
-        result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
+        result = {
+            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+            for k, t in concatenated_examples.items()
+        }
         result["labels"] = result["input_ids"].copy()
         return result
diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py
index d00a8043d..511419b9b 100644
--- a/src/modalities/models/model.py
+++ b/src/modalities/models/model.py
@@ -1,9 +1,11 @@
 from abc import abstractmethod
 from typing import Dict
-from modalities.batch import DatasetBatch, InferenceResultBatch
+
 import torch
 import torch.nn as nn

+from modalities.batch import DatasetBatch, InferenceResultBatch
+

 class NNModel(nn.Module):
     def __init__(self, seed: int = None):
diff --git a/src/modalities/resolver_register.py b/src/modalities/resolver_register.py
index 9f571efe6..b2417ef46 100644
--- a/src/modalities/resolver_register.py
+++ b/src/modalities/resolver_register.py
@@ -8,19 +8,22 @@
 from transformers import PreTrainedTokenizer

 from modalities.checkpointing.checkpointing import CheckpointingExecutionIF, CheckpointingStrategyIF
-from modalities.config.config import AppConfig, OptimizerTypes, SchedulerTypes
+from modalities.config.config import OptimizerTypes, SchedulerTypes
 from modalities.config.lookup_types import (
     BatchSamplerTypes,
     CheckpointingExectionTypes,
     CheckpointingStrategyTypes,
+    CodecTypes,
     CollatorTypes,
     DataloaderTypes,
     DatasetTypes,
+    LookupEnum,
     LossTypes,
     ModelTypes,
     SamplerTypes,
     TokenizerTypes,
 )
+from modalities.dataloader.codecs import Codec
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.dataloader.dataset import Dataset
 from modalities.loss_functions import CLMCrossEntropyLoss, Loss
@@ -29,119 +32,104 @@
 from modalities.running_env.fsdp.fsdp_running_env import FSDPRunningEnv, RunningEnv, RunningEnvTypes


+# TODO: this should be a singleton
 class ResolverRegister:
-    def __init__(self, config: AppConfig) -> None:
-        self._resolver_register: Dict[str, ClassResolver] = self._create_resolver_register(config=config)
+    # TODO: *args and **kwargs are accepted only for backwards compatibility;
+    #  older versions required the AppConfig as an argument
+    def __init__(self, *args, **kwargs):
+        self._resolver_register = self._build_resolver_register()

-    def build_component_by_config(self, config: BaseModel, extra_kwargs: Dict = {}) -> Any:
+    def build_component_by_key_query(self, register_key: str, type_hint: str, extra_kwargs: Dict = {}) -> Any:
+        raise NotImplementedError
+
+    def build_component_by_config(self, config: BaseModel, extra_kwargs: Dict[str, Any] = {}) -> Any:
         assert (
             "type_hint" in config.model_fields.keys()
         ), f"Field 'type_hint' missing but needed for initalisation in {config}"
-        kwargs = {key: getattr(config.config, key) for key in config.config.model_dump().keys()}
-        kwargs.update(extra_kwargs)  # allow override via extra_kwargs, to add nested objects
+        assert (
+            "config" in config.model_fields.keys()
+        ), f"Field 'config' missing but needed for initialisation in {config}"
+
+        kwargs = extra_kwargs.copy()
+
+        for key in config.config.model_fields.keys():
+            # get the value corresponding to the key,
+            # preferring the extra keyword arguments when both are specified
+            val = getattr(config.config, key)
+            val = kwargs.get(key, val)
+
+            # handle nested components
+            if isinstance(val, BaseModel) and "type_hint" in val.model_fields and "config" in val.model_fields:
+                kwargs[key] = self.build_component_by_config(val)
+
+            else:
+                kwargs[key] = val
+
         return self._build_component(
-            register_key=config.type_hint,
+            register_key=type(config.type_hint),
             register_query=config.type_hint.name,
             extra_kwargs=kwargs,
         )

-    def build_component_by_key_query(self, register_key: str, type_hint: str, extra_kwargs: Dict = {}) -> Any:
-        return self._build_component(register_key=register_key, register_query=type_hint, extra_kwargs=extra_kwargs)
-
-    def _build_component(self, register_key: str, register_query: str, extra_kwargs: Dict = {}):
+    def _build_component(self, register_key: LookupEnum, register_query: str, extra_kwargs: Dict[str, Any] = {}):
+        assert register_key in self._resolver_register
         return self._resolver_register[register_key].make(
             query=register_query,
             pos_kwargs=extra_kwargs,
         )

-    def _find_values_with_key_in_nested_structure(self, nested_structure: Dict, key: str) -> List[Any]:
-        found_values = []
-        for k, v in nested_structure.items():
-            if k == key:
-                found_values.append(v)
-            elif isinstance(v, dict):
-                found_values.extend(self._find_values_with_key_in_nested_structure(v, key))
-        return found_values
-
-    def _create_resolver_register(self, config: AppConfig) -> Dict[str, ClassResolver]:
-        set(self._find_values_with_key_in_nested_structure(nested_structure=config.model_dump(), key="type_hint"))
-        resolvers = {
-            config.running_env.type_hint: ClassResolver(
+    def _build_resolver_register(self) -> Dict[LookupEnum, ClassResolver]:
+        return {
+            RunningEnvTypes: ClassResolver(
                 [t.value for t in RunningEnvTypes],
                 base=RunningEnv,
                 default=FSDPRunningEnv,
             ),
-            config.model.type_hint: ClassResolver(
+            ModelTypes: ClassResolver(
                 [t.value for t in ModelTypes],
                 base=NNModel,
                 default=GPT2LLM,
             ),
-            config.optimizer.type_hint: ClassResolver(
+            OptimizerTypes: ClassResolver(
                 [t.value for t in OptimizerTypes],
                 base=optim.Optimizer,
                 default=optim.AdamW,
            ),
-            config.scheduler.type_hint: ClassResolver(
+            SchedulerTypes: ClassResolver(
                 [t.value for t in SchedulerTypes],
                 base=optim.lr_scheduler.LRScheduler,
                 default=optim.lr_scheduler.StepLR,
             ),
-            config.loss.type_hint: ClassResolver(
+            LossTypes: ClassResolver(
                 [t.value for t in LossTypes],
                 base=Loss,
                 default=CLMCrossEntropyLoss,
             ),
-            **{
-                sampler_type: ClassResolver(
-                    classes=[t.value for t in SamplerTypes],
-                    base=Sampler,
-                    default=DistributedSampler,
-                )
-                for sampler_type in SamplerTypes
-            },
-            **{
-                batch_sampler_type: ClassResolver(
-                    classes=[t.value for t in BatchSamplerTypes],
-                    base=BatchSampler,
-                    default=BatchSampler,
-                )
-                for batch_sampler_type in BatchSamplerTypes
-            },
-            **{
-                dataloader_type: ClassResolver(
-                    [t.value for t in DataloaderTypes],
-                    base=DataLoader,
-                    default=LLMDataLoader,
-                )
-                for dataloader_type in DataloaderTypes
-            },
-            **{
-                dataset_type: ClassResolver([t.value for t in DatasetTypes], base=Dataset)
-                for dataset_type in DatasetTypes
-            },
-            **{
-                collator_type: ClassResolver([t.value for t in CollatorTypes], base=GPT2LLMCollator)
-                for collator_type in CollatorTypes
-            },
-            **{
-                tokenizer_type: ClassResolver([t.value for t in TokenizerTypes], base=PreTrainedTokenizer)
-                for tokenizer_type in TokenizerTypes
-            },
-            **{
-                checkpointing_strategy_type: ClassResolver(
-                    [t.value for t in CheckpointingStrategyTypes], base=CheckpointingStrategyIF
-                )
-                for checkpointing_strategy_type in CheckpointingStrategyTypes
-            },
-            **{
-                checkpointing_execution_type: ClassResolver(
-                    [t.value for t in CheckpointingExectionTypes], base=CheckpointingExecutionIF
-                )
-                for checkpointing_execution_type in CheckpointingExectionTypes
-            },
+            SamplerTypes: ClassResolver(
+                classes=[t.value for t in SamplerTypes],
+                base=Sampler,
+                default=DistributedSampler,
+            ),
+            BatchSamplerTypes: ClassResolver(
+                classes=[t.value for t in BatchSamplerTypes],
+                base=BatchSampler,
+                default=BatchSampler,
+            ),
+            DataloaderTypes: ClassResolver(
+                [t.value for t in DataloaderTypes],
+                base=DataLoader,
+                default=LLMDataLoader,
+            ),
+            DatasetTypes: ClassResolver([t.value for t in DatasetTypes], base=Dataset),
+            CollatorTypes: ClassResolver([t.value for t in CollatorTypes], base=GPT2LLMCollator),
+            TokenizerTypes: ClassResolver([t.value for t in TokenizerTypes], base=PreTrainedTokenizer),
+            CodecTypes: ClassResolver([t.value for t in CodecTypes], base=Codec),
+            CheckpointingStrategyTypes: ClassResolver(
+                [t.value for t in CheckpointingStrategyTypes], base=CheckpointingStrategyIF
+            ),
+            # TODO: fix the typo in `CheckpointingExectionTypes` ("Exection" -> "Execution")
+            CheckpointingExectionTypes: ClassResolver(
+                [t.value for t in CheckpointingExectionTypes], base=CheckpointingExecutionIF
+            ),
         }
-        return resolvers
-
-    def add_resolver(self, resolver_key: str, resolver: ClassResolver):
-        self._resolver_register[resolver_key] = resolver
diff --git a/src/modalities/test.py b/src/modalities/test.py
index ea16a0917..f81c3630c 100644
--- a/src/modalities/test.py
+++ b/src/modalities/test.py
@@ -3,7 +3,6 @@
 from rich.progress import Progress

 with Progress() as progress:
-
     task1 = progress.add_task("[red]Downloading...", total=1000)
     task2 = progress.add_task("[green]Processing...", total=1000)
     task3 = progress.add_task("[cyan]Cooking...", total=1000)
@@ -12,4 +11,4 @@
         progress.update(task1, advance=0.5)
         progress.update(task2, advance=0.3)
         progress.update(task3, advance=0.9)
-        time.sleep(0.02)
\ No newline at end of file
+        time.sleep(0.02)
diff --git a/tests/conftest.py b/tests/conftest.py
index f94133cea..7a1b263b6 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,11 +1,15 @@
 import dataclasses
+import json
 import os
 import pickle
 from pathlib import Path
 from unittest.mock import MagicMock

+import numpy as np
 import pytest
 import torch
+import torchaudio
+from PIL import Image
 from torch.optim import Optimizer
 from torch.utils.data.sampler import BatchSampler, SequentialSampler
 from transformers import GPT2TokenizerFast
@@ -26,16 +30,31 @@
 _ROOT_DIR = Path(__file__).parents[1]


+@dataclasses.dataclass
+class DataPathCollection:
+    raw_data_path: Path
+    index_path: Path
+
+
 @pytest.fixture
 def dummy_packed_data_path(tmpdir) -> Path:
     data = b""
-    header_size_in_bytes = 8
+    data_header_size_in_bytes = 8
+    codecs_header_size_in_bytes = 8
     int_size_in_bytes = 4
+    # data and codecs
     tokens = list(range(20))
+    codecs_bytes = pickle.dumps(["HfTokenizerCodec"])
-    data += (len(tokens) * int_size_in_bytes).to_bytes(header_size_in_bytes, byteorder="big")
+    # headers
+    data += (len(tokens) * int_size_in_bytes).to_bytes(data_header_size_in_bytes, byteorder="big")
+    data += len(codecs_bytes).to_bytes(codecs_header_size_in_bytes, byteorder="big")
+    # data and codecs
     data += b"".join([t.to_bytes(int_size_in_bytes, byteorder="big") for t in tokens])
-    index = [(4, 24), (28, 40), (68, 12), (80, 4)]  # [(index,len), ...] -> in 4 bytes #lengths: 6,10,3,1
+    data += codecs_bytes
+    # index: [(offset, size_in_bytes), ...] -> entry lengths in tokens: 6, 7, 3, 4
+    index = [(16, 24), (40, 28), (68, 12), (80, 16)]
     data += pickle.dumps(index)
+    # write to file
     dummy_packed_data_path = Path(tmpdir, "dummy.pbin")
     dummy_packed_data_path.write_bytes(data)
     return dummy_packed_data_path
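The numbers in this fixture are easy to cross-check against the packed-data layout (the data section starts at byte 16, right after the two 8-byte headers):

    index = [(16, 24), (40, 28), (68, 12), (80, 16)]
    assert sum(size for _, size in index) == 20 * 4          # 20 token ids at 4 bytes each
    assert [size // 4 for _, size in index] == [6, 7, 3, 4]  # entry lengths in tokens
    assert all(off + size == nxt_off for (off, size), (nxt_off, _) in zip(index, index[1:]))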
@@ -52,12 +71,6 @@ def dummy_config(monkeypatch) -> AppConfig:
     return app_config


-@dataclasses.dataclass
-class DataPathCollection:
-    raw_data_path: Path
-    index_path: Path
-
-
 @pytest.fixture
 def dummy_data_path(tmpdir) -> DataPathCollection:
     source_raw_dummy_data_path = _ROOT_DIR / Path("./data/lorem_ipsum.jsonl")
@@ -68,6 +81,58 @@ def dummy_data_path(tmpdir) -> DataPathCollection:
     return DataPathCollection(raw_data_path=dummy_data_path, index_path=index_path)


+@pytest.fixture
+def indexed_multimodal_dummy_data_path(tmpdir) -> DataPathCollection:
+    base_path = Path(tmpdir, "image_data")
+    img_base_path = Path(base_path, "images")
+    audio_base_path = Path(base_path, "audios")
+
+    base_path.mkdir(parents=True, exist_ok=True)
+    img_base_path.mkdir(parents=True, exist_ok=True)
+    audio_base_path.mkdir(parents=True, exist_ok=True)
+
+    data_path = Path(base_path, "data.jsonl")
+    index_path = Path(base_path, "data.idx")
+    img_paths = [Path(img_base_path, "img_%i.png" % i) for i in range(15)]
+    audio_paths = [Path(audio_base_path, "audio_%i.wav" % i) for i in range(15)]
+
+    # create random images and save them into the temp directory
+    for img_path in img_paths:
+        im = np.random.rand(100, 100, 3) * 255
+        im = Image.fromarray(im.astype("uint8")).convert("RGB")
+        im.save(img_path, "PNG")
+
+    # create random audio waveforms and save them into the temp directory
+    NUM_CHANNELS = 1
+    SAMPLING_RATE = 16000
+    AUDIO_DUR_SECS = 5
+
+    for audio_path in audio_paths:
+        audio = torch.randn(NUM_CHANNELS, SAMPLING_RATE * AUDIO_DUR_SECS)
+        torchaudio.save(audio_path, audio, SAMPLING_RATE)
+
+    # create the jsonl file, pairing each image with the corresponding audio file
+    with data_path.open("w+") as f:
+        for img_path, audio_path in zip(img_paths, audio_paths):
+            f.write(
+                json.dumps(
+                    {
+                        "img_path": img_path.absolute().as_posix(),
+                        "audio_path": audio_path.absolute().as_posix(),
+                        "text": (
+                            f"This item refers to the image stored at {str(img_path)} and "
+                            f"the audio stored at {str(audio_path)}"
+                        ),
+                    }
+                )
+                + "\n"
+            )
+    # create the index file to the jsonl file
+    IndexGenerator(data_path).create_index(index_path)
+
+    return DataPathCollection(raw_data_path=data_path, index_path=index_path)
+
+
 @pytest.fixture
 def indexed_dummy_data_path(dummy_data_path) -> DataPathCollection:
     index_generator = IndexGenerator(dummy_data_path.raw_data_path)
diff --git a/tests/dataloader/test_packed_dataset.py b/tests/dataloader/test_packed_dataset.py
index 64df0e9c9..22cc94942 100644
--- a/tests/dataloader/test_packed_dataset.py
+++ b/tests/dataloader/test_packed_dataset.py
@@ -1,15 +1,35 @@
+import json
+
+import numpy.testing
 import pytest
+from PIL import Image

+from modalities.dataloader.codecs import HfTokenizerCodec, PillowImageCodec, TorchaudioAudioCodec
 from modalities.dataloader.create_packed_data import PackedDataGenerator
-from modalities.dataloader.dataset import PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron
+from modalities.dataloader.dataset import (
+    PackedMemMapDataset,
+    PackedMemMapDatasetContinuous,
+    PackedMemMapDatasetMegatron,
+)


+@pytest.mark.skip(reason="New packed data format not implemented for megatron dataset")
 @pytest.mark.parametrize("block_size, expected_length", [(1, 4), (2, 3), (3, 3), (10, 2), (6, 2), (20, 1), (25, 0)])
 def test_packed_megatron_dataset_loading(dummy_packed_data_path, block_size, expected_length):
     ds = PackedMemMapDatasetMegatron(dummy_packed_data_path, block_size, sample_key="input_ids")
     assert len(ds) == expected_length


+def test_packed_dataset_loading(dummy_packed_data_path):
+    ds = PackedMemMapDataset(dummy_packed_data_path, sample_keys=["input_ids"])
+
+    assert len(ds) == 4
+    assert ds[0]["input_ids"] == [0, 1, 2, 3, 4, 5]
+    assert ds[1]["input_ids"] == [6, 7, 8, 9, 10, 11, 12]
+    assert ds[2]["input_ids"] == [13, 14, 15]
+    assert ds[3]["input_ids"] == [16, 17, 18, 19]
+
+
 @pytest.mark.parametrize(
     "block_size, expected_length, expected_output",
     [
@@ -23,7 +43,7 @@ def test_packed_megatron_dataset_loading(dummy_packed_data_path, block_size, exp
     ],
 )
 def test_packed_continuous_dataset_loading(dummy_packed_data_path, block_size, expected_length, expected_output):
-    ds = PackedMemMapDatasetContinuous(dummy_packed_data_path, block_size, sample_key="input_ids")
+    ds = PackedMemMapDatasetContinuous(dummy_packed_data_path, sample_key="input_ids", block_size=block_size)
     assert len(ds) == expected_length
     retrieved_input_ids = [list(packed_samples["input_ids"]) for packed_samples in ds]
     assert retrieved_input_ids == expected_output
@@ -39,13 +59,23 @@ def test_packed_continuous_dataset_missing_file(dummy_packed_data_path):
 def test_create_packed_dataset(indexed_dummy_data_path, gpt2_tokenizer, max_num_of_tokens, expected_index_size):
     block_size = 5
     packed_generator = PackedDataGenerator(
-        src_path=indexed_dummy_data_path.raw_data_path, tokenizer=gpt2_tokenizer, max_number_of_tokens=max_num_of_tokens
+        src_path=indexed_dummy_data_path.raw_data_path,
+        codecs={
+            ".text": HfTokenizerCodec(
+                tokenizer=gpt2_tokenizer,
+            )
+        },
+        max_num_of_bytes=(
+            (HfTokenizerCodec.TOKEN_SIZE_IN_BYTES * max_num_of_tokens) if max_num_of_tokens is not None else None
+        ),
     )
     default_packed_dataset_path = packed_generator._default_destination_path()
     assert not default_packed_dataset_path.is_file()
     packed_generator.run()
     packed_dataset = PackedMemMapDatasetContinuous(
-        default_packed_dataset_path, block_size=block_size, sample_key="input_ids"
+        default_packed_dataset_path,
+        sample_key="input_ids",
+        block_size=block_size,
     )

     start_of_jsonl_content = "0 Lorem ipsum dolor sit amet, consetetur sadipscing elitr, sed diam nonumy eirmod tempor"
     tokenized_start_of_jsonl_content = gpt2_tokenizer(start_of_jsonl_content)["input_ids"]
     packed_dataset_iterator = iter(packed_dataset)
     assert tokenized_start_of_jsonl_content[:block_size] == next(packed_dataset_iterator)["input_ids"]
     assert tokenized_start_of_jsonl_content[block_size : 2 * block_size] == next(packed_dataset_iterator)["input_ids"]
-    assert len(packed_dataset.index_base) == expected_index_size
+    assert len(packed_dataset._index_base) == expected_index_size

     # check validity of index section in packed dataset
-    for idx, (offset, entry_length) in enumerate(packed_dataset.index_base[:-1]):
-        assert offset + entry_length == packed_dataset.index_base[idx + 1][0]
+    for idx, (offset, entry_length) in enumerate(packed_dataset._index_base[:-1]):
+        assert offset + entry_length == packed_dataset._index_base[idx + 1][0]


+def test_packed_image_dataset(indexed_multimodal_dummy_data_path):
+    # create packed data file
+    packed_generator = PackedDataGenerator(
+        src_path=indexed_multimodal_dummy_data_path.raw_data_path,
+        idx_path=indexed_multimodal_dummy_data_path.index_path,
+        codecs={".img_path": PillowImageCodec()},
+    )
+    # get destination path
+    default_packed_dataset_path = packed_generator._default_destination_path()
+    assert not default_packed_dataset_path.is_file()
+    # create packed dataset file
+    packed_generator.run()
+
+    # read dataset
+    ds = PackedMemMapDataset(
+        default_packed_dataset_path,
+        sample_keys=["img"],
+    )
+    # read the jsonl to get the source image paths
+    with indexed_multimodal_dummy_data_path.raw_data_path.open("r") as f:
+        src_data = list(map(json.loads, f.read().strip().split("\n")))
+    # compare source image with dataset content
+    for src, item in zip(src_data, ds):
+        with Image.open(src["img_path"]) as src_img:
+            numpy.testing.assert_allclose(src_img, item["img"])
+
+
+def test_packed_audio_dataset(indexed_multimodal_dummy_data_path):
+    # create packed data file
+    packed_generator = PackedDataGenerator(
+        src_path=indexed_multimodal_dummy_data_path.raw_data_path,
+        idx_path=indexed_multimodal_dummy_data_path.index_path,
+        codecs={".audio_path": TorchaudioAudioCodec()},
+    )
+    # get destination path
+    default_packed_dataset_path = packed_generator._default_destination_path()
+    assert not default_packed_dataset_path.is_file()
+    # create packed dataset file
+    packed_generator.run()
+
+    # read dataset
+    ds = PackedMemMapDataset(
+        default_packed_dataset_path,
+        sample_keys=["feat"],
+    )
+    # read the jsonl to get the source feature paths
+    with indexed_multimodal_dummy_data_path.raw_data_path.open("r") as f:
+        src_data = list(map(json.loads, f.read().strip().split("\n")))
+
+    # compare source features with dataset content
+    codec = TorchaudioAudioCodec()
+    for src, item in zip(src_data, ds, strict=True):
+        audio, sample_rate = codec.load_audio(src["audio_path"])
+        audio = codec.resample(audio, sample_rate)
+        log_mel_spec = codec.extract_log_mel_spectrogram(audio)
+        numpy.testing.assert_allclose(log_mel_spec, item["feat"])
+
+
+def test_packed_multimodal_dataset(indexed_multimodal_dummy_data_path, gpt2_tokenizer):
+    # create packed data file
+    packed_generator = PackedDataGenerator(
+        src_path=indexed_multimodal_dummy_data_path.raw_data_path,
+        idx_path=indexed_multimodal_dummy_data_path.index_path,
+        codecs={
+            ".img_path": PillowImageCodec(),
+            ".text": HfTokenizerCodec(tokenizer=gpt2_tokenizer, add_eos_token=False),
+            ".audio_path": TorchaudioAudioCodec(),
+        },
+    )
+    # get destination path
+    default_packed_dataset_path = packed_generator._default_destination_path()
+    assert not default_packed_dataset_path.is_file()
+    # create packed dataset file
+    packed_generator.run()
+
+    # read dataset
+    ds = PackedMemMapDataset(
+        default_packed_dataset_path,
+        sample_keys=["img", "input_ids", "audio_feat"],
+    )
+    audio_codec = TorchaudioAudioCodec()
+
+    # read the jsonl to get the source values
+    with indexed_multimodal_dummy_data_path.raw_data_path.open("r") as f:
+        src_data = list(map(json.loads, f.read().strip().split("\n")))
+    # compare source with dataset content
+    for src, item in zip(src_data, ds):
+        with Image.open(src["img_path"]) as src_img:
+            numpy.testing.assert_allclose(src_img, item["img"])
+
+        assert gpt2_tokenizer(src["text"])["input_ids"] == item["input_ids"]
+
+        audio, sample_rate = audio_codec.load_audio(src["audio_path"])
+        audio = audio_codec.resample(audio, sample_rate)
+        log_mel_spec = audio_codec.extract_log_mel_spectrogram(audio)
+        numpy.testing.assert_allclose(log_mel_spec, item["audio_feat"])
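End to end, the text-only path exercised by these tests can be reproduced in a few lines; a hedged sketch, assuming the placeholder paths exist and an index file sits next to the jsonl:

    from pathlib import Path

    from transformers import GPT2TokenizerFast

    from modalities.dataloader.codecs import HfTokenizerCodec
    from modalities.dataloader.create_packed_data import PackedDataGenerator
    from modalities.dataloader.dataset import PackedMemMapDatasetContinuous

    tokenizer = GPT2TokenizerFast(tokenizer_file="./data/tokenizer/tokenizer.json")
    generator = PackedDataGenerator(
        {".text": HfTokenizerCodec(tokenizer=tokenizer)},
        src_path=Path("data/train.jsonl"),
    )
    generator.run()  # writes data/train.pbin next to the source file

    ds = PackedMemMapDatasetContinuous(Path("data/train.pbin"), sample_key="input_ids", block_size=1024)
    print(len(ds), ds[0]["input_ids"][:5])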