
[TPU][V1] Fix Sampler recompilation #15309


Merged
merged 3 commits into vllm-project:main on Mar 25, 2025

Conversation

NickLucche
Contributor

@NickLucche NickLucche commented Mar 21, 2025

Fixes XLA recompilations introduced by a previous change.
Namely, it factors out the on-device slicing that was happening inside InputBatch._make_sampling_metadata, and fixes an issue where XLA would not pre-compile sample_from_hidden when its output wasn't moved to CPU.
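
For context, with XLA's lazy tensors a graph is only traced and compiled when one of its outputs is actually materialized. A minimal sketch of the warm-up idea, with hypothetical names (model_runner and dummy_hidden are placeholders; sample_from_hidden stands in for the real method):

    import torch_xla.core.xla_model as xm

    def warmup_sampler(model_runner, dummy_hidden):
        # Lazily builds the sampling graph; nothing is compiled yet.
        out = model_runner.sample_from_hidden(dummy_hidden)
        # Moving the output to CPU forces XLA to trace, compile and run
        # the graph now, so the compilation cache is warm before serving.
        return out.cpu()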

Update:

  • Persistent sampling metadata tensors, re-using the ones in input_batch to reduce waste
  • (Hopefully) substantially simplified code: TPUSupportedSamplingMetadata is now a simpler wrapper around tensors managed in input_batch (see the sketch below).
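
A rough sketch of that wrapper idea (the field and attribute names here are illustrative, not the exact ones in the PR):

    import torch
    from dataclasses import dataclass

    @dataclass
    class TPUSupportedSamplingMetadata:
        # Thin views over tensors that persist in input_batch, so no new
        # device allocations are made per step.
        temperature: torch.Tensor
        top_p: torch.Tensor
        top_k: torch.Tensor

        @classmethod
        def from_input_batch(cls, input_batch, padded_num_reqs: int):
            # Slice the persistent CPU tensors up to a padded request
            # count, so the shapes XLA sees stay fixed across steps.
            return cls(
                temperature=input_batch.temperature_cpu[:padded_num_reqs],
                top_p=input_batch.top_p_cpu[:padded_num_reqs],
                top_k=input_batch.top_k_cpu[:padded_num_reqs],
            )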

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 21, 2025
@yaochengji
Copy link
Collaborator

Hi @NickLucche, one thing I noticed is that the host overhead has become larger.

Previously the average was 3.5 ms; now it is about 5 ms.

You can set the environment variable VLLM_TORCH_PROFILER_DIR and add --profile in benchmark_script.py to enable profiling.
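
For example (the exact benchmark script and flags may differ; treat this as a sketch):

    VLLM_TORCH_PROFILER_DIR=/tmp/vllm_profile \
        python benchmarks/benchmark_serving.py --model <model> --profile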

@lsy323
Copy link
Collaborator

lsy323 commented Mar 22, 2025

Hi @NickLucche, I found the sampler test is not added to TPU CI. Also, it's running with enforce_eager; could we turn off enforce_eager so that recompilation can be checked in our tests?

@NickLucche
Copy link
Contributor Author

NickLucche commented Mar 24, 2025

Hi @NickLucche, one thing I noticed is that the host overhead has become larger.

Thanks for looking into it. I had to move the sampling pre-processing (slicing) from device to host. I'll look into optimizing it, but on GPU we just keep everything on device and do the slicing there, which may cause recompilation depending on the number of reqs being scheduled. TL;DR: it's a trade-off in integrating the existing logic.
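
Roughly, the host-side approach looks like this (an illustrative sketch, not the PR's actual code): slice on the host and pad the request count to one of a few pre-compiled sizes, so the device never sees a new shape.

    import torch

    _BUCKETS = (8, 16, 32, 64)

    def padded_num_reqs(num_reqs: int) -> int:
        # Round up to the smallest pre-compiled bucket that fits.
        for b in _BUCKETS:
            if num_reqs <= b:
                return b
        return _BUCKETS[-1]

    def slice_on_host(temperature_cpu: torch.Tensor, num_reqs: int,
                      device: torch.device) -> torch.Tensor:
        # Host-side slicing costs some extra host time, but the tensor
        # shipped to the TPU always has a bucketed shape, so XLA never
        # recompiles for a new number of requests.
        n = padded_num_reqs(num_reqs)
        return temperature_cpu[:n].to(device)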

Also, it's running with enforce_eager; could we turn off enforce_eager so that recompilation can be checked in our tests?

Yes, but I'd rather change the logic then, probably merging the two tests into one.
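
Something along these lines, for instance (a sketch; the model name is a placeholder):

    from vllm import LLM, SamplingParams

    # One test exercising the sampler with compilation enabled
    # (enforce_eager=False), so recompilations surface instead of being
    # hidden by eager execution.
    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enforce_eager=False)
    out = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.7, top_p=0.9, top_k=10))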

… code

Signed-off-by: NickLucche <nlucches@redhat.com>
@NickLucche NickLucche marked this pull request as ready for review March 25, 2025 15:54
@NickLucche
Copy link
Contributor Author

@lsy323 there's still a very small graph being re-compiled at runtime that I haven't been able to track down. I'll have to hold off on that test until it's figured out.

@mgoin mgoin added the tpu (Related to Google TPUs), ready (ONLY add when PR is ready to merge/full CI is needed), and bug (Something isn't working) labels Mar 25, 2025
Member

@mgoin mgoin left a comment


Looks good to me, just needs some cleanup. Have you validated the performance win with a smoke test?

    indices = torch.zeros(
        num_reqs_to_sample,
        dtype=torch.int32,
        device=device,
    )
    xm.mark_step()
Member

Worth leaving a comment on why this isn't at the end of the loop.

@yaochengji
Copy link
Collaborator

@lsy323 there's still a very small graph being re-compiled at runtime that I haven't been able to track down. I'll have to hold off on that test until it's figured out.

@NickLucche, we can use PT_XLA_DEBUG_LEVEL=2 to find where the recompilation is triggered. And please remember to clear the XLA compilation cache before execution.
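
For reference, a sketch of how to surface recompilations with torch_xla's debug helpers (run_one_decode_step is a placeholder; PT_XLA_DEBUG_LEVEL=2 must be set in the environment before the process starts, and the on-disk XLA cache cleared so previously cached graphs don't mask compiles):

    import torch_xla.debug.metrics as met

    met.clear_all()                # start from clean compilation stats
    run_one_decode_step()          # placeholder for a single serving step
    # If CompileTime's sample count keeps growing across steady-state
    # steps, something is still being recompiled at runtime.
    print(met.metrics_report())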

@mgoin mgoin merged commit a0dd7dc into vllm-project:main Mar 25, 2025
40 checks passed
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Wes Medford <wryanmedford@gmail.com>
lengrongfu pushed a commit to lengrongfu/vllm that referenced this pull request Apr 2, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Signed-off-by: NickLucche <nlucches@redhat.com>
Labels
bug (Something isn't working), ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1