[NVIDIA] Support nvfp4 cutlass gemm #13571
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Force-pushed from c2e2e58 to 50ac4fc.
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
};

struct Fp4GemmSm100Bfloat16 {
Looks like there's a lot of commonality between Fp4GemmSm100Float, Fp4GemmSm100Half, and Fp4GemmSm100Bfloat16. Could we just template this out, i.e. Fp4GemmSm100<float>, Fp4GemmSm100<half_t>, and Fp4GemmSm100<bfloat16_t>? This will likely help reduce duplication when we start tuning tile sizes for perf.
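A minimal sketch of what that could look like, assuming the three per-dtype structs differ only in the output element type (the member names follow the snippet quoted above, the rest is illustrative):

#include <cute/layout.hpp>          // cute::make_layout, cute::make_shape
#include <cutlass/numeric_types.h>  // cutlass::half_t, cutlass::bfloat16_t

// Hypothetical single template replacing the three per-dtype structs.
// Only the output element type varies; the remaining config is shared.
template <typename OutType>
struct Fp4GemmSm100 {
  using ElementD = OutType;
  using StrideD = cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>;
  using LayoutD = decltype(cute::make_layout(cute::make_shape(0, 0, 0), StrideD{}));
  // ... remaining mainloop/epilogue configuration, identical across dtypes ...
};

// Dispatch then instantiates Fp4GemmSm100<float>, Fp4GemmSm100<cutlass::half_t>,
// and Fp4GemmSm100<cutlass::bfloat16_t>, keeping tile-size tuning in one place.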
Currently there's an issue with CUTLASS 3.8 and the gcc version we use for compiling the templates, hence the dumb fix for an initial version of the fp4 gemm. Will add a TODO to track this and go back to templating once that issue is fixed.
@tlrmchlsmth thoughts on renaming the
Again, it seems the failed tests are not related to this PR. PTAL.
Sorry, thanks for the updates. Overall this looks OK to me (I left some comments for future work), but I wanted to follow up on this (sorry, just saw it now!):
This PR requires Cutlass 3.8 (which hasn't been officially released yet) to fully function. However, it should still build using placeholder functions.
What's the motivation for landing this before 3.8 is released? To allow users to preview it using VLLM_CUTLASS_SRC_DIR? If so, we should probably add a more verbose comment in:
template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
at::Tensor const& A_sf, at::Tensor const& B_sf,
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
cudaStream_t stream) {
TORCH_CHECK(false, "Unsupported cutlass version");
}
on how to preview it
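For example (the wording and the VLLM_CUTLASS_SRC_DIR hint below are only a suggestion, not code from this PR), the stub could spell out how to build against CUTLASS 3.8 for a preview:

#include <torch/all.h>
#include <cuda_runtime.h>

// Hypothetical, more verbose fallback for builds against CUTLASS < 3.8.
// The message text and the env-var guidance are illustrative.
template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
             at::Tensor const& A_sf, at::Tensor const& B_sf,
             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
             cudaStream_t stream) {
  TORCH_CHECK(false,
              "NVFP4 GEMM requires CUTLASS 3.8, which this build does not include. "
              "To preview the kernel, build vLLM against a CUTLASS 3.8 checkout, "
              "e.g. by setting VLLM_CUTLASS_SRC_DIR to its source directory.");
}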
}

template <typename T>
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
Future work: we should see if we can unify this with cutlass_gemm_caller in csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh.
};

template <typename T>
struct Fp4GemmSm100 {
Future work: I think we should try to unify this with cutlass_3x_gemm in csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh.
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
#endif
  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel.");
Can you please elaborate on this a bit? E.g. say something like:
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should be compiled using CUDA 12.8 and target compute capability 100 or above.");
Totally makes sense.
Thanks for the review! I’ve addressed the comments, except for the future work related to code structure changes. @LucasWilkinson PTAL.
@kaixih Thanks for the hard work! Apologies for the long back and forth, but it looks like CUTLASS 3.8 just got released! Do you think you can upgrade to that in this PR, given that it's required? (It doesn't make a ton of sense to rush and land this if it's not going to be functional without it.) Also, while we are still making changes, let's try to adopt this: #13571 (comment)
@LucasWilkinson Can we focus this PR on NVFP4 support and address the code structure changes in a separate PR?
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Please keep this in sync with FetchContent_Declare line below.
set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
Just noticed this PR didn't actually update CUTLASS to 3.8 (see line 250 below)
Forked from #12519 (will be closed soon); we decided to separate the fp4 quantization and fp4 gemm into two PRs: (1) fp4 quantization (PR merged); (2) fp4 gemm (this PR).
This PR requires Cutlass 3.8 (which hasn't been officially released yet) to fully function. However, it should still build using placeholder functions.
cc. @pavanimajety @kushanam