@@ -91,22 +91,19 @@ def _configure_ray_workers_use_nsight(self,
         return ray_remote_kwargs
 
     def _get_worker_wrapper_args(self) -> Dict[str, Any]:
-        if self.speculative_config is not None:
-            worker_module_name = "vllm.spec_decode.spec_decode_worker"
-            worker_class_name = "create_spec_worker"
-        elif self.scheduler_config.is_multi_step:
-            worker_module_name = "vllm.worker.multi_step_worker"
-            worker_class_name = "MultiStepWorker"
-        else:
-            worker_module_name = "vllm.worker.worker"
-            worker_class_name = "Worker"
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
 
         return dict(
            worker_module_name=worker_module_name,
            worker_class_name=worker_class_name,
            trust_remote_code=self.model_config.trust_remote_code,
         )
 
+    # child classes can override this to return the actual env vars.
+    def _get_env_vars_to_be_updated(self):
+        return self._env_vars_for_all_workers
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if (self.parallel_config.tensor_parallel_size == 1
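
The first hunk replaces the inline branching with a call to self._get_worker_module_and_class(), a helper presumably added elsewhere in this change. A minimal sketch of what it would look like, assuming it simply hosts the branching removed above (the Tuple[str, str] signature is an assumption):

from typing import Tuple


# Sketch of a method for the same executor class; the name matches the
# call site in the hunk above, the exact signature is an assumption.
def _get_worker_module_and_class(self) -> Tuple[str, str]:
    # Same precedence as the removed branching: speculative decoding
    # first, then multi-step scheduling, then the default worker.
    if self.speculative_config is not None:
        return ("vllm.spec_decode.spec_decode_worker", "create_spec_worker")
    if self.scheduler_config.is_multi_step:
        return ("vllm.worker.multi_step_worker", "MultiStepWorker")
    return ("vllm.worker.worker", "Worker")
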
@@ -231,8 +228,12 @@ def sort_by_driver_then_worker_ip(worker):
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
+
+        self._env_vars_for_all_workers = (
+            all_args_to_update_environment_variables)
+
         self._run_workers("update_environment_variables",
-                          all_args=all_args_to_update_environment_variables)
+                          all_args=self._get_env_vars_to_be_updated())
 
         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.
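
The second hunk stashes the per-worker env-var dicts in self._env_vars_for_all_workers before the update_environment_variables call, so a subclass can intercept them through the new hook. A hypothetical sketch of such an override (the subclass name, the NCCL_DEBUG variable, and the import path are illustrative assumptions, not part of this change):

from vllm.executor.ray_gpu_executor import RayGPUExecutor  # assumed path


class PatchedRayExecutor(RayGPUExecutor):  # hypothetical subclass

    def _get_env_vars_to_be_updated(self):
        # The base class stores one entry per worker; each entry is a
        # one-tuple wrapping a dict of env vars, matching the list
        # comprehension in the hunk above.
        for (env_vars, ) in self._env_vars_for_all_workers:
            # Illustrative extra variable; not something this diff sets.
            env_vars["NCCL_DEBUG"] = "INFO"
        return self._env_vars_for_all_workers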