[FEAT] [ROCm] Add AITER int8 scaled gemm kernel #15433

Merged
11 commits merged into vllm-project:main on Mar 29, 2025

Conversation

@tjtanaa (Contributor) commented Mar 25, 2025

Description

This PR integrates the AITER int8 scaled GEMM kernel, focusing on models generated in the compressed-tensors format with llm-compressor.
To use this feature, set VLLM_ROCM_USE_AITER=1. (VLLM_ROCM_USE_AITER_LINEAR defaults to 1, so the linear path is enabled by default whenever AITER is in use.)
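A minimal usage sketch, assuming the standard `vllm.LLM` entry point and the 4×MI300X setup used below; the environment variables must be set before vLLM is imported:

```python
import os

# Master switch for AITER kernels on ROCm.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
# VLLM_ROCM_USE_AITER_LINEAR already defaults to "1", so the scaled-GEMM
# linear path is picked up automatically once AITER is enabled.

from vllm import LLM

llm = LLM(
    model="neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w8a8",
    tensor_parallel_size=4,
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)
```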

Performance Gain

Experiment setup:
GPU: 4 * MI300X
Model: neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w8a8

| Input Token Length | Output Token Length | Kernel Used | Throughput Gain over Triton |
|---|---|---|---|
| 128 | 128 | AITER | 20.2% |
| 128 | 2048 | AITER | 22.3% |
| 2048 | 128 | AITER | 13.6% |
| 2048 | 2048 | AITER | 10.9% |

Accuracy Comparison: Triton Scaled MM vs AITER Scaled GEMM Kernel

| Kernel | Filter | Exact Match | Stderr |
|---|---|---|---|
| Triton | flexible-extract | 0.9477 | 0.0061 |
| Triton | strict-match | 0.9439 | 0.0063 |
| AITER | flexible-extract | 0.9477 | 0.0061 |
| AITER | strict-match | 0.9431 | 0.0064 |

Model: neuralmagic/Meta-Llama-3.1-405B-Instruct-quantized.w8a8 (5-shot, TP=4)

Unit tests

  • tests/quantization/test_compressed_tensors.py


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run other CI tests on top of it by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa tjtanaa marked this pull request as ready for review March 25, 2025 13:27
@ProExpertProg (Contributor)

I will do a deeper review later but could you please use the ScaledMMKernel abstraction for this?

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@tjtanaa tjtanaa requested a review from tlrmchlsmth as a code owner March 26, 2025 06:16
@tjtanaa (Contributor Author) commented Mar 26, 2025

@ProExpertProg
I have implemented an AiterScaledMMLinearKernel subclass.
It is now ready for review.
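For readers unfamiliar with the abstraction, a rough skeleton of what such a subclass looks like; module paths, the base-class name, and the capability value are assumptions, and only `can_implement`'s signature is taken from the diff quoted later in this conversation:

```python
from typing import Optional, Tuple

from vllm.platforms import current_platform

# Base class and config assumed to live next to the other ScaledMM kernels.
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
                                   ScaledMMLinearLayerConfig)


class AiterScaledMMLinearKernel(ScaledMMLinearKernel):

    @classmethod
    def get_min_capability(cls) -> int:
        # Illustrative value; the real kernel targets ROCm (MI300-class) GPUs.
        return 90

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        # Report (supported?, reason-if-not); the AITER path needs ROCm and
        # the `aiter` package (availability check discussed further down).
        if not current_platform.is_rocm():
            return (False,
                    "AiterScaledMMLinearKernel is only supported on ROCm.")
        return (True, None)
```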


mergify bot commented Mar 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tjtanaa.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2025
@ProExpertProg (Contributor)

Please resolve the merge conflicts, thanks!

Comment on lines 33 to 40
        # try import aiter
        try:
            pass
        except Exception:
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "installed supported on ROCm.")
Member

It seems you forgot to import aiter here

Contributor

Yeah agreed this is missing the import

Contributor Author

Thank you. It seems ruff removed the `import aiter`. I have annotated this line so ruff will not change it into `pass`.
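For context, one conventional way to keep ruff from rewriting an availability probe into `pass` is a `# noqa: F401` marker; a sketch with a hypothetical helper name:

```python
from typing import Optional, Tuple


def _aiter_importable() -> Tuple[bool, Optional[str]]:
    try:
        # The unused import is intentional: it only probes that the AITER
        # wheel is installed. `# noqa: F401` keeps ruff from removing it.
        import aiter  # noqa: F401
    except Exception:
        return (False,
                "AiterScaledMMLinearKernel requires `aiter`, which is not "
                "installed (only supported on ROCm).")
    return (True, None)
```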

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@mergify mergify bot removed the needs-rebase label Mar 26, 2025
@ProExpertProg (Contributor) left a comment

A few initial comments!

Comment on lines 310 to 312
assert qkv_proj.weight.dtype is (torch.float8_e4m3fnuz
                                 if current_platform.is_rocm()
                                 else torch.float8_e4m3fn)
Contributor

Use current_platform.fp8_dtype()

Contributor Author

Done
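The resolved assertion presumably collapses to a single line using the platform helper; a small runnable sketch (`qkv_proj` belongs to the quoted test and is not redefined here):

```python
import torch
from vllm.platforms import current_platform

# fp8_dtype() resolves to float8_e4m3fnuz on ROCm and float8_e4m3fn elsewhere,
# so the test assertion no longer needs an explicit is_rocm() branch, e.g.:
#     assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
assert current_platform.fp8_dtype() in (torch.float8_e4m3fnuz,
                                        torch.float8_e4m3fn)
```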

Comment on lines 148 to 151
# assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
# [M, 1])
# assert scale_b.shape == torch.Size([1, 1]) or scale_b.shape == torch.Size(
# [N, 1])
Contributor

Please clean up comments

"triton_scaled_mm")
triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm
return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
if is_rocm_aiter_gemm_w8a8_scaled_mm_enabled():
Contributor

This shouldn't live inside cutlass_scaled_mm; it has nothing to do with cutlass. This code should just live inside AiterScaledMMLinearKernel.apply.

I know the Triton kernel is here but it shouldn't be either, I'm currently refactoring that.
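A sketch of where the reviewer wants the call to end up: a kernel-owned helper (name and signature are hypothetical) rather than a branch inside `cutlass_scaled_mm`. The `gemm_a8w8_CK` call itself matches the one quoted later in this conversation.

```python
from typing import Optional

import torch


def _aiter_w8a8_scaled_mm(x_q: torch.Tensor, w_q: torch.Tensor,
                          x_s: torch.Tensor, w_s: torch.Tensor,
                          bias: Optional[torch.Tensor],
                          out_dtype: torch.dtype) -> torch.Tensor:
    # Lives inside AiterScaledMMLinearKernel.apply, so cutlass_scaled_mm
    # never needs to know about AITER.
    from aiter import gemm_a8w8_CK

    # w_q is stored transposed by the shared ScaledMM weight layout, so it is
    # transposed back before the CK kernel call (see the later thread on why).
    return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
```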

Comment on lines 23 to 28
        if current_platform.is_cpu():
            return (
                False,
                "AiterScaledMMLinearKernel requires `aiter` which is not " +
                "currently supported on CPU.")
        if not current_platform.is_rocm():
Contributor

This can be a single check

Contributor Author

Done.

@@ -57,6 +71,11 @@ def use_v0_only(monkeypatch):
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0, is_symmetric = model_args

Contributor

This logic is a bit confusing. What models are and aren't supported by aiter vs Triton?

Contributor Author

AITER only supports per-channel-per-channel INT8 GEMM and per-tensor-per-tensor INT8 GEMM. It does not support mixed-precision MM or mixed quantization schemes.
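A small illustration of the two supported scale layouts, assuming an activation matrix of shape (M, K) and a weight of shape (N, K); the `numel()` checks mirror the ones quoted further down in this review:

```python
import torch

M, N, K = 4, 8, 16  # illustrative sizes

# per-tensor-per-tensor: one scale for the activations and one for the weight
x_s = torch.ones(1, 1)
w_s = torch.ones(1, 1)

# one scale per activation row (token) and one per weight output channel --
# the layout the thread below settles on calling per-token-per-channel
x_s = torch.ones(M, 1)
w_s = torch.ones(N, 1)

# The kernel's capability check reduces to numel() comparisons:
assert x_s.numel() == M and w_s.numel() == N
```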

Contributor

Could you add that as a comment?

@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.skipif(not current_platform.is_rocm(),
reason="This tests is skipped on non-ROCm platform.")
def test_compressed_tensors_w8a8_logprobs_rocm_aiter(
Contributor

Could this be folded into the existing tests by adding a boolean use_aiter parameter? And we can do [False] if <platform ...> else [False, True]

Contributor Author

Done
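A sketch of the folded-in parametrization (test and parameter names are hypothetical; the actual change may differ):

```python
import pytest

from vllm.platforms import current_platform


# AITER is only exercised on ROCm; everywhere else only the default path runs.
@pytest.mark.parametrize(
    "use_aiter",
    [False] if not current_platform.is_rocm() else [False, True])
def test_compressed_tensors_w8a8_logprobs(use_aiter, monkeypatch):
    if use_aiter:
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
    ...  # run the shared logprobs comparison as before
```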

tjtanaa added 2 commits March 27, 2025 16:07
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@ProExpertProg (Contributor) left a comment

A few more comments, thanks for working with me on this!

@@ -20,25 +27,20 @@ def get_min_capability(cls) -> int:
     @classmethod
     def can_implement(
             cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
-        if current_platform.is_cpu():
+        if current_platform.is_cpu() or not current_platform.is_rocm():
Contributor

Why can't this be simpler?

Suggested change:
-        if current_platform.is_cpu() or not current_platform.is_rocm():
+        if not current_platform.is_rocm():

@tjtanaa (Contributor Author) commented Mar 28, 2025

Ok. I have removed the check for CPU.

from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig


def is_rocm_aiter_gemm_w8a8_scaled_mm_enabled() -> bool:
Contributor

This method should just be inlined to the sole callsite (unless I'm missing another use)

Contributor Author

Resolved.
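The resolved version presumably inlines the environment-flag check at its single callsite, roughly as below; the flag names are taken from the PR description, and the exact expression is an assumption:

```python
import vllm.envs as envs

# At the sole callsite (inside can_implement, where ROCm has already been
# verified), the former helper reduces to checking the two env flags:
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR:
    ...  # the AITER w8a8 scaled-GEMM path is enabled
```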

Comment on lines 97 to 98
per_channel_tensor_scale_a = (x_s.numel() == m)
per_channel_tensor_scale_b = (w_s.numel() == n)
Contributor

Isn't this more accurate here?

Suggested change:
-        per_channel_tensor_scale_a = (x_s.numel() == m)
-        per_channel_tensor_scale_b = (w_s.numel() == n)
+        per_token_scale_a = (x_s.numel() == m)
+        per_channel_scale_b = (w_s.numel() == n)

Contributor Author

Yes you are right. I have made the amendments. Thank you so much.

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
"installed supported on ROCm.")
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
if not (
current_platform.is_rocm() \
Contributor

Already checked this above

Contributor Author

OK. removed current_platform.is_rocm()

" ATIER block scaled GEMM yet.")

from aiter import gemm_a8w8_CK
return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
Contributor

Just curious for future work: does this kernel support fp8?

Also, can you add a comment why w_q needs to be transposed here? I assume because it's using the Cutlass prepare weights which are transposed so here we restore it?

Contributor Author

ROCm/aiter does not support FP8 at this moment.
I have added the comment.

tjtanaa added 2 commits March 28, 2025 17:49
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Comment on lines 103 to 109
        # per-channel-per-channel a8w8 scacled GEMM
        assert ((per_tensor_scale_a and per_tensor_scale_b)
                or (per_token_scale_a and per_channel_scale_b)), (
                    "Currently only support per-tensor-per-tensor GEMM " +
                    " and per-channel-per-channel GEMM through AITER"
                    " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" +
                    " ATIER block scaled GEMM yet.")
Contributor

Typos/inaccuracies:

Suggested change:
-        # per-channel-per-channel a8w8 scacled GEMM
-        assert ((per_tensor_scale_a and per_tensor_scale_b)
-                or (per_token_scale_a and per_channel_scale_b)), (
-                    "Currently only support per-tensor-per-tensor GEMM " +
-                    " and per-channel-per-channel GEMM through AITER"
-                    " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" +
-                    " ATIER block scaled GEMM yet.")
+        # per-token-per-channel a8w8 scaled GEMM
+        assert ((per_tensor_scale_a and per_tensor_scale_b)
+                or (per_token_scale_a and per_channel_scale_b)), (
+                    "Currently only support per-tensor-per-tensor GEMM " +
+                    " and per-token-per-channel GEMM through AITER"
+                    " w8a8 scaled gemm. `cutlass_scaled_mm` does not support" +
+                    " AITER block scaled GEMM yet.")

And then, what does this mean: "cutlass_scaled_mm does not support AITER block scaled GEMM yet"?

Contributor Author

My bad. This is the modified version:

        # @TODO:
        # Maybe broadcast the per-tensor-scale into per-channel-scale
        # if one of the scale is a per-channel-scale.
        # For now, it only supports:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert ((per_tensor_scale_a and per_tensor_scale_b)
                or (per_token_scale_a and per_channel_scale_b)), (
                    "Currently only support per-tensor-per-tensor GEMM " +
                    " and per-token-per-channel GEMM through AITER"
                    " w8a8 scaled gemm. `AiterScaledMMLinearKernel` " +
                    "does not support AITER block scaled GEMM.")

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
@mgoin (Member) left a comment

LGTM, nice work!

@mgoin added the rocm (Related to AMD ROCm), quantization, and ready (ONLY add when PR is ready to merge/full CI is needed) labels Mar 28, 2025
@mgoin mgoin enabled auto-merge (squash) March 28, 2025 18:51
@vllm-bot vllm-bot merged commit 4965ec4 into vllm-project:main Mar 29, 2025
42 of 44 checks passed
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Apr 2, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Alex4210987 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Apr 5, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: xinyuxiao <xinyuxiao2024@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Labels
quantization · ready (ONLY add when PR is ready to merge/full CI is needed) · rocm (Related to AMD ROCm)