diff --git a/training/DeepSpeed-Domino/domino/arguments.py b/training/DeepSpeed-Domino/domino/arguments.py
index 8bc59223a..cda214897 100644
--- a/training/DeepSpeed-Domino/domino/arguments.py
+++ b/training/DeepSpeed-Domino/domino/arguments.py
@@ -206,6 +206,8 @@ def parse_args():
                         help='Report loss and timing interval.')
     parser.add_argument('--save-interval', type=int, default=None,
                         help='Number of iterations between checkpoint saves.')
+    parser.add_argument('--input-split-dim', type=str, default='batch',
+                        help='Dimension for input split.')
 
     args = parser.parse_args()
 
@@ -355,6 +357,8 @@ class TransformerConfig():
     no_sync_func: Callable = None
     # grad_sync_func: Callable = None
     # param_sync_func: Callable = None
+
+    input_split_dim: str = 'batch'
 
     def __post_init__(self):
         """ Python dataclass method that is used to modify attributes after initialization.
@@ -396,5 +400,6 @@ def core_transformer_config_from_args(args):
     kw_args['init_method'] = args.init_method
     kw_args['output_layer_init_method'] = args.init_method
     kw_args['params_dtype'] = args.params_dtype
+    kw_args['input_split_dim'] = args.input_split_dim
 
     return TransformerConfig(**kw_args)
diff --git a/training/DeepSpeed-Domino/domino/language_model.py b/training/DeepSpeed-Domino/domino/language_model.py
index 2cfb2f9fd..a92933fc7 100644
--- a/training/DeepSpeed-Domino/domino/language_model.py
+++ b/training/DeepSpeed-Domino/domino/language_model.py
@@ -127,6 +127,7 @@ def __init__(self,
         self.init_method = config.init_method
         self.encoder_attn_mask_type = encoder_attn_mask_type
         self.encoder_hidden_state = None
+        self.input_split_dim = config.input_split_dim
 
         if self.pre_process:
             self.embedding = Embedding(self.hidden_size,
@@ -177,17 +178,30 @@ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
 
         encoder_out_size = encoder_input.shape
         p_batch_size = encoder_out_size[1] // 2
+        p_seq_size = encoder_out_size[0] // 2
         dtype = encoder_input.dtype
         encoder_output_t = torch.empty(encoder_out_size, dtype=dtype, device=torch.cuda.current_device())
         intra_partitions = 2
-        encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=1)
+        if self.input_split_dim == 'batch':
+            encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=1)
+        elif self.input_split_dim == 'seq':
+            encoder_inputs = torch.tensor_split(encoder_input, intra_partitions, dim=0)
+        else:
+            raise NotImplementedError
         encoder_outputs = self.encoder(
             encoder_inputs,
             enc_attn_mask,
             rotary_pos_emb=rotary_pos_emb)
-        encoder_output_t[:, 0:p_batch_size, :] = encoder_outputs[0]
-        encoder_output_t[:, p_batch_size:2*p_batch_size, :] = encoder_outputs[1]
+
+        if self.input_split_dim == 'batch':
+            encoder_output_t[:, 0:p_batch_size, :] = encoder_outputs[0]
+            encoder_output_t[:, p_batch_size:2*p_batch_size, :] = encoder_outputs[1]
+        elif self.input_split_dim == 'seq':
+            encoder_output_t[0:p_seq_size, :, :] = encoder_outputs[0]
+            encoder_output_t[p_seq_size:2*p_seq_size, :, :] = encoder_outputs[1]
+        else:
+            raise NotImplementedError
+
         encoder_output = encoder_output_t
 
-        return encoder_output
\ No newline at end of file
+        return encoder_output
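For reference, a minimal standalone sketch (not part of the patch) of what the two --input-split-dim modes do. It assumes the activation layout implied by the diff, where shape[0] is the sequence dimension and shape[1] is the batch dimension, and shows that splitting with torch.tensor_split along the chosen dimension and concatenating the per-chunk outputs along that same dimension, which the slice assignments in forward() do in place, recovers the original layout:

import torch

# Hypothetical toy activation; [seq, batch, hidden] as in the diff,
# where p_seq_size = shape[0] // 2 and p_batch_size = shape[1] // 2.
encoder_input = torch.randn(8, 4, 16)
intra_partitions = 2

for split_dim_name, dim in (('batch', 1), ('seq', 0)):
    chunks = torch.tensor_split(encoder_input, intra_partitions, dim=dim)
    # Each chunk halves the chosen dimension and leaves the others intact.
    assert all(c.shape[dim] == encoder_input.shape[dim] // intra_partitions
               for c in chunks)
    # Concatenating along the same dimension restores the original layout,
    # mirroring the paired slice assignments into encoder_output_t.
    restored = torch.cat(chunks, dim=dim)
    assert torch.equal(restored, encoder_input)

Note that both branches assume the split dimension is even, since p_batch_size and p_seq_size are computed with floor division before the two equal-sized chunk outputs are written back.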