
Commit

add comments
huyiwen committed Jul 25, 2024
1 parent 7c4f530 commit b5d4812
Showing 5 changed files with 65 additions and 40 deletions.
50 changes: 33 additions & 17 deletions utilization/model/model_utils/batch_sampler.py
@@ -43,38 +43,53 @@ def info_dataset_group(

class AutoBatchSizeSampler(Sampler[List[int]]):

- def __init__(self, data, batch_size: int, auto_batch_size: bool, start_from: int = 0):
+ def __init__(self, data, batch_size: int, auto_batch_size: bool, index_offset: int = 0):
+     """Sampler that automatically adjusts the batch size based on the maximum length of the data.
+
+     Args:
+         data: The data to sample from.
+         batch_size: The maximum batch size.
+         auto_batch_size: Whether to automatically adjust the batch size based on the maximum length of the data.
+         index_offset: The offset of indices to yield.
+     """
self.data = [src.to_model_prompt() if hasattr(src, "to_model_prompt") else src for src in data]
total = len(self.data)
self.batch_size = batch_size
self.auto_batch_size = auto_batch_size
self.first_max_len = None
+ self.index_offset = index_offset
self.data_order = [[]]
- self.start_from = start_from
+ """The data indices to yield (batches of indices). For the convenience of the `__iter__` method, the indices are offset-based: `range(index_offset, index_offset + total)`."""

if not self.auto_batch_size:
for i in range(0, total, self.batch_size):
- st = i + self.start_from
- ed = min(i + self.batch_size, total) + self.start_from
+ st = i + self.index_offset
+ ed = min(i + self.batch_size, total) + self.index_offset
self.data_order[-1].extend(range(st, ed))
if len(self.data_order[-1]) == self.batch_size:
self.data_order.append([])
else:
for i in range(total):
- self.data_order[-1].append(i + self.start_from)
+ self.data_order[-1].append(i + self.index_offset)
if self.check_new_batch(self.data_order[-1], i + 1):
self.data_order.append([])

# remove the last empty batches
while self.data_order[-1] == []:
self.data_order.pop()
logger.debug(f"AutoBatchSizeSampler: {len(self.data_order)} batches starting from {self.start_from}")
logger.debug(f"AutoBatchSizeSampler: {len(self.data_order)} batches starting from {self.index_offset}")

- def check_new_batch(self, queries: List[int], next_data: int) -> bool:
+ def check_new_batch(self, offset_query_indices: List[int], next_data: int) -> bool:
"""Check the condition to start a new batch."""

- current_batch = len(queries)
+ current_batch = len(offset_query_indices)
if not self.auto_batch_size:
return current_batch > self.batch_size
- max_len = max(len(self.data[q - self.start_from]) for q in queries)
+
+ # data: 0-based
+ # offset_query_indices: offset-based
+ # next_data: 0-based
+ max_len = max(len(self.data[q - self.index_offset]) for q in offset_query_indices)
if next_data < len(self.data):
max_len = max(len(self.data[next_data]), max_len)

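As a quick illustration of the renamed index_offset parameter, here is a minimal standalone sketch (not the repository code) of the non-auto path above: batches are built over range(index_offset, index_offset + total), so each dataset group can hand out globally unique indices.

from typing import Iterator, List

def fixed_size_batches(total: int, batch_size: int, index_offset: int = 0) -> Iterator[List[int]]:
    # Mirrors the `not auto_batch_size` branch above: chunk the offset-based
    # index range into batches of at most `batch_size`.
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        yield list(range(start + index_offset, end + index_offset))

# With total=5, batch_size=2, index_offset=10 this yields
# [10, 11], [12, 13], [14], matching the sampler's data_order layout.
print(list(fixed_size_batches(5, 2, index_offset=10)))
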
@@ -85,7 +100,6 @@ def check_new_batch(self, queries: List[int], next_data: int) -> bool:

batch_size = available_space // max_len
batch_size = round_down(batch_size)
# print("!!!", queries, current_batch, batch_size, available_space, max_len, self.first_max_len)
return current_batch >= batch_size

def __iter__(self) -> Iterator[List[int]]:
@@ -162,17 +176,19 @@ def wrapper():

def __iter__(self) -> Iterator[List[int]]:
model = self.dataset_collection._datasets[0].model
- accumulative = 0
- for total, init_model, self._forward_call in zip(*self._splitted):
+ accumulative_offset = 0
+
+ # iterate over the dataset groups
+ for group_total, init_model, self._forward_call in zip(*self._splitted):
iterator, total_prefix_num = init_model()
if total_prefix_num > 1 and model.support_cache:
sampler = CachePrefixSampler(
data=iterator,
- total=total,
+ total=group_total,
total_prefix_num=total_prefix_num,
batch_size=self.batch_size,
auto_batch_size=self.auto_batch_size,
- start_from=accumulative,
+ index_offset=accumulative_offset,
)
model.set_cacher(sampler)
yield from sampler
@@ -182,11 +198,11 @@ def __iter__(self) -> Iterator[List[int]]:
# dynamic batch size for vLLM
yield from AutoBatchSizeSampler(
iterator,
- self.batch_size if not self.vllm else total,
+ self.batch_size if not self.vllm else group_total,
self.auto_batch_size and not self.vllm,
- start_from=accumulative
+ index_offset=accumulative_offset
)
- accumulative += total
+ accumulative_offset += group_total

def call_model(self, *args, **kwargs) -> List[Any]:
"""Route the model to call the corresponding `model_evaluation_method`"""
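The offset bookkeeping in __iter__ can be summarized with a short sketch (hypothetical group sizes, not the repository code): each group's sampler starts its indices where the previous group ended.

group_totals = [3, 4]  # assumed sizes of two dataset groups
accumulative_offset = 0
for group_total in group_totals:
    indices = list(range(accumulative_offset, accumulative_offset + group_total))
    print(indices)  # [0, 1, 2] for the first group, then [3, 4, 5, 6]
    accumulative_offset += group_total
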
34 changes: 16 additions & 18 deletions utilization/model/model_utils/conversation.py
@@ -7,8 +7,6 @@
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment

- import copy
-
from ...chat_templates import DEFAULT_CHAT_CONFIGS, DEFAULT_CHAT_TEMPLATE, add_space, smart_space

# legacy types
@@ -54,20 +52,26 @@ def __init__(
chat_template: str,
special_tokens_map: Optional[Dict[str, Union[str, List[str]]]] = None,
):
- self.default_stop = chat_config.pop("default_stop", [])
- self.auto_leading_space = chat_config.pop("auto_leading_space", True)
- self.final_lstrip = chat_config.pop("final_lstrip", True)
- self.final_rstrip = chat_config.pop("final_rstrip", True)
- self.merge_system_to_user = chat_config.pop("merge_system_to_user", False)
- self.system_user_sep = chat_config.pop("system_user_sep", "\n")
+ sequences = deepcopy(chat_config)
+ self.default_stop = sequences.pop("default_stop", [])
+ self.auto_leading_space = sequences.pop("auto_leading_space", True)
+ self.final_lstrip = sequences.pop("final_lstrip", True)
+ self.final_rstrip = sequences.pop("final_rstrip", True)
+ self.merge_system_to_user = sequences.pop("merge_system_to_user", False)
+ self.system_user_sep = sequences.pop("system_user_sep", "\n")

# api model does not need bos_token
if "bos_token" not in chat_config:
chat_config["bos_token"] = ""
if "bos_token" not in sequences:
sequences["bos_token"] = ""

+ if special_tokens_map is not None:
+     for key, value in special_tokens_map.items():
+         if key not in sequences:
+             sequences[key] = value

- self.sequences = chat_config
+ self.sequences = sequences
self.chat_template = chat_template
- self.special_tokens_map = special_tokens_map or {}
+ self.special_tokens_map = deepcopy(special_tokens_map or {})

@classmethod
def from_chat_template(
@@ -79,7 +83,6 @@ def from_chat_template(
chat_template = "base"

if chat_template in DEFAULT_CHAT_CONFIGS:
- chat_config = copy.deepcopy(DEFAULT_CHAT_CONFIGS[chat_template])
+ chat_config = DEFAULT_CHAT_CONFIGS[chat_template]
chat_template = DEFAULT_CHAT_TEMPLATE
else:
@@ -89,11 +92,6 @@
chat_config = {}
chat_template = chat_config

- if special_tokens_map is not None:
-     for key, value in special_tokens_map.items():
-         if key not in chat_config:
-             chat_config[key] = value
-
return cls(chat_config=chat_config, chat_template=chat_template, special_tokens_map=special_tokens_map)

@staticmethod
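The deepcopy in __init__ (and the dropped copy.deepcopy in from_chat_template) matters because DEFAULT_CHAT_CONFIGS entries are shared dicts that the constructor pops keys from. A minimal sketch of the hazard, using a hypothetical config dict rather than the library's actual data:

from copy import deepcopy

DEFAULTS = {"base": {"default_stop": ["\n"], "bos_token": "<s>"}}  # hypothetical shared config

def build_unsafe(name: str) -> dict:
    cfg = DEFAULTS[name]
    cfg.pop("default_stop", [])     # mutates the shared default in place
    return cfg

def build_safe(name: str) -> dict:
    cfg = deepcopy(DEFAULTS[name])  # later pops only touch the copy
    cfg.pop("default_stop", [])
    return cfg

build_unsafe("base")
print(DEFAULTS["base"])  # {'bos_token': '<s>'} -- the stop list is gone for every later caller
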
17 changes: 13 additions & 4 deletions utilization/model/model_utils/prefix_caching.py
@@ -296,7 +296,16 @@ class CachePrefixSampler(Sampler[List[int]], Cacher):
Consider a batch of data indexed from 0 to 7 with cache level 2. Assume data
0~3 have the same prefix and 4~7 have the same prefix. We need to yield the
data 0 and 4 to cache the prefix, and then yield 0~7 to generate with the cache.
- Notes that the data 0 and 4 will be yielded twice in total."""
+ Note that the data 0 and 4 will be yielded twice in total.
+
+ Args:
+     data: The data to sample from.
+     total: The total length of data.
+     total_prefix_num: The number of prefixes to cache.
+     batch_size: The maximum batch size.
+     auto_batch_size: Whether to automatically adjust the batch size based on the maximum length of the data.
+     index_offset: The offset of indices to yield.
+ """

def __init__(
self,
@@ -305,14 +314,14 @@ def __init__(
total_prefix_num: int,
batch_size: int,
auto_batch_size: bool = False,
- start_from: int = 0,
+ index_offset: int = 0,
):

# split data into (src,) and (src, tgt)
self.total_prefix_num = total_prefix_num
self.joined_data = [[] for _ in range(self.total_prefix_num)]
self.cache_levels = [0] * total
- self.start_from = start_from
+ self.index_offset = index_offset

# the batch_size for the kvcache is smaller than the batch_size to avoid OOM
cache_batch_size = (batch_size + 1) // 2
@@ -394,7 +403,7 @@ def _get_data_order(self, total):
order_idx_by_cache[i] = -1

for o in data_order_with_cache:
- o = [i + self.start_from for i in o]
+ o = [i + self.index_offset for i in o]

return data_order_with_cache

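The yield order described in the CachePrefixSampler docstring (cache data 0 and 4 first, then generate for 0~7) can be reproduced with a toy generator. This is only an illustration of the pattern, not the actual sampler logic:

from typing import Iterator, List

def cache_then_generate(total: int = 8, prefix_group: int = 4, batch_size: int = 4) -> Iterator[List[int]]:
    # First yield one representative per shared prefix so its KV cache can be built...
    yield list(range(0, total, prefix_group))  # [0, 4]
    # ...then yield every index in ordinary batches, reusing the cached prefixes.
    for start in range(0, total, batch_size):
        yield list(range(start, min(start + batch_size, total)))

print(list(cache_then_generate()))  # [[0, 4], [0, 1, 2, 3], [4, 5, 6, 7]]
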
1 change: 1 addition & 0 deletions utilization/utils/arguments.py
@@ -488,6 +488,7 @@ class DatasetArguments:
)

continue_from: ClassVar[int] = 0
"""The number of instances (lines) in .json file to resume from. This is set in `PredictionWriter.write_metainfo`."""

# set in `set_logging` with format "{evaluation_results_dir}/{log_filename}.json"
evaluation_results_path: ClassVar[Optional[str]] = None
3 changes: 2 additions & 1 deletion utilization/utils/log_results.py
@@ -82,11 +82,12 @@ def write_metainfo(
self.continue_from_path = None
return

+ # load num instances
self.continue_from_path = continue_from or self.evaluation_args.continue_from
if self.continue_from_path:
self.continue_from_instance = self.check_continue()

- # load num instances
+ # set num instances in dataset_args
if self.continue_from_instance is not None and continue_from is None:
self.dataset_args.continue_from = self.continue_from_instance

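Taken together, the new comments describe the resume flow: continue_from ends up holding the number of prediction lines already written, so those instances are skipped on the next run. A rough, hypothetical sketch of that bookkeeping (the real logic lives in PredictionWriter.check_continue and write_metainfo):

import os

def count_finished_instances(results_path: str) -> int:
    # Hypothetical helper: assume one prediction per line in the .json lines file.
    if not os.path.exists(results_path):
        return 0
    with open(results_path, "r", encoding="utf-8") as f:
        return sum(1 for _ in f)

# dataset_args.continue_from would then be set to this count so that the first
# `continue_from` instances are not evaluated again.
print(count_finished_instances("results.json"))  # 0 when no previous run exists
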
