[Quantization][FP8] Adding support for fp8 gemm layer input in fp8 #14578
Conversation
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@@ -172,6 +173,9 @@ def apply(
        if use_per_token_if_dynamic is None:
            use_per_token_if_dynamic = self.use_per_token_if_dynamic

        if out_dtype is None:
            out_dtype = input.dtype

        # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
        if self.cutlass_fp8_supported:
            qinput, x_scale = ops.scaled_fp8_quant(
Either add an assert here or use the same logic as the non-cutlass case.
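A minimal sketch of the guard this comment suggests, assuming the surrounding apply() context from the diff above (input, input_scale, ops.scaled_fp8_quant); this is illustrative only, not the exact PR code:

    # Skip dynamic quantization when the activation already arrives in fp8
    # (both e4m3fn and the ROCm e4m3fnuz variant), otherwise quantize as before.
    if input.dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
        # The caller must then supply the matching activation scale.
        assert input_scale is not None, "fp8 input requires an explicit input_scale"
        qinput, x_scale = input, input_scale
    else:
        qinput, x_scale = ops.scaled_fp8_quant(
            input,
            input_scale,
            use_per_token_if_dynamic=use_per_token_if_dynamic)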
@@ -116,6 +116,21 @@ def get_quant_method(self, layer: torch.nn.Module,
            return Fp8KVCacheMethod(self)
        return None

    def get_cache_scale(self, name: str) -> Optional[str]:
Where is this method used?
For models with additional scales, such as amd/Llama-3.1-8B-Instruct-FP8-QKV-Proj.
This has been broken since this change: de0526f#diff-48d2ca5476d5b776f6401436fcf015c5ce4dc1a23d2b78a09e08fb85acc3697cL399
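For reference, a hedged sketch of what such a remapping method could look like; the substring patterns below are assumptions based on the QKV-projection scales mentioned above, not a quote of the PR:

    from typing import Optional

    def get_cache_scale(self, name: str) -> Optional[str]:
        # Map a checkpoint scale name onto the vLLM attention parameter it
        # should load into; return None when no remapping applies.
        # (Illustrative sketch only; the exact patterns are assumptions.)
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        return None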
Could we add a test for this? Or does it already exist in CI?
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
@@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.out_dtype = torch.get_default_dtype()
@ProExpertProg - are you sure this is okay?
I know that throughout the models, we pipe the dtype through.
I didn't know we did that - I thought the default dtype was used for unquantized layers.
That's the unquantized dtype by design here, to do fp8 x fp8 -> half instead of half x fp8 -> half
I looked and we use the default dtype in many places (attention, RMSNorm, etc.). So I think this is fine @robertgshaw2-redhat
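The discussion above hinges on the default dtype tracking the model's activation dtype; a tiny illustrative check, assuming (as the thread notes) that the model dtype is installed as the torch default while the scheme is constructed:

    import torch

    # With bfloat16 set as the default dtype, torch.get_default_dtype()
    # returns the same dtype that unquantized layers already use, so the
    # fp8 x fp8 GEMM emits activations in the model dtype.
    torch.set_default_dtype(torch.bfloat16)
    out_dtype = torch.get_default_dtype()
    assert out_dtype is torch.bfloat16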
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Force-pushed from 73be373 to 28a1958
Thanks @gshtras - nice work.
Adding support for the case where both inputs to the FP8 GEMM are in the FP8 data type, not only the weights (in preparation for attention with fused FP8 conversion).
Functionality ported over from ROCm/vllm.
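A conceptual demo in plain PyTorch of the data flow this enables (not vLLM code; tensor shapes and the per-tensor scale recipe are illustrative assumptions):

    import torch

    # The activation arrives already quantized to fp8 with a per-tensor scale,
    # so the linear layer can do fp8 x fp8 and emit the model dtype directly
    # instead of quantizing a half-precision tensor inside the GEMM helper.
    x = torch.randn(16, 32, dtype=torch.bfloat16)
    x_scale = x.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    x_fp8 = (x.float() / x_scale).to(torch.float8_e4m3fn)  # what fused-fp8 attention would emit
    # x_fp8 and x_scale would then be passed straight into the fp8 GEMM,
    # whose out_dtype defaults to the model dtype (bfloat16 here).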