diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 14ddf507d..6d6fd72d3 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -28,6 +28,7 @@ checkpointing: save_period: 10 policy: + training_backend: "hf" # Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA/nemo-rl/issues/227) model_name: "Qwen/Qwen2.5-1.5B" tokenizer: diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml new file mode 100644 index 000000000..c19a68a70 --- /dev/null +++ b/examples/configs/grpo_math_1B_megatron.yaml @@ -0,0 +1,132 @@ +# GRPO Algorithm Configuration +defaults: "grpo_math_1B.yaml" + +grpo: + num_prompts_per_step: 32 + num_generations_per_prompt: 16 + max_num_steps: 1000000 + normalize_rewards: true + use_leave_one_out_baseline: true + val_period: 10 + val_at_start: false + max_val_samples: 256 + val_batch_size: 256 + +loss_fn: + reference_policy_kl_penalty: 0.01 + ratio_eps_min: 0.2 + ratio_eps_max: 0.2 + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false + +checkpointing: + enabled: false + checkpoint_dir: "results/grpo_megatron" + metric_name: "val_reward" + higher_is_better: true + keep_top_k: 3 + save_period: 10 + +policy: + training_backend: "megatron" + model_name: "Qwen/Qwen2.5-1.5B-Instruct" + tokenizer: + name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default + train_global_batch_size: 512 + train_micro_batch_size: 2 + generation_batch_size: 64 # Only used when generating using megatron backend + logprob_batch_size: 4 + max_total_sequence_length: 512 + precision: "bfloat16" + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_dtype: "float32" + context_parallel_size: 1 + refit_buffer_size_gb: 4 # used for refitting inference engine, the unit is GB + + dtensor_cfg: + enabled: false + + max_grad_norm: 1.0 + # makes the training sequence length divisible by the tensor parallel size + # this is useful for sequence parallel training + make_sequence_length_divisible_by: ${policy.tensor_model_parallel_size} + + optimizer: null # remove default FSDP optimizer + + megatron_cfg: + enabled: true + empty_unused_memory_level: 1 + converter_type: "Qwen2ForCausalLM" + + optimizer: + optimizer: "adam" + lr: 5.0e-6 + min_lr: 5.0e-7 + weight_decay: 0.01 + bf16: false + fp16: false + params_dtype: "float32" + + #adam + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_eps: 1e-8 + + #sgd + sgd_momentum: 0.9 + + #distributed optimizer + use_distributed_optimizer: true + use_precision_aware_optimizer: true + + clip_grad: ${policy.max_grad_norm} + + scheduler: + start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay} + weight_decay_incr_style: "constant" + lr_decay_style: "constant" + lr_decay_iters: null + lr_warmup_iters: 50 + lr_warmup_init: 5.0e-7 + + generation: + backend: "vllm" + max_new_tokens: ${policy.max_total_sequence_length} + temperature: 1.0 + top_p: 1.0 + top_k: null + vllm_cfg: + tensor_parallel_size: 1 + gpu_memory_utilization: 0.6 + max_model_len: ${policy.max_total_sequence_length} + +data: + max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len + prompt_file: "examples/prompts/cot.txt" + 
system_prompt_file: null + dataset_name: "OpenMathInstruct-2" + +env: + math: + num_workers: 8 + +logger: + log_dir: "logs" # Base directory for all logs + num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal + wandb_enabled: false + tensorboard_enabled: false + monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard + wandb: + project: "grpo-dev" + name: "sj_megatron_1B" + tensorboard: {} + gpu_monitoring: + collection_interval: 10 # How often to collect GPU usage metrics (in seconds) + flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds) + +cluster: + gpus_per_node: 1 + num_nodes: 1 \ No newline at end of file diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 5a007451d..4fe68f69c 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -49,6 +49,7 @@ from nemo_rl.models.interfaces import PolicyInterface from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.hf_policy import HfPolicy +from nemo_rl.models.policy.megatron_policy import MegatronPolicy from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager from nemo_rl.utils.logger import ( Logger, @@ -198,7 +199,7 @@ def setup( # Cluster # ========================== print("\nā–¶ Setting up compute cluster...") - colocated_inference = generation_config["backend"] != "hf" + colocated_inference = generation_config["backend"] not in ["hf", "megatron"] cluster = RayVirtualCluster( name="grpo_policy_cluster", bundle_ct_per_node_list=[cluster_config["gpus_per_node"]] @@ -218,30 +219,42 @@ def setup( backend = generation_config["backend"] generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM - if backend == "hf": + if backend in ["hf", "megatron"]: policy_generation = None - print(f" āœ“ Using HF backend for generation with {policy_config['model_name']}") elif backend == "vllm": policy_generation = VllmGeneration(cluster=cluster, config=generation_config) # Worker groups are not initialized until the first call to run something on workergroups. # vllm 0.8 fails in initialization if its called in the first training step since it has no clean view of the GPU memory (HF is sharing the same memory). 
policy_generation.finish_generation() - print( - f" āœ“ Using vLLM backend for generation with {policy_config['model_name']}" + else: + raise ValueError(f"Unknown generation backend: {backend}") + print(f" āœ“ Using {backend} for generation with {policy_config['model_name']}") + + if policy_config["training_backend"] == "hf": + policy = HfPolicy( + cluster=cluster, + config=policy_config, + tokenizer=tokenizer, + weights_path=Path(last_checkpoint_path) / "policy" / "weights" + if last_checkpoint_path + else None, + optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" + if last_checkpoint_path + else None, + init_optimizer=True, + ) + elif policy_config["training_backend"] == "megatron": + policy = MegatronPolicy( + cluster=cluster, + config=policy_config, + tokenizer=tokenizer, + init_optimizer=True, + init_reference_model=True, + ) + else: + raise ValueError( + f"Unknown training backend: {policy_config['training_backend']}" ) - - policy = HfPolicy( - cluster=cluster, - config=policy_config, - tokenizer=tokenizer, - weights_path=Path(last_checkpoint_path) / "policy" / "weights" - if last_checkpoint_path - else None, - optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer" - if last_checkpoint_path - else None, - init_optimizer=True, - ) loss_fn = ClippedPGLossFn(loss_config) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 84d751036..ec17ecb82 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Tuple, TypedDict +from typing import Any, Optional, Tuple, TypedDict import torch @@ -21,6 +21,7 @@ masked_mean, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs from nemo_rl.models.dtensor.parallelize import ( get_logprobs_from_vocab_parallel_logits, ) @@ -90,6 +91,8 @@ def __call__( self, next_token_logits: torch.Tensor, data: BatchedDataDict[ClippedPGLossDataDict], + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> Tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" token_mask = data["token_mask"][:, 1:] @@ -109,7 +112,16 @@ def __call__( next_token_logits = next_token_logits.to(torch.float32) - if isinstance(next_token_logits, torch.distributed.tensor.DTensor): + if vocab_parallel_group is not None: + curr_logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + data["input_ids"], + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + group=vocab_parallel_group, + inference_only=False, + ) + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): curr_logprobs = get_logprobs_from_vocab_parallel_logits( next_token_logits, data["input_ids"] ) diff --git a/nemo_rl/distributed/named_sharding.py b/nemo_rl/distributed/named_sharding.py new file mode 100644 index 000000000..5eb4bef60 --- /dev/null +++ b/nemo_rl/distributed/named_sharding.py @@ -0,0 +1,169 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Sequence, Union + +import numpy as np + + +class NamedSharding: + """Represents an N-dimensional arrangement of ranks with named axes, facilitating data sharding, replication, and collection based on these axes. + + Example: + layout = [ + [[0, 1, 2, 3], [4, 5, 6, 7]], + ] + names = ["dp", "pp", "tp"] + # This represents DP=1, PP=2, TP=4 + sharding = NamedSharding(layout, names) + print(sharding.shape) # Output: (1, 2, 4) + print(sharding.names) # Output: ['dp', 'pp', 'tp'] + print(sharding.get_ranks(dp=0, pp=1)) # Output: [4, 5, 6, 7] + """ + + def __init__(self, layout: Sequence[Any], names: List[str]): + """Initializes the NamedSharding object. + + Args: + layout: A nested sequence (e.g., list of lists) representing the ND rank layout. + All inner lists must contain integer rank IDs. + names: A list of strings representing the names of the dimensions, + ordered from the outermost to the innermost dimension. + """ + # Convert to numpy array first, inferring dtype + try: + initial_array = np.array(layout) + except ( + ValueError + ) as e: # Catch potential errors during array creation (e.g., ragged arrays) + raise ValueError(f"Could not create NumPy array from layout: {e}") + + # Check if the inferred dtype is integer-like or float representing integers + if not np.issubdtype(initial_array.dtype, np.integer): + # Check if all elements are actually integers (handles floats like 1.0) + if not np.equal(np.mod(initial_array, 1), 0).all(): + raise ValueError("Layout must contain only integer rank IDs.") + # If they are float but represent integers (e.g., 1.0), cast them + self._layout = initial_array.astype(int) + else: + self._layout = initial_array # Already integer type + + self._names = list(names) + + if self._layout.ndim != len(self._names): + raise ValueError( + f"Number of dimensions in layout ({self._layout.ndim}) " + f"must match the number of names ({len(self._names)})." + ) + + # Check for duplicate ranks (on the final integer array) + unique_ranks, counts = np.unique(self._layout, return_counts=True) + duplicates = unique_ranks[counts > 1] + if duplicates.size > 0: + raise ValueError(f"Duplicate ranks found in layout: {duplicates.tolist()}") + + self._name_to_axis = {name: i for i, name in enumerate(self._names)} + + @property + def shape(self) -> Dict[str, int]: + """Returns the shape of the rank layout.""" + return {name: size for name, size in zip(self._names, self._layout.shape)} + + @property + def names(self) -> List[str]: + """Returns the names of the axes.""" + return list(self._names) # Return a copy + + @property + def ndim(self) -> int: + """Returns the number of dimensions.""" + return self._layout.ndim + + @property + def size(self) -> int: + """Returns the total number of ranks.""" + return self._layout.size + + @property + def layout(self) -> np.ndarray: + """Returns the underlying NumPy array representing the layout.""" + return self._layout.copy() # Return a copy + + def get_ranks(self, **kwargs: int) -> Union["NamedSharding", int]: + """Gets the ranks corresponding to specific indices along named axes. 
+ + Args: + **kwargs: Keyword arguments where the key is the axis name (e.g., "dp", "tp") + and the value is the index along that axis. + + Returns: + A new NamedSharding instance representing the subset of ranks. + The shape of the returned sharding corresponds to the axes *not* specified + in the kwargs. If all axes are specified, an int is returned. + + Raises: + ValueError: If an invalid axis name is provided or if an index is out of bounds. + """ + indices: List[Any] = [slice(None)] * self.ndim + specified_axes = set() + + for name, index in kwargs.items(): + if name not in self._name_to_axis: + raise ValueError( + f"Invalid axis name: '{name}'. Valid names are: {self.names}" + ) + if not (0 <= index < self.shape[name]): + raise IndexError( + f"Index {index} is out of bounds for axis '{name}' with size {self.shape[name]}" + ) + + axis_index = self._name_to_axis[name] + indices[axis_index] = index + specified_axes.add(axis_index) + + # Get the subset of ranks + subset_layout = self._layout[tuple(indices)] + + # Create a new list of names for the remaining dimensions + remaining_names = [ + name for i, name in enumerate(self._names) if i not in specified_axes + ] + + # If all dimensions were specified, we need to handle the 0-dimensional case + if not remaining_names: + return subset_layout.item() + + return NamedSharding(subset_layout, remaining_names) + + def get_axis_index(self, name: str) -> int: + """Gets the numerical index of a named axis.""" + if name not in self._name_to_axis: + raise ValueError( + f"Invalid axis name: '{name}'. Valid names are: {self.names}" + ) + return self._name_to_axis[name] + + def get_axis_size(self, name: str) -> int: + """Gets the size of a named axis.""" + return self.shape[name] + + def __repr__(self) -> str: + shape_str = ", ".join([f"{self.shape[name]}" for name in self.names]) + return f"NamedSharding(shape=({shape_str}), names={self.names}, layout={self._layout})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, NamedSharding): + return NotImplemented + return ( + np.array_equal(self._layout, other._layout) and self._names == other._names + ) diff --git a/nemo_rl/distributed/worker_groups.py b/nemo_rl/distributed/worker_groups.py index 76c1be627..0c1e15312 100644 --- a/nemo_rl/distributed/worker_groups.py +++ b/nemo_rl/distributed/worker_groups.py @@ -15,13 +15,14 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Dict, Iterable, List, Literal, Optional, Union import ray from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from nemo_rl.distributed.batched_data_dict import SlicedDataDict +from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.utils.venvs import create_local_venv @@ -32,6 +33,7 @@ class MultiWorkerFuture: futures: List[ray.ObjectRef] used_workers: List[int] + return_from_workers: List[str] = None respect_tied_workers: bool = True def get_results(self, worker_group): @@ -52,6 +54,9 @@ def get_results(self, worker_group): # Basic case: Get all results all_results = ray.get(self.futures) + if self.return_from_workers: + return [all_results[worker_idx] for worker_idx in self.return_from_workers] + # If we don't need to deduplicate by tied workers, return all results if not self.respect_tied_workers: return all_results @@ -190,6 +195,7 @@ def 
__init__( workers_per_node: Optional[Union[int, List[int]]] = None, name_prefix: str = "", bundle_indices_list: Optional[List[tuple]] = None, + sharding_annotations: Optional[NamedSharding] = None, ): """Initialize a group of distributed Ray workers. @@ -202,6 +208,7 @@ def __init__( bundle_indices_list: Explicit list of (node_idx, [local_bundle_indices]) tuples. Each tuple defines a tied group of workers placed on the same node. If provided, workers_per_node is ignored. + sharding_annotations: NamedSharding object representing mapping of named axes to ranks (i.e. for TP, PP, etc.) """ self._workers = [] self._worker_metadata = [] @@ -212,6 +219,7 @@ def __init__( # For example, if worker with index 3 belongs to tied worker group 1, # then worker_to_tied_group_index[3] = 1 self.worker_to_tied_group_index = {} + self.sharding_annotations = sharding_annotations # If explicit bundle indices are provided, use those if bundle_indices_list is None: @@ -501,6 +509,112 @@ def run_all_workers_single_data( return futures + def run_all_workers_sharded_data( + self, + method_name: str, + data: Iterable[SlicedDataDict], # arbitrary nested iterables of SlicedDataDicts + in_sharded_axes: List[str], + replicate_on_axes: List[str], + output_is_replicated: List[str], + common_kwargs: Optional[Dict[str, Any]] = None, + ): + """Run a method on all workers in parallel with sharded data. + + All axes provided in in_sharded_axes will be replicated on replicate_on_axes. For axes not provided in either, + data will just be sent to index 0 of that axis. + + Args: + method_name: Name of the method to call on each worker + data: Iterable of SlicedDataDicts to pass to workers/groups + in_sharded_axes: List of axes that are sharded + replicate_on_axes: List of axes that are to be replicated + output_is_replicated: List of axes along which the output is replicated (and we should just return the first result) + common_kwargs: Additional keyword arguments to pass to all workers + Returns: + MultiWorkerFuture: Object containing futures and their associated worker information + """ + if self.sharding_annotations is None: + raise ValueError( + "Sharding annotations must be provided to use sharded data distribution" + ) + + if common_kwargs is None: + common_kwargs = {} + + futures = [] + used_workers = [] + + # Validate axes + for axis in in_sharded_axes + replicate_on_axes: + if axis not in self.sharding_annotations.names: + raise ValueError( + f"Axis '{axis}' not found in sharding annotations. 
Valid axes: {self.sharding_annotations.names}" + ) + + # Check for overlapping axes + overlap = set(in_sharded_axes).intersection(set(replicate_on_axes)) + if overlap: + raise ValueError(f"Axes cannot be both sharded and replicated: {overlap}") + + return_from_workers = [] + # For each worker, determine what data it should receive + for worker_idx, worker in enumerate(self._workers): + # Get the worker's coordinates in the sharding space + worker_coords = {} + for axis in self.sharding_annotations.names: + # For this worker, find its position in each axis + for i in range(self.sharding_annotations.get_axis_size(axis)): + ranks = self.sharding_annotations.get_ranks(**{axis: i}) + if isinstance(ranks, int): + if ranks == worker_idx: + worker_coords[axis] = i + break + elif worker_idx in ranks.layout.flatten(): + worker_coords[axis] = i + break + + # Determine if this worker should receive data + should_receive_data = True + return_from_this_worker = True + for axis in self.sharding_annotations.names: + if axis not in worker_coords: + continue + if ( + axis not in in_sharded_axes + and axis not in replicate_on_axes + and worker_coords[axis] != 0 + ): + # For axes not in either list, only workers at index 0 receive data + should_receive_data = False + break + if axis in output_is_replicated: + if worker_coords[axis] != 0: + return_from_this_worker = False + if return_from_this_worker: + return_from_workers.append(worker_idx) + + if should_receive_data: + # Find the appropriate data slice for this worker + worker_data = data + for axis in in_sharded_axes: + if axis in worker_coords: + # Select the appropriate slice for this axis + worker_data = worker_data[worker_coords[axis]] + + # Call the method on the worker with its data slice + future = getattr(worker, method_name).remote( + worker_data, **common_kwargs + ) + futures.append(future) + else: + # If this worker doesn't need data, just call the method with None + future = getattr(worker, method_name).remote(None, **common_kwargs) + futures.append(future) + + return MultiWorkerFuture( + futures=futures, used_workers=None, return_from_workers=return_from_workers + ) + def get_all_worker_results(self, future_bundle): """Get results from all workers, optionally filtering to get just one result per tied worker group. diff --git a/nemo_rl/models/megatron/common.py b/nemo_rl/models/megatron/common.py index 341a77c5b..2baccbe82 100644 --- a/nemo_rl/models/megatron/common.py +++ b/nemo_rl/models/megatron/common.py @@ -11,3 +11,141 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial +from typing import Iterable + +import torch +import torch.distributed as dist +from megatron.core.models.gpt import GPTModel +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, +) +from megatron.training.utils import get_ltor_masks_and_position_ids +from nemo.tron.state import GlobalState + +from nemo_rl.algorithms.loss_functions import LossFunction + + +def forward_step_arbitrary_loss( + state: GlobalState, data_iterator: Iterable, model: GPTModel, loss_fn: LossFunction +): + """Forward training step. 
+ + Args: + state (GlobalState): Global state for the run + data_iterator : Input data iterator + model (GPTModel): The GPT Model + """ + # timers = state.timers + straggler_timer = state.straggler_timer + + # timers("batch-generator", log_level=2).start() + with straggler_timer(bdata=True): + data_dict = next(data_iterator).to("cuda") + input_ids = data_dict["input_ids"] + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + input_ids, 0, False, False, False + ) + output_tensor = model(input_ids, position_ids, attention_mask) + loss_data = data_dict + # timers("batch-generator").stop() + + with straggler_timer: + output_tensor = model(input_ids, position_ids, attention_mask) + + return output_tensor, partial( + loss_fn, + data=loss_data, + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + ) # lambda x: (torch.sum(x), {'a': x}) # + + +def broadcast_tensor( + tensor: torch.Tensor | None, src_rank: int, group: dist.ProcessGroup +): + """Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. + + Handles the case where the input tensor might be None on non-source ranks. + If the input tensor is provided on non-source ranks, it must have the + correct shape and dtype matching the tensor on the source rank. + + Args: + tensor: The tensor to broadcast on the source rank. Can be None on + non-source ranks (will be created with correct shape/dtype). + If not None on non-source ranks, it's used as the buffer + for the broadcast and must match the source tensor's metadata. + src_rank (int): The global rank of the source process. + group: The process group for communication. + + Returns: + torch.Tensor: The broadcasted tensor. On non-source ranks, this will + be the tensor received from the source. + + Raises: + ValueError: If the tensor is None on the source rank, or if a tensor + provided on a non-source rank has mismatched shape/dtype/device. + TypeError: If broadcasting metadata fails (e.g., due to pickling issues). + """ + rank = dist.get_rank() + # Assume operations happen on the default CUDA device for the rank + # TODO: Consider making device explicit if needed, e.g., derive from tensor on src + device = torch.cuda.current_device() + + # 1. Broadcast metadata (shape and dtype) using broadcast_object_list + if rank == src_rank: + if tensor is None: + raise ValueError(f"Rank {rank} is source ({src_rank}) but tensor is None.") + # Package metadata into a list containing shape and dtype + metadata = [tensor.shape, tensor.dtype] + object_list = [metadata] + else: + # Placeholder for receiving the object on non-source ranks + object_list = [None] + + # Broadcast the list containing the metadata object + # This relies on the underlying distributed backend supporting object serialization (pickle) + try: + dist.broadcast_object_list(object_list, src=src_rank, group=group) + except Exception as e: + # Catch potential issues with pickling or backend support + raise TypeError( + f"Failed to broadcast tensor metadata using broadcast_object_list: {e}" + ) from e + + # All ranks now have the metadata in object_list[0] + received_shape, received_dtype = object_list[0] + + # 2. 
Prepare tensor buffer on non-source ranks + if rank != src_rank: + if tensor is None: + # Create tensor if it wasn't provided by the caller + tensor = torch.empty(received_shape, dtype=received_dtype, device=device) + else: + # Validate the tensor provided by the caller on the non-source rank + if tensor.shape != received_shape: + raise ValueError( + f"Rank {rank}: Provided tensor has shape {tensor.shape}, " + f"but source rank {src_rank} is broadcasting shape {received_shape}." + ) + if tensor.dtype != received_dtype: + raise ValueError( + f"Rank {rank}: Provided tensor has dtype {tensor.dtype}, " + f"but source rank {src_rank} is broadcasting dtype {received_dtype}." + ) + # Ensure the provided tensor is on the correct device + # Compare torch.device objects directly for accuracy + if tensor.device != torch.device(device): + raise ValueError( + f"Rank {rank}: Provided tensor is on device {tensor.device}, " + f"but expected broadcast device is {device}." + ) + + # 3. Broadcast the actual tensor data + # The tensor object (either original on src, newly created, or validated user-provided on non-src) + # must exist on all ranks before calling broadcast. + # `dist.broadcast` operates in-place on the provided tensor object. + dist.broadcast(tensor, src=src_rank, group=group) + + return tensor diff --git a/nemo_rl/models/megatron/converters/__init__.py b/nemo_rl/models/megatron/converters/__init__.py new file mode 100644 index 000000000..f37ce6f70 --- /dev/null +++ b/nemo_rl/models/megatron/converters/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + +from .common import ( + SafeDict, + get_all_rank_ids_in_group, + get_global_layer_num, + get_local_layer_num, +) +from .llama import mcore_te_to_hf_llama +from .qwen2 import mcore_te_to_hf_qwen2 + + +class ModelType(Enum): + LLAMA = "LlamaForCausalLM" + QWEN2 = "Qwen2ForCausalLM" + + +REGISTRY = { + ModelType.LLAMA: mcore_te_to_hf_llama, + ModelType.QWEN2: mcore_te_to_hf_qwen2, +} +# Allow indexing by string name +for key in list(REGISTRY.keys()): + REGISTRY[key.value] = REGISTRY[key] + +__all__ = [ + "get_all_rank_ids_in_group", + "get_local_layer_num", + "get_global_layer_num", + "REGISTRY", + "SafeDict", +] diff --git a/nemo_rl/models/megatron/converters/common.py b/nemo_rl/models/megatron/converters/common.py new file mode 100644 index 000000000..c0bdcd163 --- /dev/null +++ b/nemo_rl/models/megatron/converters/common.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from megatron.core import parallel_state + +_GROUP_TO_RANKS_CACHE = {} + + +def get_all_rank_ids_in_group(group): + """Get all rank ids in a group.""" + if group in _GROUP_TO_RANKS_CACHE: + return _GROUP_TO_RANKS_CACHE[group] + + curr_global_rank = int(torch.distributed.get_rank()) + group_size = torch.distributed.get_world_size(group=group) + global_rank_tensor = torch.tensor( + [curr_global_rank], dtype=torch.int, device=torch.cuda.current_device() + ) + global_ranks = [ + torch.empty(1, dtype=torch.int, device=torch.cuda.current_device()) + for _ in range(group_size) + ] + torch.distributed.all_gather(global_ranks, global_rank_tensor, group=group) + _GROUP_TO_RANKS_CACHE[group] = [ + int(global_ranks[i].item()) for i in range(group_size) + ] + return _GROUP_TO_RANKS_CACHE[group] + + +def get_local_layer_num(s): + """Assumes layer number is preceeded by 'layers.'.""" + segments = s.split(".") + number = None + for i, segment in enumerate(segments): + if segment == "layers": + if segments[i + 1].isdigit(): + number = int(segments[i + 1]) + break + return number + + +def get_global_layer_num(s, cfg): + """Assumes layer number is preceeded by 'layers.'. + + Assumes pipeline model parallel size is set. + In the state dict, the layer number is the local layer number (PP local). + This function converts the local layer number to the global layer number. + """ + local_layer_num = get_local_layer_num(s) + global_layer_num = ( + parallel_state.get_pipeline_model_parallel_rank() + * cfg.num_layers + // parallel_state.get_pipeline_model_parallel_world_size() + + local_layer_num + ) + return global_layer_num + + +class SafeDict(dict): + def __missing__(self, key): + return "{" + key + "}" diff --git a/nemo_rl/models/megatron/converters/llama.py b/nemo_rl/models/megatron/converters/llama.py new file mode 100644 index 000000000..97fdedcb8 --- /dev/null +++ b/nemo_rl/models/megatron/converters/llama.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
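# Note on the converter recipe format used by the mapping in this module (and in qwen2.py):
#   keys are mcore/TransformerEngine parameter-name templates, where {l} is the
#   PP-local layer index and {gl} the global layer index (see converters/common.py);
#   "tp" gives the dimension along which tensor-parallel shards are concatenated
#   during refit (0 for column/vocab-parallel weights, 1 for row-parallel weights);
#   "hf" gives the target Hugging Face parameter-name template; and
#   "hf_func" names a function that splits one gathered mcore tensor into several
#   HF tensors (e.g. the fused QKV weight into separate q/k/v projections).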
+ +import einops +import torch + +_GROUP_TO_RANKS_CACHE = {} + + +def split_qkv_llama(gathered_mcore_qkv_layer, cfg): + hidden_size = cfg.hidden_size + head_num = cfg.num_attention_heads + num_query_groups = ( + cfg.num_query_groups or head_num + ) # different num_query_groups for 70B + + head_size = cfg.kv_channels or ( + hidden_size // head_num + ) # equivalent to hf's head_dim + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + qkv_weights = gathered_mcore_qkv_layer + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange( + (heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group + ) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + ## Example of slices + ## 7b: num_query_groups = head_num = 32, + ## q_slice = [0, 3, 6, 9 , ... 90, 93] + ## k_slice = [1, 4, 7, 10, ... 91, 94] + ## v_slice = [2, 5, 8, 11, ... 92, 95] + ## 70b (with GQA): num_query_groups = 8, head_num = 64 + ## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77] + ## k_slice = [8, 18, 28, ... 68, 78] + ## v_slice = [9, 19, 29, ... 69, 79] + + q_name = "model.layers.{gl}.self_attn.q_proj.weight" + k_name = "model.layers.{gl}.self_attn.k_proj.weight" + v_name = "model.layers.{gl}.self_attn.v_proj.weight" + q = qkv_weights[q_slice].reshape(-1, hidden_size) + k = qkv_weights[k_slice].reshape(-1, hidden_size) + v = qkv_weights[v_slice].reshape(-1, hidden_size) + + return {q_name: q, k_name: k, v_name: v} + + +def split_fc1_gate_down_llama(gathered_mcore_fc1, cfg): + # gate proj and up proj are mixed right now, and we need to reshape them + # [ gate_tp0 ] [ gate_tp0 ] + # [ up_tp0 ] --\ [ gate_tp1 ] --\ (split gate) + # [ gate_tp1 ] --/ [ up_tp0 ] --/ (split up) + # [ up_tp1 ] [ up_tp1 ] + tp = cfg.tensor_model_parallel_size + gathered_mcore_fc1 = einops.rearrange( + gathered_mcore_fc1, "(t c d) a1 -> c (t d) a1", c=2, t=tp + ) + mlp_gate_proj_weight = gathered_mcore_fc1[0] + mlp_up_proj_weight = gathered_mcore_fc1[1] + mlp_gate_proj_base_name = "model.layers.{gl}.mlp.gate_proj.weight" + mlp_up_proj_base_name = "model.layers.{gl}.mlp.up_proj.weight" + return { + mlp_up_proj_base_name: mlp_up_proj_weight, + mlp_gate_proj_base_name: mlp_gate_proj_weight, + } + + +mcore_te_to_hf_llama = { + "embedding.word_embeddings.weight": {"tp": 0, "hf": "model.embed_tokens.weight"}, + "decoder.final_layernorm.weight": {"hf": "model.norm.weight"}, + "output_layer.weight": {"tp": 0, "hf": "lm_head.weight"}, + "decoder.layers.{l}.self_attention.linear_proj.weight": { + "tp": 1, + "hf": "model.layers.{gl}.self_attn.o_proj.weight", + }, + "decoder.layers.{l}.self_attention.linear_qkv.weight": { + "tp": 0, + "hf_func": split_qkv_llama, + }, + "decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight": { + "hf": "model.layers.{gl}.input_layernorm.weight" + }, + "decoder.layers.{l}.mlp.linear_fc1.weight": { + "tp": 0, + "hf_func": split_fc1_gate_down_llama, + }, + "decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight": { + "hf": "model.layers.{gl}.post_attention_layernorm.weight" + }, + "decoder.layers.{l}.mlp.linear_fc2.weight": { + "tp": 1, + "hf": "model.layers.{gl}.mlp.down_proj.weight", + }, +} diff --git a/nemo_rl/models/megatron/converters/qwen2.py b/nemo_rl/models/megatron/converters/qwen2.py new file mode 100644 index 
000000000..784dd1976 --- /dev/null +++ b/nemo_rl/models/megatron/converters/qwen2.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from nemo_rl.models.megatron.converters.llama import ( + split_fc1_gate_down_llama, + split_qkv_llama, +) + + +def split_qkv_bias_qwen(gathered_mcore_qkv_layer, cfg): + hidden_size = cfg.hidden_size + head_num = cfg.num_attention_heads + num_query_groups = ( + cfg.num_query_groups or head_num + ) # different num_query_groups for GQA + + head_size = cfg.kv_channels or ( + hidden_size // head_num + ) # equivalent to hf's head_dim + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + qkv_bias = gathered_mcore_qkv_layer + qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) + + q_slice = torch.cat( + [ + torch.arange( + (heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group + ) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_name = "model.layers.{gl}.self_attn.q_proj.bias" + k_name = "model.layers.{gl}.self_attn.k_proj.bias" + v_name = "model.layers.{gl}.self_attn.v_proj.bias" + q = qkv_bias[q_slice].reshape(-1) + k = qkv_bias[k_slice].reshape(-1) + v = qkv_bias[v_slice].reshape(-1) + + return {q_name: q, k_name: k, v_name: v} + + +mcore_te_to_hf_qwen2 = { + "embedding.word_embeddings.weight": {"tp": 0, "hf": "model.embed_tokens.weight"}, + "decoder.final_layernorm.weight": {"hf": "model.norm.weight"}, + "output_layer.weight": {"tp": 0, "hf": "lm_head.weight"}, + "decoder.layers.{l}.self_attention.linear_proj.weight": { + "tp": 1, + "hf": "model.layers.{gl}.self_attn.o_proj.weight", + }, + "decoder.layers.{l}.self_attention.linear_qkv.weight": { + "tp": 0, + "hf_func": split_qkv_llama, + }, + "decoder.layers.{l}.self_attention.linear_qkv.bias": { + "tp": 0, + "hf_func": split_qkv_bias_qwen, + }, + "decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight": { + "hf": "model.layers.{gl}.input_layernorm.weight" + }, + "decoder.layers.{l}.mlp.linear_fc1.weight": { + "tp": 0, + "hf_func": split_fc1_gate_down_llama, + }, + "decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight": { + "hf": "model.layers.{gl}.post_attention_layernorm.weight" + }, + "decoder.layers.{l}.mlp.linear_fc2.weight": { + "tp": 1, + "hf": "model.layers.{gl}.mlp.down_proj.weight", + }, +} diff --git a/nemo_rl/models/megatron/refit_utils.py b/nemo_rl/models/megatron/refit_utils.py new file mode 100644 index 000000000..0344a40dd --- /dev/null +++ b/nemo_rl/models/megatron/refit_utils.py @@ -0,0 +1,198 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import re +from typing import Dict, List, Tuple + +import torch +from megatron.core import parallel_state +from nemo.collections.llm.gpt.model.base import GPTConfig + +import nemo_rl.models.megatron.converters as model_converters + + +def get_param_conversion_recipe_dict( + name, converter_type: model_converters.ModelType, model_cfg: GPTConfig +): + converter_dict = model_converters.REGISTRY[converter_type] + + local_layer = model_converters.get_local_layer_num(name) + global_layer = ( + model_converters.get_global_layer_num(name, model_cfg) + if local_layer is not None + else None + ) + format_dict = model_converters.SafeDict(l=local_layer, gl=global_layer) + + formatted_mapping = { + k.format_map(format_dict): rec for k, rec in converter_dict.items() + } + return formatted_mapping, format_dict + + +@torch.no_grad() +def get_global_param_key_to_local_key_map( + model, model_cfg: GPTConfig, keys: List[Tuple[str, str]] +) -> Dict[str, Tuple[int, str]]: + """Get a mapping from global parameter keys to local parameter keys. + + Args: + model: The model to get the mapping for. + model_cfg: The model configuration. + keys: The keys to get the mapping for. Tuple of (local_key, global_hf_key) + + Returns: + A dictionary mapping global parameter keys to a tuple of (rank, local parameter key). + """ + # Initialize pipeline parallel group information. + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) + pp_global_rank_ids = model_converters.get_all_rank_ids_in_group(pp_group) + + # Build a mapping on each PP rank from a computed global key to the raw state dict key. + # The global key is computed by replacing the local layer number (after "layers.") + # with its corresponding global layer number (if applicable). + local_map = {} + for local_key, _ in keys: + if local_key not in model.state_dict(): + continue + local_layer = model_converters.get_local_layer_num(local_key) + if local_layer is not None: + global_layer = model_converters.get_global_layer_num(local_key, model_cfg) + # Replace the first occurrence of the digits after "layers." with the global layer number. + global_key = re.sub( + r"(?<=layers\.)\d+", str(global_layer), local_key, count=1 + ) + else: + global_key = local_key + local_map[global_key] = local_key + + # Gather the local maps from all PP ranks (only lightweight key info is gathered). + all_maps = [None] * pp_world_size + torch.distributed.all_gather_object(all_maps, local_map, group=pp_group) + + # Build the union over global keys and assign an owner (the rank with the smallest PP rank). 
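    # Worked example of the keys being merged here (assuming pipeline_model_parallel_size=2
    # and num_layers=8, i.e. 4 layers per PP stage):
    #   PP rank 0: local "decoder.layers.3.mlp.linear_fc2.weight" -> global "decoder.layers.3.mlp.linear_fc2.weight"
    #   PP rank 1: local "decoder.layers.3.mlp.linear_fc2.weight" -> global "decoder.layers.7.mlp.linear_fc2.weight"
    # since global layer = pp_rank * num_layers // pp_world_size + local layer.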
+ union_global_map = {} + for pp_rank, omap in enumerate(all_maps): + for gk, raw_key in omap.items(): + if ( + gk not in union_global_map + or pp_global_rank_ids[pp_rank] < union_global_map[gk][0] + ): + union_global_map[gk] = (pp_global_rank_ids[pp_rank], raw_key) + else: + print( + f"WARNING: {gk} already in union_global_map when gathering keys", + flush=True, + ) + + return union_global_map + + +@torch.no_grad() +def gather_and_convert_params( + model, + converter_type: model_converters.ModelType, + model_cfg: GPTConfig, + param_name_to_rank_and_key, +): + # Process each parameter (by its unique global key) one at a time. + gathered_params = {} + for gk in sorted(param_name_to_rank_and_key.keys()): + owner_pp_global_rank, owner_raw_key = param_name_to_rank_and_key[gk] + + # Only the owner PP rank has the parameter locally. + if torch.distributed.get_rank() == owner_pp_global_rank: + param = model.state_dict()[owner_raw_key] + + # Use the conversion dict to get the appropriate recipe for this parameter. + recipe_dict, format_dict = get_param_conversion_recipe_dict( + owner_raw_key, converter_type, model_cfg + ) + recipe = recipe_dict.get(owner_raw_key, None) + if recipe is None and "_extra_state" not in owner_raw_key: + print( + f"WARNING: {owner_raw_key} has no recipe mapping for conversion", + flush=True, + ) + hf_mapping = {"None": None} + else: + # If the parameter is TP-sharded, gather its slices on GPU. + if recipe.get("tp", None) is not None: + tp_group = parallel_state.get_tensor_model_parallel_group() + tp_world_size = torch.distributed.get_world_size(tp_group) + gathered_slices = [ + torch.empty_like(param) for _ in range(tp_world_size) + ] + torch.distributed.all_gather(gathered_slices, param, group=tp_group) + full_param = torch.cat(gathered_slices, dim=recipe["tp"]).to( + torch.bfloat16 + ) + else: + full_param = torch.clone(param).to(torch.bfloat16) + + # Convert the parameter using the provided function or mapping. + if recipe.get("hf_func", None) is not None: + hf_mapping = recipe["hf_func"](full_param, model_cfg) + hf_mapping = { + k.format_map(format_dict): v for k, v in hf_mapping.items() + } + elif recipe.get("hf", None) is not None: + hf_mapping = {recipe["hf"].format_map(format_dict): full_param} + else: + raise NotImplementedError( + f"No conversion recipe found for {owner_raw_key}" + ) + else: + hf_mapping = None # Non-owner ranks will receive the converted tensors. + + # Broadcast the list of target HF parameter keys from the owner. + pp_group = parallel_state.get_pipeline_model_parallel_group() + if torch.distributed.get_rank() == owner_pp_global_rank: + target_keys = [list(hf_mapping.keys())] + else: + target_keys = [None] # Placeholder to be filled by broadcast. + + torch.distributed.broadcast_object_list( + target_keys, src=owner_pp_global_rank, group=pp_group + ) + if "None" in target_keys[0]: + continue + + # For each converted tensor (could be more than one per original parameter), broadcast it individually. + for target_key in target_keys[0]: + if torch.distributed.get_rank() == owner_pp_global_rank: + tensor_to_send = hf_mapping[target_key] + else: + tensor_to_send = None + # Broadcast tensor metadata (shape and dtype) to allocate GPU buffer on receiving ranks. 
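            # (Same metadata-then-data pattern as broadcast_tensor() in
            # nemo_rl/models/megatron/common.py, with the PP rank that owns the
            # converted parameter acting as the broadcast source.)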
+ meta = [None] + if torch.distributed.get_rank() == owner_pp_global_rank: + meta[0] = (tensor_to_send.shape, str(tensor_to_send.dtype)) + torch.distributed.broadcast_object_list( + meta, src=owner_pp_global_rank, group=pp_group + ) + shape, dtype_str = meta[0] + dtype = getattr(torch, dtype_str.split(".")[-1]) + if torch.distributed.get_rank() != owner_pp_global_rank: + tensor_to_send = torch.empty( + *shape, dtype=dtype, device=torch.cuda.current_device() + ) + torch.distributed.broadcast( + tensor_to_send, src=owner_pp_global_rank, group=pp_group + ) + gathered_params[target_key] = tensor_to_send + + torch.cuda.empty_cache() + torch.cuda.synchronize() + return gathered_params diff --git a/nemo_rl/models/policy/megatron_policy.py b/nemo_rl/models/policy/megatron_policy.py new file mode 100644 index 000000000..7428a27ad --- /dev/null +++ b/nemo_rl/models/policy/megatron_policy.py @@ -0,0 +1,321 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from collections import defaultdict +from typing import List, Optional, Union + +import numpy as np +import ray +from ray.util.queue import Queue +from transformers import AutoTokenizer + +from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationInterface, + GenerationOutputSpec, +) +from nemo_rl.models.interfaces import PolicyInterface +from nemo_rl.models.policy import PolicyConfig + + +class MegatronPolicy(PolicyInterface, GenerationInterface): + def __init__( + self, + cluster: RayVirtualCluster, + config: PolicyConfig, + tokenizer: AutoTokenizer, + name_prefix: str = "megatron_policy", + workers_per_node: Optional[Union[int, List[int]]] = None, + init_optimizer: bool = True, + weights_path: Optional[str] = None, + optimizer_path: Optional[str] = None, + init_reference_model: bool = True, + ): + from nemo_rl.models.policy.megatron_policy_worker import ( + MegatronPolicyWorker, + ) + + if weights_path: + weights_path = os.path.abspath(weights_path) + if optimizer_path: + optimizer_path = os.path.abspath(optimizer_path) + + self.sharding_annotations = NamedSharding( + layout=np.arange(cluster.world_size()).reshape( + ( + config["pipeline_model_parallel_size"], + -1, + config["tensor_model_parallel_size"], + ) + ), + names=["pipeline_model_parallel", "data_parallel", "tensor_model_parallel"], + ) + + pre_init_queue = ( + Queue() + ) # just for communication before torch distributed is set up + worker_builder = RayWorkerBuilder( + MegatronPolicyWorker, + config, + tokenizer=tokenizer, + checkpoint_dir=None, + worker_sharding_annotations=self.sharding_annotations, + pre_init_communication_queue=pre_init_queue, + 
init_optimizer=init_optimizer, + init_reference_model=init_reference_model, + ) + + self.worker_group = RayWorkerGroup( + cluster, + worker_builder, + name_prefix=name_prefix, + workers_per_node=workers_per_node, + sharding_annotations=self.sharding_annotations, + ) + self.dp_size = self.sharding_annotations.get_axis_size("data_parallel") + self.cfg = config + + def get_logprobs( + self, data: BatchedDataDict[GenerationDatumSpec] + ) -> BatchedDataDict: + """Get the logprobs of the model for a data dict. + + Returns: + a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. + The logprob of input token i is specified at position i in the output logprobs tensor. + """ + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) + futures = self.worker_group.run_all_workers_sharded_data( + "get_logprobs", + sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=["pipeline_model_parallel", "tensor_model_parallel"], + output_is_replicated=["tensor_model_parallel", "pipeline_model_parallel"], + ) + logprobs = BatchedDataDict.from_batches( + self.worker_group.get_all_worker_results(futures) + ) + return logprobs + + def get_reference_policy_logprobs( + self, data: BatchedDataDict[GenerationDatumSpec], micro_batch_size: int = None + ) -> BatchedDataDict: + """Get the logprobs of the reference policy for a data dict. + + If micro_batch_size is provided, it will be used instead of the configured + logprob_batch_size. + Returns: Identical to get_logprobs. + """ + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) + futures = self.worker_group.run_all_workers_sharded_data( + "get_reference_policy_logprobs", + sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=["pipeline_model_parallel", "tensor_model_parallel"], + output_is_replicated=["tensor_model_parallel", "pipeline_model_parallel"], + ) + logprobs = BatchedDataDict.from_batches( + self.worker_group.get_all_worker_results(futures) + ) + return logprobs + + def train( + self, + data: BatchedDataDict, + loss_fn: LossFunction, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ): + """Train the policy on a batch of data with a given loss function.""" + batch_size = gbs or self.cfg["train_global_batch_size"] + # Shard and replicate the batch + shards = self.dp_size + sharded_data = data.shard_by_batch_size(shards, batch_size=batch_size) + + # Train each shard in parallel + futures = self.worker_group.run_all_workers_sharded_data( + "train", + sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=["pipeline_model_parallel", "tensor_model_parallel"], + output_is_replicated=["tensor_model_parallel", "pipeline_model_parallel"], + common_kwargs={ + "loss_fn": loss_fn, + "eval_mode": eval_mode, + "gbs": gbs, + "mbs": mbs, + }, + ) + results = self.worker_group.get_all_worker_results(futures) + + # Aggregate the results + aggregated_results = {} + aggregated_results["loss"] = results[0]["global_loss"] + + # Aggregate metrics across all workers + all_mb_metrics = defaultdict(list) + for r in results: + for k, v in r["all_mb_metrics"].items(): + all_mb_metrics[k].extend(v) + aggregated_results["all_mb_metrics"] = dict(all_mb_metrics) + + return aggregated_results + + def generate( + self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False + ) -> BatchedDataDict[GenerationOutputSpec]: + """Generate 
a batch of data using the policy.""" + # Verify input data is right-padded + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + "Missing required input fields" + ) + + sharded_data = data.shard_by_batch_size(self.dp_size, batch_size=None) + futures = self.worker_group.run_all_workers_sharded_data( + "generate", + sharded_data, + in_sharded_axes=["data_parallel"], + replicate_on_axes=["pipeline_model_parallel", "tensor_model_parallel"], + common_kwargs={"greedy": greedy}, + output_is_replicated=["tensor_model_parallel", "pipeline_model_parallel"], + ) + result = BatchedDataDict.from_batches( + self.worker_group.get_all_worker_results(futures), + pad_value_dict={"output_ids": self.cfg["generation"]["pad_token_id"]}, + ) + + # Verify the output has all required fields + required_keys = [ + "output_ids", + "generation_lengths", + "unpadded_sequence_lengths", + "logprobs", + ] + missing_keys = [key for key in required_keys if key not in result] + if missing_keys: + raise ValueError( + f"Missing required keys for GenerationOutputSpec: {missing_keys}" + ) + + return result + + def prepare_for_generation(self, *args, **kwargs): + # We don't need to do anything here + pass + + def prepare_for_training(self, *args, **kwargs): + # onload everything to the GPU + futures = self.worker_group.run_all_workers_single_data( + "prepare_for_training", only_on="all_tied_workers" + ) + ray.get(futures) + + def prepare_for_lp_inference(self, *args, **kwargs): + futures = self.worker_group.run_all_workers_single_data( + "prepare_for_lp_inference", only_on="all_tied_workers" + ) + ray.get(futures) + + def finish_generation(self, *args, **kwargs): + # We don't need to do anything here + pass + + def finish_training(self, *args, **kwargs): + # Placeholder implementation + pass + + def prepare_weights_for_ipc(self): + futures = self.worker_group.run_all_workers_single_data( + "prepare_weights_for_ipc", only_on="all_tied_workers" + ) + return ray.get(futures)[0] + + def get_weights_ipc_handles(self, keys): + """Fetch weight IPC handles from all workers. + + Returns: + dict: A dictionary mapping device UUIDs to parameter IPC handles. 
+ """ + # Collect IPC handles from all workers + worker_handles = ray.get( + [ + worker.get_weights_ipc_handles.remote(keys) + for worker in self.worker_group.workers + ] + ) + + # Combine all worker handles into a single dictionary + all_handles = {} + for handle in worker_handles: + all_handles.update(handle) + + return all_handles + + def offload_before_refit(self): + """Offload the optimizer and buffers to the CPU.""" + futures = self.worker_group.run_all_workers_single_data( + "offload_before_refit", only_on="all_tied_workers" + ) + ray.get(futures) + + def offload_after_refit(self): + """Offload the optimizer and buffers to the CPU.""" + futures = self.worker_group.run_all_workers_single_data( + "offload_after_refit", only_on="all_tied_workers" + ) + ray.get(futures) + + def save_checkpoint( + self, + weights_path: str, + optimizer_path: Optional[str] = None, + offload_to_cpu: bool = True, + ): + """Save a checkpoint of the model.""" + futures = self.worker_group.run_all_workers_single_data( + "save_checkpoint", + weights_path, + optimizer_path, + offload_to_cpu=offload_to_cpu, + only_on="all_tied_workers", + ) + ray.get(futures) + + def shutdown(self) -> bool: + """Shut down all HF workers and clean up resources.""" + try: + # Use the worker group's shutdown method with the worker's cleanup method + return self.worker_group.shutdown(cleanup_method="shutdown") + except Exception as e: + print(f"Error during policy shutdown: {e}") + return False + + def __del__(self): + """Shuts down the worker groups when the object is deleted or is garbage collected. + + This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to + the object is lost due to leaving a function scope. It's always recommended that the + user calls worker_group.shutdown(). + """ + self.worker_group.shutdown() diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py new file mode 100644 index 000000000..3a0ceb2e7 --- /dev/null +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -0,0 +1,1071 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import gc +import os +import time +import warnings +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from functools import partial +from typing import Any, Dict, Iterable, Optional + +import ray +import torch +from megatron.core import parallel_state +from megatron.core.inference.engines import ( + StaticInferenceEngine, +) +from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import ( + InferenceWrapperConfig, +) +from megatron.core.inference.text_generation_controllers.text_generation_controller import ( + TextGenerationController, +) +from megatron.core.models.gpt import GPTModel +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_group, + get_pipeline_model_parallel_last_rank, + get_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size, + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, + is_pipeline_last_stage, +) +from megatron.core.pipeline_parallel import get_forward_backward_func +from megatron.core.rerun_state_machine import get_rerun_state_machine +from megatron.inference.text_generation.mcore_engine_server import ( + run_mcore_engine, +) +from megatron.training.utils import get_ltor_masks_and_position_ids +from nemo.tron import fault_tolerance +from nemo.tron.checkpointing import checkpoint_exists, load_checkpoint +from nemo.tron.config import ( + CheckpointConfig, + ConfigContainer, + DistributedDataParallelConfig, + LoggerConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainingConfig, +) +from nemo.tron.init import initialize_megatron, set_jit_fusion_options +from nemo.tron.model import get_model_from_config +from nemo.tron.optim import setup_optimizer +from nemo.tron.setup import ( + HAVE_FSDP2, + _init_checkpointing_context, + _update_model_config_funcs, +) +from nemo.tron.state import GlobalState +from nemo.tron.tokenizers.tokenizer import build_tokenizer +from nemo.tron.utils.common_utils import get_rank_safe +from nemo.tron.utils.train_utils import ( + logical_and_across_model_parallel_group, + reduce_max_stat_across_model_parallel_group, +) +from ray.util.queue import Queue +from transformers import AutoTokenizer + +from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs +from nemo_rl.distributed.named_sharding import NamedSharding +from nemo_rl.distributed.virtual_cluster import ( + PY_EXECUTABLES, +) +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationOutputSpec, + verify_right_padding, +) +from nemo_rl.models.megatron.common import ( + broadcast_tensor, + forward_step_arbitrary_loss, +) +from nemo_rl.models.megatron.refit_utils import ( + gather_and_convert_params, + get_global_param_key_to_local_key_map, + get_param_conversion_recipe_dict, +) +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.utils import get_gpu_info + + +def setup_megatron_model( + cfg: ConfigContainer, + load_optimizer: bool = True, + get_embedding_ranks=None, # TODO @sahilj: What is this? 
+ get_position_embedding_ranks=None, +): + state = GlobalState() + state.cfg = cfg + # TODO: Freeze state.cfg + + initialize_megatron( + cfg=cfg, + get_embedding_ranks=get_embedding_ranks, + get_position_embedding_ranks=get_position_embedding_ranks, + gpu_visibility_externally_set=True, + ) + + if cfg.ft_config and cfg.ft_config.enable_ft_package: + fault_tolerance.setup(cfg, state) + fault_tolerance.maybe_setup_simulated_fault(cfg.ft_config) + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options(cfg.model_config, cfg.train_config.micro_batch_size) + + # Adjust the startup time so it reflects the largest value. + # This will be closer to what scheduler will see (outside of + # image ... launches. + start_time_tensor = torch.tensor( + [state.start_time], dtype=torch.double, device="cuda" + ) + torch.distributed.all_reduce(start_time_tensor, op=torch.distributed.ReduceOp.MIN) + state.start_time = start_time_tensor.item() + + print( + "time to initialize megatron (seconds): {:.3f}".format( + time.time() - state.start_time + ) + ) + torch.distributed.barrier() + + # Context used for persisting some state between checkpoint saves. + checkpointing_context = _init_checkpointing_context(cfg.checkpoint_config) + + # Tokenizer + tokenizer = build_tokenizer( + cfg.tokenizer_config, + make_vocab_size_divisible_by=cfg.model_config.make_vocab_size_divisible_by, + tensor_model_parallel_size=cfg.model_config.tensor_model_parallel_size, + ) + if not cfg.model_config.vocab_size: + cfg.model_config.vocab_size = tokenizer.vocab_size + + torch.distributed.barrier() + + # Model, optimizer, and learning rate. + model = get_model_from_config( + cfg.model_config, + cfg.ddp_config, + use_torch_fsdp2=cfg.dist_config.use_torch_fsdp2, + overlap_param_gather_with_optimizer_step=cfg.optimizer_config.overlap_param_gather_with_optimizer_step, + data_parallel_random_init=cfg.rng_config.data_parallel_random_init, + ) + if load_optimizer: + optimizer, scheduler = setup_optimizer( + optimizer_config=cfg.optimizer_config, + scheduler_config=cfg.scheduler_config, + model=model, + use_gloo_process_groups=cfg.dist_config.use_gloo_process_groups, + ) + else: + optimizer = None + scheduler = None + + _update_model_config_funcs( + model, + cfg.model_config, + cfg.ddp_config, + optimizer, + align_grad_reduce=cfg.dist_config.align_grad_reduce, + ) + print("Model, optimizer, and learning rate scheduler built") + torch.distributed.barrier() + + # Load checkpoint if applicable + if ( + cfg.checkpoint_config.load is not None + or cfg.checkpoint_config.pretrained_checkpoint is not None + ) and ( + checkpoint_exists(cfg.checkpoint_config.load) + or checkpoint_exists(cfg.checkpoint_config.pretrained_checkpoint) + ): + load_checkpoint( + state, + model, + optimizer, + scheduler, + checkpointing_context=checkpointing_context, + skip_load_to_model_and_opt=HAVE_FSDP2 and cfg.dist_config.use_torch_fsdp2, + ) + print("Checkpoint loaded") + torch.distributed.barrier() + + return state, model, optimizer, scheduler, checkpointing_context + + +@ray.remote +class MegatronPolicyWorker: + DEFAULT_PY_EXECUTABLE = PY_EXECUTABLES.SYSTEM + + def __repr__(self): + """Customizes the actor's prefix in the Ray logs. + + This makes it easier to identify which worker is producing specific log messages. 
+ """ + if torch.distributed.is_initialized(): + return f"{self.__class__.__qualname__}[rank={torch.distributed.get_rank()}]" + else: + return f"{self.__class__.__qualname__}" + + def __init__( + self, + config: PolicyConfig, + tokenizer: AutoTokenizer, + checkpoint_dir: str, + worker_sharding_annotations: NamedSharding, + pre_init_communication_queue: Queue, + init_reference_model: bool = True, + init_optimizer: bool = True, + ): + self.cfg = config + self.checkpoint_dir = checkpoint_dir + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + } + self.dtype = dtype_map[self.cfg["precision"]] + + hf_model_name = self.cfg["model_name"] + self.tokenizer = tokenizer + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Check if the checkpoint already exists + output_path = f"/opt/checkpoints/tron/{hf_model_name}" + pt_checkpoint_exists = os.path.exists(output_path) and os.path.exists( + os.path.join(output_path, "iter_0000000") + ) + + if get_rank_safe() == 0: + if pt_checkpoint_exists: + print(f"Checkpoint already exists at {output_path}. Skipping import.") + else: + if "llama" in hf_model_name.lower(): + from nemo.tron.converter.llama import HFLlamaImporter + + print(f"Importing model {hf_model_name} to {output_path}...") + importer = HFLlamaImporter( + hf_model_name, + output_path=f"/opt/checkpoints/tron/{hf_model_name}", + ) + elif "qwen" in hf_model_name.lower(): + from nemo.tron.converter.qwen import HFQwen2Importer + + print(f"Importing model {hf_model_name} to {output_path}...") + importer = HFQwen2Importer( + hf_model_name, + output_path=f"/opt/checkpoints/tron/{hf_model_name}", + ) + else: + raise ValueError(f"Unknown model: {hf_model_name}") + importer.apply() + import megatron.core.rerun_state_machine + + megatron.core.rerun_state_machine.destroy_rerun_state_machine() + pre_init_communication_queue.put(True) + else: + pre_init_communication_queue.get() + pre_init_communication_queue.put(True) + + pretrained_ckpt = f"/opt/checkpoints/tron/{hf_model_name}" + pretrained_run_config = os.path.join( + pretrained_ckpt, "iter_0000000/run_config.yaml" + ) + cfg_from_pretrained = ConfigContainer.from_yaml(pretrained_run_config) + model_cfg = cfg_from_pretrained.model_config + cfg_from_pretrained.logger_config = LoggerConfig() + cfg_from_pretrained.checkpoint_config = CheckpointConfig( + save_interval=100, + save="/nemo_run/checkpoints", + load="/nemo_run/checkpoints", + pretrained_checkpoint=pretrained_ckpt, # This is the path to the pretrained ckpt for the SFT case + async_save=True, + fully_parallel_save=True, + ) + + model_cfg.tensor_model_parallel_size = self.cfg["tensor_model_parallel_size"] + model_cfg.pipeline_model_parallel_size = self.cfg[ + "pipeline_model_parallel_size" + ] + model_cfg.context_parallel_size = self.cfg[ + "context_parallel_size" + ] # not supported right now + model_cfg.bf16 = self.dtype == torch.bfloat16 + model_cfg.fp16 = self.dtype == torch.float16 + model_cfg.params_dtype = self.dtype # amp + model_cfg.pipeline_dtype = self.dtype # dtype_map[self.cfg["pipeline_dtype"]] + model_cfg.parallel_output = True + + checkpoint_config = CheckpointConfig( + save_interval=100, + save="/nemo_run/checkpoints", + load="/nemo_run/checkpoints", + pretrained_checkpoint=pretrained_ckpt, # This is the path to the pretrained ckpt for the SFT case + async_save=True, + fully_parallel_save=True, + fully_parallel_load=True, # Enable fully parallel load + ) + ref_checkpoint_config = 
CheckpointConfig( + pretrained_checkpoint=pretrained_ckpt, # This is the path to the pretrained ckpt for the SFT case + fully_parallel_load=True, # Enable fully parallel load + ) + self.megatron_cfg = ConfigContainer( + model_config=model_cfg, + checkpoint_config=checkpoint_config, + logger_config=LoggerConfig(logging_level=0), + train_config=TrainingConfig( + micro_batch_size=self.cfg["train_micro_batch_size"], # ignored + global_batch_size=self.cfg["train_global_batch_size"], # ignored + train_iters=1000, # Default value for inference + ), + optimizer_config=OptimizerConfig( + **self.cfg["megatron_cfg"]["optimizer"], + ), + ddp_config=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + use_distributed_optimizer=True, + ), + scheduler_config=SchedulerConfig( + **self.cfg["megatron_cfg"]["scheduler"], + ), + dataset_config=None, + tokenizer_config=TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model=hf_model_name, + ), + ) + self.megatron_cfg.validate() + + print(f"cfg: {self.megatron_cfg}") + ( + self.mcore_state, + self.model, + self.optimizer, + self.scheduler, + self.checkpointing_context, + ) = setup_megatron_model(self.megatron_cfg, load_optimizer=init_optimizer) + self.model = self.model[0] # Get the first model from the list + for name, item in self.model.state_dict().items(): + if isinstance(item, torch.Tensor): + item = item.detach().to(device="cpu", non_blocking=True, copy=True) + self.model.state_dict()[name] = item + + if init_reference_model: + ref_ckpt_context = _init_checkpointing_context(ref_checkpoint_config) + reference_model = get_model_from_config( + self.megatron_cfg.model_config, + self.megatron_cfg.ddp_config, + use_torch_fsdp2=self.megatron_cfg.dist_config.use_torch_fsdp2, + overlap_param_gather_with_optimizer_step=self.megatron_cfg.optimizer_config.overlap_param_gather_with_optimizer_step, + data_parallel_random_init=self.megatron_cfg.rng_config.data_parallel_random_init, + ) + if ( + ref_checkpoint_config.pretrained_checkpoint is not None + and checkpoint_exists(ref_checkpoint_config.pretrained_checkpoint) + ): + load_checkpoint( + self.mcore_state, + reference_model, + None, # no optimizer + None, # no scheduler + checkpointing_context=ref_ckpt_context, + skip_load_to_model_and_opt=HAVE_FSDP2 + and self.megatron_cfg.dist_config.use_torch_fsdp2, + ) + reference_model = reference_model[0] + reference_model.eval() + self.reference_state_dict = {} + for name, item in reference_model.state_dict().items(): + if isinstance(item, torch.Tensor): + item = item.detach().to( + device="cpu", non_blocking=True, copy=True + ) + self.reference_state_dict[name] = item + print("Reference model loaded") + else: + print("Reference model not loaded") + + for name, item in self.model.state_dict().items(): + if isinstance(item, torch.Tensor): + item = item.detach().to(device="cuda", non_blocking=True, copy=True) + self.model.state_dict()[name] = item + + from nemo.tron.tokenizers.tokenizer import build_tokenizer + + tokenizer_config = TokenizerConfig( + tokenizer_type="HuggingFaceTokenizer", + tokenizer_model=hf_model_name, + ) + + self.megatron_tokenizer = build_tokenizer( + tokenizer_config, + make_vocab_size_divisible_by=128, + tensor_model_parallel_size=self.cfg["tensor_model_parallel_size"], + ) + self.final_padded_vocab_size = tokenizer_config.padded_vocab_size + self.dp_size = 
worker_sharding_annotations.get_axis_size("data_parallel") + self.converter_type = self.cfg["megatron_cfg"]["converter_type"] + self._held_gather_buffer = None + + def is_alive(self): + return True + + def get_gpu_info(self): + """Return information about the GPU being used by this worker.""" + return get_gpu_info(self.model) + + def train( + self, + data: BatchedDataDict, + loss_fn: LossFunction, + eval_mode: bool = False, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> Dict[str, Any]: + """Train the policy on a batch of data with a given loss function.""" + if gbs is None: + gbs = self.cfg["train_global_batch_size"] + if mbs is None: + mbs = self.cfg["train_micro_batch_size"] + local_gbs = gbs // self.dp_size + dataset_size = data.size + + if eval_mode: + ctx = torch.no_grad() + self.model.eval() + else: + ctx = nullcontext() + # Ensure model is in training mode + self.model.train() + + with ctx: + forward_step = partial(forward_step_arbitrary_loss, loss_fn=loss_fn) + all_mb_metrics = [] + for gb_start in range(0, dataset_size, local_gbs): + num_microbatches = local_gbs // mbs + data_iterator = data.slice( + gb_start, gb_start + local_gbs + ).make_microbatch_iterator(mbs) + + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + # Set grad to zero. + self.model.zero_grad_buffer() + self.optimizer.zero_grad() + + # Forward pass. + forward_backward_func = get_forward_backward_func() + losses_reduced = forward_backward_func( + forward_step_func=partial(forward_step, self.mcore_state), + data_iterator=data_iterator, + model=self.model, + num_microbatches=num_microbatches, + seq_length=self.cfg[ + "max_total_sequence_length" + ], # model_config.seq_length, + micro_batch_size=self.cfg["train_micro_batch_size"], + decoder_seq_length=self.cfg[ + "max_total_sequence_length" + ], # model_config.seq_length, + forward_only=False, + ) + + # Empty unused memory. + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: + torch.cuda.empty_cache() + + # Update parameters. + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step() + + # when freezing sub-models we may have a mixture of successful and unsucessful ranks, + # so we must gather across mp ranks + update_successful = logical_and_across_model_parallel_group( + update_successful + ) + # grad_norm and num_zeros_in_grad will be None on ranks without trainable params, + # so we must gather across mp ranks + grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm) + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group( + num_zeros_in_grad + ) + + # Update learning rate. + if update_successful: + increment = ( + num_microbatches + * self.cfg["train_micro_batch_size"] + * self.dp_size + ) + self.scheduler.step(increment=increment) + skipped_iter = 0 + curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0]) + curr_wd = self.scheduler.get_wd() + else: + skipped_iter = 1 + + # Empty unused memory. + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 2: + torch.cuda.empty_cache() + + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_metrics = {} + for key in losses_reduced[0].keys(): + numerator = 0 + denominator = 0 + for x in losses_reduced: + val = x[key] + # there is one dict per microbatch. in new reporting, we average + # over the total number of tokens across the global batch. 
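+                            # In that case each value is a (sum, num_tokens) pair, so we
+                            # accumulate both and divide once at the end.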
+ if isinstance(val, tuple) or isinstance(val, list): + numerator += val[0] + denominator += val[1] + else: + # legacy behavior. we average over the number of microbatches, + # and so the denominator is 1. + numerator += val + denominator += 1 + loss_metrics[key] = numerator / denominator + + loss_metrics["lr"] = curr_lr + loss_metrics["wd"] = curr_wd + torch.distributed.broadcast_object_list( + [loss_metrics], + src=get_pipeline_model_parallel_last_rank(), + group=get_pipeline_model_parallel_group(), + ) + else: + loss_metrics = [None] + torch.distributed.broadcast_object_list( + loss_metrics, + src=get_pipeline_model_parallel_last_rank(), + group=get_pipeline_model_parallel_group(), + ) + loss_metrics = loss_metrics[0] + + all_mb_metrics.append(loss_metrics) + + # Aggregate metrics across all microbatches + mb_metrics = defaultdict(list) + for m in all_mb_metrics: + for k, v in m.items(): + mb_metrics[k].append(v) + + with torch.no_grad(): + loss = torch.tensor(loss_metrics["loss"], device="cuda") + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) + + metrics = { + "global_loss": loss.cpu(), + "rank": torch.distributed.get_rank(), + "all_mb_metrics": dict(mb_metrics), + } + return metrics + + @torch.no_grad() + def get_logprobs( + self, data: BatchedDataDict, micro_batch_size: int = None + ) -> BatchedDataDict: + """Get the logprobs of the model for a batch of data. + + Uses the configured logprob_batch_size to do microbatching. + Input data is assumed to be right-padded. The method internally converts to + left-padded format for computation, and returns outputs in right-padded format. + If micro_batch_size is provided, it will be used instead of the configured + logprob_batch_size. + + Returns: + a BatchedDataDict with key "logprobs" and shape [batch_size, sequence_length]. + We use the convention that the logprob of the first token is 0 so that the sequence length is maintained. + The logprob of input token i is specified at position i in the output logprobs tensor. 
+ """ + logprob_batch_size = ( + micro_batch_size + if micro_batch_size is not None + else self.cfg["logprob_batch_size"] + ) + all_log_probs = [] + self.model.eval() + + pp_rank = get_pipeline_model_parallel_rank() + pp_grp = get_pipeline_model_parallel_group() + pp_size = get_pipeline_model_parallel_world_size() + + def forward_step_fn(data_iterator: Iterable, model: GPTModel): + data_dict = next(data_iterator) + input_ids = data_dict["input_ids"].cuda() + attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + input_ids, 0, False, False, False + ) + output_tensor = model(input_ids, position_ids, attention_mask) + + def collection_fn(output_tensor): + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + token_logprobs = from_parallel_logits_to_logprobs( + output_tensor.to(torch.float32), + target=input_ids, + vocab_start_index=tp_rank * output_tensor.shape[-1], + vocab_end_index=(tp_rank + 1) * output_tensor.shape[-1], + group=tp_grp, + inference_only=True, + ) + + # Prepend 0 logprob for first token to maintain same sequence length as input + token_logprobs = torch.cat( + [torch.zeros_like(token_logprobs[:, :1]), token_logprobs], dim=1 + ) + + return torch.tensor(0.0), {"logprobs": token_logprobs} + + return output_tensor, collection_fn + + forward_backward_func = get_forward_backward_func() + list_of_logprobs = forward_backward_func( + forward_step_func=forward_step_fn, + data_iterator=data.make_microbatch_iterator(logprob_batch_size), + model=self.model, + num_microbatches=max(1, data.size // logprob_batch_size), + seq_length=self.cfg["max_total_sequence_length"], + micro_batch_size=logprob_batch_size, + decoder_seq_length=self.cfg["max_total_sequence_length"], + forward_only=True, + ) + if is_pipeline_last_stage(ignore_virtual=True): + all_logprobs = [l["logprobs"] for l in list_of_logprobs] + logprobs = torch.cat(all_logprobs, dim=0) + # broadcast logprobs to first pp rank + broadcast_tensor(logprobs, torch.distributed.get_rank(), pp_grp) + else: + logprobs = broadcast_tensor( + None, get_pipeline_model_parallel_last_rank(), pp_grp + ) + return BatchedDataDict(logprobs=logprobs).to("cpu") + + @contextmanager + def use_reference_model(self): + """Context manager that temporarily swaps the reference model and active model. + + On entry: Moves model to CPU, moves reference_model to CUDA. 
Swaps the references.
+        On exit: Restores original references and re-flips cuda/cpu
+        """
+        with torch.no_grad():
+            try:
+                # Save a CPU copy of the active model's state dict
+                model_state_dict = {}
+                for name, item in self.model.state_dict().items():
+                    if isinstance(item, torch.Tensor):
+                        item = item.detach().to(
+                            device="cpu", non_blocking=True, copy=True
+                        )
+                        model_state_dict[name] = item
+
+                self.model.load_state_dict(self.reference_state_dict, strict=True)
+                # for name, item in self.reference_state_dict.items():
+                #     if isinstance(item, torch.Tensor):
+                #         self.model.state_dict()[name] = item.detach().to(device="cuda", non_blocking=True, copy=True)
+
+                gc.collect()
+                torch.cuda.empty_cache()
+
+                # - self.model now holds the reference model weights
+                # - the original policy weights are kept in model_state_dict on CPU
+                yield
+
+            finally:
+                # Restore the original policy weights
+                self.model.load_state_dict(model_state_dict, strict=True)
+                # for name, item in model_state_dict.items():
+                #     if isinstance(item, torch.Tensor):
+                #         item = item.detach().to(device="cuda", non_blocking=True, copy=True)
+                #         self.model.state_dict()[name] = item
+
+                gc.collect()
+                torch.cuda.empty_cache()
+
+    def get_reference_policy_logprobs(
+        self, data: BatchedDataDict, micro_batch_size: int = None
+    ) -> BatchedDataDict:
+        """Get the logprobs from the reference policy for a batch of data.
+
+        If micro_batch_size is provided, it will be used instead of the configured
+        logprob_batch_size.
+
+        Returns:
+            a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length].
+            We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
+            The logprob of input token i is specified at position i in the output logprobs tensor.
+        """
+        with self.use_reference_model():
+            reference_logprobs = self.get_logprobs(data, micro_batch_size)
+
+        return_data = BatchedDataDict()
+        return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu()
+        return return_data
+
+    @torch.no_grad()
+    def generate(
+        self, data: BatchedDataDict[GenerationDatumSpec], greedy: bool = False
+    ) -> BatchedDataDict[GenerationOutputSpec]:
+        """Generate a batch of data using Megatron Core's static inference engine.
+ + Args: + data: BatchedDataDict containing input_ids and input_lengths tensors + Returns: + BatchedDataDict conforming to GenerationOutputSpec: + - output_ids: input + generated token IDs + - logprobs: Log probabilities for each token + - generation_lengths: Lengths of each response + """ + self.model.config.flash_decode = True + # Verify input is right padded + assert isinstance(data, BatchedDataDict), ( + f"data must be a BatchedDataDict, got type: {type(data)}" + ) + assert "input_ids" in data and "input_lengths" in data, ( + f"input_ids and input_lengths must be present in the BatchedDataDict, got keys: {data.keys()}" + ) + is_right_padded, error_msg = verify_right_padding( + data, pad_value=self.tokenizer.pad_token_id + ) + if not is_right_padded: + warnings.warn( + f"Input to Megatron Generation worker is not properly right-padded: {error_msg}" + ) + + model_cfg = self.megatron_cfg.model_config + inference_wrapper_config = InferenceWrapperConfig( + hidden_size=model_cfg.hidden_size, + inference_batch_times_seqlen_threshold=1000000, + fp32_residual_connection=model_cfg.fp32_residual_connection, + params_dtype=model_cfg.params_dtype, + padded_vocab_size=self.final_padded_vocab_size, # Use the potentially updated value + inference_max_seq_length=self.cfg["generation"]["max_new_tokens"], + inference_max_requests=self.cfg["generation_batch_size"], + ) + + from megatron.core.inference.contexts import StaticInferenceContext + from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import ( + GPTInferenceWrapper, + ) + + inference_context = StaticInferenceContext.from_config(inference_wrapper_config) + + inference_wrapped_model = GPTInferenceWrapper( + self.model, inference_wrapper_config, inference_context + ) + text_generation_controller = TextGenerationController( + inference_wrapped_model=inference_wrapped_model, + tokenizer=self.megatron_tokenizer, + ) + inference_engine = StaticInferenceEngine( + text_generation_controller=text_generation_controller, + max_batch_size=self.cfg["generation_batch_size"], + ) + + # detokenize the prompts + # detokenized_prompts = [ + # self.tokenizer.decode(prompt) + # for prompt in data.get("input_ids") + # ] + # apply chat template + out = run_mcore_engine( + engine=inference_engine, + # prompts = detokenized_prompts, + prompt_tokens_tensor=data.get("input_ids"), + prompt_lengths_tensor=data.get("input_lengths"), + tokens_to_generate=self.cfg["generation"]["max_new_tokens"] + - data.get("input_ids").size(1), + ) + # print(out) + + input_lengths = data.get("input_lengths") + # pad the out "tokens" and "logprobs" and make them into tensors from lists + batch_size = data.get("input_ids").size(0) + max_seq_len = max([len(tokens) for tokens in out["tokens"]]) + + # Create padded tensors for tokens and logprobs + output_ids_padded = torch.full( + (batch_size, max_seq_len), + self.tokenizer.pad_token_id, + dtype=torch.long, + device=data.get("input_ids").device, + ) + + logprobs_padded = torch.zeros( + (batch_size, max_seq_len), + dtype=torch.float, + device=data.get("input_ids").device, + ) + + # Fill in the padded tensors with actual values + for i in range(batch_size): + seq_len = len(out["tokens"][i]) + output_ids_padded[i, :seq_len] = torch.tensor( + out["tokens"][i], dtype=torch.long, device=data.get("input_ids").device + ) + + logprob_len = len(out["logprobs"][i]) + logprobs_padded[i, 1 : logprob_len + 1] = torch.tensor( + out["logprobs"][i], + dtype=torch.float, + device=data.get("input_ids").device, + ) + + out_dict = { + 
"output_ids": output_ids_padded, + "logprobs": logprobs_padded, + "generation_lengths": torch.tensor( + [len(o) - input_lengths[i] for i, o in enumerate(out["logprobs"])] + ), + "unpadded_sequence_lengths": torch.tensor( + [len(o) for o in out["logprobs"]] + ), + } + + self.model.config.flash_decode = False + return BatchedDataDict.from_batches([out_dict]).to("cpu") + + def zero_out_weights(self): + """Zero out the weights of the model.""" + pass + + def report_device_id(self) -> str: + """Report the UUID of the current CUDA device using NVML. + + Returns: + str: UUID of the device in the format "GPU-xxxxx" + """ + from nemo_rl.utils.nvml import get_device_uuid + + # Get current device index from torch + device_idx = torch.cuda.current_device() + # Get device UUID using NVML + return get_device_uuid(device_idx) + + @torch.no_grad() + def prepare_weights_for_ipc(self): + """Prepare Megatron model weights for IPC transfer to vLLM. + + Collects information about weight tensors (names and sizes). + Returns a list of (parameter_name, size_in_bytes) tuples. + """ + # Ensure model is in evaluation mode + self.model.eval() + + # Get tensor parallel info + tp_world_size = parallel_state.get_tensor_model_parallel_world_size() + + # Collect parameter info + param_info = [] + + # Process each parameter in the model + for name, param in self.model.state_dict().items(): + # Skip _extra_state entries (these are metadata, not actual weights) + if "_extra_state" in name: + continue + + # Use the conversion dict to get the appropriate recipe for this parameter. + recipe_dict, _ = get_param_conversion_recipe_dict( + name, self.converter_type, self.megatron_cfg.model_config + ) + tp = 1 + if name in recipe_dict: + recipe = recipe_dict[name] + if "tp" in recipe and recipe["tp"] is not None: + tp = tp_world_size + + # Calculate size for this parameter + prec_to_bytes = { + torch.bfloat16: 2, + torch.float16: 2, + torch.float32: 4, + } + scale = prec_to_bytes[self.dtype] / prec_to_bytes[param.dtype] + size_in_bytes = param.element_size() * param.numel() * tp * scale + param_info.append(((name, recipe.get("hf", None)), size_in_bytes)) + + # Include buffers (non-parameter tensors) + for name, buffer in self.model.named_buffers(): + if "_extra_state" in name: + continue + + prec_to_bytes = { + torch.bfloat16: 2, + torch.float16: 2, + torch.float32: 4, + } + scale = prec_to_bytes[self.dtype] / prec_to_bytes[buffer.dtype] + size_in_bytes = buffer.element_size() * buffer.numel() * scale + param_info.append((name, size_in_bytes)) + + # Gather parameter info from all pipeline parallel ranks to ensure complete coverage + pp_group = parallel_state.get_pipeline_model_parallel_group() + pp_world_size = torch.distributed.get_world_size(pp_group) + + # Gather all parameter info from all PP ranks + all_param_infos = [None] * pp_world_size + torch.distributed.all_gather_object(all_param_infos, param_info, group=pp_group) + + # Merge all parameter infos, keeping only unique parameter names + merged_param_info = [] + seen_params = set() + + for rank_param_info in all_param_infos: + for name, size in rank_param_info: + if name not in seen_params: + merged_param_info.append((name, size)) + seen_params.add(name) + + # Update param_info with the merged information + param_info = merged_param_info + + print(f"Prepared {len(param_info)} tensors for IPC transfer") + return param_info + + @torch.no_grad() + def get_weights_ipc_handles(self, keys): + """Get IPC handles for the requested Megatron model weights. 
+ + Args: + keys: List of parameter names to get handles for + Returns: + Dict mapping device UUID to list of (mapped_key, handle) tuples + """ + param_name_to_rank_and_key = get_global_param_key_to_local_key_map( + self.model, self.megatron_cfg.model_config, keys + ) + gathered_params = gather_and_convert_params( + self.model, + self.converter_type, + self.megatron_cfg.model_config, + param_name_to_rank_and_key, + ) + gc.collect() + torch.cuda.empty_cache() + + # Get device UUID for IPC handles + device_uuid = self.report_device_id() + from torch.multiprocessing.reductions import reduce_tensor + + # Create IPC handles for each parameter + all_handles = [] + for key, tensor in gathered_params.items(): + handle = reduce_tensor(tensor.detach()) + all_handles.append((key, handle)) + + # Store references to avoid premature garbage collection + self._held_gather_buffer = gathered_params + shapes = {} + for key, tensor in gathered_params.items(): + shapes[key] = tensor.shape + + return {device_uuid: all_handles} + + def prepare_for_lp_inference(self): + self.model.to("cuda") + self.model.eval() + self.offload_before_refit() + + def prepare_for_training(self, *args, **kwargs): + # onload models and optimizer state to cuda + self.model.to("cuda") + self.model.train() + + # Move optimizer state to CUDA if it exists + if hasattr(self, "optimizer") and self.optimizer is not None: + # for state in self.optimizer.state.values(): + for state in self.optimizer._get_state().values(): + for k, v in state.items(): + if torch.is_tensor(v) and not v.is_cuda: + state[k] = v.to("cuda") + + torch.cuda.empty_cache() + + @torch.no_grad() + def offload_before_refit(self): + """Offload the optimizer and buffers to the CPU.""" + torch.randn(1).cuda() # wake up torch allocator + if hasattr(self, "optimizer") and self.optimizer is not None: + # Iterate through the state dictionaries for each parameter group + for state in self.optimizer._get_state().values(): + # Iterate through the state items (e.g., momentum, variance) for a parameter + for k, v in state.items(): + # Check if the item is a tensor and on the GPU + if torch.is_tensor(v) and v.is_cuda: + # Move the tensor to CPU and update the state dictionary + state[k] = v.to("cpu") + + gc.collect() + torch.cuda.empty_cache() + + # Print memory stats after offloading + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print( + f"GPU Memory after optimizer offload: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) + + @torch.no_grad() + def offload_after_refit(self): + # Offload as much as possible on the CPU + self.model = self.move_model(self.model, "cpu") + self.model.eval() + torch.randn(1).cuda() # wake up torch allocator + self.offload_before_refit() # rerun the old offload function + + if self._held_gather_buffer is not None: + del self._held_gather_buffer + self._held_gather_buffer = None + + gc.collect() + torch.cuda.empty_cache() + + allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB + reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB + print( + f"GPU Memory after refit complete: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved" + ) + + def move_model(self, model, device): + for name, item in model.state_dict().items(): + if isinstance(item, torch.Tensor): + item = item.detach().to(device=device, non_blocking=True, copy=True) + model.state_dict()[name] = item + return model + + def save_checkpoint( + self, + 
weights_path: str,
+        optimizer_path: Optional[str] = None,
+        offload_to_cpu: bool = True,
+    ):
+        pass
+
+    def load_checkpoint(self, weights_path: str, optimizer_path: Optional[str] = None):
+        pass
+
+    def shutdown(self):
+        """Shutdown the policy."""
+        pass
diff --git a/nemo_rl/models/policy/utils.py b/nemo_rl/models/policy/utils.py
index af2e84245..c81566b58 100644
--- a/nemo_rl/models/policy/utils.py
+++ b/nemo_rl/models/policy/utils.py
@@ -14,6 +14,10 @@
 import importlib
 import os
+from contextlib import contextmanager
+from copy import deepcopy
+
+import torch
 
 
 def import_class_from_path(name):
@@ -94,3 +98,98 @@ def get_gpu_info(model):
         if k.startswith("CUDA") or k in ["LOCAL_RANK", "RANK", "WORLD_SIZE"]
     },
 }
+
+
+def convert_to_amp_o2_format(state_dict):
+    """When amp_o2 is enabled, the model is wrapped in a Float16Module, which renames state-dict keys from 'model.' to 'model.module.'; this applies that renaming so the dict can be loaded into the wrapped model."""
+    new_state_dict = {}
+
+    for key, item in state_dict.items():
+        if "model.module." not in key:
+            key = key.replace("model.", "model.module.", 1)
+        new_state_dict[key] = item
+
+    return new_state_dict
+
+
+def retrieve_model_state_dict_in_cpu(model, megatron_amp_O2=True):
+    """Get a copy of the model state dict on CPU."""
+    cpu_dict = {}
+
+    for name, item in model.state_dict().items():
+        if isinstance(item, torch.Tensor):
+            item = item.detach().to(device="cpu", non_blocking=True, copy=True)
+
+        cpu_dict[name] = item
+
+    if megatron_amp_O2:
+        cpu_dict = convert_to_amp_o2_format(cpu_dict)
+
+    torch.cuda.synchronize()
+    return cpu_dict
+
+
+@torch.no_grad()
+def swap_dict(resident_model, cpu_weights, offload_onto_cpu=True, megatron_amp_O2=True):
+    """Swap the state dict with a specified state dict, and offload the current state dict onto CPU if needed."""
+    offloaded_weights = {}
+
+    if offload_onto_cpu:
+        offloaded_weights = retrieve_model_state_dict_in_cpu(
+            resident_model, megatron_amp_O2=megatron_amp_O2
+        )
+
+    resident_model.load_state_dict(cpu_weights)
+    return offloaded_weights
+
+
+@contextmanager
+def cpu_weight_swap(resident_model, cpu_weights, megatron_amp_O2=True):
+    """Swap the given weights into the resident model, and swap the original weights back in on exit."""
+    cpu_dict = swap_dict(resident_model, cpu_weights, megatron_amp_O2=megatron_amp_O2)
+    try:
+        yield
+
+    finally:
+        swap_dict(
+            resident_model,
+            cpu_dict,
+            offload_onto_cpu=False,
+            megatron_amp_O2=megatron_amp_O2,
+        )
+
+
+@torch.no_grad()
+def copy_model_states_to_cpu(
+    model, cpu_dict=None, megatron_amp_O2=True, sync=True, alias_non_tensor=False
+):
+    """Mutates the cpu_dict object, copying the model states into preallocated pinned CPU tensors (if they exist).
+
+    Non-tensor entries are deep-copied unless alias_non_tensor is True.
+ """ + if cpu_dict is None: + cpu_dict = {} + + for name, item in model.state_dict().items(): + if isinstance(item, torch.Tensor): + if name not in cpu_dict: + cpu_dict[name] = torch.empty( + item.size(), + dtype=item.dtype, + layout=item.layout, + device="cpu", + pin_memory=True, + ) + cpu_dict[name].copy_(item, non_blocking=sync) + elif alias_non_tensor: + cpu_dict[name] = item + else: + cpu_dict[name] = deepcopy(item) + + if megatron_amp_O2: + cpu_dict = convert_to_amp_o2_format(cpu_dict) + + if sync: + torch.cuda.synchronize() + + return cpu_dict diff --git a/nemo_rl/utils/nvml.py b/nemo_rl/utils/nvml.py index 137374e00..c2c519397 100644 --- a/nemo_rl/utils/nvml.py +++ b/nemo_rl/utils/nvml.py @@ -60,7 +60,11 @@ def get_device_uuid(device_idx: int) -> str: with nvml_context(): try: handle = pynvml.nvmlDeviceGetHandleByIndex(global_device_idx) - return pynvml.nvmlDeviceGetUUID(handle) + uuid = pynvml.nvmlDeviceGetUUID(handle) + # Ensure the UUID is returned as a string, not bytes + if isinstance(uuid, bytes): + uuid = uuid.decode("utf-8") + return uuid except pynvml.NVMLError as e: raise RuntimeError( f"Failed to get device UUID for device {device_idx} (global index: {global_device_idx}): {e}" diff --git a/tests/unit/distributed/test_named_sharding.py b/tests/unit/distributed/test_named_sharding.py new file mode 100644 index 000000000..bb9d327c8 --- /dev/null +++ b/tests/unit/distributed/test_named_sharding.py @@ -0,0 +1,151 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
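+
+# NamedSharding is the axis-labeled rank layout used as the worker sharding
+# annotations for the policy workers (these tests use the short axis names
+# "dp", "pp", "tp").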
+import numpy as np +import pytest + +from nemo_rl.distributed.named_sharding import NamedSharding + + +@pytest.fixture +def sample_sharding(): + """Provides a standard NamedSharding instance for testing.""" + layout = [[[0, 1, 2, 3], [4, 5, 6, 7]]] # dp=1, pp=2, tp=4 + names = ["dp", "pp", "tp"] + return NamedSharding(layout, names) + + +def test_initialization_success(sample_sharding): + assert sample_sharding.shape == {"dp": 1, "pp": 2, "tp": 4} + assert sample_sharding.names == ["dp", "pp", "tp"] + assert sample_sharding.ndim == 3 + assert sample_sharding.size == 8 + np.testing.assert_array_equal( + sample_sharding.layout, np.array([[[0, 1, 2, 3], [4, 5, 6, 7]]]) + ) + + +def test_initialization_dim_mismatch(): + layout = [[0, 1], [2, 3]] + names = ["dp", "pp", "tp"] + with pytest.raises(ValueError, match="Number of dimensions.*must match"): + NamedSharding(layout, names) + + +def test_initialization_non_integer(): + layout = [[0, 1.5], [2, 3]] + names = ["dp", "pp"] + with pytest.raises(ValueError, match="Layout must contain only integer rank IDs"): + NamedSharding(layout, names) + + +def test_initialization_duplicate_ranks(): + layout = [[0, 1], [2, 0]] + names = ["dp", "pp"] + with pytest.raises(ValueError, match="Duplicate ranks found"): + NamedSharding(layout, names) + + +def test_get_ranks_full_slice(sample_sharding): + # Get all ranks for dp=0 + ranks = sample_sharding.get_ranks(dp=0) + correct_out = NamedSharding(np.array([[0, 1, 2, 3], [4, 5, 6, 7]]), ["pp", "tp"]) + assert ranks == correct_out + + +def test_get_ranks_partial_slice(sample_sharding): + # Get ranks for dp=0, pp=1 + ranks = sample_sharding.get_ranks(dp=0, pp=1) + correct_out = NamedSharding(np.array([4, 5, 6, 7]), ["tp"]) + assert ranks == correct_out + + +def test_get_ranks_partial_slice_2(sample_sharding): + ranks = sample_sharding.get_ranks(dp=0, tp=2) + correct_out = NamedSharding(np.array([2, 6]), ["pp"]) + assert ranks == correct_out + + +def test_get_ranks_single_rank(sample_sharding): + # Get rank for dp=0, pp=0, tp=2 + ranks = sample_sharding.get_ranks(dp=0, pp=0, tp=2) + correct_out = 2 + assert ranks == correct_out + + +def test_get_ranks_no_args(sample_sharding): + # Get all ranks flattened + ranks = sample_sharding.get_ranks() + assert ranks == sample_sharding + + +def test_get_ranks_invalid_name(sample_sharding): + with pytest.raises(ValueError, match="Invalid axis name: 'xx'"): + sample_sharding.get_ranks(xx=0) + + +def test_get_ranks_index_out_of_bounds(sample_sharding): + with pytest.raises(IndexError, match="Index 2 is out of bounds for axis 'pp'"): + sample_sharding.get_ranks(pp=2) + with pytest.raises(IndexError, match="Index 4 is out of bounds for axis 'tp'"): + sample_sharding.get_ranks(tp=4) + + +def test_get_axis_index(sample_sharding): + assert sample_sharding.get_axis_index("dp") == 0 + assert sample_sharding.get_axis_index("pp") == 1 + assert sample_sharding.get_axis_index("tp") == 2 + + +def test_get_axis_index_invalid_name(sample_sharding): + with pytest.raises(ValueError, match="Invalid axis name: 'xx'"): + sample_sharding.get_axis_index("xx") + + +def test_get_axis_size(sample_sharding): + assert sample_sharding.get_axis_size("dp") == 1 + assert sample_sharding.get_axis_size("pp") == 2 + assert sample_sharding.get_axis_size("tp") == 4 + + +def test_equality(): + layout1 = [[[0, 1], [2, 3]]] + names1 = ["a", "b", "c"] + sharding1 = NamedSharding(layout1, names1) + + layout2 = [[[0, 1], [2, 3]]] + names2 = ["a", "b", "c"] + sharding2 = NamedSharding(layout2, names2) + + layout3 = 
[[[0, 1], [2, 4]]] # Different layout + names3 = ["a", "b", "c"] + sharding3 = NamedSharding(layout3, names3) + + layout4 = [[[0, 1], [2, 3]]] + names4 = ["x", "y", "z"] # Different names + sharding4 = NamedSharding(layout4, names4) + + assert sharding1 == sharding2 + assert sharding1 != sharding3 + assert sharding1 != sharding4 + assert sharding1 != "not a sharding object" + + +def test_repr(sample_sharding): + representation = repr(sample_sharding) + assert "NamedSharding" in representation + assert "shape=(1, 2, 4)" in representation + assert "names=['dp', 'pp', 'tp']" in representation + assert "layout=" in representation + assert "[[[0 1 2 3]" in representation # Check layout content part + assert "[4 5 6 7]]]" in representation