Skip to content

Commit

Permalink
[python] Use lmi-dist api to get batch class (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored and frankfliu committed Jul 5, 2023
1 parent 8b7a072 commit 8dc922d
Showing 1 changed file with 1 addition and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
from transformers import AutoConfig
from lmi_dist.models import get_model
from lmi_dist.models.flash_causal_lm import FlashCausalLMBatch
from lmi_dist.models.seq2seq_lm import Seq2SeqLMBatch
from lmi_dist.utils.parameters import (
NextTokenChooserParameters,
StoppingCriteriaParameters,
Expand All @@ -25,19 +23,6 @@

import torch

# Maps a HuggingFace architecture name (as found in config.architectures[0])
# to the lmi-dist batch class used to assemble requests for that model family.
# Flash-attention causal LMs share FlashCausalLMBatch; encoder-decoder models
# (T5) use Seq2SeqLMBatch.
ARCHITECTURE_2_BATCH_CLS = {
    "RWForCausalLM": FlashCausalLMBatch,
    "GPTNeoXForCausalLM": FlashCausalLMBatch,
    "T5ForConditionalGeneration": Seq2SeqLMBatch,
    "LlamaForCausalLM": FlashCausalLMBatch
}


def get_batch_cls_from_architecture(architecture):
    """Return the lmi-dist batch class registered for *architecture*.

    Raises ValueError when the architecture is not in the supported table.
    """
    batch_cls = ARCHITECTURE_2_BATCH_CLS.get(architecture)
    if batch_cls is None:
        raise ValueError("Invalid architecture, not supported by lmi-dist")
    return batch_cls


class LmiDistRollingBatch(RollingBatch):

Expand All @@ -60,15 +45,14 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):

def _init_model(self, kwargs, model_id_or_path):
    """Load the model through lmi-dist and record its batch class.

    Side effects: sets self.config, self.model and self.batch_cls.
    """
    self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
    # Shard only when more than one tensor-parallel rank is configured;
    # the property defaults to "-1" (no tensor parallelism).
    tp_degree = int(self.properties.get("tensor_parallel_degree", "-1"))
    self.model = get_model(model_id_or_path,
                           revision=None,
                           sharded=tp_degree > 1,
                           quantize=None,
                           trust_remote_code=kwargs.get("trust_remote_code"))
    # lmi-dist exposes the correct batch class for the loaded model,
    # so no local architecture-to-batch-class lookup is needed.
    self.batch_cls = self.model.batch_type

@stop_on_any_exception
def inference(self, input_data, parameters):
Expand Down

0 comments on commit 8dc922d

Please sign in to comment.