
[TPU] optimize the all-reduce performance #15903


Merged

Conversation

@yaochengji (Collaborator) commented Apr 1, 2025

Before this PR, all-reduce performance was suboptimal for two reasons:

  • On v6e-8, the XLA compiler unexpectedly applies a 2D-ring strategy, while a 1D-ring is expected.
  • The ring order cannot be adjusted automatically.


github-actions bot commented Apr 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@yaochengji yaochengji marked this pull request as draft April 1, 2025 20:10
@yaochengji yaochengji requested a review from alexm-redhat April 1, 2025 20:10
@mergify mergify bot added the tpu (Related to Google TPUs) label Apr 1, 2025
@yaochengji (Collaborator, Author) commented:

I have two questions:

  • Currently I wrap the all-reduce in a PyTorch custom op to make it compatible with Dynamo (a sketch follows this list); should I wrap it in vllm/distributed/parallel_state.py instead?
  • I need another environment variable, LIBTPU_INIT_ARGS="--xla_tpu_force_1d_allreduce_at_chunk_count=1", for the performance optimization; should I add it in the code?
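
For context, a minimal sketch of the custom-op wrapping described in the first question might look like the following. The op name tpu::all_reduce_demo, the torch_xla xm.all_reduce call, and the register_fake stub are illustrative assumptions, not the code from this PR.

import torch
import torch_xla.core.xla_model as xm


@torch.library.custom_op("tpu::all_reduce_demo", mutates_args=())
def all_reduce_demo(x: torch.Tensor) -> torch.Tensor:
    # Sum-reduce the tensor across all TPU replicas via torch_xla.
    return xm.all_reduce(xm.REDUCE_SUM, x)


@all_reduce_demo.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Dynamo/fake-tensor tracing only needs the output shape and dtype.
    return torch.empty_like(x)

Registering a fake implementation is what lets Dynamo trace through the op without executing the collective.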

@alexm-redhat (Collaborator) commented Apr 1, 2025

@yaochengji thanks for this important PR!

  1. About dynamo, I don't have a strong opinion there.
  2. About the libtpu flag, I think you can detect inside init_device() of tpu_worker that it is a v6 and simply add the env var, similar to the code below (which sets PJRT_DEVICE):

def init_device(self):
    os.environ["PJRT_DEVICE"] = "TPU"
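
A sketch of that suggestion, assuming the v6e check can be made self-contained: the _is_tpu_v6e helper and the TPU_ACCELERATOR_TYPE lookup below are placeholders rather than vLLM's actual detection logic, while the LIBTPU_INIT_ARGS flag is the one discussed in this PR.

import os


def _is_tpu_v6e() -> bool:
    # Placeholder detection: a real implementation would query the TPU
    # generation (e.g. via torch_xla); reading an env var here only keeps
    # the sketch self-contained and is an assumption.
    return os.environ.get("TPU_ACCELERATOR_TYPE", "").startswith("v6e")


def init_device() -> None:
    os.environ["PJRT_DEVICE"] = "TPU"
    if _is_tpu_v6e():
        # Force XLA's 1D-ring all-reduce, the optimization this PR targets.
        os.environ["LIBTPU_INIT_ARGS"] = (
            "--xla_tpu_force_1d_allreduce_at_chunk_count=1")

Note that LIBTPU_INIT_ARGS generally has to be set before the TPU runtime is initialized for the flag to take effect.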

@yaochengji yaochengji requested a review from youkaichao April 1, 2025 20:27
@yaochengji (Collaborator, Author) commented:

Thanks @alexm-redhat for your suggestion.

Hi @youkaichao, do you have any suggestions on the Dynamo custom op?

@yaochengji (Collaborator, Author) commented:

My local experiments show that throughput improves from ~4.2 reqs/s to ~4.9 reqs/s for Llama 70B on 8 v6e chips.

@yaochengji yaochengji marked this pull request as ready for review April 2, 2025 06:10
@mergify mergify bot added the v1 label Apr 2, 2025

if USE_RAY:
    from vllm.executor import ray_utils


@torch.library.custom_op("tpu::all_reduce", mutates_args=())

Member commented:

Is it possible to use the custom op call in vLLM directly? E.g., extending this line

self.use_custom_op_call = current_platform.is_cuda_alike()

to include TPU.

I didn't do it initially because I remember TPU has some custom Dynamo-related logic.
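
The suggested extension would presumably amount to a one-line change along these lines; this is a sketch, and the is_tpu() check on current_platform is an assumption rather than the diff that landed:

# In parallel_state.py: route TPU through the same custom-op call path
# as CUDA-like platforms (is_tpu() is assumed to exist on current_platform).
self.use_custom_op_call = (current_platform.is_cuda_alike()
                           or current_platform.is_tpu())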

Collaborator Author commented:

Thanks for your suggestion!

I removed the custom op in tpu_communicator.py and made use of the custom op for TPU in parallel_state.py

@alexm-redhat (Collaborator) left a comment:

LGTM!

@alexm-redhat alexm-redhat enabled auto-merge (squash) April 2, 2025 17:06
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Apr 2, 2025
auto-merge was automatically disabled April 2, 2025 22:04

Head branch was pushed to by a user without write access

@yaochengji yaochengji force-pushed the chengji/optimize-allreduce branch from 5edfa49 to 91628ee April 2, 2025 22:04
@robertgshaw2-redhat (Collaborator) commented:

Magical incantations!

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) April 3, 2025 00:15
@robertgshaw2-redhat (Collaborator) commented:

NOTE: the failing V1 test is fixed by #15969

@robertgshaw2-redhat robertgshaw2-redhat merged commit 01b6113 into vllm-project:main Apr 3, 2025
38 checks passed
Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Signed-off-by: Chengji Yao <chengjiyao@google.com>
Labels: ready (ONLY add when PR is ready to merge/full CI is needed), tpu (Related to Google TPUs), v1

4 participants