From d0a982c107f7a7455c97b159d962533bb5bead80 Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Thu, 6 Jul 2023 17:32:39 -0700
Subject: [PATCH] [python] Reformat python code

---
 .../scheduler/seq_batch_scheduler.py          | 27 +++++------
 .../setup/djl_python/tests/test_scheduler.py  | 45 ++++++++++++-------
 2 files changed, 44 insertions(+), 28 deletions(-)

diff --git a/engines/python/setup/djl_python/scheduler/seq_batch_scheduler.py b/engines/python/setup/djl_python/scheduler/seq_batch_scheduler.py
index 3519af69d..ab9fdf0d0 100644
--- a/engines/python/setup/djl_python/scheduler/seq_batch_scheduler.py
+++ b/engines/python/setup/djl_python/scheduler/seq_batch_scheduler.py
@@ -43,8 +43,7 @@ def __init__(self, lm_block: LMBlock, default_search_algorithm: str,
         self.results: Dict[int, List[int]] = defaultdict(list)
 
         self.seq_batchers: Dict[
-            Type[SeqBatcher]:List[SeqBatcher]] = defaultdict(
-                list)  # {key: List[SeqBatcher]}
+            Type[SeqBatcher]:List[SeqBatcher]] = defaultdict(list)
 
         self.lru_kv_cache = OrderedDict()
         self.lru_max_size = 10
@@ -55,7 +54,8 @@ def add_request(self,
                     search_algorithm: str = None,
                     search_configs: List[SearchConfig] = None,
                     kv_cache: Union[Tuple, None] = None,
-                    kv_cache_prompt_ids: Union[Dict[int, torch.tensor], None] = None):
+                    kv_cache_prompt_ids: Union[Dict[int, torch.tensor],
+                                               None] = None):
         """
         Args: kv_cache_prompt_ids = {request_uid -> List[token_ids]}
         """
@@ -66,20 +66,23 @@ def add_request(self,
         if search_configs:
             for idx, search_config in enumerate(search_configs):
                 if search_config.use_lru_kv_cache:
-                    prompt_ids_tensor = kv_cache_prompt_ids[request_uids[idx].item()]
+                    prompt_ids_tensor = kv_cache_prompt_ids[
+                        request_uids[idx].item()]
                     key = tuple(prompt_ids_tensor.flatten().tolist())
                     if not key:
-                        raise Exception(f"request_uids = {request_uids[idx]}: search_config says use_kv_cache_prompt, "
-                                        f"but the prompt_ids is not provided.")
+                        raise Exception(
+                            f"request_uids = {request_uids[idx]}: search_config says use_kv_cache_prompt, "
+                            f"but the prompt_ids is not provided.")
                     else:
                         # lru operations
                         if key not in self.lru_kv_cache:
                             if len(self.lru_kv_cache) + 1 > self.lru_max_size:
                                 # If cache size exceeds the maximum, remove by FIFO order
                                 self.lru_kv_cache.popitem(last=False)
-                            kv_cache_tuple = compute_kv_cache(input_ids=prompt_ids_tensor,
-                                                              lm_block=self.lm_block,
-                                                              search_configs=[search_config])
+                            kv_cache_tuple = compute_kv_cache(
+                                input_ids=prompt_ids_tensor,
+                                lm_block=self.lm_block,
+                                search_configs=[search_config])
                             kv_cache_new = []
                             for k, v in kv_cache_tuple:
                                 k_new = k.cpu()
@@ -111,10 +114,8 @@ def add_request(self,
 
         index_not_use_prompt = torch.tensor(index_not_use_prompt)
         self._add_request(input_ids[index_not_use_prompt],
-                          request_uids[index_not_use_prompt],
-                          search_algorithm,
-                          search_configs_not_use_prompt,
-                          kv_cache)
+                          request_uids[index_not_use_prompt], search_algorithm,
+                          search_configs_not_use_prompt, kv_cache)
 
     def _add_request(self,
                      input_ids: torch.Tensor,
diff --git a/engines/python/setup/djl_python/tests/test_scheduler.py b/engines/python/setup/djl_python/tests/test_scheduler.py
index 7e9809e10..fe757d0d5 100644
--- a/engines/python/setup/djl_python/tests/test_scheduler.py
+++ b/engines/python/setup/djl_python/tests/test_scheduler.py
@@ -46,7 +46,8 @@ def test_lm_block(self):
     def test_greedy_scheduler(self):
         model_id = "gpt2"
         model = GPT2LMHeadModel.from_pretrained(model_id)
-        tokenizer = GPT2Tokenizer.from_pretrained(model_id, padding_side='left')
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id,
+                                                  padding_side='left')
         tokenizer.pad_token = "[PAD]"
 
         lm_block = HuggingfaceBlock(model)
@@ -67,9 +68,12 @@ def test_greedy_scheduler(self):
 
         # Test add request
         scheduler.add_request(input_ids_0, request_ids)
-        input_ids = tokenizer([r"When your legs don't work like they used to before And I can't sweep you off",
-                               r"There's a time that I remember, when I did not know"],
-                              return_tensors='pt', padding=True).input_ids
+        input_ids = tokenizer([
+            r"When your legs don't work like they used to before And I can't sweep you off",
+            r"There's a time that I remember, when I did not know"
+        ],
+                              return_tensors='pt',
+                              padding=True).input_ids
 
         # Test merging longer sequences
         request_ids = torch.tensor([[1], [2]])
@@ -87,9 +91,10 @@ def test_greedy_scheduler(self):
             "remember the last time I saw a girl in a dress. I can't remember the last time"
 
         # Load a kv_cache from file and test merging a shorter sequence
-        input_ids = tokenizer([r"When your legs don't work",
-                               r"'t remember",
-                               r""], return_tensors='pt', padding=True).input_ids
+        input_ids = tokenizer(
+            [r"When your legs don't work", r"'t remember", r""],
+            return_tensors='pt',
+            padding=True).input_ids
         request_ids = torch.tensor([[3], [4], [5]])
 
         # Load a kv_cache file to simulate a fixed reusable prefix which is pre-calculated
@@ -434,7 +439,8 @@ def test_utils(self):
     def test_lru_kv_cache(self):
         model_id = "gpt2"
         model = GPT2LMHeadModel.from_pretrained(model_id)
-        tokenizer = GPT2Tokenizer.from_pretrained(model_id, padding_side='left')
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id,
+                                                  padding_side='left')
         tokenizer.pad_token = "[PAD]"
 
         lm_block = HuggingfaceBlock(model)
@@ -442,20 +448,29 @@ def test_lru_kv_cache(self):
         search_config = SearchConfig()
         search_config.max_new_seqlen = 30
         scheduler = SeqBatchScheduler(lm_block, "greedy", search_config)
-        prompt_ids = tokenizer(
-            'Memories follow me left and right. I can', return_tensors='pt', padding=True).input_ids
+        prompt_ids = tokenizer('Memories follow me left and right. I can',
+                               return_tensors='pt',
+                               padding=True).input_ids
         prompt_ids = prompt_ids.view(1, -1)
         prompt_ids_dict = {1: prompt_ids, 2: prompt_ids}
 
         # Load a kv_cache from file and test merging a shorter sequence
-        input_ids = tokenizer([r"When your legs don't work",
-                               r"'t remember",
-                               r""], return_tensors='pt', padding=True).input_ids
+        input_ids = tokenizer(
+            [r"When your legs don't work", r"'t remember", r""],
+            return_tensors='pt',
+            padding=True).input_ids
         request_ids = torch.tensor([[0], [1], [2]])
-        search_configs = [SearchConfig(), SearchConfig(use_lru_kv_cache=True), SearchConfig(use_lru_kv_cache=True)]
+        search_configs = [
+            SearchConfig(),
+            SearchConfig(use_lru_kv_cache=True),
+            SearchConfig(use_lru_kv_cache=True)
+        ]
 
         # Load a kv_cache file to simulate a fixed reusable prefix which is pre-calculated
-        scheduler.add_request(input_ids, request_ids, search_configs=search_configs, kv_cache_prompt_ids=prompt_ids_dict)
+        scheduler.add_request(input_ids,
+                              request_ids,
+                              search_configs=search_configs,
+                              kv_cache_prompt_ids=prompt_ids_dict)
 
         # Test trim_and_collect
         for idx, _ in enumerate(scheduler.increment_forward(100)):