Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Classifier-Free Guidance (CFG) #651

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added aphrodite/cfg/__init__.py
Empty file.
164 changes: 164 additions & 0 deletions aphrodite/cfg/cfg_model_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import List, Optional, Union

import torch

from aphrodite.common.sequence import IntermediateTensors, SamplerOutput
from aphrodite.distributed import get_pp_group
from aphrodite.multimodal import MultiModalInputs
from aphrodite.task_handler.model_runner import (
FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper, ModelInputForGPUWithSamplingMetadata,
ModelRunner)


class CFGModelRunner(ModelRunner):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@torch.inference_mode()
def model_execute(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> torch.Tensor:
if num_steps > 1:
raise ValueError("num_steps > 1 is not supported in ModelRunner")

if self.lora_config:
assert model_input.lora_requests is not None
assert model_input.lora_mapping is not None
self.set_active_loras(model_input.lora_requests,
model_input.lora_mapping)

if self.prompt_adapter_config:
assert model_input.prompt_adapter_requests is not None
assert model_input.prompt_adapter_mapping is not None
self.set_active_prompt_adapters(
model_input.prompt_adapter_requests,
model_input.prompt_adapter_mapping)

if self.attn_backend.get_name() == "flashinfer":
assert model_input.attn_metadata is not None
assert model_input.input_tokens is not None
if self.flashinfer_decode_workspace_buffer is None:
self.flashinfer_decode_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_decode_wrapper = \
BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_decode_workspace_buffer, "NHD")
self.flashinfer_prefill_workspace_buffer = torch.empty(
FLASHINFER_WORKSPACE_BUFFER_SIZE,
dtype=torch.uint8,
device=self.device)
self.flashinfer_prefill_wrapper = \
BatchPrefillWithPagedKVCacheWrapper(
self.flashinfer_prefill_workspace_buffer, "NHD")

model_input.attn_metadata.prefill_wrapper = \
self.flashinfer_prefill_wrapper
if model_input.attn_metadata.use_cuda_graph:
batch_size = model_input.input_tokens.shape[0]
model_input.attn_metadata.decode_wrapper = self.graph_runners[
model_input.
virtual_engine][batch_size].flashinfer_decode_wrapper
else:
model_input.attn_metadata.decode_wrapper = \
self.flashinfer_decode_wrapper
model_input.attn_metadata.begin_forward()

# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
# TODO: We can remove this once all
# virtual engines share the same kv cache.
virtual_engine = model_input.virtual_engine
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
else:
model_executable = self.model

multi_modal_kwargs = model_input.multi_modal_kwargs or {}
seqlen_agnostic_kwargs = {
"finished_requests_ids": model_input.finished_requests_ids,
"request_ids_to_seq_ids": model_input.request_ids_to_seq_ids,
} if self.has_seqlen_agnostic else {}
if (self.observability_config is not None
and self.observability_config.collect_model_forward_time):
raise NotImplementedError("")

hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)

return hidden_or_intermediate_states

@torch.inference_mode()
def get_logits(
self,
hidden_or_intermediate_states: torch.Tensor,
model_input: ModelInputForGPUWithSamplingMetadata,
) -> torch.Tensor:
return self.model._get_logits(hidden_or_intermediate_states,
model_input.sampling_metadata)

@torch.inference_mode()
def compute_logits(
self,
logits: torch.Tensor,
model_input: ModelInputForGPUWithSamplingMetadata,
) -> torch.Tensor:
return self.model.compute_logits(logits,
model_input.sampling_metadata)

@torch.inference_mode()
def do_sample(
self,
logits: torch.Tensor,
model_input: ModelInputForGPUWithSamplingMetadata,
):
if not self.is_driver_worker:
return []

# Sample the next token.
output: SamplerOutput = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)

if self.return_hidden_states:
raise NotImplementedError("return_hidden_states is not supported in CFGModelRunner")

return [output]

@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForGPUWithSamplingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:

hidden_or_intermediate_states = self.model_execute(model_input, kv_caches, intermediate_tensors, num_steps)

if not get_pp_group().is_last_rank:
return hidden_or_intermediate_states

hidden_or_intermediate_states = self.get_logits(hidden_or_intermediate_states, model_input)
logits = self.compute_logits(hidden_or_intermediate_states, model_input)

return self.do_sample(logits, model_input)
162 changes: 162 additions & 0 deletions aphrodite/cfg/cfg_worker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import copy
from typing import Dict, List, Optional, Tuple

import torch

from aphrodite.cfg.cfg_model_runner import CFGModelRunner
from aphrodite.cfg.separated_worker import SeparatedWorker
from aphrodite.common.config import (ClassifierFreeGuidanceConfig,
ParallelConfig)
from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceData, SequenceGroupMetadata)
from aphrodite.distributed import get_pp_group, get_tp_group
from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
WorkerBase)


def create_cfg_worker(*args, **kwargs) -> "CFGWorker":

assert "classifier_free_guidance_config" in kwargs
classifier_free_guidance_config: ClassifierFreeGuidanceConfig = kwargs.get("classifier_free_guidance_config")
assert classifier_free_guidance_config is not None
kwargs.pop("classifier_free_guidance_config")

kwargs["model_runner_cls"] = CFGModelRunner
root_worker = SeparatedWorker(*args, **kwargs)

guidance_model_config = classifier_free_guidance_config.guidance_model_config
guidance_parallel_config = classifier_free_guidance_config.guidance_parallel_config
kwargs.update(
model_config=guidance_model_config,
parallel_config=guidance_parallel_config,
)
guidance_worker = SeparatedWorker(*args, **kwargs)

return CFGWorker(
root_worker=root_worker,
guidance_worker=guidance_worker,
is_driver_worker=kwargs["is_driver_worker"],
parallel_config=kwargs["parallel_config"],
)


class CFGWorker(LoraNotSupportedWorkerBase):
def __init__(
self,
root_worker: WorkerBase,
guidance_worker: WorkerBase,
is_driver_worker: bool,
parallel_config: ParallelConfig,
):
self.root_worker = root_worker
self.guidance_worker = guidance_worker
self.is_driver_worker = is_driver_worker
self.parallel_config = parallel_config
assert self.parallel_config.pipeline_parallel_size == 1

def init_device(self):
self.root_worker.init_device()
self.guidance_worker.init_device()

def load_model(self):
self.root_worker.load_model()
self.guidance_worker.share_model(self.root_worker)

def determine_num_available_blocks(self) -> Tuple[int, int]:
num_gpu_blocks, num_cpu_blocks = (
self.root_worker.determine_num_available_blocks())

root_cache_block_size_bytes = (
self.root_worker.get_cache_block_size_bytes())
guidance_cache_block_size_bytes = (
self.guidance_worker.get_cache_block_size_bytes())

new_num_gpu_blocks = int(
num_gpu_blocks * root_cache_block_size_bytes /
(guidance_cache_block_size_bytes + root_cache_block_size_bytes))
return new_num_gpu_blocks, num_cpu_blocks

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
self.root_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
self.guidance_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)

@property
def do_metadata_broadcast(self) -> bool:
return self.parallel_config.tensor_parallel_size > 1

@torch.inference_mode()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:

# prepare negative request with shallow copy
if execute_model_req is not None:
negative_seq_group_metadata_list: List[SequenceGroupMetadata] = []
negative_excute_model_req = execute_model_req.clone(negative_seq_group_metadata_list)
for seq_group_metadata in execute_model_req.seq_group_metadata_list:
negative_seq_group_metadata = copy.copy(seq_group_metadata)
negative_seq_data: Dict[int, SequenceData] = {}
negative_block_tables: Dict[int, List[int]] = {}
assert len(seq_group_metadata.seq_data) == 1
for seq_id in seq_group_metadata.seq_data.keys():
negative_seq_data[seq_id] = seq_group_metadata.negative_seq_data
negative_block_tables[seq_id] = seq_group_metadata.negative_block_table

if negative_seq_group_metadata.is_prompt:
negative_seq_group_metadata.token_chunk_size = list(negative_seq_data.values())[0].get_len()

negative_seq_group_metadata.seq_data = negative_seq_data
negative_seq_group_metadata.block_tables = negative_block_tables
negative_seq_group_metadata.negative_seq_data = None
negative_seq_group_metadata.negative_block_table = None
negative_seq_group_metadata_list.append(negative_seq_group_metadata)
negative_excute_model_req.seq_group_metadata_list = negative_seq_group_metadata_list
else:
negative_excute_model_req = None

inputs = self.root_worker.prepare_input(execute_model_req)
negative_inputs = self.guidance_worker.prepare_input(negative_excute_model_req)
if inputs is None:
assert negative_inputs is None
return None

# get root models's logits
condition_logits = self.root_worker.execute_model_part(inputs)
# get unconditional logits
unconditional_logits = self.guidance_worker.execute_model_part(negative_inputs)

# do classifier free guidance logist process
model_input, _ = inputs
if condition_logits is not None:
for seq_group in model_input.sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
guidance_scale = seq_group.sampling_params.guidance_scale
if guidance_scale == 1.0:
break
for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices):
logits_row = torch.nn.functional.log_softmax(condition_logits[logits_row_idx], dim=-1)
unconditional_logits_row = torch.nn.functional.log_softmax(unconditional_logits[logits_row_idx], dim=-1)
condition_logits[logits_row_idx] = guidance_scale * (logits_row - unconditional_logits_row) + unconditional_logits_row

# do logist_processor
scores = self.root_worker.compute_logits(condition_logits, model_input)
if not self.is_driver_worker:
return []

# do sample
output = self.root_worker.do_sample(scores, model_input)

if not get_pp_group().is_last_rank:
# output is IntermediateTensors
get_pp_group().send_tensor_dict(output.tensors, all_gather_group=get_tp_group())
return [None]

# output is List[SamplerOutput]
return output

def get_cache_block_size_bytes(self):
raise NotImplementedError
Loading
Loading