
[V1][PP] Do not block engine core when no requests to schedule #14585


Merged

merged 2 commits on Mar 11, 2025
27 changes: 11 additions & 16 deletions vllm/v1/engine/core.py
@@ -205,23 +205,18 @@ def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
                 self.batch_queue.put_nowait(
                     (future, scheduler_output))  # type: ignore
 
-        # If all requests are scheduled or the job queue is full,
+        scheduled_batch = (scheduler_output is not None
+                           and scheduler_output.total_num_scheduled_tokens > 0)
+
+        # If no more requests can be scheduled and the job queue is not empty,
Collaborator

QQ: If no more requests can be scheduled now, new ones may still be added in the near future. So we have two options: block here waiting for the oldest batch to finish, or move forward to the next loop iteration. We preferred blocking here. Would this always be the best strategy?

Collaborator Author

We block here only when the batch queue is full (so all stages are busy). In this case, if we scheduled one more batch, its scheduling overhead could be hidden, so that might be better. However, that batch could not include any on-the-fly requests, so I'm not sure it would be optimal.

         # block until the first batch in the job queue is finished.
-        if (scheduler_output is None
-                or scheduler_output.total_num_scheduled_tokens == 0):
-            try:
-                future, scheduler_output = self.batch_queue.get(
-                    timeout=POLLING_TIMEOUT_S)
-                # Blocking until the first result is available.
-                model_output = future.result()
-                self.batch_queue.task_done()
-                engine_core_outputs = self.scheduler.update_from_output(
-                    scheduler_output, model_output)
-            except queue.Empty:
-                # If the queue is empty (timeout at .get), return
-                # an empty EngineCoreOutputs for logging.
-                engine_core_outputs = EngineCoreOutputs(
-                    outputs=[], scheduler_stats=self.scheduler.make_stats())
+        if not scheduled_batch and not self.batch_queue.empty():
+            future, scheduler_output = self.batch_queue.get_nowait()
+            # Blocking until the first result is available.
+            model_output = future.result()
Comment on lines +215 to +216
Collaborator

We don't need any timeout here?

Collaborator Author

No, because in this case we have nothing else to do and can just wait for the first batch to finish (similar to PP=1).

Collaborator

But in this case, the engine will not add any requests to the scheduler in the meantime, right?

Collaborator Author

Yes, but it doesn't matter because all GPUs are busy. On the other hand, if we did not wait here and instead added new requests to the scheduler, we would suffer from two issues:

  1. The newly added requests cannot be scheduled anyway because no resources are free.
  2. If the oldest batch finishes while we are adding new requests, that batch will be delayed.
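
To illustrate the trade-off, here is a minimal, self-contained sketch of the policy discussed in this thread (it is not the actual vLLM code; `run_batch`, `MAX_IN_FLIGHT`, and the toy queue are invented for the example): keep scheduling while the batch queue has room, and block on the oldest in-flight batch only when nothing new could be scheduled.

```python
import queue
import time
from concurrent.futures import ThreadPoolExecutor

MAX_IN_FLIGHT = 2                     # e.g. the number of pipeline stages
executor = ThreadPoolExecutor(max_workers=MAX_IN_FLIGHT)
batch_queue: "queue.Queue" = queue.Queue(maxsize=MAX_IN_FLIGHT)
pending_batches = [f"batch-{i}" for i in range(5)]

def run_batch(batch: str) -> str:
    time.sleep(0.1)                   # stand-in for model execution on the GPUs
    return f"output of {batch}"

while pending_batches or not batch_queue.empty():
    scheduled = False
    if pending_batches and not batch_queue.full():
        batch = pending_batches.pop(0)
        batch_queue.put_nowait((executor.submit(run_batch, batch), batch))
        scheduled = True

    # Nothing new could be scheduled but work is in flight: every stage is
    # already busy, so blocking on the oldest batch wastes no GPU time.
    if not scheduled and not batch_queue.empty():
        future, batch = batch_queue.get_nowait()
        print(future.result())        # blocks until that batch finishes
        batch_queue.task_done()
```

Because the block only happens when every stage already has work, waiting on `future.result()` does not leave any GPU idle, which is the point made above.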

Collaborator

I see. Thanks for the explanation!

Collaborator

Oh, by the way, what about the FSM compilation? Can it be done in parallel with the execution?

Collaborator Author

I guess so. Although we won't be able to schedule the request before its FSM is compiled, we should be able to kick off the compilation first.
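
For concreteness, here is a hypothetical sketch of what that could look like; it is not from this PR or the vLLM codebase, and `AsyncFSMCompiler`, `compile_fsm`, and `spec` are invented names. The idea is that FSM/grammar compilation is submitted to a background thread as soon as the request arrives, so it overlaps with ongoing model execution, and the request only becomes schedulable once its future is done.

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Dict

def compile_fsm(spec: Any) -> Any:
    """Placeholder for the real FSM/grammar compilation."""
    ...

class AsyncFSMCompiler:
    """Kick off FSM compilation early so it overlaps with model execution."""

    def __init__(self, max_workers: int = 4) -> None:
        self._pool = ThreadPoolExecutor(max_workers=max_workers)
        self._futures: Dict[str, Future] = {}

    def submit(self, request_id: str, spec: Any) -> None:
        # Start compiling as soon as the request arrives.
        self._futures[request_id] = self._pool.submit(compile_fsm, spec)

    def is_ready(self, request_id: str) -> bool:
        # The scheduler would skip the request until this returns True.
        fut = self._futures.get(request_id)
        return fut is not None and fut.done()

    def pop(self, request_id: str) -> Any:
        # Retrieve the compiled FSM once the request is actually scheduled.
        return self._futures.pop(request_id).result()
```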

+            self.batch_queue.task_done()
+            engine_core_outputs = self.scheduler.update_from_output(
+                scheduler_output, model_output)
 
         return engine_core_outputs
