Skip to content

Commit

Permalink
[python] Reformat python code (#917)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Jul 7, 2023
1 parent c737451 commit ea2623b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 28 deletions.
27 changes: 14 additions & 13 deletions engines/python/setup/djl_python/scheduler/seq_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ def __init__(self, lm_block: LMBlock, default_search_algorithm: str,
self.results: Dict[int, List[int]] = defaultdict(list)

self.seq_batchers: Dict[
Type[SeqBatcher]:List[SeqBatcher]] = defaultdict(
list) # {key: List[SeqBatcher]}
Type[SeqBatcher]:List[SeqBatcher]] = defaultdict(list)

self.lru_kv_cache = OrderedDict()
self.lru_max_size = 10
Expand All @@ -55,7 +54,8 @@ def add_request(self,
search_algorithm: str = None,
search_configs: List[SearchConfig] = None,
kv_cache: Union[Tuple, None] = None,
kv_cache_prompt_ids: Union[Dict[int, torch.tensor], None] = None):
kv_cache_prompt_ids: Union[Dict[int, torch.tensor],
None] = None):
"""
Args: kv_cache_prompt_ids = {request_uid -> List[token_ids]}
"""
Expand All @@ -66,20 +66,23 @@ def add_request(self,
if search_configs:
for idx, search_config in enumerate(search_configs):
if search_config.use_lru_kv_cache:
prompt_ids_tensor = kv_cache_prompt_ids[request_uids[idx].item()]
prompt_ids_tensor = kv_cache_prompt_ids[
request_uids[idx].item()]
key = tuple(prompt_ids_tensor.flatten().tolist())
if not key:
raise Exception(f"request_uids = {request_uids[idx]}: search_config says use_kv_cache_prompt, "
f"but the prompt_ids is not provided.")
raise Exception(
f"request_uids = {request_uids[idx]}: search_config says use_kv_cache_prompt, "
f"but the prompt_ids is not provided.")
else:
# lru operations
if key not in self.lru_kv_cache:
if len(self.lru_kv_cache) + 1 > self.lru_max_size:
# If cache size exceeds the maximum, remove by FIFO order
self.lru_kv_cache.popitem(last=False)
kv_cache_tuple = compute_kv_cache(input_ids=prompt_ids_tensor,
lm_block=self.lm_block,
search_configs=[search_config])
kv_cache_tuple = compute_kv_cache(
input_ids=prompt_ids_tensor,
lm_block=self.lm_block,
search_configs=[search_config])
kv_cache_new = []
for k, v in kv_cache_tuple:
k_new = k.cpu()
Expand Down Expand Up @@ -111,10 +114,8 @@ def add_request(self,

index_not_use_prompt = torch.tensor(index_not_use_prompt)
self._add_request(input_ids[index_not_use_prompt],
request_uids[index_not_use_prompt],
search_algorithm,
search_configs_not_use_prompt,
kv_cache)
request_uids[index_not_use_prompt], search_algorithm,
search_configs_not_use_prompt, kv_cache)

def _add_request(self,
input_ids: torch.Tensor,
Expand Down
45 changes: 30 additions & 15 deletions engines/python/setup/djl_python/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def test_lm_block(self):
def test_greedy_scheduler(self):
model_id = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id, padding_side='left')
tokenizer = GPT2Tokenizer.from_pretrained(model_id,
padding_side='left')
tokenizer.pad_token = "[PAD]"
lm_block = HuggingfaceBlock(model)

Expand All @@ -67,9 +68,12 @@ def test_greedy_scheduler(self):
# Test add request
scheduler.add_request(input_ids_0, request_ids)

input_ids = tokenizer([r"When your legs don't work like they used to before And I can't sweep you off",
r"There's a time that I remember, when I did not know"],
return_tensors='pt', padding=True).input_ids
input_ids = tokenizer([
r"When your legs don't work like they used to before And I can't sweep you off",
r"There's a time that I remember, when I did not know"
],
return_tensors='pt',
padding=True).input_ids

# Test merging longer sequences
request_ids = torch.tensor([[1], [2]])
Expand All @@ -87,9 +91,10 @@ def test_greedy_scheduler(self):
"remember the last time I saw a girl in a dress. I can't remember the last time"

# Load a kv_cache from file and test merging a shorter sequence
input_ids = tokenizer([r"When your legs don't work",
r"'t remember",
r""], return_tensors='pt', padding=True).input_ids
input_ids = tokenizer(
[r"When your legs don't work", r"'t remember", r""],
return_tensors='pt',
padding=True).input_ids
request_ids = torch.tensor([[3], [4], [5]])

# Load a kv_cache file to simulate a fixed reusable prefix which is pre-calculated
Expand Down Expand Up @@ -434,28 +439,38 @@ def test_utils(self):
def test_lru_kv_cache(self):
model_id = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_id)
tokenizer = GPT2Tokenizer.from_pretrained(model_id, padding_side='left')
tokenizer = GPT2Tokenizer.from_pretrained(model_id,
padding_side='left')
tokenizer.pad_token = "[PAD]"
lm_block = HuggingfaceBlock(model)

search_config = SearchConfig()
search_config.max_new_seqlen = 30
scheduler = SeqBatchScheduler(lm_block, "greedy", search_config)

prompt_ids = tokenizer(
'Memories follow me left and right. I can', return_tensors='pt', padding=True).input_ids
prompt_ids = tokenizer('Memories follow me left and right. I can',
return_tensors='pt',
padding=True).input_ids
prompt_ids = prompt_ids.view(1, -1)
prompt_ids_dict = {1: prompt_ids, 2: prompt_ids}

# Load a kv_cache from file and test merging a shorter sequence
input_ids = tokenizer([r"When your legs don't work",
r"'t remember",
r""], return_tensors='pt', padding=True).input_ids
input_ids = tokenizer(
[r"When your legs don't work", r"'t remember", r""],
return_tensors='pt',
padding=True).input_ids
request_ids = torch.tensor([[0], [1], [2]])
search_configs = [SearchConfig(), SearchConfig(use_lru_kv_cache=True), SearchConfig(use_lru_kv_cache=True)]
search_configs = [
SearchConfig(),
SearchConfig(use_lru_kv_cache=True),
SearchConfig(use_lru_kv_cache=True)
]

# Load a kv_cache file to simulate a fixed reusable prefix which is pre-calculated
scheduler.add_request(input_ids, request_ids, search_configs=search_configs, kv_cache_prompt_ids=prompt_ids_dict)
scheduler.add_request(input_ids,
request_ids,
search_configs=search_configs,
kv_cache_prompt_ids=prompt_ids_dict)

# Test trim_and_collect
for idx, _ in enumerate(scheduler.increment_forward(100)):
Expand Down

0 comments on commit ea2623b

Please sign in to comment.