
Support for pipeline decoder model #729

Merged: 19 commits, Sep 19, 2024

Conversation

@baijumeswani (Contributor) commented on Jul 29, 2024

This pull request adds support for scenarios where the decoder-only model is split into multiple smaller ONNX models.

onnxruntime-genai will execute these models in a pipeline fashion. The pipeline is expected to be defined in the genai config.

Update to the GenAI config

The outputs of the previous model in the pipeline can be fed into the inputs of the next model in the pipeline. Configurations exposed to the user (a combined sketch of these fields follows the list):

  • filename: file name of the ONNX model in the pipeline. Required field.
  • session_options: each pipeline model can define its own SessionOptions; that pipeline model's session is created and run with these options. If not provided, the default session options are used.
  • run_on_first_token_generation: whether that pipeline model should run on the first token generation. Default: true.
  • run_on_nth_token_generation: whether that pipeline model should run on the nth token generation (n != 1). Default: true.
  • output_names_forwarder: a mapping that can be defined in the config when output names from the previous model do not align with input names of the following model.
  • inputs: input names of the pipeline model. Required field.
  • outputs: output names of the pipeline model. Required field.
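
A minimal sketch of a single pipeline model entry that combines these fields. The model name, file name, option values, and the exact shape of the output_names_forwarder and session_options entries are illustrative assumptions, not taken verbatim from this pull request:

"example_model": {
    "filename": "model-part.onnx",
    "session_options": {
        "intra_op_num_threads": 4
    },
    "run_on_first_token_generation": true,
    "run_on_nth_token_generation": false,
    "output_names_forwarder": {
        "model_output": "next_model_input"
    },
    "inputs": [
        "previous_model_output"
    ],
    "outputs": [
        "model_output"
    ]
}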

Example pipeline

Here we have split the decoder model into 3 parts:

  1. Embeddings
  2. Transformer
  3. Model head
"pipeline": [
    {
        "embedding": {
            "filename": "phi-3-embedding.onnx",
            "inputs": [
                "input_ids"
            ],
            "outputs": [
                "inputs_embeds"
            ]
        },
        "transformer_model": {
            "filename": "phi-3-transformer.onnx",
            "inputs": [
                "inputs_embeds",
                "past_keys_0",
                "past_values_0",
                "..."
            ],
            "outputs": [
                "transformer_output",
                "present_keys_0",
                "present_values_0",
                "..."
            ]
        },
        "transformer_head": {
            "filename": "phi-3-transformer-head.onnx",
            "inputs": [
                "transformer_output"
            ],
            "outputs": [
                "logits"
            ]
        }
    }
]

In the above example, the outputs of the embedding pipeline model are fed into the inputs of the transformer model. Similarly, the outputs of the transformer model are fed into the inputs of the transformer head model.

Assumptions and limitations

  • The final model's inputs and outputs are expected to be the same as those currently supported for decoder-only models. No other inputs/outputs are managed by the pipeline. Inputs/outputs managed by the pipeline (a config sketch follows this list):
    • input_ids (input)
    • kv cache (input)
    • attention_mask (input)
    • logits (output)
    • kv cache (output)
  • The managed inputs and outputs (listed above) must be allocated on the device where the search is expected to take place, i.e. the pipeline does not move/copy data from one device to another after a session runs.
  • The intermediate (unmanaged) inputs and outputs must reside on CPU, and the ORT session is responsible for making copies to and from the device (host) that its session options are registered for.
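
For orientation, a rough sketch of how the managed, model-level inputs/outputs might be declared alongside the pipeline in the genai config. The key names and nesting below are assumptions for illustration, not taken from this pull request:

"model": {
    "decoder": {
        "inputs": {
            "input_ids": "input_ids",
            "attention_mask": "attention_mask",
            "past_key_names": "past_key_values.%d.key",
            "past_value_names": "past_key_values.%d.value"
        },
        "outputs": {
            "logits": "logits",
            "present_key_names": "present.%d.key",
            "present_value_names": "present.%d.value"
        },
        "pipeline": [
            "..."
        ]
    }
}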

Where can this feature be used

  • In scenarios where the user would like to run different parts of the model using different session options (see the sketch after this list).
  • In scenarios where it is hard to combine multiple smaller models into one big model (due to limitations of the model or device being used to run the model).
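
As a sketch of the first point, two models in the same pipeline could use different session options, for example one running on CPU and one using a GPU execution provider. The provider names and session_options keys below are illustrative assumptions, not taken from this pull request:

"pipeline": [
    {
        "embedding": {
            "filename": "phi-3-embedding.onnx",
            "session_options": {
                "provider_options": []
            },
            "inputs": ["input_ids"],
            "outputs": ["inputs_embeds"]
        },
        "transformer_model": {
            "filename": "phi-3-transformer.onnx",
            "session_options": {
                "provider_options": [
                    { "cuda": {} }
                ]
            },
            "inputs": ["inputs_embeds", "..."],
            "outputs": ["transformer_output", "..."]
        }
    }
]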

Co-authors: @edgchen1 @ajindal1

@wangyems (Contributor)

qq: is this for pipeline parallelism?

@baijumeswani (Contributor, Author)

qq: is this for pipeline parallelism?

No, the work here is not intended for pipeline parallelism. However, it could potentially be useful in pipeline parallelism. Sorry for the late response.

@yufenglee (Member)

Could you please add a unit test to cover the pipeline?

@yufenglee (Member) left a comment


:shipit:

@baijumeswani merged commit f81b6eb into main on Sep 19, 2024
13 checks passed
@baijumeswani deleted the baijumeswani/phi3-pipeline branch on September 19, 2024 at 00:04
@baijumeswani (Contributor, Author)

Thank you all for the review. :)
