Commit 050b084

async

youngkent committed Jan 13, 2025
1 parent 66e5dea commit 050b084

Showing 5 changed files with 336 additions and 235 deletions.
2 changes: 1 addition & 1 deletion vllm/v1/outputs.py
@@ -8,7 +8,7 @@
 class SamplerOutput:
 
     # [num_reqs]
-    sampled_token_ids: List[int]
+    sampled_token_ids: torch.Tensor
 
     # [num_reqs, max_num_logprobs + 1]
     logprob_token_ids: Optional[torch.Tensor]
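The only change to vllm/v1/outputs.py is the type of sampled_token_ids: SamplerOutput now carries the raw [num_reqs] GPU tensor instead of an eagerly converted Python list, so constructing the output no longer forces a device-to-host copy. A minimal sketch of how a downstream consumer might defer that copy (consume_sampler_output is a hypothetical name, not the actual vLLM call site):

import torch
from typing import List

def consume_sampler_output(sampled_token_ids: torch.Tensor) -> List[int]:
    # .tolist() is where the CPU finally waits on the GPU; doing it here,
    # after other work has been scheduled, keeps the sampler itself async.
    return sampled_token_ids.tolist()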
31 changes: 18 additions & 13 deletions vllm/v1/sample/sampler.py
@@ -1,13 +1,13 @@
"""A layer that samples the next tokens from the model's outputs."""

from typing import Tuple

import torch
import torch.nn as nn

from vllm.v1.outputs import SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
from vllm.v1.sample.ops.penalties import apply_all_penalties, apply_min_token_penalties

Check failure on line 10 in vllm/v1/sample/sampler.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/v1/sample/sampler.py:10:81: E501 Line too long (87 > 80)
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

_SAMPLING_EPS = 1e-5
@@ -34,7 +34,8 @@ def forward(
             # modify the logits tensor in-place (and we don't want to clone
             # the logits tensor for memory efficiency).
             topk_logprobs, topk_indices = self.get_topk_logprobs(
-                logits, sampling_metadata)
+                logits, sampling_metadata
+            )
         else:
             topk_logprobs = None
             topk_indices = None
@@ -50,9 +51,8 @@
         # Use int32 to reduce the tensor size.
         sampled = sampled.to(torch.int32)
 
-        # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
-            sampled_token_ids=sampled.tolist(),
+            sampled_token_ids=sampled,
             logprob_token_ids=topk_indices,
             logprobs=topk_logprobs,
             prompt_logprob_token_ids=None,
@@ -79,8 +79,7 @@ def sample(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        assert not (sampling_metadata.all_greedy
-                    and sampling_metadata.all_random)
+        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
[GitHub Actions / ruff (3.12)] Check failure, Ruff (E501): vllm/v1/sample/sampler.py:82:81: Line too long (82 > 80)
 
         if sampling_metadata.all_greedy:
             return self.greedy_sample(logits)
@@ -112,7 +111,8 @@ def get_topk_logprobs(
         # FIXME: Mask the sampled token_id, get topk logprobs,
         # and concatenate the topk with the sampled token_id.
         topk_logprobs, topk_indices = torch.topk(
-            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
+            logprobs, sampling_metadata.max_num_logprobs, dim=-1
+        )
         # Use int32 to reduce the tensor size.
         topk_indices = topk_indices.to(torch.int32)
         return topk_logprobs, topk_indices
@@ -122,15 +122,20 @@ def apply_penalties(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        apply_min_token_penalties(logits, sampling_metadata.output_token_ids,
-                                  sampling_metadata.stop_token_ids,
-                                  sampling_metadata.min_tokens)
+        apply_min_token_penalties(
+            logits,
+            sampling_metadata.output_token_ids,
+            sampling_metadata.stop_token_ids,
+            sampling_metadata.min_tokens,
+        )
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
             logits = apply_all_penalties(
-                logits, sampling_metadata.prompt_token_ids,
+                logits,
+                sampling_metadata.prompt_token_ids,
                 sampling_metadata.presence_penalties,
                 sampling_metadata.frequency_penalties,
                 sampling_metadata.repetition_penalties,
-                sampling_metadata.output_token_ids)
+                sampling_metadata.output_token_ids,
+            )
         return logits
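Taken together, the sampler.py changes drop the explicit "CPU-GPU synchronization happens here" step: sampling stays on the GPU, the result is downcast to int32 to shrink the eventual transfer, and the tensor is returned as-is. The standalone sketch below illustrates that pattern under simplified assumptions (sample_ids is illustrative and is not the real Sampler.sample method):

import torch

def sample_ids(logits: torch.Tensor) -> torch.Tensor:
    # Sample one token per request entirely on the GPU.
    probs = torch.softmax(logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    # Use int32 to reduce the tensor size, and return the tensor without
    # calling .tolist(), so no host synchronization happens on this path.
    return sampled.to(torch.int32)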
60 changes: 37 additions & 23 deletions vllm/v1/utils.py
@@ -1,9 +1,21 @@
+import asyncio
 import multiprocessing
 import os
 import weakref
 from collections.abc import Sequence
-from typing import (Any, Callable, Dict, Generic, List, Optional, TypeVar,
-                    Union, overload)
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    overload,
+    TypeVar,
+    Union,
+)
 
+import torch
+
 from vllm.logger import init_logger
 from vllm.utils import get_mp_context, kill_process_tree
@@ -13,6 +25,12 @@
 T = TypeVar("T")
 
 
+async def cuda_stream_sync() -> None:
+    await asyncio.get_running_loop().run_in_executor(
+        executor=None, func=torch.cuda.current_stream().synchronize
+    )
+
+
 class ConstantList(Generic[T], Sequence):
 
     def __init__(self, x: List[T]) -> None:
@@ -36,31 +54,23 @@ def remove(self, item):
     def clear(self):
         raise Exception("Cannot clear a constant list")
 
-    def index(self,
-              item: T,
-              start: int = 0,
-              stop: Optional[int] = None) -> int:
-        return self._x.index(item, start,
-                             stop if stop is not None else len(self._x))
+    def index(self, item: T, start: int = 0, stop: Optional[int] = None) -> int:
+        return self._x.index(item, start, stop if stop is not None else len(self._x))
[GitHub Actions / ruff (3.12)] Check failure, Ruff (E501): vllm/v1/utils.py:58:81: Line too long (85 > 80)
 
     @overload
-    def __getitem__(self, item: int) -> T:
-        ...
+    def __getitem__(self, item: int) -> T: ...
 
     @overload
-    def __getitem__(self, s: slice, /) -> List[T]:
-        ...
+    def __getitem__(self, s: slice, /) -> List[T]: ...
 
     def __getitem__(self, item: Union[int, slice]) -> Union[T, List[T]]:
         return self._x[item]
 
     @overload
-    def __setitem__(self, item: int, value: T):
-        ...
+    def __setitem__(self, item: int, value: T): ...
 
     @overload
-    def __setitem__(self, s: slice, value: T, /):
-        ...
+    def __setitem__(self, s: slice, value: T, /): ...
 
     def __setitem__(self, item: Union[int, slice], value: Union[T, List[T]]):
         raise Exception("Cannot set item in a constant list")
@@ -95,23 +105,27 @@ def __init__(
         context = get_mp_context()
         reader, writer = context.Pipe(duplex=False)
 
-        assert ("ready_pipe" not in process_kwargs
-                and "input_path" not in process_kwargs
-                and "output_path" not in process_kwargs)
+        assert (
+            "ready_pipe" not in process_kwargs
+            and "input_path" not in process_kwargs
+            and "output_path" not in process_kwargs
+        )
         process_kwargs["ready_pipe"] = writer
         process_kwargs["input_path"] = input_path
         process_kwargs["output_path"] = output_path
 
         # Run busy loop in background process.
         self.proc = context.Process(target=target_fn, kwargs=process_kwargs)
-        self._finalizer = weakref.finalize(self, shutdown, self.proc,
-                                           input_path, output_path)
+        self._finalizer = weakref.finalize(
+            self, shutdown, self.proc, input_path, output_path
+        )
         self.proc.start()
 
         # Wait for startup.
         if reader.recv()["status"] != "READY":
-            raise RuntimeError(f"{process_name} initialization failed. "
-                               "See root cause above.")
+            raise RuntimeError(
+                f"{process_name} initialization failed. " "See root cause above."
[GitHub Actions / ruff (3.12)] Check failure, Ruff (E501): vllm/v1/utils.py:127:81: Line too long (81 > 80)
+            )
 
     def shutdown(self):
         self._finalizer()
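The new cuda_stream_sync() helper added above wraps torch.cuda.current_stream().synchronize in run_in_executor, so awaiting it parks the blocking CUDA synchronization on the default thread-pool executor instead of stalling the asyncio event loop. A hedged usage sketch (finish_step is a hypothetical caller that assumes a CUDA device is present; it is not code from this commit):

import asyncio
from typing import List

import torch

async def cuda_stream_sync() -> None:
    # Run the blocking synchronize on a worker thread so other asyncio
    # tasks keep running while the GPU drains its queued work.
    await asyncio.get_running_loop().run_in_executor(
        executor=None, func=torch.cuda.current_stream().synchronize
    )

async def finish_step(sampled_token_ids: torch.Tensor) -> List[int]:
    # Wait for the GPU without blocking the event loop; afterwards the
    # device-to-host copy of the sampled ids is cheap.
    await cuda_stream_sync()
    return sampled_token_ids.tolist()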
