[V1] Implement sliding window attention in kv_cache_manager #14097

Merged: 41 commits merged into vllm-project:main on Apr 1, 2025

Conversation

@heheda12345 (Collaborator) commented Mar 2, 2025

Built on top of #14079; this should be merged after it.

This PR supports a “real” sliding window in V1:

  1. Support dropping blocks outside the sliding window.
  2. For prefix caching, only the last sliding_window tokens need to be cached to achieve a prefix cache hit. E.g., for request ABCDE with sliding window size 2 and block_size 1, if DE is cached while ABC is not, we can still regard ABCDE as the cached prefix (see the sketch after this list).
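
To make the hit rule concrete, here is a minimal sketch (illustrative Python, not vLLM code) of the check, with block_size 1 so each token maps to one block:

# Hypothetical helper illustrating the sliding-window prefix-cache-hit rule.
sliding_window = 2

def is_full_hit(prompt: str, cached: set) -> bool:
    # The whole prompt counts as a cache hit if its last
    # `sliding_window` tokens are all cached.
    return all(tok in cached for tok in prompt[-sliding_window:])

# DE cached, ABC evicted: ABCDE still counts as a cached prefix.
assert is_full_hit("ABCDE", cached={"D", "E"})
# Only C cached: D and E are missing from the window, so no full hit.
assert not is_full_hit("ABCDE", cached={"C"})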

For models that combine global attention and sliding window attention, the kv cache manager still treats them as global-attention-only models.

Answers to some questions raised in #13296:

Q: Is it compatible with cascade attention?
A: Not yet, but the result should still be correct, because cascade attention is disabled in that case:

if use_alibi or use_sliding_window:

Q: How does it work with chunked prefill?
A: It allocates blocks for the tokens that will be computed in the current step, and frees the blocks that fall outside the sliding window in the next step.
Assume window size 1k, chunk size 2k, prompt length 4k, block_size=1 (a runnable replay of these steps follows the list):

  1. Chunked prefill of 2k tokens: block_table: [0:2k]
  2. Chunked prefill of 2k tokens:
    1. Free the first 1k blocks as they won’t be used after step 1. block_table becomes [null_block*1000] + [1k:2k]
    2. Allocate blocks for the [2k:4k] tokens that will be computed. block_table becomes [null_block*1000] + [1k:4k]
  3. First decode step:
    1. Free the [1k:3k] blocks as they won’t be used after step 2. block_table becomes [null_block*3000] + [3k:4k]
    2. Allocate a new slot for decoding. block_table becomes [null_block*3000] + [3k:4001]
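
The walkthrough above can be replayed with a minimal sketch (illustrative Python, not vLLM's allocator). block_size is 1, so block i holds token i, and NULL marks a freed block replaced by the null block:

NULL = -1
WINDOW, CHUNK = 1000, 2000

block_table = []

def step(num_new_tokens):
    # Free blocks outside the sliding window of the tokens computed so far:
    # only the last WINDOW tokens are still needed by later steps.
    keep_from = max(0, len(block_table) - WINDOW)
    for i in range(keep_from):
        block_table[i] = NULL
    # Allocate blocks for the tokens computed in this step.
    start = len(block_table)
    block_table.extend(range(start, start + num_new_tokens))

step(CHUNK)  # prefill chunk 1: block_table = [0:2000]
step(CHUNK)  # prefill chunk 2: [NULL]*1000 + [1000:4000]
step(1)      # first decode:    [NULL]*3000 + [3000:4001]

assert block_table[:3000] == [NULL] * 3000
assert block_table[3000:] == list(range(3000, 4001))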

Q: What's the shape of the block table for SWA? Is it append-only?
A: It has the same length as for global attention, but the blocks outside the sliding window are replaced with a special null_block. This replacement only happens on the kv_cache_manager side. Since the model runner’s block_table is append-only, we do not replace existing blocks with null blocks in the model runner. The result is still correct because the model runner never accesses the blocks outside the sliding window (see the sketch below).
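
A small sketch (illustrative, not vLLM code) of why that is safe, assuming a token at position pos attends to at most the previous window of tokens:

WINDOW = 1000

def slots_read_at(pos):
    # With block_size 1, a token at `pos` reads at most the slots of the
    # last WINDOW positions (including itself).
    return range(max(0, pos - WINDOW + 1), pos + 1)

# In the 4k example above, the manager has nulled blocks [0:3000] by the
# first decode step; the decode at position 4000 never touches them.
assert min(slots_read_at(4000)) >= 3000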

This pr is part of the hybrid allocator #11382


github-actions bot commented Mar 2, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 2, 2025

mergify bot commented Mar 5, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 5, 2025
@WoosukKwon (Collaborator) commented:

@zhuohan123 Did you have a chance to take a look?

@zhuohan123 (Member) left a comment:

WIP partial review. Will add more tomorrow.

Comment on lines 659 to 664:

# Verify that the virtual layers of each rank are the same.
for kv_cache_config in kv_cache_configs[1:]:
    for virtual_layer1, virtual_layer2 in zip(
            kv_cache_configs[0].virtual_layers,
            kv_cache_config.virtual_layers):
        assert virtual_layer1.kv_cache_spec == virtual_layer2.kv_cache_spec
@zhuohan123 (Member):

Q: Will pipeline parallelism fail this assert for some models?

@heheda12345 (Collaborator, Author):

Yes, it will fail when different stages have different types of layers. For hybrid models, we just throw an error as a first step. For non-hybrid models, the assert won't fail.

This function is introduced in https://github.com/vllm-project/vllm/pull/14079/files; should we discuss it there?

@heheda12345 (Collaborator, Author) left a comment:

Thanks for your review! I've replied to some questions and will update the code after #14079.



mergify bot commented Mar 29, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @heheda12345.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 29, 2025
@mergify mergify bot removed the needs-rebase label Mar 29, 2025
@WoosukKwon (Collaborator) commented Mar 30, 2025:

As we discussed offline, I think we need a clear separation between the two APIs of SpecializedManager:

  1. The first API describes how to free the KV cache for a request. It is invoked whenever the request calls allocate_slots.
  2. The second API describes how to check for prefix cache hits. It is invoked in get_computed_blocks.

Two particular things I want to fix in this PR:

  1. The first API is also used in get_computed_blocks, blurring the separation.
  2. The second API does not provide "efficient intersection" or "early stopping". I think this can be addressed by adding an extended API (e.g., get longest prefix); a sketch of such a split follows below.
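
For illustration, a sketch of what such a split might look like (names and signatures are hypothetical, not the final vLLM API):

from abc import ABC, abstractmethod

class SpecializedManager(ABC):
    # Hypothetical interface sketch of the two-API separation.

    @abstractmethod
    def remove_skipped_blocks(self, block_table, num_computed_tokens):
        """API 1: free the blocks the request no longer needs.
        Invoked whenever the request calls allocate_slots."""

    @abstractmethod
    def find_longest_cache_hit(self, block_hashes):
        """API 2: return the longest cached prefix, with early stopping.
        Invoked in get_computed_blocks, without touching API 1."""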

@heheda12345 (Collaborator, Author) left a comment:

I've changed the prefix caching API in SpecializedManager to find_longest_cache_hit. It seems most of the complexity can be removed with that change. find_longest_cache_hit also works for the hybrid allocator, as long as we can avoid recomputation across the multiple calls made for the same request. A sketch of the idea follows below.
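
A minimal sketch of how find_longest_cache_hit could work for sliding window (simplified and illustrative; the real vLLM code operates on block hashes and handles block-size rounding more carefully): scan backwards for the first contiguous cached run long enough to cover the window, and replace everything before it with the null block:

from math import ceil

NULL_BLOCK = None  # stand-in for vLLM's special null block

def find_longest_cache_hit(blocks, cached, sliding_window, block_size):
    # Trailing contiguous cached blocks needed to cover the window.
    need = ceil(sliding_window / block_size)
    run_len, run_end = 0, 0
    # Scan from the end so the first sufficient run yields the longest hit
    # (early stopping).
    for i in range(len(blocks) - 1, -1, -1):
        if blocks[i] in cached:
            if run_len == 0:
                run_end = i + 1  # exclusive right edge of this run
            run_len += 1
            if run_len == need:
                # Hit of length run_end: the last `need` blocks are real,
                # everything before them is served as the null block.
                return [NULL_BLOCK] * (run_end - need) + blocks[run_end - need:run_end]
        else:
            run_len = 0
    return []  # no window-sized cached run: no prefix cache hit

# Block size 1, window 2: DE cached => ABCDE is a cache hit of length 5.
assert find_longest_cache_hit(list("ABCDE"), {"D", "E"}, 2, 1) == \
    [NULL_BLOCK] * 3 + ["D", "E"]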

@WoosukKwon (Collaborator) left a comment:

Great simplification. I love it!

@WoosukKwon (Collaborator) left a comment:

LGTM. Thanks for the awesome work! I’m really happy with how the final APIs and simplifications turned out 🔥🔥

Also, huge thanks for your incredible patience throughout all the back-and-forth edits and long discussions. Really appreciate it!

@WoosukKwon WoosukKwon added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 31, 2025
@WoosukKwon WoosukKwon merged commit 3a5f0af into vllm-project:main Apr 1, 2025
33 checks passed
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Apr 2, 2025
[V1] Implement sliding window attention in kv_cache_manager (vllm-project#14097)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
[V1] Implement sliding window attention in kv_cache_manager (vllm-project#14097)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
[V1] Implement sliding window attention in kv_cache_manager (vllm-project#14097)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1

4 participants