
Commit 2092a6f

AoyuQC and youkaichao authored
[V1][Core] Add worker_base for v1 worker (vllm-project#12816)
Signed-off-by: Aoyu <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Aoyu <[email protected]>
Co-authored-by: youkaichao <[email protected]>
1 parent c9d3ecf · commit 2092a6f

File tree

4 files changed: +153 -52 lines changed

vllm/utils.py
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/worker_base.py
vllm/worker/worker_base.py


vllm/utils.py (+43)

@@ -2220,3 +2220,46 @@ def import_pynvml():
     """
     import vllm.third_party.pynvml as pynvml
     return pynvml
+
+
+def warn_for_unimplemented_methods(cls: Type[T]) -> Type[T]:
+    """
+    A replacement for `abc.ABC`.
+    When we use `abc.ABC`, subclasses will fail to instantiate
+    if they do not implement all abstract methods.
+    Here, we only require `raise NotImplementedError` in the
+    base class, and log a warning if the method is not implemented
+    in the subclass.
+    """
+
+    original_init = cls.__init__
+
+    def find_unimplemented_methods(self: object):
+        unimplemented_methods = []
+        for attr_name in dir(self):
+            # bypass inner method
+            if attr_name.startswith('_'):
+                continue
+
+            try:
+                attr = getattr(self, attr_name)
+                # get the func of callable method
+                if callable(attr):
+                    attr_func = attr.__func__
+            except AttributeError:
+                continue
+            src = inspect.getsource(attr_func)
+            if "NotImplementedError" in src:
+                unimplemented_methods.append(attr_name)
+        if unimplemented_methods:
+            method_names = ','.join(unimplemented_methods)
+            msg = (f"Methods {method_names} not implemented in {self}")
+            logger.warning(msg)
+
+    @wraps(original_init)
+    def wrapped_init(self, *args, **kwargs) -> None:
+        original_init(self, *args, **kwargs)
+        find_unimplemented_methods(self)
+
+    type.__setattr__(cls, '__init__', wrapped_init)
+    return cls
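
Note: as a reading aid, here is a minimal usage sketch of the decorator added above. The BaseWorker/PartialWorker classes are hypothetical (not part of the commit) and assume a vLLM checkout that includes this change; the point is that an incomplete subclass still instantiates, and a warning is logged from __init__ instead of the TypeError that abc.ABC would raise.

# Hypothetical example classes; only warn_for_unimplemented_methods comes
# from this commit (vllm/utils.py).
from vllm.utils import warn_for_unimplemented_methods


@warn_for_unimplemented_methods
class BaseWorker:

    def load_model(self) -> None:
        raise NotImplementedError

    def execute_model(self) -> None:
        raise NotImplementedError


class PartialWorker(BaseWorker):

    def load_model(self) -> None:
        # Implemented here, so it is not reported.
        print("model loaded")


# With abc.ABC this instantiation would fail; here it succeeds and logs
# something like "Methods execute_model not implemented in <PartialWorker ...>".
worker = PartialWorker()
worker.load_model()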

vllm/v1/worker/gpu_worker.py (+9, -19)

@@ -21,14 +21,15 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.v1.worker.worker_base import WorkerBase

 logger = init_logger(__name__)

 if TYPE_CHECKING:
     from vllm.v1.core.scheduler_output import SchedulerOutput


-class Worker:
+class Worker(WorkerBase):

     def __init__(
         self,
@@ -39,23 +40,11 @@ def __init__(
         is_driver_worker: bool = False,
     ):

-        # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config)
-        self.vllm_config = vllm_config
-        self.model_config = vllm_config.model_config
-        self.cache_config = vllm_config.cache_config
-        self.lora_config = vllm_config.lora_config
-        self.load_config = vllm_config.load_config
-        self.parallel_config = vllm_config.parallel_config
-        self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
-        self.speculative_config = vllm_config.speculative_config
-        self.prompt_adapter_config = vllm_config.prompt_adapter_config
-        self.observability_config = vllm_config.observability_config
-
-        self.parallel_config.rank = rank
-        self.local_rank = local_rank
-        self.rank = rank
-        self.distributed_init_method = distributed_init_method
+        super().__init__(vllm_config=vllm_config,
+                         local_rank=local_rank,
+                         rank=rank,
+                         distributed_init_method=distributed_init_method,
+                         is_driver_worker=is_driver_worker)

         if self.model_config.trust_remote_code:
             # note: lazy import to avoid importing torch before initializing
@@ -126,7 +115,8 @@ def init_device(self):
         set_random_seed(self.model_config.seed)

         # Construct the model runner
-        self.model_runner = GPUModelRunner(self.vllm_config, self.device)
+        self.model_runner: GPUModelRunner = GPUModelRunner(
+            self.vllm_config, self.device)

     def load_model(self) -> None:
         if self.vllm_config.model_config.enable_sleep_mode:
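
Note: the __init__ hunk above folds the repeated per-field unpacking of vllm_config into a single super().__init__(...) call, so that bookkeeping now lives in the shared base class. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the actual vLLM types:

from dataclasses import dataclass


@dataclass
class DummyConfig:
    # Stand-in for VllmConfig, which bundles many sub-configs.
    model_name: str
    num_gpu_blocks: int


class BaseWorkerSketch:

    def __init__(self, config: DummyConfig, local_rank: int, rank: int) -> None:
        # Common bookkeeping is done once in the base class instead of being
        # repeated in every hardware-specific worker.
        self.config = config
        self.local_rank = local_rank
        self.rank = rank


class GpuWorkerSketch(BaseWorkerSketch):

    def __init__(self, config: DummyConfig, local_rank: int, rank: int) -> None:
        # The subclass forwards to the base and only adds device-specific state.
        super().__init__(config, local_rank=local_rank, rank=rank)
        self.device_str = f"cuda:{local_rank}"


worker = GpuWorkerSketch(DummyConfig("demo-model", 1024), local_rank=0, rank=0)
print(worker.rank, worker.device_str)  # 0 cuda:0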

vllm/v1/worker/worker_base.py (+63)

@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.worker.worker_base import WorkerBase as WorkerBaseV0
+
+logger = init_logger(__name__)
+
+
+class WorkerBase(WorkerBaseV0):
+    """
+    Abstract class for v1 worker, mainly define some methods for v1.
+    For methods shared by v0 and v1, define them in v0 WorkerBase
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
+    ):
+        """
+        Initialize common worker components.
+
+        Args:
+            vllm_config: Complete vLLM configuration
+            local_rank: Local device index
+            rank: Global rank in distributed setup
+            distributed_init_method: Distributed initialization method
+            is_driver_worker: Whether this worker handles driver
+                responsibilities
+        """
+        # Configuration storage
+        super().__init__(vllm_config=vllm_config)
+
+        self.local_rank = local_rank
+        self.rank = rank
+        self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+
+        # Device and model state
+        self.device: Optional[torch.device] = None
+        self.model_runner: Optional[nn.Module] = None
+
+    def get_kv_cache_spec(self) -> KVCacheSpec:
+        """Get specifications for KV cache implementation."""
+        raise NotImplementedError
+
+    def compile_or_warm_up_model(self) -> None:
+        """Prepare model for execution through compilation/warmup."""
+        raise NotImplementedError
+
+    def check_health(self) -> None:
+        """Basic health check (override for device-specific checks)."""
+        return
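
Note: a sketch of how a v1 worker is expected to build on this new base class. MyWorker and its method bodies are illustrative assumptions, not part of the commit, and the snippet assumes a vLLM checkout that includes this change:

# Illustrative only: MyWorker is hypothetical; the imports are the ones
# introduced or used by this commit.
import torch

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.worker.worker_base import WorkerBase


class MyWorker(WorkerBase):

    def __init__(self,
                 vllm_config: VllmConfig,
                 local_rank: int,
                 rank: int,
                 distributed_init_method: str,
                 is_driver_worker: bool = False):
        # Config unpacking and rank bookkeeping are handled by the base class.
        super().__init__(vllm_config=vllm_config,
                         local_rank=local_rank,
                         rank=rank,
                         distributed_init_method=distributed_init_method,
                         is_driver_worker=is_driver_worker)

    def init_device(self) -> None:
        # Placeholder: a real worker would initialize its accelerator here.
        self.device = torch.device("cpu")

    def get_kv_cache_spec(self) -> KVCacheSpec:
        # Deliberately left unimplemented: because the v0 base class is
        # decorated with @warn_for_unimplemented_methods, constructing
        # MyWorker only logs a warning rather than raising a TypeError.
        raise NotImplementedError

    def compile_or_warm_up_model(self) -> None:
        # Placeholder warm-up step.
        pass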

vllm/worker/worker_base.py (+38, -33)

@@ -3,7 +3,7 @@
 import dataclasses
 import os
 import time
-from abc import ABC, abstractmethod
+from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

 import cloudpickle
@@ -19,15 +19,17 @@
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
-                        update_environment_variables)
+                        update_environment_variables,
+                        warn_for_unimplemented_methods)
 from vllm.worker.model_runner_base import (BroadcastableModelInput,
                                            ModelRunnerBase,
                                            ModelRunnerInputBase)

 logger = init_logger(__name__)


-class WorkerBase(ABC):
+@warn_for_unimplemented_methods
+class WorkerBase:
     """Worker interface that allows vLLM to cleanly separate implementations for
     different hardware. Also abstracts control plane communication, e.g., to
     communicate request metadata to other workers.
@@ -53,35 +55,31 @@ def __init__(
         from vllm.platforms import current_platform
         self.current_platform = current_platform

-    @abstractmethod
     def init_device(self) -> None:
         """Initialize device state, such as loading the model or other on-device
         memory allocations.
         """
         raise NotImplementedError

-    @abstractmethod
-    def determine_num_available_blocks(self) -> Tuple[int, int]:
-        """Determine the number of available blocks for the GPU KV cache and
-        swappable CPU KV cache.
-
-        The implementation may run profiling or other heuristics to determine
-        the size of caches.
-
-        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
-        are blocks that are "active" on the device and can be appended to.
-        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
-        appended to.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         """Initialize the KV cache with the given size in blocks.
         """
         raise NotImplementedError

+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
+    def load_model(self) -> None:
+        """Load model onto target device."""
+        raise NotImplementedError
+
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> Optional[List[SamplerOutput]]:
+        raise NotImplementedError
+
     def start_worker_execution_loop(self) -> None:
         """Execute model loop in parallel worker.

@@ -94,40 +92,43 @@ def start_worker_execution_loop(self) -> None:
         if output is None:
             return None

-    @abstractmethod
-    def get_model(self) -> nn.Module:
-        raise NotImplementedError
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        """Determine the number of available blocks for the GPU KV cache and
+        swappable CPU KV cache.

-    @abstractmethod
-    def execute_model(
-        self,
-        execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[List[SamplerOutput]]:
+        The implementation may run profiling or other heuristics to determine
+        the size of caches.
+
+        Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks
+        are blocks that are "active" on the device and can be appended to.
+        num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be
+        appended to.
+        """
         raise NotImplementedError

-    @abstractmethod
     def get_cache_block_size_bytes(self) -> int:
         """Return the size of a single cache block, in bytes. Used in
         speculative decoding.
         """
         raise NotImplementedError

-    @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def remove_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError

-    @abstractmethod
     def list_loras(self) -> Set[int]:
         raise NotImplementedError

+    @property
+    def vocab_size(self) -> int:
+        """Get vocabulary size from model configuration."""
+        return self.model_config.get_vocab_size()
+

 class DelegateWorkerBase(WorkerBase):
     """
@@ -156,6 +157,10 @@ def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

+    def load_model(self) -> None:
+        """Load model onto target device."""
+        self.worker.load_model()
+
     def get_model(self) -> nn.Module:
         return self.worker.get_model()
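
Note: the last hunk gives DelegateWorkerBase a load_model pass-through that simply forwards to the wrapped worker. A small sketch of that delegation pattern with hypothetical classes (not the vLLM implementations):

class InnerWorkerSketch:

    def load_model(self) -> None:
        print("loading model on device")

    def get_model(self) -> str:
        return "model"


class DelegateSketch:
    """Wraps a concrete worker and forwards calls to it."""

    def __init__(self, worker: InnerWorkerSketch) -> None:
        self.worker = worker

    def load_model(self) -> None:
        # Pass-through, mirroring DelegateWorkerBase.load_model in the diff.
        self.worker.load_model()

    def get_model(self) -> str:
        return self.worker.get_model()


delegate = DelegateSketch(InnerWorkerSketch())
delegate.load_model()
print(delegate.get_model())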
