Format python
maaquib committed Jun 30, 2023
1 parent 9c502aa commit 58a5eaa
Showing 1 changed file with 45 additions and 35 deletions.

@@ -21,13 +21,10 @@
     StoppingCriteriaParameters,
 )
 import lmi_dist
-from lmi_dist.utils.types import (
-    Batch,
-    Request,
-    Generation
-)
+from lmi_dist.utils.types import (Batch, Request, Generation)
 import logging
 import torch
+import os

 ARCHITECTURE_2_BATCH_CLS = {
     "RWForCausalLM": FlashCausalLMBatch,
@@ -36,6 +33,8 @@
     "LlamaForCausalLM": FlashCausalLMBatch
 }

+QUANTIZATION_SUPPORT_ALGO = ["bitsandbytes"]
+

 def get_batch_cls_from_architecture(architecture):
     if architecture in ARCHITECTURE_2_BATCH_CLS:
@@ -54,7 +53,7 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
         :param properties: other properties of the model, such as decoder strategy
         :param kwargs passed while loading the model
         """
-
+        os.environ["BITSANDBYTES_NOWELCOME"] = "1"
         super().__init__(device)
         self.properties = properties
         self.batch_cls = None
@@ -63,20 +62,25 @@ def __init__(self, model_id_or_path, device, properties, **kwargs):
         self.cache: Batch = None

     def _init_model(self, kwargs, model_id_or_path):
-        self.config = AutoConfig.from_pretrained(model_id_or_path,
-                                                  **kwargs)
-        self.batch_cls = get_batch_cls_from_architecture(self.config.architectures[0])
+        self.config = AutoConfig.from_pretrained(model_id_or_path, **kwargs)
+        self.batch_cls = get_batch_cls_from_architecture(
+            self.config.architectures[0])
         sharded = int(self.properties.get("tensor_parallel_degree", "-1")) > 1
         quantize = self.properties.get("quantize", None)
-        if quantize and quantize != "bitsanadbytes":
-            logging.info(f"Invalid value for quantize: {quantize}. Only `bitsandbytes` quantization is supported. "
-                         f"Setting quantization to None")
+        dtype = self.properties.get("dtype", None)
+        if quantize is not None and dtype is not None:
+            logging.error("Can't set both dtype and quantize")
+        if quantize and quantize not in QUANTIZATION_SUPPORT_ALGO:
+            logging.error(
+                f"Invalid value for quantize: {quantize}. Only `bitsandbytes` quantization is supported."
+            )
             quantize = None
-        self.model = get_model(model_id_or_path,
-                               revision=None,
-                               sharded=sharded,
-                               quantize=quantize,
-                               trust_remote_code=kwargs.get("trust_remote_code"))
+        self.model = get_model(
+            model_id_or_path,
+            revision=None,
+            sharded=sharded,
+            quantize=quantize,
+            trust_remote_code=kwargs.get("trust_remote_code"))

     def inference(self, input_data, parameters):
         """
@@ -96,15 +100,18 @@ def inference(self, input_data, parameters):
     def _prefill_and_decode(self, new_batch):
         # prefill step
         if new_batch:
-            generations, prefill_next_batch = self.model.generate_token(new_batch)
+            generations, prefill_next_batch = self.model.generate_token(
+                new_batch)

             if self.cache:
-                decode_generations, decode_next_batch = self.model.generate_token(self.cache)
+                decode_generations, decode_next_batch = self.model.generate_token(
+                    self.cache)
                 generations.extend(decode_generations)

                 # concatenate with the existing batch of the model
                 if decode_next_batch:
-                    self.cache = self.model.batch_type.concatenate([prefill_next_batch, decode_next_batch])
+                    self.cache = self.model.batch_type.concatenate(
+                        [prefill_next_batch, decode_next_batch])
                 else:
                     self.cache = prefill_next_batch
         else:
@@ -113,7 +120,10 @@ def _prefill_and_decode(self, new_batch):
             generations, next_batch = self.model.generate_token(self.cache)
             self.cache = next_batch

-        generation_dict = {generation.request_id: generation for generation in generations}
+        generation_dict = {
+            generation.request_id: generation
+            for generation in generations
+        }

         req_ids = []
         for r in self.pending_requests:
@@ -132,22 +142,25 @@ def preprocess_requests(self, requests, **kwargs):
         for r in requests:
             param = r.parameters
             parameters = NextTokenChooserParameters(
-                temperature=param.get("temperature", 0.5), # TODO: Find a better place to put default values
+                temperature=param.get(
+                    "temperature",
+                    0.5), # TODO: Find a better place to put default values
                 repetition_penalty=param.get("repetition_penalty", 1.0),
                 top_k=param.get("top_k", 4),
                 top_p=param.get("top_p", 1.0),
                 typical_p=param.get("typical_p", 1.0),
                 do_sample=param.get("do_sample", False),
             )
-            stop_parameters = StoppingCriteriaParameters(stop_sequences=param.get("stop_sequences", []),
-                                                         max_new_tokens=param.get("max_new_tokens", 30))
+            stop_parameters = StoppingCriteriaParameters(
+                stop_sequences=param.get("stop_sequences", []),
+                max_new_tokens=param.get("max_new_tokens", 30))

-            preprocessed_requests.append(lmi_dist.utils.types.Request(
-                id=r.id,
-                inputs=r.input_text,
-                parameters=parameters,
-                stopping_parameters=stop_parameters
-            ))
+            preprocessed_requests.append(
+                lmi_dist.utils.types.Request(
+                    id=r.id,
+                    inputs=r.input_text,
+                    parameters=parameters,
+                    stopping_parameters=stop_parameters))

         if preprocessed_requests:
             batch = Batch(id=self.batch_id_counter,
@@ -156,10 +169,7 @@ def preprocess_requests(self, requests, **kwargs):
             self.batch_id_counter += 1

             return self.batch_cls.get_batch(
-                batch,
-                self.model.tokenizer,
-                kwargs.get("torch_dtype", torch.float16),
-                self.device
-            )
+                batch, self.model.tokenizer,
+                kwargs.get("torch_dtype", torch.float16), self.device)
         else:
             return None
