OOM error fixing (#58)
* First try at fixing OOMs

* Second round of fixes for OOM errors

* Fix tests

* Rework to add to huggingface directly

* Revert initial implementation.
Respond to reviewer comments.

* Add fixme

* Use OutOfMemoryError to back off the batch size

* Fix small error.
Change to the real error type and suppress warnings.
Rename to oom_batch_size_backoff_mult.

---------

Co-authored-by: Damien Daspit <[email protected]>
johnml1135 and ddaspit authored Nov 22, 2023
1 parent d324ec7 commit 4faa596
Showing 6 changed files with 75 additions and 18 deletions.
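The thread running through these changes is a multiplicative batch-size backoff on CUDA out-of-memory errors. As a rough standalone sketch of the pattern (illustrative only: translate_with_backoff, translate_batch, and segments are placeholder names, not part of this diff; the real implementation is in hugging_face_nmt_engine.py below):

import logging

import torch

logger = logging.getLogger(__name__)


def translate_with_backoff(translate_batch, segments, batch_size, backoff_mult=0.5):
    # Retry with smaller batches whenever the GPU runs out of memory.
    while True:
        try:
            results = []
            for step in range(0, len(segments), batch_size):
                results.extend(translate_batch(segments[step : step + batch_size]))
            return results
        except torch.cuda.OutOfMemoryError:
            if backoff_mult >= 1.0 or batch_size <= 1:
                raise  # backoff disabled, or nothing left to shrink
            batch_size = max(int(round(batch_size * backoff_mult)), 1)
            logger.warning("Out of memory; retrying with batch size %d", batch_size)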
10 changes: 9 additions & 1 deletion machine/jobs/huggingface/hugging_face_nmt_model_factory.py
@@ -1,3 +1,4 @@
+import logging
 from pathlib import Path
 from typing import Any, cast
 
@@ -15,6 +16,8 @@
 from ..nmt_model_factory import NmtModelFactory
 from ..shared_file_service import SharedFileService
 
+logger = logging.getLogger(__name__)
+
 
 class HuggingFaceNmtModelFactory(NmtModelFactory):
     def __init__(self, config: Any, shared_file_service: SharedFileService) -> None:
@@ -67,7 +70,11 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
             add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens,
         )
 
-    def create_engine(self) -> TranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
+        if half_previous_batch_size:
+            self._config.huggingface.generate_params.batch_size = max(
+                self._config.huggingface.generate_params.batch_size // 2, 1
+            )
         return HuggingFaceNmtEngine(
             self._model,
             src_lang=self._config.src_lang,
@@ -76,6 +83,7 @@ def create_engine(self) -> TranslationEngine:
             num_beams=self._config.huggingface.generate_params.num_beams,
             batch_size=self._config.huggingface.generate_params.batch_size,
             truncation=TruncationStrategy.LONGEST_FIRST,
+            oom_batch_size_backoff_mult=self._config.huggingface.generate_params.oom_batch_size_backoff_mult,
         )
 
     def save_model(self) -> None:
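The new half_previous_batch_size flag gives job code a way to rebuild the engine with the configured batch size halved (floored at 1) after an OOM failure. A hypothetical caller-side sketch; model_factory and translate_all are placeholder names, not part of this diff:

import torch

try:
    engine = model_factory.create_engine()
    translate_all(engine)  # placeholder for the job's translation step
except torch.cuda.OutOfMemoryError:
    # Rebuild with batch_size // 2 (but never below 1) and try once more.
    engine = model_factory.create_engine(half_previous_batch_size=True)
    translate_all(engine)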
2 changes: 1 addition & 1 deletion machine/jobs/nmt_model_factory.py
@@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer:
         ...
 
     @abstractmethod
-    def create_engine(self) -> TranslationEngine:
+    def create_engine(self, half_previous_batch_size=False) -> TranslationEngine:
         ...
 
     @abstractmethod
3 changes: 2 additions & 1 deletion machine/jobs/settings.yaml
@@ -20,6 +20,7 @@ default:
       device: 0
       num_beams: 2
       batch_size: 16
+      oom_batch_size_backoff_mult: 0.5
     tokenizer:
       add_unk_src_tokens: true
       add_unk_trg_tokens: true
@@ -34,4 +35,4 @@ staging:
   huggingface:
     parent_model_name: facebook/nllb-200-distilled-600M
     generate_params:
-      num_beams: 1
\ No newline at end of file
+      num_beams: 1
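With the default oom_batch_size_backoff_mult: 0.5, each OOM roughly halves the effective batch size, so the default batch_size: 16 steps through 16 → 8 → 4 → 2 → 1 before a final OOM is re-raised. A quick check of that sequence, using the same formula as the engine code below:

batch_size, backoff_mult = 16, 0.5
sizes = [batch_size]
while batch_size > 1:
    batch_size = max(int(round(batch_size * backoff_mult)), 1)
    sizes.append(batch_size)
print(sizes)  # [16, 8, 4, 2, 1]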
63 changes: 51 additions & 12 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import gc
+import logging
 from math import exp, prod
 from typing import Any, Iterable, List, Sequence, Tuple, Union, cast
 
@@ -17,29 +18,36 @@
 from ..translation_sources import TranslationSources
 from ..word_alignment_matrix import WordAlignmentMatrix
 
+logger = logging.getLogger(__name__)
+
 
 class HuggingFaceNmtEngine(TranslationEngine):
     def __init__(
         self,
         model: Union[PreTrainedModel, StrPath, str],
+        oom_batch_size_backoff_mult: float = 1.0,
         **pipeline_kwargs,
     ) -> None:
-        if isinstance(model, PreTrainedModel):
-            model.eval()
+        self._model = model
+        self._pipeline_kwargs = pipeline_kwargs
+        if isinstance(self._model, PreTrainedModel):
+            self._model.eval()
         else:
-            model_config = AutoConfig.from_pretrained(str(model), label2id={}, id2label={}, num_labels=0)
-            model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(model), config=model_config))
-        self._tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
+            model_config = AutoConfig.from_pretrained(str(self._model), label2id={}, id2label={}, num_labels=0)
+            self._model = cast(
+                PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
+            )
+        self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
 
-        src_lang = pipeline_kwargs.get("src_lang")
-        tgt_lang = pipeline_kwargs.get("tgt_lang")
+        src_lang = self._pipeline_kwargs.get("src_lang")
+        tgt_lang = self._pipeline_kwargs.get("tgt_lang")
         if (
             src_lang is not None
             and tgt_lang is not None
-            and "prefix" not in pipeline_kwargs
-            and (model.name_or_path.startswith("t5-") or model.name_or_path.startswith("google/mt5-"))
+            and "prefix" not in self._pipeline_kwargs
+            and (self._model.name_or_path.startswith("t5-") or self._model.name_or_path.startswith("google/mt5-"))
         ):
-            pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
+            self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
         else:
             additional_special_tokens = self._tokenizer.additional_special_tokens
             if (
@@ -56,10 +64,15 @@ def __init__(
             ):
                 raise ValueError(f"The specified model does not support the language code '{tgt_lang}'")
 
+        self._batch_size = int(self._pipeline_kwargs.pop("batch_size", 1))
+
+        self._oom_batch_size_backoff_mult = oom_batch_size_backoff_mult
+
         self._pipeline = _TranslationPipeline(
-            model=model,
+            model=self._model,
             tokenizer=self._tokenizer,
-            **pipeline_kwargs,
+            batch_size=self._batch_size,
+            **self._pipeline_kwargs,
         )
 
     def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
@@ -73,6 +86,32 @@ def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequ
 
     def translate_n_batch(
         self, n: int, segments: Sequence[Union[str, Sequence[str]]]
     ) -> Sequence[Sequence[TranslationResult]]:
+        while True:
+            if type(segments) is str:
+                segments = [segments]
+            else:
+                segments = [segment for segment in segments]
+            outer_batch_size = len(segments)
+            all_results: List[Sequence[TranslationResult]] = []
+            try:
+                for step in range(0, outer_batch_size, self._batch_size):
+                    all_results.extend(self._try_translate_n_batch(n, segments[step : step + self._batch_size]))
+                return all_results
+            except torch.cuda.OutOfMemoryError:  # type: ignore[reportGeneralTypeIssues]
+                if self._oom_batch_size_backoff_mult >= 0.9999 or self._batch_size <= 1:
+                    raise
+                self._batch_size = max(int(round(self._batch_size * self._oom_batch_size_backoff_mult)), 1)
+                logger.warning(f"Out of memory error caught. Reducing batch size to {self._batch_size} and retrying.")
+                self._pipeline = _TranslationPipeline(
+                    model=self._model,
+                    tokenizer=self._tokenizer,
+                    batch_size=self._batch_size,
+                    **self._pipeline_kwargs,
+                )
+
+    def _try_translate_n_batch(
+        self, n: int, segments: Sequence[Union[str, Sequence[str]]]
+    ) -> Sequence[Sequence[TranslationResult]]:
         all_results: List[List[TranslationResult]] = []
         i = 0
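From the caller's perspective, the new behavior is opt-in via the constructor keyword; at the default of 1.0 an OOM is simply re-raised. A minimal usage sketch: the checkpoint and language codes are placeholder values, and the import path assumes the package's usual public exports:

from machine.translation.huggingface import HuggingFaceNmtEngine

engine = HuggingFaceNmtEngine(
    "facebook/nllb-200-distilled-600M",  # placeholder checkpoint
    src_lang="eng_Latn",
    tgt_lang="fra_Latn",
    batch_size=16,
    oom_batch_size_backoff_mult=0.5,  # on CUDA OOM, halve the batch and retry
)
results = engine.translate_batch(["Hello, world!", "How are you?"])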
11 changes: 10 additions & 1 deletion machine/translation/huggingface/hugging_face_nmt_model_trainer.py
@@ -147,6 +147,8 @@ def train(
             num_labels=0,
         )
         model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(self._model, config=config))
+
+        logger.info("Initializing tokenizer")
         tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True)
 
         src_lang = self._src_lang
@@ -194,6 +196,7 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
             return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True)
 
         if self._add_unk_src_tokens or self._add_unk_trg_tokens:
+            logger.info("Checking for missing tokens")
             if not isinstance(tokenizer, PreTrainedTokenizerFast):
                 logger.warning(
                     f"Tokenizer can not be updated from default configuration: \
@@ -234,6 +237,7 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
             tokenizer.id_to_lang_token[lang_id] = lang_code
 
         if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
+            logger.info("Add new language codes as tokens")
             if self._src_lang is not None:
                 add_lang_code_to_tokenizer(tokenizer, self._src_lang)
             if self._tgt_lang is not None:
@@ -309,6 +313,7 @@ def preprocess_function(examples):
             model_inputs["labels"] = labels["input_ids"]
             return model_inputs
 
+        logger.info("Run tokenizer")
         train_dataset = train_dataset.map(
             preprocess_function,
             batched=True,
@@ -339,17 +344,21 @@ def preprocess_function(examples):
             ],
         )
 
+        logger.info("Train NMT model")
         ckpt = None
         if self._training_args.resume_from_checkpoint is not None:
             ckpt = self._training_args.resume_from_checkpoint
         elif last_checkpoint is not None:
             ckpt = last_checkpoint
-        train_result = self._trainer.train(resume_from_checkpoint=ckpt)
+        train_result = self._trainer.train(
+            resume_from_checkpoint=ckpt,
+        )
 
         self._metrics = train_result.metrics
         self._metrics["train_samples"] = len(train_dataset)
 
         self._trainer.log_metrics("train", self._metrics)
+        logger.info("Model training finished")
 
     def save(self) -> None:
         if self._trainer is None:
4 changes: 2 additions & 2 deletions machine/translation/translation_engine.py
@@ -2,12 +2,12 @@
 
 from abc import abstractmethod
 from types import TracebackType
-from typing import ContextManager, Optional, Sequence, Type, Union
+from typing import Optional, Sequence, Type, Union
 
 from .translation_result import TranslationResult
 
 
-class TranslationEngine(ContextManager["TranslationEngine"]):
+class TranslationEngine:
     @abstractmethod
     def translate(self, segment: Union[str, Sequence[str]]) -> TranslationResult:
         ...
