Skip to content

[TPU] optimize the all-reduce performance #15903

New issue

Have a question about this project? No Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “No Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? No Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion vllm/distributed/device_communicators/tpu_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)

if USE_RAY:
from vllm.executor import ray_utils
Expand Down Expand Up @@ -79,9 +81,12 @@ def __init__(self,

pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
self.groups = create_optimized_replica_groups()

def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
return xm.all_reduce(xm.REDUCE_SUM, input_)
# TODO: Remove the groups specification after XLA compiler can support
# auto-reordering the ring order for all-reduce.
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)

def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
Expand Down
5 changes: 4 additions & 1 deletion vllm/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:


if supports_custom_op():
from vllm.platforms import current_platform
direct_register_custom_op(
op_name="all_reduce",
op_func=all_reduce,
mutates_args=[],
fake_impl=all_reduce_fake,
dispatch_key=current_platform.dispatch_key,
)


Expand Down Expand Up @@ -219,7 +221,8 @@ def __init__(
self.cpu_group, 1 << 22, 6)

from vllm.platforms import current_platform
self.use_custom_op_call = current_platform.is_cuda_alike()
self.use_custom_op_call = (current_platform.is_cuda_alike()
or current_platform.is_tpu())

@property
def first_rank(self):
Expand Down
6 changes: 6 additions & 0 deletions vllm/v1/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ def __init__(

def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
# Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D
# ring, the xla tpu compiler flag
# `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to
# fix this. It will be removed after the bug in XLA compiler is fixed.
os.environ["LIBTPU_INIT_ARGS"] = (
"--xla_tpu_force_1d_allreduce_at_chunk_count=1")
torch.set_grad_enabled(False)
torch.set_default_dtype(self.model_config.dtype)

Expand Down