
[Kernel] LoRA - Enable CUDAGraphs for V1 #14626


Merged

merged 6 commits into vllm-project:main from varun/v1-lora-cudagraph on Mar 14, 2025

Conversation

@varun-sundar-rabindranath (Contributor) commented Mar 11, 2025

Enable CUDAGraphs support for V1 LoRA

Related issue: #10617

github-actions bot

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

embeddings_indices = torch.narrow(
    self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))

Contributor Author

^ Changes are to avoid errors such as:

  raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['input_ids'].size()[0], L['positions'].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['input_ids'].size()[0]) are valid because L['input_ids'].size()[0] was inferred to be a constant (8192).
  - Not all values of RelaxedUnspecConstraint(L['positions'].size()[0]) are valid because L['positions'].size()[0] was inferred to be a constant (8192).
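
For context, here is a minimal, self-contained sketch of the failure mode (buffer names are illustrative, not the actual vLLM code): combining x with a tensor taken from a fixed-size metadata buffer at its full size forces Dynamo to specialize x.size(0) to that constant, whereas torch.narrow with length x.size(0) keeps the slice tied to the symbolic token dimension.

```python
import torch

MAX_TOKENS = 8192  # stand-in for max_num_batched_tokens

# Persistent LoRA metadata buffer, sized for the worst case (hypothetical name).
embeddings_indices_buf = torch.zeros(2, MAX_TOKENS, dtype=torch.long)

def forward_bad(x: torch.Tensor) -> torch.Tensor:
    # Using the buffer at its full, fixed size in an op with x forces the
    # compiler to conclude x.size(0) == MAX_TOKENS, i.e. the dynamic token
    # dimension collapses to a constant -> ConstraintViolationError.
    return x + embeddings_indices_buf[0]

def forward_good(x: torch.Tensor) -> torch.Tensor:
    # torch.narrow with length=x.size(0) ties the slice to the symbolic
    # number of tokens, so the graph stays valid for any batch size.
    return x + torch.narrow(embeddings_indices_buf, 1, 0, x.size(0))[0]

x = torch.zeros(MAX_TOKENS, dtype=torch.long)
torch._dynamo.mark_dynamic(x, 0)             # the token dim is marked dynamic
# torch.compile(forward_bad)(x)              # raises ConstraintViolationError
print(torch.compile(forward_good)(x).shape)  # compiles with a dynamic token dim
```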

-        full_output = self.base_layer.forward(
-            x.add_(indices * added_tokens_mask))
+        full_output = self.base_layer.forward(x +
+                                              (indices * added_tokens_mask))
Contributor Author

x here is input_ids. In V1, we don't zero out the CUDA graph pad region.
Avoid the in-place update here to prevent accumulating garbage in the input buffer.
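
As a rough illustration of that point (plain tensors standing in for the captured CUDA graph input buffer; names are made up): the padded region of a persistent input buffer is not rewritten each step, so an in-place add keeps accumulating into it across replays, while the out-of-place add leaves the buffer untouched.

```python
import torch

MAX_TOKENS = 8                                             # padded capture size
input_ids_buf = torch.zeros(MAX_TOKENS, dtype=torch.long)  # persistent graph input

def step(num_tokens: int, bump: torch.Tensor, in_place: bool) -> torch.Tensor:
    # Only the first num_tokens entries are refreshed each step; the pad
    # region keeps whatever it held before (V1 does not zero it out).
    input_ids_buf[:num_tokens] = torch.arange(num_tokens)
    x = input_ids_buf
    if in_place:
        x.add_(bump)      # mutates the persistent buffer, pad region included
        return x
    return x + bump       # new tensor; the buffer itself stays clean

bump = torch.ones(MAX_TOKENS, dtype=torch.long)
for _ in range(3):
    step(4, bump, in_place=True)
print(input_ids_buf[4:])  # tensor([3, 3, 3, 3]) -- garbage piled up in the pads
```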

vllm/config.py Outdated
vllm_factors.append(
    hashlib.md5(
        str(self.scheduler_config.max_num_batched_tokens).encode()
    ).hexdigest())
Contributor Author

During torch.compile, LoRA static buffers like in

self._token_lora_indices = torch.empty(max_num_batched_tokens,
and
token_lora_mapping = torch.empty(max_num_tokens,
get captured along with their sizes and strides (they aren't dynamic)

When max_num_batched_tokens changes and the captured graph is executed, we hit assert_size_stride asserts on these tensors. As a solution, we simply recompile when max_num_batched_tokens changes.

Member

str(self.scheduler_config.max_num_batched_tokens) should be enough. You only need to add a string to factors; no need to hash it here.
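
A self-contained sketch of that idea (a hypothetical helper, not the actual vllm/config.py code): max_num_batched_tokens goes into the list of factors that make up the compilation cache key, so changing it produces a different key and forces a fresh compile instead of replaying a graph whose static LoRA buffers have the wrong size.

```python
import hashlib

def compile_cache_key(factors: list[str], max_num_batched_tokens: int) -> str:
    # The raw string is enough as a factor; the whole factor list is hashed
    # once at the end to form the cache key.
    all_factors = factors + [str(max_num_batched_tokens)]
    return hashlib.md5(str(all_factors).encode()).hexdigest()

# Configs that differ only in max_num_batched_tokens get different keys,
# so the second one recompiles rather than tripping assert_size_stride.
assert compile_cache_key(["model=llama"], 8192) != compile_cache_key(["model=llama"], 4096)
```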

@varun-sundar-rabindranath
Contributor Author

LoRA TP test times increase from 43m to 51m. I believe this is mostly coming from the CUDA graph capture for V1.
The test times of the other LoRA tests are relatively stable.

-        y = self._apply_bias(self.token_lora_indices, y, output_slices,
+        token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                          y.size(0))
+        y = self._apply_bias(token_lora_indices, y, output_slices,
                              lora_bias_stacked)
Contributor Author

^ Changes are to avoid errors such as:

  raise ConstraintViolationError(
torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['input_ids'].size()[0], L['positions'].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['input_ids'].size()[0]) are valid because L['input_ids'].size()[0] was inferred to be a constant (8192).
  - Not all values of RelaxedUnspecConstraint(L['positions'].size()[0]) are valid because L['positions'].size()[0] was inferred to be a constant (8192).

@varun-sundar-rabindranath
Contributor Author

@jeejeelee @youkaichao @bnellnm @ProExpertProg Can you please take a look when you get a chance! Thanks 🙏

@jeejeelee jeejeelee requested a review from youkaichao March 13, 2025 02:06
@jeejeelee
Collaborator

Thank you for your outstanding work. The following are my local test results.

  • V1 eager vs. CUDA graph (throughput screenshot)
  • V0 CUDA graph vs. V1 CUDA graph (throughput screenshot)

@varun-sundar-rabindranath
Contributor Author

Thanks @jeejeelee for running this 🙌

I believe #14685 should help the V0 case 👍

@youkaichao (Member) left a comment

This is surprising; I didn't expect the LoRA code could be traced by Dynamo correctly. How do you pass the LoRA metadata, e.g. which LoRA is used for which request?

@varun-sundar-rabindranath
Contributor Author

This is surprising; I didn't expect the LoRA code could be traced by Dynamo correctly. How do you pass the LoRA metadata, e.g. which LoRA is used for which request?

Hi @youkaichao - the metadata update is done like before, in punica_gpu.py.
This is the metadata file - https://github.com/vllm-project/vllm/blob/main/vllm/lora/ops/triton_ops/v1/v1_kernel_metadata.py
The update happens here:

self._v1_prepare_metadata_tensors(self.token_lora_indices,

I did run into issues with the dynamic shape tracing; I have added comments in such places in the PR, and fortunately there were few.
There were cases where the metadata tensors were traced with their shapes and strides, all of which depend on max_num_batched_tokens. I resorted to recompilation in such cases: https://github.com/vllm-project/vllm/pull/14626/files#r1990059792
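
A hedged sketch of that pattern (toy class and names, not the real punica_gpu.py code): the mapping tensors are persistent, max-sized buffers that are refreshed in-place before each forward pass, so the traced/captured graph always reads the current LoRA mapping from the same storage rather than taking it as a fresh graph input.

```python
import torch

MAX_TOKENS = 8192  # sized by max_num_batched_tokens

class LoRAMetadataSketch:
    """Toy stand-in for the V1 punica metadata tensors."""

    def __init__(self, device: str = "cpu") -> None:
        # Persistent buffer; the compiled graph always reads from this storage.
        self.token_lora_mapping = torch.zeros(MAX_TOKENS, dtype=torch.int32,
                                              device=device)

    def prepare(self, token_lora_indices: torch.Tensor) -> None:
        # Refresh the buffer in-place before the forward pass; nothing the
        # captured graph sees changes identity between replays.
        num_tokens = token_lora_indices.size(0)
        self.token_lora_mapping[:num_tokens].copy_(token_lora_indices)

meta = LoRAMetadataSketch()
meta.prepare(torch.tensor([0, 0, 1, 1], dtype=torch.int32))  # per-step update
```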

@@ -3443,12 +3455,6 @@ def __post_init__(self):
                " Disabling `torch.compile`.")
            self.compilation_config.level = CompilationLevel.NO_COMPILATION

        if self.lora_config is not None and self.compilation_config.level !=\
Member

Is V0 LoRA compatible with torch.compile?

Contributor Author

Not at the moment. The SGMV ops take some forward-pass-specific metadata, such as token_nums and max_seq_length, as Python integers, and IIUC these are captured as constants during tracing even though they shouldn't be.

With the lora/layers.py changes in this PR and with #14685, V0 LoRA should become compatible.

Member

Thanks for the information. Then can you keep the assert in V0?

Contributor Author

Yes. Re-introduced the check for V0 👍

@@ -237,16 +237,19 @@ def set_lora(
        self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
        embeddings_indices = self.punica_wrapper.embeddings_indices
Member

Yeah, this is the culprit. It is actually a function, and it slices a tensor using a Python int, which will fail the symbolic shape compilation. Changing to torch.narrow with x.size(0) is the correct fix 👍

@youkaichao (Member) left a comment

The change makes sense to me. Leaving it to @jeejeelee to verify the correctness.

@youkaichao (Member) left a comment

Some general comments:

For a computation graph to be compatible with vLLM's torch.compile integration, the input/output tensors of every operation seen by the PyTorch compiler (Dynamo, to be specific) must share the same dynamic shape, and all other shapes must be static.

If you want to slice a tensor along a certain dimension to num_tokens, you cannot use a Python int; you should use x.size(0).

@ProExpertProg
Contributor

@varun-sundar-rabindranath do you know why V1 CUDA graph TPOT is worse than V1 eager?

@varun-sundar-rabindranath
Contributor Author

@varun-sundar-rabindranath do you know why V1 CUDA graph TPOT is worse than V1 eager?

Hi @ProExpertProg, where are you seeing this?

@ProExpertProg
Contributor

ProExpertProg commented Mar 13, 2025

Thank you for your outstanding work. The following are my local test results.

  • V1 eager vs. CUDA graph (throughput screenshot)
  • V0 CUDA graph vs. V1 CUDA graph (throughput screenshot)

Here

Oops, did not realize this was toks/s, not seconds.

Varun Sundar Rabindranath added 5 commits March 13, 2025 10:13
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@@ -52,6 +52,7 @@ def set_active_loras(worker: Union[Worker, V1Worker],
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=True,
Contributor

Are you planning on keeping this eager?

Member

Is it testing code that should be removed before this PR is ready?

Contributor Author

I intend to keep it. The CI test was running out of memory, which I assume is because of the CUDA graph capture.

Also, that specific test doesn't actually run the model.

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
@jeejeelee jeejeelee enabled auto-merge (squash) March 14, 2025 01:56
@github-actions bot added the ready label Mar 14, 2025
@vllm-bot vllm-bot merged commit 0b1cfa6 into vllm-project:main Mar 14, 2025
51 of 53 checks passed
richardsliu pushed a commit to richardsliu/vllm that referenced this pull request Mar 14, 2025
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Richard Liu <ricliu@google.com>
@youkaichao youkaichao deleted the varun/v1-lora-cudagraph branch March 19, 2025 15:54
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Labels
ready (ONLY add when PR is ready to merge/full CI is needed), v1
6 participants