Skip to content

Commit

Permalink
more improve
Browse files Browse the repository at this point in the history
  • Loading branch information
thevasudevgupta committed May 19, 2022
1 parent b492fb1 commit 65e8d88
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 45 deletions.
1 change: 1 addition & 0 deletions docs/readme.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
WIP
77 changes: 33 additions & 44 deletions experiments/train_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import optax
from flax import traverse_util
from flax.training import train_state
from tqdm.auto import tqdm
from transformers import (FlaxWav2Vec2ForCTC, Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor)

Expand All @@ -35,14 +34,14 @@ def weight_decay_mask(params):


model_id = "facebook/wav2vec2-large-lv60"
# model = FlaxWav2Vec2ForCTC.from_pretrained(model_id)
model = FlaxWav2Vec2ForCTC.from_pretrained(model_id)

# state = TrainState.create(
# apply_fn=model.__call__,
# params=model.params,
# tx=create_tx(1e-4, 1e-4),
# loss_fn=optax.ctc_loss,
# )
state = TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=create_tx(1e-4, 1e-4),
loss_fn=optax.ctc_loss,
)


@partial(jax.pmap, axis_name="batch")
Expand Down Expand Up @@ -87,22 +86,21 @@ def __call__(self, batch: List[Dict[str, Any]]):
text = [sample["text"] for sample in batch]

# TODO: explore other padding options in JAX (special dynamic padding?)
# audio = self.feature_extractor(
# audio,
# padding="max_length",
# max_length=self.audio_max_len,
# truncation=True,
# return_tensors="np",
# )
# text = self.tokenizer(
# text,
# max_length=self.text_max_len,
# truncation=True,
# padding="max_length",
# return_tensors="np",
# )
# return audio, text
return (text,)
audio = self.feature_extractor(
audio,
padding="max_length",
max_length=self.audio_max_len,
truncation=True,
return_tensors="np",
)
text = self.tokenizer(
text,
max_length=self.text_max_len,
truncation=True,
padding="max_length",
return_tensors="np",
)
return audio, text


# TODO (for fine-tuning):
Expand All @@ -117,36 +115,27 @@ def __call__(self, batch: List[Dict[str, Any]]):
)

trainer_config = TrainerConfig(
max_epochs=30,
max_epochs=2,
train_batch_size_per_device=2,
eval_batch_size_per_device=2,
wandb_project_name="speech-JAX",
)

# trainer = Trainer(
# config=trainer_config,
# datacollator=collate_fn,
# training_step=training_step,
# validation_step=validation_step,
# state=state,
# )
trainer = Trainer(
config=trainer_config,
datacollator=collate_fn,
training_step=training_step,
validation_step=validation_step,
state=state,
)


from datasets import interleave_datasets, load_dataset

train_data = [
load_dataset("librispeech_asr", "clean", split="train.100", streaming=True),
# load_dataset("librispeech_asr", "clean", split="train.360", streaming=True),
# load_dataset("librispeech_asr", "other", split="train.500", streaming=True),
load_dataset("librispeech_asr", "clean", split="train.360", streaming=True),
load_dataset("librispeech_asr", "other", split="train.500", streaming=True),
]
train_data = interleave_datasets(train_data)
# val_data = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)


from speech_jax.training import DataLoader

dataloader = DataLoader(train_data, batch_size=4, collate_fn=collate_fn)

i = 0
for batch in tqdm(dataloader):
print(batch)
val_data = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
21 changes: 21 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
@@ -1 +1,22 @@
Something exciting WIP

### For development locally

```bash
# JAX must be installed by the user before running the following commands

git clone https://github.com/vasudevgupta7/speech-jax.git
cd speech-jax
pip3 install -e .
```

### Running tests

```bash
pytest -sv tests/
```

### Usage

```python
from speech_jax.training import DataLoader, Trainer, TrainerConfig
```
20 changes: 19 additions & 1 deletion src/speech_jax/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
from typing import Callable
from pathlib import Path
from typing import Callable, Union

import jax
import jax.numpy as jnp
Expand All @@ -10,6 +11,8 @@
from flax.training.common_utils import shard
from tqdm.auto import tqdm

PathType = Union[Path, str]


class DataLoader:
def __init__(self, dataset: IterableDataset, batch_size=1, collate_fn=None):
Expand Down Expand Up @@ -93,3 +96,18 @@ def evaluate(self, data: DataLoader, state: train_state.TrainState):
loss = self.validation_step(batch, state)
val_loss += jax_utils.unreplicate(loss)
return val_loss

def save_checkpoint(self, ckpt_dir: PathType) -> Path:
    """Create the checkpoint directory and return it as a ``Path``.

    Only the directory itself is created for now; nothing is written into
    it yet (see the TODO below).

    Args:
        ckpt_dir: Target directory, given as a string or a ``Path``.

    Returns:
        ``ckpt_dir`` converted to a ``Path``.
    """
    ckpt_dir = Path(ckpt_dir)
    # parents=True so that nested targets such as "ckpts/run-1" work even
    # when the intermediate directories do not exist yet; exist_ok=True
    # keeps repeated saves into the same directory from raising.
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    # TODO: add logic here
    # directory saving
    # flax model in flax_model.msgpack
    # optim state in optim_state.msgpack
    # model config in config.yaml
    # training config in ...

    return ckpt_dir

def load_checkpoint(self, ckpt_dir: PathType):
    """Placeholder for restoring state from ``ckpt_dir``.

    Not implemented yet: the argument is ignored and ``None`` is returned.
    """
    return None
5 changes: 5 additions & 0 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import unittest

class DataLoaderTester(unittest.TestCase):
    """Placeholder test case for the data loader."""

    def test_hello(self):
        """Print a greeting; makes no assertions."""
        greeting = "hello world!"
        print(greeting)

0 comments on commit 65e8d88

Please sign in to comment.