Skip to content

Commit

Permalink
added pad_time function
Browse files Browse the repository at this point in the history
  • Loading branch information
nvdreidenbach committed Jan 2, 2025
1 parent 00d5451 commit ac26890
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 27 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 87 files
+1 −1 .gitlab/stages/00.pre.yml
+0 −1 examples/gpt3/gpt_config.yaml
+50 −55 examples/inference/README.md
+5 −5 examples/inference/gpt/simple_gpt_batch_inference.py
+3 −3 examples/inference/t5/simple_t5_batch_inference.py
+2 −2 examples/multimodal/README.md
+66 −137 examples/multimodal/dataset_helpers.py
+0 −0 examples/multimodal/evaluate_ai2d.py
+0 −0 examples/multimodal/evaluate_chartqa.py
+0 −0 examples/multimodal/evaluate_coco.py
+0 −0 examples/multimodal/evaluate_mathvista.py
+0 −6 examples/multimodal/evaluate_mmmu.py
+0 −0 examples/multimodal/evaluate_ocrbench.py
+0 −0 examples/multimodal/evaluate_textvqa.py
+0 −0 examples/multimodal/evaluate_vqav2.py
+0 −0 examples/multimodal/evaluation_datasets.py
+2 −9 examples/multimodal/nvlm/README.md
+1 −1 examples/multimodal/nvlm/pretrain_qwen20_72b_internvit_6b.sh
+1 −1 examples/multimodal/nvlm/pretrain_yi_34b_internvit_6b.sh
+1 −1 examples/multimodal/nvlm/run_text_generation_qwen20_72b_internvit_6b.sh
+2 −2 examples/multimodal/nvlm/run_text_generation_yi_34b_internvit_6b.sh
+1 −1 examples/multimodal/nvlm/sft_34b_internvit.sh
+1 −1 examples/multimodal/nvlm/sft_qwen20_72b_internvit_6b.sh
+6 −1 examples/multimodal/pretrain_mistral_clip.sh
+1 −1 examples/multimodal/run_text_generation.py
+6 −1 examples/multimodal/sft_mistral_clip.sh
+13 −4 examples/multimodal/text_generation_mistral_clip.sh
+2 −2 examples/multimodal/train.py
+1 −2 megatron/core/dist_checkpointing/mapping.py
+2 −0 megatron/core/dist_checkpointing/serialization.py
+14 −13 megatron/core/dist_checkpointing/validation.py
+6 −21 megatron/core/distributed/distributed_data_parallel.py
+18 −29 megatron/core/extensions/transformer_engine.py
+29 −4 megatron/core/inference/common_inference_params.py
+8 −15 megatron/core/inference/engines/mcore_engine.py
+2 −2 megatron/core/inference/inference_request.py
+0 −35 megatron/core/inference/sampling_params.py
+3 −3 megatron/core/inference/scheduler.py
+4 −4 megatron/core/inference/text_generation_controllers/encoder_decoder_text_generation_controller.py
+398 −3 megatron/core/inference/text_generation_controllers/simple_text_generation_controller.py
+0 −400 megatron/core/inference/text_generation_controllers/text_generation_controller.py
+24 −48 megatron/core/models/bert/bert_layer_specs.py
+13 −16 megatron/core/models/common/embeddings/rope_utils.py
+2 −5 megatron/core/models/multimodal/llava_model.py
+24 −49 megatron/core/optimizer/__init__.py
+9 −21 megatron/core/optimizer/clip_grads.py
+102 −193 megatron/core/optimizer/distrib_optimizer.py
+86 −143 megatron/core/optimizer/optimizer.py
+0 −65 megatron/core/optimizer/optimizer_config.py
+0 −10 megatron/core/pipeline_parallel/schedules.py
+76 −80 megatron/core/rerun_state_machine.py
+212 −670 megatron/core/transformer/cuda_graphs.py
+1 −4 megatron/core/transformer/moe/README.md
+4 −115 megatron/core/transformer/moe/moe_utils.py
+26 −55 megatron/core/transformer/moe/router.py
+1 −4 megatron/core/transformer/transformer_block.py
+6 −32 megatron/core/transformer/transformer_config.py
+5 −4 megatron/inference/text_generation/forward_step.py
+8 −39 megatron/training/arguments.py
+12 −27 megatron/training/checkpointing.py
+13 −37 megatron/training/training.py
+11 −49 megatron/training/utils.py
+10 −4 pretrain_vlm.py
+0 −1 ...sts/test_cases/gpt/gpt3_345m_nightly_dgx_a100_1N8G_mcore_tp2_pp2_ep2_te_4experts2parallel/model_config.yaml
+1 −0 ...s/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_resume_torch_dist_rope_embeddings_dgx_a100_1N8G/model_config.yaml
+1 −0 tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp1_pp2_rope_embeddings_dgx_a100_1N8G/model_config.yaml
+0 −1 tests/functional_tests/test_cases/gpt/gpt3_nightly_mcore_te_tp2_pp1_modelopt_distill_resume/model_config.yaml
+0 −1 ...imodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_dev.json
+0 −1 ...imodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/golden_values_lts.json
+0 −57 .../multimodal-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dgx_a100_1N8G/model_config.yaml
+0 −1 ...ava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_dev.json
+0 −1 ...ava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/golden_values_lts.json
+0 −58 ...al-llava/multimodal_llava_mr_mcore_te_tp4_pp1_freeze_vit_freeze_lm_dist_opt_dgx_a100_1N8G/model_config.yaml
+0 −1 ...ional_tests/test_cases/t5/t5_220m_nightly_dgx_a100_1N8G_mcore_te_tp1_pp1_vp1_resume_torch/model_config.yaml
+0 −1 tests/functional_tests/test_cases/t5/t5_220m_nightly_dgx_a100_1N8G_mcore_te_tp2_pp1_vp1/model_config.yaml
+0 −1 ..._tests/test_cases/t5/t5_220m_nightly_dgx_a100_1N8G_mcore_te_tp2_pp1_vp1_sequence_parallel/model_config.yaml
+0 −2 tests/test_utils/recipes/multimodal-llava.yaml
+0 −68 tests/unit_tests/dist_checkpointing/test_flattened_resharding.py
+0 −33 tests/unit_tests/dist_checkpointing/test_serialization.py
+8 −6 tests/unit_tests/inference/engines/test_mcore_engine.py
+3 −3 tests/unit_tests/inference/test_common_inference_params.py
+2 −2 tests/unit_tests/inference/test_scheduler.py
+2 −2 tests/unit_tests/inference/text_generation_controllers/test_encoder_decoder_text_generation_controller.py
+13 −13 tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
+0 −47 tests/unit_tests/test_optimizer.py
+0 −44 tests/unit_tests/transformer/moe/test_aux_loss.py
+0 −59 tests/unit_tests/transformer/moe/test_routers.py
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated 372 files
31 changes: 28 additions & 3 deletions sub-packages/bionemo-moco/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,30 @@ Generate the time schedule as a tensor.
- `nsteps` _Optional[int]_ - Number of time steps. If None, uses the value from initialization.
- `device` _Optional[str]_ - Device to place the schedule on (default is "cpu").

<a id="mocoschedulesinference_time_schedulesInferenceSchedulepad_time"></a>

#### pad\_time

```python
def pad_time(n_samples: int,
scalar_time: Float,
device: Optional[Union[str, torch.device]] = None) -> Tensor
```

Creates a tensor of shape (n_samples,) filled with a scalar time value.

**Arguments**:

- n_samples (int): The desired dimension of the output tensor.
- scalar_time (Float): The scalar time value to fill the tensor with.
- device (Optional[Union[str, torch.device]], optional):
The device to place the tensor on. Defaults to None, which uses the default device.


**Returns**:

- Tensor: A tensor of shape (n_samples,) filled with the scalar time value.

<a id="mocoschedulesinference_time_schedulesContinuousInferenceSchedule"></a>

## ContinuousInferenceSchedule Objects
Expand Down Expand Up @@ -2906,7 +2930,7 @@ A Continuous Flow Matching interpolant.
>>> from bionemo.bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
>>> from bionemo.bionemo.moco.schedules.inference_time_schedules import LinearTimeSchedule
>>> from bionemo.bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule

flow_matcher = ContinuousFlowMatcher(
time_distribution = UniformTimeDistribution(...),
Expand All @@ -2929,8 +2953,9 @@ for epoch in range(1000):

# Generation
x_pred = flow_matcher.sample_prior(data.shape)
for t in LinearTimeSchedule(...).generate_schedule():
time = torch.full((batch_size,), t)
inference_sched = LinearInferenceSchedule(...)
for t in inference_sched.generate_schedule():
time = inference_sched.pad_time(x_pred.shape[0], t)
u_hat = model(x_pred, time)
x_pred = flow_matcher.step(u_hat, x_pred, time)
return x_pred
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@
"outputs": [],
"source": [
"for dt, t in zip(dts, ts):\n",
" t = torch.full((xt.shape[0],), t).to(DEVICE)\n",
" t = schedule.pad_time(num_samples, t, DEVICE)\n",
" logits = model(xt, t)\n",
" xt = dfm.step(logits, t, xt, dt, stochasticity=0)"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class ContinuousFlowMatcher(Interpolant):
>>> from bionemo.moco.distributions.prior.continuous.gaussian import GaussianPrior
>>> from bionemo.moco.distributions.time.uniform import UniformTimeDistribution
>>> from bionemo.moco.interpolants.continuous_time.continuous.continuous_flow_matching import ContinuousFlowMatcher
>>> from bionemo.moco.schedules.inference_time_schedules import LinearTimeSchedule
>>> from bionemo.moco.schedules.inference_time_schedules import LinearInferenceSchedule
flow_matcher = ContinuousFlowMatcher(
time_distribution = UniformTimeDistribution(...),
Expand All @@ -78,8 +78,9 @@ class ContinuousFlowMatcher(Interpolant):
# Generation
x_pred = flow_matcher.sample_prior(data.shape)
for t in LinearTimeSchedule(...).generate_schedule():
time = torch.full((batch_size,), t)
inference_sched = LinearInferenceSchedule(...)
for t in inference_sched.generate_schedule():
time = inference_sched.pad_time(x_pred.shape[0], t)
u_hat = model(x_pred, time)
x_pred = flow_matcher.step(u_hat, x_pred, time)
return x_pred
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,22 @@ def generate_schedule(
"""
pass

def pad_time(
    self, n_samples: int, scalar_time: Float, device: Optional[Union[str, torch.device]] = None
) -> Tensor:
    """Creates a tensor of shape (n_samples,) filled with a scalar time value.

    Args:
        n_samples (int): The desired dimension of the output tensor.
        scalar_time (Float): The scalar time value to fill the tensor with.
        device (Optional[Union[str, torch.device]], optional):
            The device to place the tensor on. Defaults to None, which uses the default device.

    Returns:
        Tensor: A tensor of shape (n_samples,) filled with the scalar time value.
    """
    # Allocate directly on the requested device instead of creating the tensor on the
    # default device and copying it afterwards with ``.to(device)``; this avoids an
    # extra allocation and transfer on every step of a sampling loop.
    return torch.full((n_samples,), fill_value=scalar_time, device=device)


class ContinuousInferenceSchedule(InferenceSchedule):
"""A base class for continuous time inference schedules."""
Expand Down

0 comments on commit ac26890

Please sign in to comment.