[FEAT] [ROCm] Add AITER int8 scaled gemm kernel #15433
Conversation
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
I will do a deeper review later but could you please use the
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@ProExpertProg
This pull request has merge conflicts that must be resolved before it can be merged.
Please resolve the merge conflicts, thanks! |
# try import aiter
try:
    pass
except Exception:
    return (
        False,
        "AiterScaledMMLinearKernel requires `aiter` which is not " +
        "installed supported on ROCm.")
It seems you forgot to import aiter here
Yeah agreed this is missing the import
Thank you. It seems Ruff removed the `import aiter`. I have annotated this line so Ruff will not change it into `pass`.
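For reference, a minimal sketch of the availability check with the import kept; the `# noqa: F401` annotation is what stops Ruff from stripping the unused import and rewriting the body to `pass`. The wrapper function name here is hypothetical, standing in for the surrounding `can_implement` classmethod.

def _check_aiter_available():
    try:
        import aiter  # noqa: F401
    except Exception:
        return (False,
                "AiterScaledMMLinearKernel requires `aiter`, which is not "
                "installed on ROCm.")
    return (True, None)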
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
A few initial comments!
assert qkv_proj.weight.dtype is (torch.float8_e4m3fnuz
                                 if current_platform.is_rocm()
                                 else torch.float8_e4m3fn)
Use current_platform.fp8_dtype()
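A minimal sketch of the suggested change, assuming the surrounding test still defines `qkv_proj` as in the snippet above; `current_platform.fp8_dtype()` returns the platform-appropriate fp8 dtype, mirroring the original conditional.

from vllm.platforms import current_platform

# qkv_proj comes from the surrounding test; only the assert changes.
assert qkv_proj.weight.dtype is current_platform.fp8_dtype()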
Done
# assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
#     [M, 1])
# assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
#     [N, 1])
Please clean up comments
vllm/_custom_ops.py (Outdated)
"triton_scaled_mm") | ||
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm | ||
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) | ||
if is_rocm_aiter_gemm_w8a8_scaled_mm_enabled(): |
This shouldn't live inside `cutlass_scaled_mm`; it has nothing to do with cutlass. This code should just live inside `AiterScaledMMLinearKernel.apply`.
I know the Triton kernel is here, but it shouldn't be either; I'm currently refactoring that.
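A hypothetical sketch of the reviewer's suggestion: call the AITER kernel from the Aiter kernel class instead of routing through `ops.cutlass_scaled_mm`. The `_get_weight_params` helper, its return values, and the `scaled_int8_quant` call are assumptions modelled on the other `ScaledMMLinearKernel` classes, not confirmed by this PR.

import torch
from vllm import _custom_ops as ops

def apply_weights(self, layer, x, bias=None):
    # Assumed accessor: quantized weight, weight scale, input scale, input zp.
    w_q, w_s, i_s, i_zp = self._get_weight_params(layer)
    # Quantize the activations to int8 (symmetric).
    x_q, x_s, _ = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=True)

    from aiter import gemm_a8w8_CK
    # w_q is stored transposed (Cutlass-style), so restore it for AITER.
    return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(x.dtype)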
if current_platform.is_cpu():
    return (
        False,
        "AiterScaledMMLinearKernel requires `aiter` which is not " +
        "currently supported on CPU.")
if not current_platform.is_rocm():
This can be a single check
Done.
@@ -57,6 +71,11 @@ def use_v0_only(monkeypatch):
 )
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
     model_path, strategy, quant_type, shape_0, is_symmetric = model_args
This logic is a bit confusing. What models are and aren't supported by aiter vs Triton?
AITER only supports per-channel-per-channel INT8 GEMM and per-tensor-per-tensor INT8 GEMM. It does not support mixed-precision MM or mixed quantization schemes.
Could you add that as a comment?
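For reference, a shape-level illustration of the two supported schemes; the scale shapes are assumptions based on the checks discussed elsewhere in this PR, for an [M, K] int8 activation and an [N, K] int8 weight.

# per-tensor / per-tensor : x_s.shape == (1, 1), w_s.shape == (1, 1)
# per-token  / per-channel: x_s.shape == (M, 1), w_s.shape == (N, 1)
# Only these two cases take the AITER path in this PR; other combinations
# are rejected by the assert shown later in the thread.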
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
                    reason="This test is skipped on non-ROCm platforms.")
def test_compressed_tensors_w8a8_logprobs_rocm_aiter(
Could this be folded into the existing tests by adding a boolean `use_aiter` parameter? And we can do `[False] if <platform ...> else [False, True]`.
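A hypothetical sketch of that suggestion: fold the AITER case into the existing logprobs test via a boolean parameter that is only enabled on ROCm. The test name and fixtures here are stand-ins modelled on the tests quoted above.

import pytest
from vllm.platforms import current_platform

@pytest.mark.parametrize(
    "use_aiter", [False] if not current_platform.is_rocm() else [False, True])
def test_compressed_tensors_w8a8_logprobs(vllm_runner, use_aiter, monkeypatch):
    if use_aiter:
        # Route the int8 scaled GEMM through the AITER kernel.
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    ...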
Done
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
A few more comments, thanks for working with me on this!
@@ -20,25 +27,20 @@ def get_min_capability(cls) -> int:
     @classmethod
     def can_implement(
             cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
-        if current_platform.is_cpu():
+        if current_platform.is_cpu() or not current_platform.is_rocm():
Why can't this be simpler?

-        if current_platform.is_cpu() or not current_platform.is_rocm():
+        if not current_platform.is_rocm():
Ok. I have removed the check for CPU.
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool:
This method should just be inlined at the sole call site (unless I'm missing another use).
Resolved.
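A sketch of what inlining the check at its call site could look like, assuming the helper simply combined the ROCm check with the AITER environment flags named in this PR's description; the dispatch body is elided.

import vllm.envs as envs
from vllm.platforms import current_platform

if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER \
        and envs.VLLM_ROCM_USE_AITER_LINEAR:
    ...  # dispatch to the AITER w8a8 scaled GEMM path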
per_channel_tensor_scale_a = (x_s.numel() == m)
per_channel_tensor_scale_b = (w_s.numel() == n)
Isn't this more accurate here?

-    per_channel_tensor_scale_a = (x_s.numel() == m)
-    per_channel_tensor_scale_b = (w_s.numel() == n)
+    per_token_scale_a = (x_s.numel() == m)
+    per_channel_scale_b = (w_s.numel() == n)
Yes you are right. I have made the amendments. Thank you so much.
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
"installed supported on ROCm.") | ||
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled | ||
if not ( | ||
current_platform.is_rocm() \ |
Already checked this above
OK, removed `current_platform.is_rocm()`.
" ATIER block scaled GEMM yet.") | ||
|
||
from aiter import gemm_a8w8_CK | ||
return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype) |
Just curious for future work: does this kernel support fp8?
Also, can you add a comment on why `w_q` needs to be transposed here? I assume it's because the Cutlass weight preparation produces transposed weights, so here we restore them?
ROCm/aiter does not support FP8 at this moment.
I have added the comment.
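An assumed layout note behind the `w_q.t()` call above, based on the reviewer's guess (which the author confirmed by adding a comment); the exact layouts are not spelled out in this PR.

# Assumption: the kernel classes in this file store the quantized weight
# Cutlass-style as [K, N], while aiter.gemm_a8w8_CK expects an [N, K]
# weight, so .t() restores that layout before the call:
#   x_q: [M, K] int8, w_q.t(): [N, K] int8,
#   x_s: [M, 1] or [1, 1], w_s: [N, 1] or [1, 1]  ->  output: [M, N]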
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
    # per-channel-per-channel a8w8 scacled GEMM
    assert ((per_tensor_scale_a and per_tensor_scale_b)
            or (per_token_scale_a and per_channel_scale_b)), (
        "Currently only support per-tensor-per-tensor GEMM " +
        " and per-channel-per-channel GEMM through AITER"
        " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" +
        " ATIER block scaled GEMM yet.")
Typos/inaccuracies:

-    # per-channel-per-channel a8w8 scacled GEMM
+    # per-token-per-channel a8w8 scaled GEMM
     assert ((per_tensor_scale_a and per_tensor_scale_b)
             or (per_token_scale_a and per_channel_scale_b)), (
         "Currently only support per-tensor-per-tensor GEMM " +
-        " and per-channel-per-channel GEMM through AITER"
+        " and per-token-per-channel GEMM through AITER"
         " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" +
-        " ATIER block scaled GEMM yet.")
+        " AITER block scaled GEMM yet.")
And then, what does this mean: "`cutlass_scaled_mm` does not support AITER block scaled GEMM yet"?
My bad. This is the modified version:
# @TODO:
# Maybe broadcast the per-tensor-scale into per-channel-scale
# if one of the scale is a per-channel-scale.
# For now, it only supports:
# - per-tensor-per-tensor a8w8 scaled GEMM, and
# - per-token-per-channel a8w8 scaled GEMM
assert ((per_tensor_scale_a and per_tensor_scale_b)
or (per_token_scale_a and per_channel_scale_b)), (
"Currently only support per-tensor-per-tensor GEMM " +
" and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
"does not support AITER block scaled GEMM.")
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
LGTM, nice work!
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Description
This PR integrates the AITER int8 scaled GEMM kernel, focusing on models generated in the compressed-tensors format using llm-compressor.
To use this feature, set `VLLM_ROCM_USE_AITER=1`. (The default value of `VLLM_ROCM_USE_AITER_LINEAR` is `1`, so the AITER linear path is enabled by default when AITER is used.)
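A hypothetical usage sketch: enable the AITER path via the environment variable above and run the benchmarked model offline; the model name and TP=4 are taken from this PR's experiment setup, everything else is illustrative.

import os
os.environ["VLLM_ROCM_USE_AITER"] = "1"  # route int8 scaled GEMM to AITER

from vllm import LLM, SamplingParams

llm = LLM(model="neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w8a8",
          tensor_parallel_size=4)
print(llm.generate(["Hello"], SamplingParams(max_tokens=8)))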
Performance Gain
Experiment setup:
GPU: 4 * MI300X
Model:
neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w8a8
Accuracy Comparison: Triton Scaled MM vs AITER Scaled GEMM Kernel
Model: neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w8a8 (5-shot, TP=4)
Unit tests
tests/quantization/test_compressed_tensors.py