Skip to content

Commit

Permalink
remove train with descriptors from default
Browse files Browse the repository at this point in the history
  • Loading branch information
maclandrol committed Jul 28, 2024
1 parent 0d51d9a commit 3a2e40d
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
22 changes: 10 additions & 12 deletions env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies:
- transformers
- datasets
- tokenizers
- accelerate >=0.28.0 # for accelerator_config update
- accelerate >=0.33 # for accelerator_config update
- evaluate
- wandb
- huggingface_hub
Expand All @@ -29,19 +29,17 @@ dependencies:
- black >=24
- ruff
- pytest >=6.0
- nbconvert
- jupyterlab
- nbconvert
- ipywidgets

- pip:
- mkdocs <1.6.0
- mkdocs-material >=7.1.1
- mkdocs-material-extensions
- mkdocstrings
- mkdocstrings-python
- mkdocs-jupyter
- markdown-include
- markdown-include
- mdx_truly_sane_lists
- mike >=1.0.0
- mkdocs <1.6.0
- mkdocs-material >=7.1.1
- mkdocs-material-extensions
- mkdocstrings
- mkdocstrings-python
- mkdocs-jupyter
- markdown-include
- mdx_truly_sane_lists
- mike >=1.0.0
5 changes: 3 additions & 2 deletions safe/trainer/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ModelArguments:
default=None, metadata={"help": "Optional number of labels for the descriptors"}
)
include_descriptors: Optional[bool] = field(
default=True,
default=False,
metadata={"help": "Whether to train with descriptors if they are available or Not"},
)
prop_loss_coeff: Optional[float] = field(
Expand Down Expand Up @@ -331,7 +331,8 @@ def compute_metrics(eval_preds):

if model_args.include_descriptors:
training_args.label_names = ["labels", "mc_labels"]

else:
training_args.label_names = ["labels"]
# update dispatch_batches in accelerator
training_args.accelerator_config.dispatch_batches = data_args.streaming is not True

Expand Down
4 changes: 2 additions & 2 deletions safe/trainer/collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def __call__(self, samples: List[Union[List[int], Any, Dict[str, Any]]]):

# If special token mask has been preprocessed, pop it from the dict.
batch.pop("special_tokens_mask", None)
labels = batch.get("labels", batch["input_ids"].clone())
labels = batch.get(self.label_key, batch["input_ids"].clone())
if tokenizer.pad_token_id is not None:
labels[labels == tokenizer.pad_token_id] = -100
batch["labels"] = labels
batch[self.label_key] = labels

if mc_labels is not None and self.include_descriptors:
batch.update(
Expand Down
10 changes: 8 additions & 2 deletions safe/trainer/trainer_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from transformers import Trainer
from transformers.modeling_utils import unwrap_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.trainer import _is_peft_model


class SAFETrainer(Trainer):
Expand All @@ -22,15 +23,19 @@ def compute_loss(self, model, inputs, return_outputs=False):
labels = (
inputs.pop("labels") if self.label_smoother is not None and "labels" in inputs else None
)

outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]

if labels is not None:
if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
unwrapped_model = self.accelerator.unwrap_model(model)
if _is_peft_model(unwrapped_model):
model_name = unwrapped_model.base_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True)
else:
loss = self.label_smoother(outputs, labels)
Expand All @@ -42,6 +47,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

mc_loss = outputs.get("mc_loss", None) if isinstance(outputs, dict) else outputs[1]
if mc_loss is not None:
loss = loss + self.prop_loss_coeff * mc_loss
Expand Down

0 comments on commit 3a2e40d

Please sign in to comment.