
[NVIDIA] Support nvfp4 cutlass gemm #13571


Merged: 10 commits, merged Feb 22, 2025

Conversation

@kaixih (Contributor) commented Feb 19, 2025

Forked from #12519 (will be closed soon); we decided to separate fp4 quantization and fp4 gemm into two PRs: (1) fp4 quantization (PR merged); (2) fp4 gemm (this PR).

This PR requires Cutlass 3.8 (which hasn't been officially released yet) to fully function. However, it should still build using placeholder functions.

cc. @pavanimajety @kushanam


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: kaixih <kaixih@nvidia.com>
@mergify mergify bot added the ci/build label Feb 19, 2025
@kaixih force-pushed the kaixih/nvfp4_scaled_mm branch from c2e2e58 to 50ac4fc on February 19, 2025 23:20
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};

struct Fp4GemmSm100Bfloat16 {
Collaborator

Looks like there's a lot of commonality between Fp4GemmSm100Float, Fp4GemmSm100Half, and Fp4GemmSm100Bfloat16. Could we just template this out, i.e. Fp4GemmSm100&lt;float&gt;, Fp4GemmSm100&lt;half_t&gt;, and Fp4GemmSm100&lt;bfloat16_t&gt;? This will likely help reduce duplication when we start tuning tile sizes for perf.

@pavanimajety (Contributor) commented Feb 20, 2025

Currently there's an issue with CUTLASS 3.8 and the gcc version we use for compiling the templates. Hence, we have a dumb fix for an initial version of the fp4 gemm. Will add a to-do to track this and go back to templating once that issue is fixed.

@LucasWilkinson (Collaborator)

@tlrmchlsmth thoughts on renaming the cutlass_w8a8 folder to cutlass_wXaX or cutlass_scaled_mm and moving these files there?

Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>
Signed-off-by: kaixih <kaixih@nvidia.com>
@kaixih (Contributor, Author) commented Feb 20, 2025

Again, it seems the failed tests are not related to this PR. PTAL.

@LucasWilkinson (Collaborator) left a comment

Thanks for the updates. Overall this looks OK to me (I left some comments for future work), but I wanted to follow up on the following (sorry, just saw this now!):

This PR requires Cutlass 3.8 (which hasn't been officially released yet) to fully function. However, it should still build using placeholder functions.

What's the motivation for landing this before 3.8 is released? To allow users to preview it using VLLM_CUTLASS_SRC_DIR? If so, we should probably add a more verbose comment in:

 template <typename T>
 void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
              at::Tensor const& A_sf, at::Tensor const& B_sf,
              at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
              cudaStream_t stream) {
   TORCH_CHECK(false, "Unsupported cutlass version");
 }

on how to preview it
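For illustration, the more verbose comment might read something like this (hypothetical wording; only the function signature comes from the PR):

```cpp
// Fallback compiled when vLLM is built against CUTLASS < 3.8: the nvfp4
// GEMM kernels depend on features that first ship in CUTLASS 3.8. To
// preview them before the pinned CUTLASS version is bumped, point
// VLLM_CUTLASS_SRC_DIR at a CUTLASS 3.8 checkout and rebuild vLLM.
template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  TORCH_CHECK(false, "Unsupported cutlass version");
}
```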

}

template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
Collaborator

future work: we should see if we can unify this with cutlass_gemm_caller in csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh

};

template <typename T>
struct Fp4GemmSm100 {
Collaborator

future work: I think we should try to unify this in the future with cutlass_3x_gemm in csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh

#if defined ENABLE_NVFP4 && ENABLE_NVFP4
return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
Collaborator

Can you please elaborate on this a bit? For example, say something like:

TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should be compiled using CUDA 12.8 and target compute capability 100 or above.");

@tlrmchlsmth (Collaborator)

@tlrmchlsmth thoughts on renaming the cutlass_w8a8 folder to cutlass_wXaX or cutlass_scaled_mm and moving these files there?

Totally makes sense

Signed-off-by: kaixih <kaixih@nvidia.com>
@kaixih (Contributor, Author) commented Feb 21, 2025

Thanks for the review! I’ve addressed the comments, except for the future work related to code structure changes.

@LucasWilkinson PTAL.

@LucasWilkinson (Collaborator) commented Feb 21, 2025

@kaixih Thanks for the hard work! Apologies for the long back and forth, but it looks like CUTLASS 3.8 just got released! Do you think you can upgrade to it in this PR? It's required, and it doesn't make a ton of sense to rush to land this if it won't be functional without it.

Also, while we are still making changes, let's try to adopt this: #13571 (comment)

Signed-off-by: kaixih <kaixih@nvidia.com>
@kaixih (Contributor, Author) commented Feb 21, 2025

@LucasWilkinson Can we focus this PR on NVFP4 support and address the code structure changes in a separate PR?

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) February 21, 2025 23:16
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 21, 2025
@simon-mo simon-mo merged commit e109e59 into vllm-project:main Feb 22, 2025
76 of 83 checks passed
WoosukKwon pushed a commit that referenced this pull request Feb 23, 2025
Signed-off-by: Roger Wang <ywang@roblox.com>
Comment on lines 230 to +232
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Please keep this in sync with FetchContent_Declare line below.
- set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
+ set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
Collaborator

Just noticed this PR didn't actually update CUTLASS to 3.8 (see line 250 below)
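One way to keep the two from drifting apart in the future (a sketch, not the actual CMakeLists content; the target name is illustrative) is to reuse the cache variable as the FetchContent tag:

```cmake
# Single source of truth for the CUTLASS version: the tag fetched below
# follows the cache variable, so updating one line updates both.
set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")

FetchContent_Declare(
  cutlass
  GIT_REPOSITORY https://github.com/nvidia/cutlass.git
  GIT_TAG ${CUTLASS_REVISION}
)
```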

Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…t#13709)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
Labels: ci/build, ready (ONLY add when PR is ready to merge/full CI is needed)

5 participants