
Adding 'labels' input to model with 'include_loss_args' fails hf examples #1119

Open
@alexlan137

Hi,

I'm trying to use PiPPy with a custom model that takes both 'input_ids' and 'labels' as inputs. To check whether this is supported, I modified the basic pippy_gpt2.py example. I changed model_class and model_name to GPT2LMHeadModel, and I set include_loss_args to True in the generate_inputs_for_model call that builds example_inputs:
example_inputs = generate_inputs_for_model(model_class, gpt2, model_name, args.batch_size, args.device, include_loss_args=True)

However, this fails with the following traceback:

[rank0]: TypeError: forward() got an unexpected keyword argument 'labels'
RuntimeError: 
[rank0]:             [Stage 0] failed to run forward:
[rank0]:             args: ()
[rank0]:             kwargs: {'input_ids': 'Tensor(torch.Size([1, 1024]), grad=False)', 'labels': 'Tensor(torch.Size([1, 1024]), grad=False)'}

This happens because PiPPy splits the graph module (split_gm) so that the 'labels' input is routed to the last (4th) submodule, where the loss is computed. As a result, the first submodule's forward() does not accept a 'labels' argument.
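To make the mismatch concrete, here is a framework-free toy sketch in plain Python (no PiPPy or PyTorch). The functions stage0_forward, stage3_forward and the kwarg-filtering helper root_kwargs_for are hypothetical. They only mimic the fact that, after splitting, each stage's forward() declares just the root inputs it consumes; this is not how PiPPy routes inputs internally.

```python
import inspect
from typing import Any, Callable, Dict, List


# Hypothetical stand-ins for two submodules produced by the split.
# Stage 0 only declares `input_ids`; the last stage declares the activation
# it receives from the previous stage plus the root input `labels`.
def stage0_forward(input_ids: List[int]) -> List[int]:
    return [t + 1 for t in input_ids]  # placeholder "hidden states"


def stage3_forward(hidden: List[int], labels: List[int]) -> int:
    # Placeholder "loss": number of positions where hidden != labels.
    return sum(h != y for h, y in zip(hidden, labels))


def root_kwargs_for(fn: Callable[..., Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments that `fn` declares as parameters."""
    params = inspect.signature(fn).parameters
    return {k: v for k, v in kwargs.items() if k in params}


example_inputs = {"input_ids": [1, 2, 3], "labels": [2, 3, 4]}

# Passing every example input to stage 0 reproduces the failure mode.
try:
    stage0_forward(**example_inputs)
except TypeError as err:
    print("stage 0 failed:", err)

# Routing each root input only to the stage that declares it avoids the error.
hidden = stage0_forward(**root_kwargs_for(stage0_forward, example_inputs))
loss = stage3_forward(hidden, **root_kwargs_for(stage3_forward, example_inputs))
print("loss:", loss)  # hidden == labels here, so the toy loss is 0
```

Filtering by signature only removes the first-stage TypeError in this toy. In the real pipeline, the last stage still needs a supported way to receive 'labels', which is what the second attempt in this issue runs into.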

I also tried modifying pippy_gpt2.py to pass the labels only to the last stage's schedule.step call, as follows (this is not a good long-term solution):

input_values = torch.randint(0, 50257, (args.batch_size, 1024), device=args.device)
example_inputs_0 = {"input_ids": input_values}
example_inputs_3 = {"labels": input_values}

# Run
if args.rank == 0:
    schedule.step(**example_inputs_0)
elif args.rank == 3:
    schedule.step(**example_inputs_3)
else:
    out = schedule.step()

This throws the following error. My guess is that every stage after the first expects each of its inputs to arrive as a RecvInfo (a tensor received from the previous stage), not as a new value from a root input placeholder:
[rank3]: AssertionError: Expected RecvInfo but got <class 'torch.distributed.pipelining._PipelineStage.RootArgPlaceholder'>
I could debug this further, but is there a better solution, or does anyone have ideas on how to implement this? Thanks.
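A possible alternative is to keep labels out of the model's forward entirely: run GPT2LMHeadModel without labels so it only returns logits, and compute the loss after the last stage. torch.distributed.pipelining schedules appear to support this through a loss_fn argument on the schedule constructor and target=/losses= arguments on the last stage's step() call; this should be verified against the PyTorch version in use. The following is a single-process, framework-free simulation of that data flow only. split_microbatches, pipeline_step, add_one and mismatch_loss are hypothetical names and do not model PyTorch's scheduling or communication.

```python
from typing import Callable, List, Sequence

Microbatch = List[List[int]]
Stage = Callable[[Microbatch], Microbatch]


def split_microbatches(batch: Microbatch, n: int) -> List[Microbatch]:
    """Split a batch of sequences into at most n contiguous microbatches."""
    size = -(-len(batch) // n)  # ceiling division
    return [batch[i:i + size] for i in range(0, len(batch), size)]


def pipeline_step(
    stages: Sequence[Stage],
    input_ids: Microbatch,
    target: Microbatch,
    loss_fn: Callable[[Microbatch, Microbatch], int],
    n_microbatches: int,
    losses: List[int],
) -> List[Microbatch]:
    """Hypothetical stand-in for one pipeline step.

    Only the first stage sees a root input (input_ids). Every later stage
    sees only the previous stage's output, and `target` is consumed by
    loss_fn after the last stage, so it is never a model input.
    """
    outputs = []
    for mb_ids, mb_target in zip(
        split_microbatches(input_ids, n_microbatches),
        split_microbatches(target, n_microbatches),
    ):
        x = mb_ids
        for stage in stages:
            x = stage(x)
        losses.append(loss_fn(x, mb_target))  # one loss per microbatch
        outputs.append(x)
    return outputs


def add_one(mb: Microbatch) -> Microbatch:
    """Toy stage: add 1 to every token id."""
    return [[t + 1 for t in seq] for seq in mb]


def mismatch_loss(output: Microbatch, target: Microbatch) -> int:
    """Toy loss: number of positions where output != target."""
    return sum(
        o != t
        for out_seq, tgt_seq in zip(output, target)
        for o, t in zip(out_seq, tgt_seq)
    )


input_ids = [[0, 1], [2, 3], [4, 5], [6, 7]]
target = [[4, 5], [6, 7], [0, 0], [10, 11]]
losses: List[int] = []
outputs = pipeline_step([add_one] * 4, input_ids, target, mismatch_loss,
                        n_microbatches=2, losses=losses)
print(losses)  # [0, 2]
```

In a real setup the toy loss would be replaced by a language-modeling loss on the logits. GPT2LMHeadModel shifts labels by one position when it computes its own loss, so an external loss_fn would need to do the same to match it.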
