Skip to content

Memmap dataset for multimodal data #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 17 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions config_files/config_example_audio_mem_map_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
features:
- jq_pattern: .transcript
codec:
type_hint: HfTokenizerCodec
config:
add_eos_token: true
tokenizer:
type_hint: GPT2TokenizerFast
config:
tokenizer_file: ./data/tokenizer/tokenizer.json
- jq_pattern: .audio_path
codec:
type_hint: TorchaudioAudioCodec
config:
target_sample_rate: 16_000
n_mels: 80
15 changes: 15 additions & 0 deletions config_files/data_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
features:
- jq_pattern: .cls
codec:
type_hint: HfTokenizerCodec
config:
add_eos_token: true
tokenizer:
type_hint: GPT2TokenizerFast
config:
tokenizer_file: ./data/tokenizer/tokenizer.json
- jq_pattern: .img_path
codec:
type_hint: PillowImageCodec
config:
save_format: png
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ dependencies = [
"jq",
"xformers",
"class_resolver",
"wandb",
"pillow",
"scipy",
"torchaudio",
# NOTE(review): the PyPI package named "ffmpeg" is not FFmpeg itself — confirm
# whether "ffmpeg-python" (or a system-level FFmpeg install) was intended.
"ffmpeg",
"soundfile"
]

[project.optional-dependencies]
Expand Down
47 changes: 15 additions & 32 deletions src/modalities/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from modalities.batch import EvaluationResultBatch
from modalities.checkpointing.checkpointing import Checkpointing, CheckpointingIF
from modalities.checkpointing.checkpointing_factory import CheckpointingFactory
from modalities.config.config import AppConfig, ModalitiesSetupConfig, RunMode
from modalities.config.config import AppConfig, ModalitiesSetupConfig, PreparationAppConfig, RunMode
from modalities.config.lookup_types import TokenizerTypes
from modalities.dataloader.create_index import IndexGenerator
from modalities.dataloader.create_packed_data import PackedDataGenerator
Expand Down Expand Up @@ -104,48 +104,31 @@ def entry_point_create_memmap_index(src_path, index_path):

@main.command(name="create_packed_data")
@click.argument("src_path", type=Path)
@click.argument("config_file_path", type=Path)
@click.option(
"--dst_path",
type=str,
default=None,
help="output path for packed data file. will use parent directory of src_path if none.",
)
@click.option(
"--index_path",
"--idx_path",
type=Path,
default=None,
help="input path for index. will search in parent directory of src_path if none.",
)
@click.option(
"--tokenizer_type",
type=TokenizerTypes,
show_default=True,
default=TokenizerTypes.GPT2TokenizerFast,
help="Specify which Tokenizer (inheriting from transformers.PretrainedTokenizers) should get used.",
)
@click.option(
"--tokenizer_file",
type=Path,
show_default=True,
default=Path(__file__).parents[2] / Path("data/tokenizer/tokenizer.json"),
help="path to tokenizer json",
)
@click.option(
"--jq_pattern",
type=str,
show_default=True,
default=".text",
help="jq pattern to extract the data from the json line.",
)
def entry_point_create_packed_data(src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern):
# TODO: if we want to use alternative entrypoints together with the ResolverRegistry,
# we can currently not rely on the existing class resolver.
# This is based on its connection to the overall `AppConfig`.
# One would requires an object of it to instantiate the ResolverRegistry.
# This could get resolved by implementing on own ResolverRegistry for each entrypoint or adapting the existing
# ResolverRegistry to work dynamically with any type-hinted config object from config.py.
tokenizer = tokenizer_type.value(tokenizer_file=str(tokenizer_file))
generator = PackedDataGenerator(src_path, index_path=index_path, tokenizer=tokenizer, jq_pattern=jq_pattern)
def entry_point_create_packed_data(src_path, config_file_path, dst_path, idx_path):
config_dict = load_app_config_dict(config_file_path)
config = PreparationAppConfig.model_validate(config_dict)
# build codec components
resolvers = ResolverRegister()
codecs = {f.jq_pattern: resolvers.build_component_by_config(f.codec) for f in config.features}
# generate packed data
generator = PackedDataGenerator(
codecs,
src_path=src_path,
idx_path=idx_path,
)
generator.run(dst_path)


Expand Down
5 changes: 4 additions & 1 deletion src/modalities/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,15 @@ class EvaluationResultBatch(Batch):
losses: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())
metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())
throughput_metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())

def __str__(self) -> str:
eval_str = (
f"Evaluation result on dataset tag {self.dataloader_tag} after {self.global_train_sample_id + 1} samples:"
)
eval_str += "\n\nlosses: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.losses.items()])
eval_str += "\n\nmetrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.metrics.items()])
eval_str += "\n\nthroughput metrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()])
eval_str += "\n\nthroughput metrics: " + "\n\t".join(
[f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()]
)
eval_str += "\n==============================================="
return eval_str
50 changes: 50 additions & 0 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BatchSamplerTypes,
CheckpointingExectionTypes,
CheckpointingStrategyTypes,
CodecTypes,
CollatorTypes,
DataloaderTypes,
DatasetTypes,
Expand Down Expand Up @@ -51,6 +52,54 @@ class GPT2TokenizerFastConfig(BaseModel):
config: GPT2TokenizerFastConfig


class CodecConfig(BaseModel):
class HfTokenizerCodecConfig(BaseModel):
tokenizer: TokenizerConfig
max_length: Optional[int] = None
add_eos_token: bool = True

class PillowImageCodecConfig(BaseModel):
save_format: str = "png"

class TorchaudioAudioCodecConfig(BaseModel):
target_sample_rate: int = 16_000
n_mels: int = 80

type_hint: CodecTypes
config: Union[
HfTokenizerCodecConfig,
PillowImageCodecConfig,
TorchaudioAudioCodecConfig,
] = Field(union_mode="left_to_right")

@model_validator(mode="before")
def _resolve_type(cls, data):
if isinstance(data, dict):
# resolve config type from type hint
type_hint = data["type_hint"]
CONFIG_RESOLVER = {
CodecTypes.HfTokenizerCodec.name: cls.HfTokenizerCodecConfig,
CodecTypes.PillowImageCodec.name: cls.PillowImageCodecConfig,
CodecTypes.TorchaudioAudioCodec.name: cls.TorchaudioAudioCodecConfig,
}
# create config object of correct type
config_type = CONFIG_RESOLVER.get(type_hint)
config = config_type(**data["config"])
# return codec config
return {"type_hint": type_hint, "config": config}

return data


class FeatureConfig(BaseModel):
codec: CodecConfig
jq_pattern: str


class PreparationAppConfig(BaseModel):
features: List[FeatureConfig]


class DatasetConfig(BaseModel):
class MemMapDatasetConfig(BaseModel):
raw_data_path: FilePath
Expand Down Expand Up @@ -284,6 +333,7 @@ class RunMode(Enum):
FROM_SCRATCH = "FROM_SCRATCH"
WARM_START = "WARM_START"


class ModalitiesSetupConfig(BaseModel):
class WarmStartSettings(BaseModel):
checkpoint_model_path: Path
Expand Down
7 changes: 7 additions & 0 deletions src/modalities/config/lookup_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SaveEveryKStepsCheckpointingStrategy,
SaveKMostRecentCheckpointsStrategy,
)
from modalities.dataloader.codecs import HfTokenizerCodec, PillowImageCodec, TorchaudioAudioCodec
from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader
from modalities.dataloader.dataset import MemMapDataset, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron
from modalities.dataloader.open_gptx_dataset.mmap_dataset import MMapIndexedDatasetBuilder
Expand Down Expand Up @@ -47,6 +48,12 @@ class TokenizerTypes(LookupEnum):
GPT2TokenizerFast = GPT2TokenizerFast


class CodecTypes(LookupEnum):
HfTokenizerCodec = HfTokenizerCodec
PillowImageCodec = PillowImageCodec
TorchaudioAudioCodec = TorchaudioAudioCodec


class DatasetTypes(LookupEnum):
MemMapDataset = MemMapDataset
PackedMemMapDatasetContinuous = PackedMemMapDatasetContinuous
Expand Down
Loading