
[V1][PP] Do not block engine core when no requests to schedule #14585


Merged: 2 commits merged on Mar 11, 2025

Conversation

@comaniac (Collaborator) commented on Mar 11, 2025

The current engine step with a batch queue blocks the busy loop for POLLING_TIMEOUT_S even when the batch queue is empty. This results in unstable and possibly longer TTFT when the first request comes in. This PR fixes the issue.
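For context, here is a minimal sketch of the behavior this PR aims for, assuming hypothetical helper names (schedule, execute_model_async, update_from_output); it is not the actual vLLM code, just an illustration of where the blocking used to happen:

from collections import deque

def step_with_batch_queue(scheduler, executor, batch_queue: deque, max_in_flight: int):
    # Try to schedule a new batch while the pipeline still has a free slot.
    if len(batch_queue) < max_in_flight:
        scheduler_output = scheduler.schedule()
        scheduled_batch = (scheduler_output is not None
                           and scheduler_output.total_num_scheduled_tokens > 0)
        if scheduled_batch:
            # execute_model_async is a stand-in for the executor call that
            # returns a future for the in-flight batch.
            future = executor.execute_model_async(scheduler_output)
            batch_queue.append((future, scheduler_output))

    # Nothing in flight: return to the busy loop right away instead of
    # blocking for POLLING_TIMEOUT_S, so a newly arrived request can be
    # added to the scheduler immediately (this is the gist of the fix).
    if not batch_queue:
        return None

    # All stages busy: block on the oldest in-flight batch, since nothing
    # else can be scheduled anyway.
    future, scheduler_output = batch_queue[0]
    model_output = future.result()
    batch_queue.popleft()
    return scheduler.update_from_output(scheduler_output, model_output)

The key difference from main is the early return when the batch queue is empty, which removes the up-to-POLLING_TIMEOUT_S stall before the first request gets scheduled.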

  • Serving command on L4 GPUs (with and without VLLM_USE_V1=1)
VLLM_USE_V1=1 vllm serve unsloth/Llama-3.1-8B-Instruct \
--no-enable-prefix-caching \
--distributed-executor-backend="ray" \
--max-model-len=8192 \
--pipeline-parallel-size=2
  • Benchmark command
python benchmarks/benchmark_serving.py \
    --backend vllm \
    --model unsloth/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json \
    --request-rate 2 \
    --num-prompts 200

V0

============ Serving Benchmark Result ============
Successful requests:                     200
Benchmark duration (s):                  143.01
Total input tokens:                      42659
Total generated tokens:                  43128
Request throughput (req/s):              1.40
Output token throughput (tok/s):         301.57
Total Token throughput (tok/s):          599.86
---------------Time to First Token----------------
Mean TTFT (ms):                          157.58
Median TTFT (ms):                        150.50
P99 TTFT (ms):                           289.79
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          94.65
Median TPOT (ms):                        96.47
P99 TPOT (ms):                           115.44
---------------Inter-token Latency----------------
Mean ITL (ms):                           93.56
Median ITL (ms):                         86.70
P99 ITL (ms):                            277.66
==================================================

V1 (main)

============ Serving Benchmark Result ============
Successful requests:                     200
Benchmark duration (s):                  138.58
Total input tokens:                      42659
Total generated tokens:                  43325
Request throughput (req/s):              1.44
Output token throughput (tok/s):         312.63
Total Token throughput (tok/s):          620.45
---------------Time to First Token----------------
Mean TTFT (ms):                          201.24
Median TTFT (ms):                        151.40
P99 TTFT (ms):                           1611.02
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          80.32
Median TPOT (ms):                        81.41
P99 TPOT (ms):                           92.69
---------------Inter-token Latency----------------
Mean ITL (ms):                           79.94
Median ITL (ms):                         76.87
P99 ITL (ms):                            202.60
==================================================

V1 (this PR)

============ Serving Benchmark Result ============
Successful requests:                     200
Benchmark duration (s):                  138.82
Total input tokens:                      42659
Total generated tokens:                  43463
Request throughput (req/s):              1.44
Output token throughput (tok/s):         313.10
Total Token throughput (tok/s):          620.40
---------------Time to First Token----------------
Mean TTFT (ms):                          159.05
Median TTFT (ms):                        149.20
P99 TTFT (ms):                           288.27
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          80.24
Median TPOT (ms):                        81.46
P99 TPOT (ms):                           93.69
---------------Inter-token Latency----------------
Mean ITL (ms):                           79.99
Median ITL (ms):                         76.96
P99 ITL (ms):                            205.92
==================================================

cc @ruisearch42

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
@comaniac added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Mar 11, 2025

@mergify (bot) added the v1 label on Mar 11, 2025
Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
@ruisearch42 (Collaborator) left a comment

Thanks for the fix

scheduled_batch = (scheduler_output is not None
                   and scheduler_output.total_num_scheduled_tokens > 0)

# If no more requests can be scheduled and the job queue is not empty,
Collaborator:

QQ: If no more requests can be scheduled now, there may be new ones added in the near future. So we have two options: one is to block here waiting for the oldest batch to finish, the other is to move on to the next loop iteration. We preferred blocking here. Would this always be the best strategy?

comaniac (Author):

We block here only when the batch queue is full (so all stages are busy). In this case, if we schedule one more batch, the scheduling overhead of that batch can be hidden, so it might be better. However, that batch cannot include any on-the-fly requests, so I'm not sure whether this is optimal.

Comment on lines +215 to +216
# Blocking until the first result is available.
model_output = future.result()
Collaborator:

We don't need any timeout here?

comaniac (Author):

No, because in this case we have nothing else to do and can just wait for the first batch to finish (similar to PP=1).
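For illustration only, the two waiting styles being discussed here, sketched with concurrent.futures (not the actual engine code):

from concurrent.futures import Future, TimeoutError

def wait_blocking(future: Future):
    # What the code above does: every pipeline stage is busy, so simply
    # block until the oldest in-flight batch finishes.
    return future.result()

def wait_with_timeout(future: Future, timeout_s: float):
    # The alternative: wait with a timeout and hand control back to the
    # busy loop, which only helps if the engine has other work to do.
    try:
        return future.result(timeout=timeout_s)
    except TimeoutError:
        return None

Since all GPUs are busy in this case, the blocking variant loses nothing here.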

Collaborator:

But in this case, the engine will not add any requests to the scheduler in the meantime, right?

comaniac (Author):

Yes, but it doesn't matter because all GPUs are busy. On the other hand, if we don't wait here and instead add new requests to the scheduler, we suffer from two issues:

  1. The newly added requests cannot be scheduled anyway because no resources are free.
  2. If the oldest batch finishes while we are adding new requests, that batch will be delayed.

Collaborator:

I see. Thanks for the explanation!

Collaborator:

Oh, btw, what about the FSM compilation? Can it be done in parallel with the execution?

comaniac (Author):

I guess so. Although we won't be able to schedule a request before its FSM is compiled, we should be able to kick off the compilation first.
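A hypothetical sketch of that idea (compile_fsm and the request fields are illustrative stand-ins, not vLLM APIs): kick off the compilation in a background thread when the request is added, and only treat the request as schedulable once the compilation has finished.

from concurrent.futures import ThreadPoolExecutor

def compile_fsm(guided_decoding_params):
    # Placeholder for the real grammar/FSM compilation, which can be slow.
    ...

fsm_compiler = ThreadPoolExecutor(max_workers=2)

def add_request(request, waiting_queue):
    # Start FSM compilation in the background so it overlaps with ongoing
    # model execution instead of blocking the engine loop.
    if getattr(request, "guided_decoding_params", None) is not None:
        request.fsm_future = fsm_compiler.submit(
            compile_fsm, request.guided_decoding_params)
    waiting_queue.append(request)

def is_schedulable(request):
    # A guided-decoding request can only be scheduled after its FSM is ready.
    future = getattr(request, "fsm_future", None)
    return future is None or future.done()

Whether this pays off depends on how long FSM compilation takes relative to a scheduling step; the comment above is only floating the idea.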

@WoosukKwon (Collaborator) left a comment

LGTM! Nice optimization!

@WoosukKwon merged commit 4290b70 into vllm-project:main on Mar 11, 2025
29 of 30 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…project#14585)

Signed-off-by: Cody Yu <hao.yu.cody@gmail.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1
3 participants