From 8dc922d4ddc92d99265a251e30dea8f6ddb0f53c Mon Sep 17 00:00:00 2001
From: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date: Wed, 5 Jul 2023 11:23:07 -0700
Subject: [PATCH] [python] Use lmi-dist api to get batch class (#901)

---
 .../rolling_batch/lmi_dist_rolling_batch.py   | 18 +-----------------
 1 file changed, 1 insertion(+), 17 deletions(-)

diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
index f54901d2f..77ec54eba 100644
--- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py
@@ -14,8 +14,6 @@
 from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception
 from transformers import AutoConfig
 from lmi_dist.models import get_model
-from lmi_dist.models.flash_causal_lm import FlashCausalLMBatch
-from lmi_dist.models.seq2seq_lm import Seq2SeqLMBatch
 from lmi_dist.utils.parameters import (
     NextTokenChooserParameters,
     StoppingCriteriaParameters,
@@ -25,19 +23,6 @@
 import torch
 
 
-ARCHITECTURE_2_BATCH_CLS = {
-    "RWForCausalLM": FlashCausalLMBatch,
-    "GPTNeoXForCausalLM": FlashCausalLMBatch,
-    "T5ForConditionalGeneration": Seq2SeqLMBatch,
-    "LlamaForCausalLM": FlashCausalLMBatch
-}
-
-
-def get_batch_cls_from_architecture(architecture):
-    if architecture in ARCHITECTURE_2_BATCH_CLS:
-        return ARCHITECTURE_2_BATCH_CLS[architecture]
-    raise ValueError("Invalid architecture, not supported by lmi-dist")
-
 
 class LmiDistRollingBatch(RollingBatch):
 
@@ -60,8 +45,6 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
 
     def _init_model(self, kwargs, model_id_or_path):
         self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
-        self.batch_cls = get_batch_cls_from_architecture(
-            self.config.architectures[0])
         sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
         self.model = get_model(
             model_id_or_path,
@@ -69,6 +52,7 @@ def _init_model(self, kwargs, model_id_or_path):
             sharded=sharded,
             quantize=None,
             trust_remote_code=kwargs.get("trust_remote_code"))
+        self.batch_cls = self.model.batch_type
 
     @stop_on_any_exception
     def inference(self, input_data, parameters):
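
Note (illustrative, not part of the patch): the change replaces the hand-maintained architecture-to-batch-class map with the batch class that lmi-dist itself reports on the loaded model. A minimal sketch of the new lookup follows, assuming only what the diff shows — that get_model takes a model id or path plus the keyword arguments visible above and that the returned model exposes a batch_type attribute. The model id and flag values are hypothetical, and any get_model parameters not visible in the diff are elided.

    from lmi_dist.models import get_model

    # lmi-dist selects the batch implementation that matches the model,
    # so callers no longer special-case architectures themselves.
    model = get_model(
        "huggingface/llama-7b",   # hypothetical model id or local path
        sharded=False,            # single-device for illustration
        quantize=None,
        trust_remote_code=False)  # other required arguments elided
    batch_cls = model.batch_type  # e.g. FlashCausalLMBatch or Seq2SeqLMBatch

Deferring to the library also removes the failure mode where a model that lmi-dist already supported still raised "Invalid architecture, not supported by lmi-dist" until the map in this file was updated by hand.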