[Kernel] [V1] Improved performance for V1 Triton (ROCm) backend #14152


Merged: 13 commits, Mar 6, 2025
135 changes: 76 additions & 59 deletions tests/kernels/test_prefix_prefill.py
@@ -3,13 +3,16 @@
import math
import random
import time
from collections.abc import Callable

import pytest
import torch
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask

from vllm.attention.backends.xformers import _make_alibi_bias
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.prefix_prefill import context_attention_fwd
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
@@ -24,6 +27,8 @@
SLIDING_WINDOW = [0, 16, 64, 128, 256, 512, 2048]
KV_CACHE_DTYPES = ["auto", "fp8", "fp8_e5m2"]

OPS = [chunked_prefill_paged_decode, context_attention_fwd]
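
Both entries in OPS are invoked through the same positional call later in this file, so the test treats them as interchangeable. Below is a minimal hedged sketch of that shared interface, inferred purely from the call sites in this test rather than from vLLM's declared signatures; the protocol name, parameter ordering, and types are assumptions for illustration only.

    from typing import Optional, Protocol

    import torch

    class PrefixAttentionOp(Protocol):
        # Sketch of the call shape used by this test; not vLLM's actual declaration.
        def __call__(
            self,
            query: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            output: torch.Tensor,
            kv_cache_dtype: str,
            k_cache: torch.Tensor,
            v_cache: torch.Tensor,
            block_table: torch.Tensor,
            b_start_loc: torch.Tensor,
            b_seq_len: torch.Tensor,
            max_input_len: int,
            k_scale: torch.Tensor,
            v_scale: torch.Tensor,
            *,
            sliding_window: Optional[int] = None,
            alibi_slopes: Optional[torch.Tensor] = None,
        ) -> None:
            ...  # both ops appear to write their results into `output` in place
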


@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@@ -32,6 +37,7 @@
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
@@ -41,6 +47,7 @@ def test_contexted_kv_attention(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:

if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
@@ -65,6 +72,9 @@
block_size = 32
max_block_per_request = 64
query_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)]
# ensure one sequence in batch is a decode
query_lens[-1] = 1

ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)]
seq_lens = [a + b for a, b in zip(query_lens, ctx_lens)]
num_kv_heads = num_heads // num_queries_per_kv
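
Forcing query_lens[-1] = 1 above mixes a single-token decode request into an otherwise prefill-style batch, so the new chunked_prefill_paged_decode op is exercised on both prefill chunks and a pure decode token in one call. For orientation, here is a hedged sketch of how a flattened variable-length batch like this is typically laid out; the actual construction of b_start_loc and b_seq_len happens further down in the test, outside the hunks shown here, so the helper below is an illustration, not the test's code.

    # Illustration only; the lengths are arbitrary and the last entry is the decode row.
    import torch

    query_lens = [37, 512, 1]
    ctx_lens = [100, 8, 640]
    seq_lens = [q + c for q, c in zip(query_lens, ctx_lens)]       # query + cached context
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1]), dim=0)
    total_query_tokens = sum(query_lens)                           # rows in the flattened query tensor
    print(b_start_loc.tolist(), seq_lens, total_query_tokens)      # [0, 37, 549] [137, 520, 641] 550
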
@@ -144,36 +154,36 @@ def test_contexted_kv_attention(

# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -228,7 +238,7 @@ def test_contexted_kv_attention(
end_time = time.time()
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
output_ref = output_ref.reshape(output.shape)
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-4
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
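
The non-fp8 tolerance is relaxed here from 1e-6 to 1e-4, presumably because the two parametrized ops do not accumulate in exactly the same order; with rtol pinned to 0 the check is purely absolute. A small hedged illustration of what that comparison accepts and rejects (the values are made up):

    import torch

    a = torch.tensor([1.00000, 2.00000])
    b = torch.tensor([1.00005, 2.00009])      # max abs difference is roughly 9e-5

    torch.testing.assert_close(a, b, atol=1e-4, rtol=0)   # passes under the relaxed bound
    try:
        torch.testing.assert_close(a, b, atol=1e-6, rtol=0)
    except AssertionError:
        print("a 9e-5 difference exceeds the old 1e-6 bound but fits within 1e-4")
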


@@ -238,6 +248,7 @@ def test_contexted_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi(
num_heads: int,
@@ -246,6 +257,7 @@ def test_contexted_kv_attention_alibi(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:

if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
@@ -375,36 +387,36 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:

# Warm up the Triton kernel by calling it once before actually measuring
# generation time
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
start_time = time.time()
context_attention_fwd(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
op(query,
k,
v,
output,
kv_cache_dtype,
k_cache,
v_cache,
block_table,
b_start_loc,
b_seq_len,
max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes)
torch.cuda.synchronize()
end_time = time.time()
print(f"triton Time: {(end_time - start_time)*1000:.2f} ms")
@@ -503,6 +515,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
num_heads: int,
@@ -512,9 +525,11 @@ def test_contexted_kv_attention_f32(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
sliding_window, dtype, kv_cache_dtype, device)
sliding_window, dtype, kv_cache_dtype, device,
op)


@pytest.mark.optional
@@ -524,6 +539,7 @@ def test_contexted_kv_attention_f32(
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("op", OPS)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
num_heads: int,
Expand All @@ -532,6 +548,7 @@ def test_contexted_kv_attention_alibi_f32(
dtype: torch.dtype,
kv_cache_dtype: str,
device: str,
op: Callable,
) -> None:
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
dtype, kv_cache_dtype, device)
dtype, kv_cache_dtype, device, op)