feat(components): migrate function_based convert_to_delimited_string to rlhf_preprocessor component

PiperOrigin-RevId: 628282787
Googler committed Apr 26, 2024
1 parent 0c26c04 commit efefe34
Showing 4 changed files with 12 additions and 7 deletions.
@@ -17,4 +17,4 @@
DO NOT EDIT - This file is generated, manual changes will be overridden.
"""

-IMAGE_TAG = '20240425_1027_RC00'
+IMAGE_TAG = '20240425_1734_RC00'
@@ -21,6 +21,7 @@
from google_cloud_pipeline_components._implementation.llm import preprocess_chat_dataset
from google_cloud_pipeline_components._implementation.llm import private_text_comparison_importer
from google_cloud_pipeline_components._implementation.llm import reward_model_trainer
+from google_cloud_pipeline_components._implementation.llm import rlhf_preprocessor
from google_cloud_pipeline_components._implementation.llm import upload_tensorboard_metrics
import kfp

@@ -45,6 +46,7 @@ def pipeline(
accelerator_type: str,
accelerator_count: int,
reward_model_image_uri: str,
+comma_separated_candidates_field_names: str,
prompt_sequence_length: int = 512,
target_sequence_length: int = 64,
batch_size: int = 64,
@@ -72,6 +74,7 @@ def pipeline(
accelerator_type: Specific accelerator type for the custom job.
accelerator_count: The number of accelerators.
reward_model_image_uri: Docker image URI to use for the reward model training job.
+comma_separated_candidates_field_names: Comma separated list of fields that contain candidate text, e.g. ``'field_1,field_2,field_3'``.
prompt_sequence_length: Maximum tokenized sequence length for input text. Higher values increase memory overhead. This value should be at most 8192. Default value is 512.
target_sequence_length: Maximum tokenized sequence length for target text. Higher values increase memory overhead. This value should be at most 1024. Default value is 64.
batch_size: Number of examples in each finetuning step. Default is 64.
@@ -91,7 +94,6 @@
"""
# fmt: on
prompt_column = 'input_text'
-candidate_columns = ['candidate_0', 'candidate_1']
choice_column = 'choice'

processed_preference_dataset = (
@@ -103,9 +105,6 @@
).set_display_name('Preprocess Prompt Dataset')
)

-comma_separated_candidates_field_names = (
-function_based.convert_to_delimited_string(items=candidate_columns)
-)
preference_dataset_importer = (
private_text_comparison_importer.private_text_comparison_importer(
project=project,
@@ -114,7 +113,7 @@
'processed_dataset_uri'
],
inputs_field_name=prompt_column,
-comma_separated_candidates_field_names=comma_separated_candidates_field_names.output,
+comma_separated_candidates_field_names=comma_separated_candidates_field_names,
choice_field_name=choice_column,
split=env.TRAIN_SPLIT,
large_model_reference=reward_model_reference,
@@ -131,7 +130,7 @@
location=location,
input_text=eval_dataset,
inputs_field_name=prompt_column,
-comma_separated_candidates_field_names=comma_separated_candidates_field_names.output,
+comma_separated_candidates_field_names=comma_separated_candidates_field_names,
choice_field_name=choice_column,
split=env.TRAIN_SPLIT,
large_model_reference=reward_model_reference,
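Note: with this change the reward model graph no longer derives the delimited string itself; the function_based.convert_to_delimited_string task and its .output reference are removed, and the value arrives as a plain comma_separated_candidates_field_names pipeline parameter instead. A minimal sketch of the value the importer receives either way, assuming the removed component simply joined the column names with commas (its implementation is not shown in this diff):

# Illustrative only: candidate_columns comes from the removed pipeline code;
# the comma join is an assumption about what convert_to_delimited_string did.
candidate_columns = ['candidate_0', 'candidate_1']
comma_separated_candidates_field_names = ','.join(candidate_columns)
assert comma_separated_candidates_field_names == 'candidate_0,candidate_1'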
@@ -14,6 +14,7 @@
"""Component that preprocesses inputs for Reinforcement Learning from Human Feedback (RLHF)."""

import os
+from typing import List

from google_cloud_pipeline_components import _placeholders
from google_cloud_pipeline_components import utils as gcpc_utils
@@ -33,6 +34,7 @@ def rlhf_preprocessor(
gcp_resources: dsl.OutputPath(str), # pytype: disable=invalid-annotation
has_tensorboard_id: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
has_inference_dataset: dsl.OutputPath(bool), # pytype: disable=invalid-annotation
+metadata_candidate_columns_string: dsl.OutputPath(str),  # pytype: disable=invalid-annotation
metadata_large_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_reference_model_path: dsl.OutputPath(str), # pytype: disable=invalid-annotation
metadata_reward_model_reference: dsl.OutputPath(str), # pytype: disable=invalid-annotation
@@ -104,6 +106,7 @@
f'--use_experimental_image={use_experimental_image}',
f'--has_tensorboard_id_path={has_tensorboard_id}',
f'--has_inference_dataset_path={has_inference_dataset}',
+f'--metadata_candidate_columns_string_path={metadata_candidate_columns_string}',
f'--metadata_large_model_reference_path={metadata_large_model_reference}',
f'--metadata_reference_model_path_path={metadata_reference_model_path}',
f'--metadata_reward_model_reference_path={metadata_reward_model_reference}',
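Note: the component gains a metadata_candidate_columns_string output and forwards its path to the container through the new --metadata_candidate_columns_string_path flag. The container image (not part of this diff) is presumably what now writes the delimited candidate-column string to that path; a rough sketch of that behavior, with the helper name and column list as illustrative assumptions:

import os
from typing import List


def write_candidate_columns_string(
    candidate_columns: List[str],
    metadata_candidate_columns_string_path: str,
) -> None:
  """Hypothetical helper: writes e.g. 'candidate_0,candidate_1' to the output path."""
  os.makedirs(
      os.path.dirname(metadata_candidate_columns_string_path), exist_ok=True
  )
  with open(metadata_candidate_columns_string_path, 'w') as f:
    f.write(','.join(candidate_columns))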
@@ -133,6 +133,9 @@ def rlhf_pipeline(
reward_model_image_uri=preprocess_metadata.outputs[
'metadata_refined_image_uri'
],
+comma_separated_candidates_field_names=preprocess_metadata.outputs[
+'metadata_candidate_columns_string'
+],
prompt_sequence_length=prompt_sequence_length,
target_sequence_length=target_sequence_length,
eval_dataset=validate_pipeline_task.outputs[
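Note: rlhf_pipeline now threads the preprocessor's metadata_candidate_columns_string output into the reward model graph's new parameter. A self-contained KFP sketch of this wiring pattern, with stand-in component bodies (only the output-to-parameter plumbing mirrors the diff; component and pipeline names are illustrative):

from kfp import dsl


@dsl.component
def preprocessor(metadata_candidate_columns_string: dsl.OutputPath(str)):
  # Stand-in for rlhf_preprocessor: emit the delimited candidate column names.
  import os

  os.makedirs(os.path.dirname(metadata_candidate_columns_string), exist_ok=True)
  with open(metadata_candidate_columns_string, 'w') as f:
    f.write('candidate_0,candidate_1')


@dsl.component
def reward_model_graph(comma_separated_candidates_field_names: str):
  # Stand-in for the reward model graph: consume the plain string parameter.
  print(comma_separated_candidates_field_names.split(','))


@dsl.pipeline(name='candidate-columns-wiring-sketch')
def wiring_sketch():
  preprocess_metadata = preprocessor()
  reward_model_graph(
      comma_separated_candidates_field_names=preprocess_metadata.outputs[
          'metadata_candidate_columns_string'
      ]
  )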
