Some QOL config/saving improvements #134
Merged (1 commit, Oct 25, 2024)
22 changes: 5 additions & 17 deletions eole/config/inference.py
@@ -105,23 +105,6 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
model_config = get_config_dict()
model_config["arbitrary_types_allowed"] = True # to allow torch.dtype

# TODO: clarify models vs model (model config retrieved from checkpoint)
model_path: str | List[str] = Field(
description="Path to model .pt file(s). "
"Multiple models can be specified for ensemble decoding."
) # some specific (mapping to "models") in legacy code, need to investigate
src: str = Field(description="Source file to decode (one line per sequence).")
tgt: str | None = Field(
default=None,
description="True target sequences, useful for scoring or prefix decoding.",
)
tgt_file_prefix: bool = Field(
default=False, description="Generate predictions using provided tgt as prefix."
)
output: str = Field(
default="pred.txt",
description="Path to output the predictions (each line will be the decoded sequence).",
)
report_align: bool = Field(
default=False, description="Report alignment for each translation."
)
@@ -148,6 +131,11 @@ class InferenceConfig(RunningConfig, DecodingConfig, LoRaConfig, QuantizeConfig)
data_type: str | None = (
"text" # deprecated? hopefully will change with input streams logic
)
chat_template: str | None = None
optional_eos: List[str] | None = Field(
default=[],
description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
)

def get_model_path(self):
return self.model_path[0]
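In short, chat_template and optional_eos now live on InferenceConfig, the part of the config that can be carried and saved with a model (see eole/config/run.py and eole/models/model_saver.py below), while the run-specific I/O fields move to PredictConfig. A purely illustrative sketch of what such a persisted inference section would carry (values are made up):

```python
# Illustrative only: an inference section limited to options that make sense to
# save with a checkpoint; run-specific paths (model_path, src, output) are gone.
inference_section = {
    "chat_template": None,           # or a chat template to apply to inputs
    "optional_eos": ["<|eot_id|>"],  # extra stop tokens, e.g. for Llama3
}
```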
20 changes: 16 additions & 4 deletions eole/config/run.py
@@ -34,6 +34,7 @@ class TrainConfig(
) # not sure this still works
model: ModelConfig | None = None # TypeAdapter handling discrimination directly
training: TrainingConfig | None = Field(default_factory=TrainingConfig)
inference: InferenceConfig | None = Field(default=None)

def get_model_path(self):
return self.training.get_model_path()
@@ -100,10 +101,21 @@ class PredictConfig(
None # patch for CT2 inference engine (to improve later)
)
model: ModelConfig | None = None
chat_template: str | None = None
optional_eos: List[str] | None = Field(
default=[],
description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
model_path: str | List[str] = Field(
description="Path to model .pt file(s). "
"Multiple models can be specified for ensemble decoding."
) # some specific (mapping to "models") in legacy code, need to investigate
src: str = Field(description="Source file to decode (one line per sequence).")
tgt: str | None = Field(
default=None,
description="True target sequences, useful for scoring or prefix decoding.",
)
tgt_file_prefix: bool = Field(
default=False, description="Generate predictions using provided tgt as prefix."
)
output: str = Field(
default="pred.txt",
description="Path to output the predictions (each line will be the decoded sequence).",
)

@model_validator(mode="after")
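Taken together with the inference.py change above, the shape being aimed at is roughly the following. This is a self-contained pydantic sketch under assumed names, not eole's actual class hierarchy or full field set: options worth persisting with a checkpoint sit on InferenceConfig, run-specific I/O sits only on PredictConfig, and TrainConfig can now carry an inference section instead of having it dropped at load time.

```python
from typing import List, Optional

from pydantic import BaseModel, Field


class InferenceConfig(BaseModel):
    # Decoding options worth persisting alongside the model.
    chat_template: Optional[str] = None
    optional_eos: Optional[List[str]] = Field(
        default=[],
        description="Optional EOS tokens that would stop generation, e.g. <|eot_id|> for Llama3",
    )


class PredictConfig(InferenceConfig):
    # Run-specific I/O that should not end up in the checkpoint config.
    model_path: str | List[str]
    src: str
    output: str = "pred.txt"


class TrainConfig(BaseModel):
    # The saved "inference" section can now be validated instead of dropped.
    inference: Optional[InferenceConfig] = None
```

Under this split, a checkpointed config containing only chat_template/optional_eos validates cleanly, while PredictConfig still requires model_path and src at prediction time.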
15 changes: 10 additions & 5 deletions eole/models/model_saver.py
@@ -52,9 +52,6 @@ def load_checkpoint(model_path):
config_dict = json.loads(os.path.expandvars(f.read()))
# drop data to prevent validation issues
config_dict["data"] = {}
# drop inference to prevent validation issues
if "inference" in config_dict.keys():
config_dict.pop("inference")
if "training" in config_dict.keys():
config_dict["training"]["dummy_load"] = True
else:
@@ -290,13 +287,18 @@ def _save_config(self):

def _save_transforms_artifacts(self):
if self.transforms is not None:
checkpoint_path = os.path.join(self.model_path, self.step_dir)
for transform_name, transform in self.transforms.items():
transform_save_config = transform._save_artifacts(self.model_path)
transform_save_config, artifacts = transform._save_artifacts(
checkpoint_path
)
setattr(
self.config.transforms_configs,
transform_name,
transform_save_config,
)
for artifact in artifacts:
self._make_symlink(artifact)
# we probably do not need to save transforms artifacts for each checkpoint
# transform._save_artifacts(os.path.join(self.model_path, self.step_dir))

@@ -323,7 +325,10 @@ def _save(self, step):
)
self._save_optimizer()
self._save_weights(model_state_dict)
logger.info(f"Saving transforms artifacts, if any, to {self.model_path}")
logger.info(
"Saving transforms artifacts, if any, "
f"to {os.path.join(self.model_path, self.step_dir)}"
)
self._save_transforms_artifacts()
logger.info(f"Saving config and vocab to {self.model_path}")
self._save_vocab()
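The artifact-saving flow is easier to see in isolation. The sketch below is an approximation under assumed behavior: _make_symlink is not shown in this diff, and the function name and link layout here are assumptions rather than eole's implementation. The idea is that each transform artifact is copied into the per-step checkpoint directory and linked from the model root so that saved "${MODEL_PATH}/<artifact>" references keep resolving.

```python
import os
import shutil


def save_artifact_with_symlink(artifact_path: str, model_path: str, step_dir: str) -> str:
    """Copy one transform artifact into model_path/step_dir and link it from model_path."""
    checkpoint_path = os.path.join(model_path, step_dir)
    os.makedirs(checkpoint_path, exist_ok=True)
    try:
        shutil.copy(artifact_path, checkpoint_path)
    except shutil.SameFileError:
        pass  # artifact already lives in the checkpoint directory
    name = os.path.basename(artifact_path)
    link = os.path.join(model_path, name)
    if os.path.lexists(link):
        os.remove(link)
    # Relative link so the model directory stays relocatable.
    os.symlink(os.path.join(step_dir, name), link)
    return name
```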
15 changes: 9 additions & 6 deletions eole/transforms/transform.py
@@ -58,6 +58,7 @@ def warm_up(self, vocabs=None):

def _save_artifacts(self, model_path):
save_config = copy.deepcopy(self.config)
artifacts = []
for artifact in self.artifacts:
maybe_artifact = getattr(self, artifact, None)
if maybe_artifact is not None and os.path.exists(maybe_artifact):
@@ -66,12 +67,14 @@ def _save_artifacts(self, model_path):
shutil.copy(maybe_artifact, model_path)
except shutil.SameFileError:
pass
setattr(
save_config,
artifact,
os.path.join("${MODEL_PATH}", os.path.basename(maybe_artifact)),
)
return save_config
finally:
artifacts.append(os.path.basename(maybe_artifact))
setattr(
save_config,
artifact,
os.path.join("${MODEL_PATH}", os.path.basename(maybe_artifact)),
)
return save_config, artifacts

@classmethod
def add_options(cls, parser):
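For context on the "${MODEL_PATH}" placeholders written by _save_artifacts: load_checkpoint (shown above) reads the saved config through os.path.expandvars, so the placeholder is resolved from the environment at load time. A minimal round-trip sketch, assuming MODEL_PATH is set in the environment and using an illustrative field name:

```python
import json
import os

# Assumption for illustration: MODEL_PATH is exported before the config is read.
os.environ["MODEL_PATH"] = "/models/my_run"

# A saved transform config entry as written by _save_artifacts (the field name is
# illustrative; the placeholder form matches the diff above).
raw = '{"src_subword_model": "${MODEL_PATH}/bpe.model"}'

config_dict = json.loads(os.path.expandvars(raw))
print(config_dict["src_subword_model"])  # -> /models/my_run/bpe.model
```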