Commit b6f99a6

[Core] Refactor executor classes to make it easier to inherit GPUExecutor (vllm-project#7673)
1 parent ad28a74 commit b6f99a6

File tree

2 files changed: +27, -21 lines changed

vllm/executor/gpu_executor.py

+16 -11
@@ -62,6 +62,18 @@ def _get_worker_kwargs(
             observability_config=self.observability_config,
         )
 
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.scheduler_config.is_multi_step:
+            worker_module_name = "vllm.worker.multi_step_worker"
+            worker_class_name = "MultiStepWorker"
+        elif self.speculative_config:
+            worker_module_name = "vllm.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+        else:
+            worker_module_name = "vllm.worker.worker"
+            worker_class_name = "Worker"
+        return (worker_module_name, worker_class_name)
+
     def _get_create_worker_kwargs(
         self,
         local_rank: int = 0,
@@ -70,17 +82,10 @@ def _get_create_worker_kwargs(
         worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                 distributed_init_method)
 
-        if self.scheduler_config.is_multi_step:
-            worker_kwargs.update(
-                worker_module_name="vllm.worker.multi_step_worker",
-                worker_class_name="MultiStepWorker")
-        elif self.speculative_config:
-            worker_kwargs.update(
-                worker_module_name="vllm.spec_decode.spec_decode_worker",
-                worker_class_name="create_spec_worker")
-        else:
-            worker_kwargs.update(worker_module_name="vllm.worker.worker",
-                                 worker_class_name="Worker")
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
+        worker_kwargs.update(worker_module_name=worker_module_name,
+                             worker_class_name=worker_class_name)
 
         return worker_kwargs
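
The extracted hook above is the point of the refactor: a subclass of GPUExecutor can now swap in its own worker by overriding a single method instead of duplicating the selection logic inside _get_create_worker_kwargs. A minimal sketch of such a subclass follows; the module path my_project.custom_worker and the class name CustomWorker are hypothetical, purely for illustration:

from typing import Tuple

from vllm.executor.gpu_executor import GPUExecutor


class CustomGPUExecutor(GPUExecutor):
    """Hypothetical executor that plugs in a custom plain-mode worker."""

    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        # Keep the default selection for multi-step and speculative decoding.
        if self.scheduler_config.is_multi_step or self.speculative_config:
            return super()._get_worker_module_and_class()
        # Hypothetical module/class names -- not part of vLLM itself.
        return ("my_project.custom_worker", "CustomWorker")

Because _get_create_worker_kwargs now routes through this one method, the override is picked up everywhere a worker is created, with no other changes to the subclass.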

vllm/executor/ray_gpu_executor.py

+11 -10
@@ -91,22 +91,19 @@ def _configure_ray_workers_use_nsight(self,
         return ray_remote_kwargs
 
     def _get_worker_wrapper_args(self) -> Dict[str, Any]:
-        if self.speculative_config is not None:
-            worker_module_name = "vllm.spec_decode.spec_decode_worker"
-            worker_class_name = "create_spec_worker"
-        elif self.scheduler_config.is_multi_step:
-            worker_module_name = "vllm.worker.multi_step_worker"
-            worker_class_name = "MultiStepWorker"
-        else:
-            worker_module_name = "vllm.worker.worker"
-            worker_class_name = "Worker"
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
 
         return dict(
             worker_module_name=worker_module_name,
             worker_class_name=worker_class_name,
             trust_remote_code=self.model_config.trust_remote_code,
         )
 
+    # child class could overwrite this to return actual env vars.
+    def _get_env_vars_to_be_updated(self):
+        return self._env_vars_for_all_workers
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if (self.parallel_config.tensor_parallel_size == 1
@@ -231,8 +228,12 @@ def sort_by_driver_then_worker_ip(worker):
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
+
+        self._env_vars_for_all_workers = (
+            all_args_to_update_environment_variables)
+
         self._run_workers("update_environment_variables",
-                          all_args=all_args_to_update_environment_variables)
+                          all_args=self._get_env_vars_to_be_updated())
 
         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
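
Together, the stashed self._env_vars_for_all_workers and the new _get_env_vars_to_be_updated hook let a RayGPUExecutor subclass adjust the environment handed to each Ray worker before update_environment_variables runs. A rough sketch, assuming the list-of-(dict,) argument shape visible in the comprehension in the hunk above; MY_PROJECT_FLAG is a made-up variable, not used by vLLM:

from vllm.executor.ray_gpu_executor import RayGPUExecutor


class CustomRayGPUExecutor(RayGPUExecutor):
    """Hypothetical executor that injects an extra env var per worker."""

    def _get_env_vars_to_be_updated(self):
        # Each element is a one-tuple holding the env dict for one worker,
        # matching the list comprehension in _init_workers_ray above.
        for (env_dict, ) in self._env_vars_for_all_workers:
            env_dict["MY_PROJECT_FLAG"] = "1"  # illustrative only
        return self._env_vars_for_all_workers

The hook is only meaningful once _init_workers_ray has populated self._env_vars_for_all_workers, which is why the base class stores the list immediately before calling _run_workers.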
