diff --git a/trl/trainer/ppo_config.py b/trl/trainer/ppo_config.py
index 8f23084ce8..c00f1c1d02 100644
--- a/trl/trainer/ppo_config.py
+++ b/trl/trainer/ppo_config.py
@@ -58,6 +58,8 @@ class PPOConfig(OnPolicyConfig):
             Discount factor.
         lam (`float`, *optional*, defaults to `0.95`):
             Lambda value for GAE.
+        save_value_model (`bool`, *optional*, defaults to `False`):
+            Whether the value model (also known as the critic model) should be saved when the policy model is saved. If `False`, the output folder contains the policy files only. If `True`, it contains `policy_model` and `value_model` subfolders, and each can be loaded by passing the subfolder as a keyword argument: `from_pretrained(repo_id, subfolder=subfolder)`.
         ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
             This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for
             generation, improving generation speed. However, disabling this option allows training models that exceed the VRAM
@@ -121,6 +123,12 @@ class PPOConfig(OnPolicyConfig):
         default=0.95,
         metadata={"help": "Lambda value for GAE."},
     )
+    save_value_model: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether the value model (also known as the critic model) should be saved when the policy model is saved. If `False`, the output folder contains the policy files only. If `True`, it contains `policy_model` and `value_model` subfolders, and each can be loaded by passing the subfolder as a keyword argument: `from_pretrained(repo_id, subfolder=subfolder)`."
+        },
+    )
     ds3_gather_for_generation: bool = field(
         default=True,
         metadata={
diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 54b62394a8..39e063edc7 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -328,6 +328,14 @@ def null_ref_context(self):
         self.model.policy.set_adapter(self.model_adapter_name or "default")
 
     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+        # Handle the None case here so that we can have subfolders for policy and value
+        if output_dir is None:
+            output_dir = self.args.output_dir
+        if output_dir is None:
+            raise ValueError("No output directory specified for saving the model")
+        # I am unsure whether this early return is legal. Line 4814 in Trainer.py says that save_model has to be executed on all processes for TPU training. Previously, save_model would be called in parallel while one process had already set self.model to self.model.policy, resulting in errors. Including this guard gets rid of those errors, and the model still gets uploaded.
+        if not hasattr(self.model, "policy"):
+            return
         backup_model = self.model
         self.model = self.model.policy  # save only the policy
 
@@ -335,13 +343,28 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
             backup_deepspeed = self.deepspeed
             self.deepspeed = self.model
 
-        super().save_model(output_dir, _internal_call)
+        policy_output_dir = output_dir if not self.args.save_value_model else os.path.join(output_dir, "policy_model")
+        super().save_model(policy_output_dir, _internal_call)
 
         self.model = backup_model
 
         if self.is_deepspeed_enabled:
             self.deepspeed = backup_deepspeed
 
+        if self.args.save_value_model:
+            backup_model = self.model
+            self.model = self.model.value_model
+
+            if self.is_deepspeed_enabled:
+                backup_deepspeed = self.deepspeed
+                self.deepspeed = self.model
+            value_output_dir = os.path.join(output_dir, "value_model")
+            super().save_model(value_output_dir, _internal_call)
+            self.model = backup_model
+
+            if self.is_deepspeed_enabled:
+                self.deepspeed = backup_deepspeed
+
     def train(self):
         args = self.args
         accelerator = self.accelerator
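Usage sketch (not part of the diff): the snippet below shows the round trip this option enables. It assumes the policy was created as an `AutoModelForCausalLM` and the value model as an `AutoModelForSequenceClassification` with `num_labels=1`, as in the TRL PPO example scripts; the `ppo_run` path is made up, while `save_value_model` and the `policy_model`/`value_model` subfolder names come from the diff above.

```python
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

from trl import PPOConfig

# Opt in to saving the critic next to the policy.
config = PPOConfig(output_dir="ppo_run", save_value_model=True)

# After training, trainer.save_model() writes two subfolders:
#   ppo_run/policy_model/  (the policy)
#   ppo_run/value_model/   (the value/critic model)

# Load each checkpoint back via the `subfolder` keyword argument.
policy = AutoModelForCausalLM.from_pretrained("ppo_run", subfolder="policy_model")
value_model = AutoModelForSequenceClassification.from_pretrained(
    "ppo_run", subfolder="value_model", num_labels=1
)
```

With `save_value_model=False` (the default), `save_model` behaves exactly as before and writes the policy files directly into `output_dir`.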