[TPU][V1] Fix Sampler recompilation #15309
Conversation
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Hi @NickLucche, one thing I noticed is that the host overhead becomes larger: previously the average was 3.5 ms, now it is about 5 ms. You can set the environment variable
Hi @NickLucche, I found the sampler test is not added to TPU CI. Also, it's running in
Thanks for looking into it. I had to move the sampling pre-processing (slicing) from device to host. I'll look into optimizing it, but on GPU we just keep everything on device and do the slicing there, which may cause recompilation depending on the number of requests being scheduled. TL;DR: it's a trade-off in integrating the existing logic.
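For context, a minimal sketch of the trade-off (names like `padded_num_reqs` and `temperatures_cpu` are illustrative assumptions, not the PR's actual code): slicing on device with the exact request count changes tensor shapes from step to step and can trigger XLA recompilation, while slicing on host to a padded bucket size keeps device shapes fixed at the cost of some host work.

```python
import torch

def device_side_slice(temperatures: torch.Tensor, num_reqs: int) -> torch.Tensor:
    # GPU-style approach: slice on device to the exact number of scheduled
    # requests. The result's shape changes whenever num_reqs changes, so XLA
    # would trace and compile a new graph for each distinct value.
    return temperatures[:num_reqs]

def host_side_slice(temperatures_cpu: torch.Tensor, padded_num_reqs: int,
                    device: torch.device) -> torch.Tensor:
    # TPU-friendly approach: slice on host to one of a few pre-compiled padded
    # sizes, then transfer. Device-side shapes stay fixed, avoiding
    # recompilation at the cost of extra per-step host overhead.
    return temperatures_cpu[:padded_num_reqs].to(device)
```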
Yes, but I'd rather change the logic then, probably merging the two tests into one.
… code Signed-off-by: NickLucche <nlucches@redhat.com>
@lsy323 there's still a very small graph being re-compiled at runtime that I haven't been able to track down. I'll have to hold off on that test until it's figured out.
Looks good to me, just needs some cleanup. Have you validated the performance win with a smoke test?
indices = torch.zeros(
    num_reqs_to_sample,
    dtype=torch.int32,
    device=device,
)
xm.mark_step()
Worth leaving a comment why this isn't at the end of the loop
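One plausible rationale, not confirmed in this thread: the `xm.mark_step()` placed right after the allocation cuts the trace so the dummy inputs are materialized in their own graph, and the sampling call that follows gets compiled as the standalone graph actually hit at serving time, when inputs already live on device. A rough sketch of such a warmup loop, with assumed names (`warmup_request_buckets`, an argmax stand-in for `sample_from_hidden`):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
warmup_request_buckets = [8, 16, 32]  # illustrative padded request counts

for num_reqs_to_sample in warmup_request_buckets:
    indices = torch.zeros(num_reqs_to_sample, dtype=torch.int32, device=device)
    # Cut the trace here so the allocation above is its own graph; a single
    # mark_step() at the end of the loop would instead fuse allocation and
    # sampling into one graph that never occurs in production.
    xm.mark_step()
    logits = torch.randn(num_reqs_to_sample, 32000, device=device)  # stand-in input
    sampled = torch.argmax(logits, dim=-1)  # stand-in for sample_from_hidden
    sampled_cpu = sampled.cpu()  # pulling to CPU forces execution, hence compilation
xm.wait_device_ops()
```

The final `.cpu()` mirrors the issue mentioned in the PR description: without moving the output to host, XLA may defer execution and the sampling graph never gets compiled during warmup.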
@NickLucche, we can use
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Wes Medford <wryanmedford@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
Signed-off-by: NickLucche <nlucches@redhat.com>
Fix XLA recompilations introduced by an earlier change. Namely, it factors out the on-device slicing that is happening inside `InputBatch._make_sampling_metadata`, as well as an issue with XLA not pre-compiling `sample_from_hidden` when its output isn't moved to CPU.

Update: sampling tensors are now pre-allocated and managed in `input_batch` to reduce waste; `TPUSupportedSamplingMetadata` is now a simpler wrapper around tensors managed in `input_batch`.
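A rough sketch of what such a wrapper could look like (field and attribute names here are assumptions, not the actual vLLM API): the metadata object no longer allocates or slices device tensors itself; it simply wraps padded, fixed-size slices of tensors that `input_batch` already manages.

```python
from dataclasses import dataclass

import torch

@dataclass
class SamplingMetadataSketch:
    # Padded, fixed-size device tensors; the backing storage lives in input_batch.
    temperature: torch.Tensor
    top_p: torch.Tensor
    top_k: torch.Tensor
    all_greedy: bool

    @classmethod
    def from_input_batch(cls, input_batch, padded_num_reqs: int,
                         device: torch.device) -> "SamplingMetadataSketch":
        # Slice the host-side (CPU) tensors to a pre-compiled padded size and
        # move them to device; shapes stay constant across steps, so XLA does
        # not need to recompile the sampling graph.
        return cls(
            temperature=input_batch.temperature_cpu[:padded_num_reqs].to(device),
            top_p=input_batch.top_p_cpu[:padded_num_reqs].to(device),
            top_k=input_batch.top_k_cpu[:padded_num_reqs].to(device),
            all_greedy=input_batch.all_greedy,
        )
```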