Skip to content

Commit

Permalink
Changed scheduler to use deques instead of lists (vllm-project#2290)
Browse files Browse the repository at this point in the history
Co-authored-by: Woosuk Kwon <[email protected]>
  • Loading branch information
NadavShmayo and WoosukKwon authored Jan 7, 2024
1 parent d0215a5 commit 05921a9
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 24 deletions.
18 changes: 10 additions & 8 deletions vllm/core/policy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
from collections import deque
from typing import Deque

from vllm.sequence import SequenceGroup

Expand All @@ -15,13 +16,14 @@ def get_priority(
def sort_by_priority(
self,
now: float,
seq_groups: List[SequenceGroup],
) -> List[SequenceGroup]:
return sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
)
seq_groups: Deque[SequenceGroup],
) -> Deque[SequenceGroup]:
return deque(
sorted(
seq_groups,
key=lambda seq_group: self.get_priority(now, seq_group),
reverse=True,
))


class FCFS(Policy):
Expand Down
28 changes: 14 additions & 14 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections import deque
import enum
import time
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
Expand Down Expand Up @@ -29,7 +30,7 @@ class SchedulerOutputs:

def __init__(
self,
scheduled_seq_groups: List[SequenceGroup],
scheduled_seq_groups: Iterable[SequenceGroup],
prompt_run: bool,
num_batched_tokens: int,
blocks_to_swap_in: Dict[int, int],
Expand Down Expand Up @@ -75,13 +76,12 @@ def __init__(
num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window)

# TODO(zhuohan): Use deque instead of list for better performance.
# Sequence groups in the WAITING state.
self.waiting: List[SequenceGroup] = []
self.waiting: Deque[SequenceGroup] = deque()
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
self.running: Deque[SequenceGroup] = deque()
# Sequence groups in the SWAPPED state.
self.swapped: List[SequenceGroup] = []
self.swapped: Deque[SequenceGroup] = deque()

def add_seq_group(self, seq_group: SequenceGroup) -> None:
# Add sequence groups to the waiting queue.
Expand Down Expand Up @@ -152,7 +152,7 @@ def _schedule(self) -> SchedulerOutputs:
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
self.waiting.popleft()
continue

# If the sequence group cannot be allocated, stop.
Expand All @@ -166,7 +166,7 @@ def _schedule(self) -> SchedulerOutputs:
for seq in waiting_seqs:
seq.status = SequenceStatus.FINISHED_IGNORED
ignored_seq_groups.append(seq_group)
self.waiting.pop(0)
self.waiting.popleft()
continue

# If the number of batched tokens exceeds the limit, stop.
Expand All @@ -188,7 +188,7 @@ def _schedule(self) -> SchedulerOutputs:
break
seq_lens = new_seq_lens

seq_group = self.waiting.pop(0)
seq_group = self.waiting.popleft()
self._allocate(seq_group)
self.running.append(seq_group)
num_curr_seqs += num_new_seqs
Expand All @@ -214,14 +214,14 @@ def _schedule(self) -> SchedulerOutputs:
self.running = self.policy.sort_by_priority(now, self.running)

# Reserve new token slots for the running sequence groups.
running: List[SequenceGroup] = []
running: Deque[SequenceGroup] = deque()
preempted: List[SequenceGroup] = []
while self.running:
seq_group = self.running.pop(0)
seq_group = self.running.popleft()
while not self.block_manager.can_append_slot(seq_group):
if self.running:
# Preempt the lowest-priority sequence groups.
victim_seq_group = self.running.pop(-1)
victim_seq_group = self.running.pop()
self._preempt(victim_seq_group, blocks_to_swap_out)
preempted.append(victim_seq_group)
else:
Expand Down Expand Up @@ -255,7 +255,7 @@ def _schedule(self) -> SchedulerOutputs:
self.scheduler_config.max_num_seqs):
break

seq_group = self.swapped.pop(0)
seq_group = self.swapped.popleft()
self._swap_in(seq_group, blocks_to_swap_in)
self._append_slot(seq_group, blocks_to_copy)
num_curr_seqs += num_new_seqs
Expand Down Expand Up @@ -376,7 +376,7 @@ def _preempt_by_recompute(
self.block_manager.free(seq)
# NOTE: For FCFS, we insert the preempted sequence group to the front
# of the waiting queue.
self.waiting.insert(0, seq_group)
self.waiting.appendleft(seq_group)

def _preempt_by_swap(
self,
Expand Down
6 changes: 4 additions & 2 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,10 @@ def _process_model_outputs(

# Create the outputs.
request_outputs: List[RequestOutput] = []
for seq_group in (scheduled_seq_groups +
scheduler_outputs.ignored_seq_groups):
for seq_group in scheduled_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)

Expand Down

0 comments on commit 05921a9

Please sign in to comment.