From ede262ad02026467c9987e1459517d5a8427d318 Mon Sep 17 00:00:00 2001
From: AlpinDale
Date: Sun, 8 Sep 2024 16:12:10 +0000
Subject: [PATCH] add base cfg worker and model runner

Co-authored-by: zhaoyinglia
---
 aphrodite/cfg/__init__.py         |   0
 aphrodite/cfg/cfg_model_runner.py | 164 ++++++++++++++++++++++++++++++
 aphrodite/cfg/cfg_worker.py       | 162 +++++++++++++++++++++++++++++
 aphrodite/cfg/separated_worker.py |  75 ++++++++++++++
 4 files changed, 401 insertions(+)
 create mode 100644 aphrodite/cfg/__init__.py
 create mode 100644 aphrodite/cfg/cfg_model_runner.py
 create mode 100644 aphrodite/cfg/cfg_worker.py
 create mode 100644 aphrodite/cfg/separated_worker.py

diff --git a/aphrodite/cfg/__init__.py b/aphrodite/cfg/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/aphrodite/cfg/cfg_model_runner.py b/aphrodite/cfg/cfg_model_runner.py
new file mode 100644
index 0000000000..a6bdc9c012
--- /dev/null
+++ b/aphrodite/cfg/cfg_model_runner.py
@@ -0,0 +1,164 @@
+from typing import List, Optional, Union
+
+import torch
+
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.distributed import get_pp_group
+from aphrodite.multimodal import MultiModalInputs
+from aphrodite.task_handler.model_runner import (
+    FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper, ModelInputForGPUWithSamplingMetadata,
+    ModelRunner)
+
+
+class CFGModelRunner(ModelRunner):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @torch.inference_mode()
+    def model_execute(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> torch.Tensor:
+        if num_steps > 1:
+            raise ValueError("num_steps > 1 is not supported in ModelRunner")
+
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
+
+        if self.prompt_adapter_config:
+            assert model_input.prompt_adapter_requests is not None
+            assert model_input.prompt_adapter_mapping is not None
+            self.set_active_prompt_adapters(
+                model_input.prompt_adapter_requests,
+                model_input.prompt_adapter_mapping)
+
+        if self.attn_backend.get_name() == "flashinfer":
+            assert model_input.attn_metadata is not None
+            assert model_input.input_tokens is not None
+            if self.flashinfer_decode_workspace_buffer is None:
+                self.flashinfer_decode_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_decode_wrapper = \
+                    BatchDecodeWithPagedKVCacheWrapper(
+                        self.flashinfer_decode_workspace_buffer, "NHD")
+                self.flashinfer_prefill_workspace_buffer = torch.empty(
+                    FLASHINFER_WORKSPACE_BUFFER_SIZE,
+                    dtype=torch.uint8,
+                    device=self.device)
+                self.flashinfer_prefill_wrapper = \
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.flashinfer_prefill_workspace_buffer, "NHD")
+
+            model_input.attn_metadata.prefill_wrapper = \
+                self.flashinfer_prefill_wrapper
+            if model_input.attn_metadata.use_cuda_graph:
+                batch_size = model_input.input_tokens.shape[0]
+                model_input.attn_metadata.decode_wrapper = self.graph_runners[
+                    model_input.
+                    virtual_engine][batch_size].flashinfer_decode_wrapper
+            else:
+                model_input.attn_metadata.decode_wrapper = \
+                    self.flashinfer_decode_wrapper
+            model_input.attn_metadata.begin_forward()
+
+        # Currently cuda graph is only supported by the decode phase.
+        assert model_input.attn_metadata is not None
+        prefill_meta = model_input.attn_metadata.prefill_metadata
+        decode_meta = model_input.attn_metadata.decode_metadata
+        # TODO: We can remove this once all
+        # virtual engines share the same kv cache.
+        virtual_engine = model_input.virtual_engine
+        if prefill_meta is None and decode_meta.use_cuda_graph:
+            assert model_input.input_tokens is not None
+            graph_batch_size = model_input.input_tokens.shape[0]
+            model_executable = self.graph_runners[virtual_engine][
+                graph_batch_size]
+        else:
+            model_executable = self.model
+
+        multi_modal_kwargs = model_input.multi_modal_kwargs or {}
+        seqlen_agnostic_kwargs = {
+            "finished_requests_ids": model_input.finished_requests_ids,
+            "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
+        } if self.has_seqlen_agnostic else {}
+        if (self.observability_config is not None
+                and self.observability_config.collect_model_forward_time):
+            raise NotImplementedError("collect_model_forward_time is not supported")
+
+        hidden_or_intermediate_states = model_executable(
+            input_ids=model_input.input_tokens,
+            positions=model_input.input_positions,
+            kv_caches=kv_caches,
+            attn_metadata=model_input.attn_metadata,
+            intermediate_tensors=intermediate_tensors,
+            **MultiModalInputs.as_kwargs(multi_modal_kwargs,
+                                         device=self.device),
+            **seqlen_agnostic_kwargs)
+
+        return hidden_or_intermediate_states
+
+    @torch.inference_mode()
+    def get_logits(
+        self,
+        hidden_or_intermediate_states: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model._get_logits(hidden_or_intermediate_states,
+                                      model_input.sampling_metadata)
+
+    @torch.inference_mode()
+    def compute_logits(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model.compute_logits(logits,
+                                         model_input.sampling_metadata)
+
+    @torch.inference_mode()
+    def do_sample(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ):
+        if not self.is_driver_worker:
+            return []
+
+        # Sample the next token.
+        output: SamplerOutput = self.model.sample(
+            logits=logits,
+            sampling_metadata=model_input.sampling_metadata,
+        )
+
+        if self.return_hidden_states:
+            raise NotImplementedError("return_hidden_states is not supported in CFGModelRunner")
+
+        return [output]
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+
+        hidden_or_intermediate_states = self.model_execute(model_input, kv_caches, intermediate_tensors, num_steps)
+
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
+        hidden_or_intermediate_states = self.get_logits(hidden_or_intermediate_states, model_input)
+        logits = self.compute_logits(hidden_or_intermediate_states, model_input)
+
+        return self.do_sample(logits, model_input)
diff --git a/aphrodite/cfg/cfg_worker.py b/aphrodite/cfg/cfg_worker.py
new file mode 100644
index 0000000000..fa14f05f07
--- /dev/null
+++ b/aphrodite/cfg/cfg_worker.py
@@ -0,0 +1,162 @@
+import copy
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from aphrodite.cfg.cfg_model_runner import CFGModelRunner
+from aphrodite.cfg.separated_worker import SeparatedWorker
+from aphrodite.common.config import (ClassifierFreeGuidanceConfig,
+                                     ParallelConfig)
+from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+                                       SequenceData, SequenceGroupMetadata)
+from aphrodite.distributed import get_pp_group, get_tp_group
+from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
+                                                WorkerBase)
+
+
+def create_cfg_worker(*args, **kwargs) -> "CFGWorker":
+
+    assert "classifier_free_guidance_config" in kwargs
+    classifier_free_guidance_config: ClassifierFreeGuidanceConfig = kwargs.get("classifier_free_guidance_config")
+    assert classifier_free_guidance_config is not None
+    kwargs.pop("classifier_free_guidance_config")
+
+    kwargs["model_runner_cls"] = CFGModelRunner
+    root_worker = SeparatedWorker(*args, **kwargs)
+
+    guidance_model_config = classifier_free_guidance_config.guidance_model_config
+    guidance_parallel_config = classifier_free_guidance_config.guidance_parallel_config
+    kwargs.update(
+        model_config=guidance_model_config,
+        parallel_config=guidance_parallel_config,
+    )
+    guidance_worker = SeparatedWorker(*args, **kwargs)
+
+    return CFGWorker(
+        root_worker=root_worker,
+        guidance_worker=guidance_worker,
+        is_driver_worker=kwargs["is_driver_worker"],
+        parallel_config=kwargs["parallel_config"],
+    )
+
+
+class CFGWorker(LoraNotSupportedWorkerBase):
+    def __init__(
+        self,
+        root_worker: WorkerBase,
+        guidance_worker: WorkerBase,
+        is_driver_worker: bool,
+        parallel_config: ParallelConfig,
+    ):
+        self.root_worker = root_worker
+        self.guidance_worker = guidance_worker
+        self.is_driver_worker = is_driver_worker
+        self.parallel_config = parallel_config
+        assert self.parallel_config.pipeline_parallel_size == 1
+
+    def init_device(self):
+        self.root_worker.init_device()
+        self.guidance_worker.init_device()
+
+    def load_model(self):
+        self.root_worker.load_model()
+        self.guidance_worker.share_model(self.root_worker)
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        num_gpu_blocks, num_cpu_blocks = (
+            self.root_worker.determine_num_available_blocks())
+
+        root_cache_block_size_bytes = (
+            self.root_worker.get_cache_block_size_bytes())
+        guidance_cache_block_size_bytes = (
+            self.guidance_worker.get_cache_block_size_bytes())
+
+        new_num_gpu_blocks = int(
+            num_gpu_blocks * root_cache_block_size_bytes /
+            (guidance_cache_block_size_bytes + root_cache_block_size_bytes))
+        return new_num_gpu_blocks, num_cpu_blocks
+
+    def initialize_cache(self, num_gpu_blocks: int,
+                         num_cpu_blocks: int) -> None:
+        self.root_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
+                                          num_cpu_blocks=num_cpu_blocks)
+        self.guidance_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
+                                              num_cpu_blocks=num_cpu_blocks)
+
+    @property
+    def do_metadata_broadcast(self) -> bool:
+        return self.parallel_config.tensor_parallel_size > 1
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+
+        # prepare the negative request with a shallow copy
+        if execute_model_req is not None:
+            negative_seq_group_metadata_list: List[SequenceGroupMetadata] = []
+            negative_execute_model_req = execute_model_req.clone(negative_seq_group_metadata_list)
+            for seq_group_metadata in execute_model_req.seq_group_metadata_list:
+                negative_seq_group_metadata = copy.copy(seq_group_metadata)
+                negative_seq_data: Dict[int, SequenceData] = {}
+                negative_block_tables: Dict[int, List[int]] = {}
+                assert len(seq_group_metadata.seq_data) == 1
+                for seq_id in seq_group_metadata.seq_data.keys():
+                    negative_seq_data[seq_id] = seq_group_metadata.negative_seq_data
+                    negative_block_tables[seq_id] = seq_group_metadata.negative_block_table
+
+                if negative_seq_group_metadata.is_prompt:
+                    negative_seq_group_metadata.token_chunk_size = list(negative_seq_data.values())[0].get_len()
+
+                negative_seq_group_metadata.seq_data = negative_seq_data
+                negative_seq_group_metadata.block_tables = negative_block_tables
+                negative_seq_group_metadata.negative_seq_data = None
+                negative_seq_group_metadata.negative_block_table = None
+                negative_seq_group_metadata_list.append(negative_seq_group_metadata)
+            negative_execute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list
+        else:
+            negative_execute_model_req = None
+
+        inputs = self.root_worker.prepare_input(execute_model_req)
+        negative_inputs = self.guidance_worker.prepare_input(negative_execute_model_req)
+        if inputs is None:
+            assert negative_inputs is None
+            return None
+
+        # get the root model's conditional logits
+        condition_logits = self.root_worker.execute_model_part(inputs)
+        # get the unconditional logits from the guidance model
+        unconditional_logits = self.guidance_worker.execute_model_part(negative_inputs)
+
+        # apply classifier-free guidance to the logits
+        model_input, _ = inputs
+        if condition_logits is not None:
+            for seq_group in model_input.sampling_metadata.seq_groups:
+                seq_ids = seq_group.seq_ids
+                guidance_scale = seq_group.sampling_params.guidance_scale
+                if guidance_scale == 1.0:
+                    break
+                for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices):
+                    logits_row = torch.nn.functional.log_softmax(condition_logits[logits_row_idx], dim=-1)
+                    unconditional_logits_row = torch.nn.functional.log_softmax(unconditional_logits[logits_row_idx], dim=-1)
+                    condition_logits[logits_row_idx] = guidance_scale * (logits_row - unconditional_logits_row) + unconditional_logits_row
+
+        # run the logits processors
+        scores = self.root_worker.compute_logits(condition_logits, model_input)
+        if not self.is_driver_worker:
+            return []
+
+        # sample the next tokens
+        output = self.root_worker.do_sample(scores, model_input)
+
+        if not get_pp_group().is_last_rank:
+            # output is IntermediateTensors
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return [None]
+
+        # output is List[SamplerOutput]
+        return output
+
+    def get_cache_block_size_bytes(self):
+        raise NotImplementedError
diff --git a/aphrodite/cfg/separated_worker.py b/aphrodite/cfg/separated_worker.py
new file mode 100644
index 0000000000..816741a5fa
--- /dev/null
+++ b/aphrodite/cfg/separated_worker.py
@@ -0,0 +1,75 @@
+from typing import List, Optional, Tuple
+
+import torch
+
+from aphrodite.distributed import get_pp_group, get_tp_group
+from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
+from aphrodite.task_handler.worker import Worker
+from aphrodite.task_handler.worker_base import WorkerInput
+from aphrodite.task_handler.model_runner_base import BroadcastableModelInput  # to be implemented soon
+from aphrodite.task_handler.model_runner import ModelInputForGPUWithSamplingMetadata
+
+
+class SeparatedWorker(Worker):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @torch.inference_mode()
+    def get_logits(
+        self,
+        hidden_or_intermediate_states: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model_runner.get_logits(hidden_or_intermediate_states, model_input)
+
+    @torch.inference_mode()
+    def compute_logits(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> torch.Tensor:
+        return self.model_runner.compute_logits(logits, model_input)
+
+    @torch.inference_mode()
+    def do_sample(
+        self,
+        logits: torch.Tensor,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+    ) -> List[SamplerOutput]:
+        return self.model_runner.do_sample(logits, model_input)
+
+    @torch.inference_mode()
+    def execute_model_part(
+        self,
+        inputs: Tuple[BroadcastableModelInput, WorkerInput],
+    ) -> Optional[List[SamplerOutput]]:
+
+        model_input, worker_input = inputs
+        num_steps = worker_input.num_steps
+
+        self.execute_worker(worker_input)
+
+        # If there is no input, we don't need to execute the model.
+        if worker_input.num_seq_groups == 0:
+            return []
+
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group()))
+
+        hidden_or_intermediate_states = self.model_runner.model_execute(
+            model_input,
+            self.kv_cache[worker_input.virtual_engine]
+            if self.kv_cache is not None else None,
+            intermediate_tensors,
+            num_steps
+        )
+
+        # Compute the logits in the last pipeline stage.
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
+        logits = self.get_logits(hidden_or_intermediate_states, model_input)
+
+        return logits
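
For reference, the classifier-free guidance mixing that CFGWorker.execute_model applies to each sampled logits row can be exercised in isolation. The snippet below is a minimal, self-contained sketch of that formula only; the vocabulary size, the random tensors, and the guidance_scale value are illustrative stand-ins rather than values taken from the patch.

import torch
import torch.nn.functional as F

# Stand-in logits for a single sampled row (vocabulary size is arbitrary).
condition_logits = torch.randn(8)      # from the root (conditional) pass
unconditional_logits = torch.randn(8)  # from the guidance (negative) pass
guidance_scale = 1.5                   # > 1.0 pulls sampling toward the conditional prompt

cond = F.log_softmax(condition_logits, dim=-1)
uncond = F.log_softmax(unconditional_logits, dim=-1)

# Same update as in CFGWorker.execute_model:
#   guided = guidance_scale * (cond - uncond) + uncond
# With guidance_scale == 1.0 this reduces to the plain conditional
# distribution, which is why the worker skips the mixing in that case.
guided = guidance_scale * (cond - uncond) + uncond
print(F.softmax(guided, dim=-1))

Because the mixing is done on log-softmax values, the result stays in (unnormalized) log-probability space and can be handed to the existing logits processors and sampler unchanged.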