
[Model] RowParallelLinear: pass bias to quant_method.apply #6327


Merged: 7 commits into vllm-project:main from rplinear_bias, Jul 19, 2024

Conversation

tdoublep (Member) commented Jul 11, 2024

I think for the TP=1 case we can pass the bias to the apply function, which should lead to more use of the fused CUTLASS kernels.

cc @tlrmchlsmth @cyang49

Update: this is now implemented for TP>1 too
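
The gist of the change, as a minimal sketch (the class names and wiring below are illustrative stand-ins, not the actual vLLM code): instead of running the GEMM and then adding the bias in a separate elementwise op, the layer hands the bias to quant_method.apply so the backend (e.g. a CUTLASS epilogue) can fuse the add into the GEMM.

```python
from typing import Optional

import torch


class UnquantizedLinearMethodSketch:
    """Stand-in for a vLLM quant method whose apply() accepts an optional bias."""

    def apply(self, layer: torch.nn.Module, x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # F.linear fuses the bias add; a quantized backend would similarly
        # forward `bias` into its fused GEMM kernel.
        return torch.nn.functional.linear(x, layer.weight, bias)


class RowParallelLinearSketch(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(out_features))
        self.skip_bias_add = False
        self.quant_method = UnquantizedLinearMethodSketch()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Before: output = apply(self, x); output = output + self.bias
        # After:  the bias travels into apply() and can be fused with the GEMM.
        bias_ = None if self.skip_bias_add else self.bias
        return self.quant_method.apply(self, x, bias=bias_)


layer = RowParallelLinearSketch(16, 8)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 8])
```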

tdoublep added 2 commits July 11, 2024 04:23
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
tdoublep (Member, Author) commented:

CI failures are a bit weird (see #6332)

tlrmchlsmth (Collaborator) left a comment

This looks good. I see that for the TP > 1 case we can't naively fuse the bias, since each rank will add the bias to its output and then the biases will be redundantly summed during the AllReduce.

Another thing to try is to pick one of the ranks to apply the bias. The overhead of applying a bias during a GEMM is low enough that we probably don't need to worry about load imbalance, and this will let every rank skip a load and store of the activation tensor.
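
(As an illustration of this suggestion, here is a minimal sketch; the helper and its tp_rank/tp_size arguments are illustrative stand-ins, not vLLM's actual distributed plumbing. Only rank 0 fuses the bias, so the AllReduce contributes the bias exactly once rather than tp_size times.)

```python
from typing import Optional

import torch
import torch.distributed as dist


def row_parallel_forward(x_parallel: torch.Tensor,
                         weight_shard: torch.Tensor,
                         bias: Optional[torch.Tensor],
                         tp_rank: int,
                         tp_size: int) -> torch.Tensor:
    # If every rank added the bias, the all-reduce would yield
    #   sum_i(x_i @ W_i^T) + tp_size * b   instead of   ... + b.
    # Fusing it on rank 0 only keeps the math correct and lets the other
    # ranks avoid the separate bias-add pass over the activations.
    bias_ = bias if tp_rank == 0 else None
    partial = torch.nn.functional.linear(x_parallel, weight_shard, bias_)
    if tp_size > 1:
        dist.all_reduce(partial)  # sum partial outputs across ranks
    return partial
```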

> which should lead to more use of the fused CUTLASS kernels.

FYI this won't lead to more use of the CUTLASS kernels. Currently they're used whenever there is no bias.

@@ -753,18 +753,23 @@ def forward(self, input_):

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_parallel)
        bias_ = None if (self.tp_size > 1 or self.skip_bias_add) else self.bias
A Collaborator commented on this line:

Could you add a comment explaining what happens when tp_size > 1? I think it should say something like:

# Fuse bias into the GEMM in the TP == 1 case.
# The TP > 1 case is problematic as the bias will be redundantly summed during the AllReduce if fused

tdoublep (Member, Author) replied:

Added a comment (although it reads a bit differently since the TP>1 case is now handled).

tlrmchlsmth (Collaborator) commented:

> CI failures are a bit weird (see #6332)

If the bias is fused into the GEMM, it will likely be added to the output at a higher precision. For example in our CUTLASS fp8 kernels, the bias add happens in fp32. On main the output will be converted to the output dtype before adding the bias. That could explain the test failure you're seeing.
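
(A toy illustration of that precision effect, using plain PyTorch tensors rather than the CUTLASS kernel itself: adding the bias to the fp32 accumulator and then casting can round differently from casting the GEMM result first and adding the bias in fp16.)

```python
import torch

torch.manual_seed(0)
acc = torch.randn(4, 8, dtype=torch.float32)   # stand-in for the fp32 GEMM accumulator
bias = torch.randn(8, dtype=torch.float32)

fused = (acc + bias).to(torch.float16)                    # bias added in fp32, then cast
unfused = acc.to(torch.float16) + bias.to(torch.float16)  # cast first, bias added in fp16

diff = (fused.float() - unfused.float()).abs().max()
print(diff.item())  # typically a small but nonzero rounding difference
```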

tdoublep (Member, Author) commented:

> For example in our CUTLASS fp8 kernels, the bias add happens in fp32. On main the output will be converted to the output dtype before adding the bias. That could explain the test failure you're seeing.

Thanks, yes that explains why this change causes the tests to fail (since the same tests also fail in fp32).

Now I need to figure out why the tests don't work in fp32; it seems to be related to the chunked prefill tests only.

tdoublep (Member, Author) commented Jul 12, 2024

OK, so the failing tests on this branch are now passing in fp32 but only after this unrelated bug is fixed: #6373

tdoublep (Member, Author) commented:

The offending tests actually also fail on main using float16 on an H100. I think the test in question (tests/samplers/test_logprobs.py::test_get_prompt_logprobs) should probably be changed to use float32 in order to get a reliable comparison against the logprobs from HF.

tdoublep (Member, Author) commented:

> Another thing to try is to pick one of the ranks to apply the bias. The overhead of applying a bias during a GEMM is low enough that we probably don't need to worry about load imbalance, and this will let every rank skip a load and store of the activation tensor.

I like that idea - I implemented it and it makes the code even simpler.

> FYI this won't lead to more use of the CUTLASS kernels. Currently they're used whenever there is no bias.

Right, but they will be after we merge #6270

> The offending tests actually also fail on main using float16 on an H100.

The CI tests should pass after we merge #6409

tdoublep (Member, Author) commented:

/ready

@github-actions bot added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) on Jul 15, 2024
tlrmchlsmth (Collaborator) left a comment

This looks good to me, although the spec decoding tests look suspicious, so I'll wait until those are resolved before accepting

Have you done any e2e benchmarking of e.g. Granite or Qwen?

tdoublep (Member, Author) commented Jul 15, 2024

> This looks good to me, although the spec decoding tests look suspicious

I agree - will look into it.

> Have you done any e2e benchmarking of e.g. Granite or Qwen?

Not yet, but hopefully tomorrow; I'll share once I have some data.

tdoublep (Member, Author) commented:

@tlrmchlsmth I fixed the CI issue by changing the failing TP=2 test to use float32 precision (which we are already using for the equivalent TP=1 test).

@tdoublep changed the title from "[Model] RowParallelLinear: pass bias to quant_method.apply for tp=1 case" to "[Model] RowParallelLinear: pass bias to quant_method.apply" on Jul 16, 2024
tlrmchlsmth (Collaborator) left a comment

LGTM

tlrmchlsmth (Collaborator) commented:

@tdoublep could you share perf results when you have them? Curious about what the impact will be. Thanks!

tdoublep (Member, Author) commented Jul 19, 2024

@tlrmchlsmth I got sidetracked with a few things, but had some time to run benchmarks today. I ran the benchmarks using granite-20b with an FP8 checkpoint I created myself using AutoFP8. The model is deployed on an H100 GPU using all default vLLM settings. I'm using our fmperf tool to simulate different numbers of concurrent users of the server (indicated as the point labels on the plots below); each user sends requests with 128 input tokens and 128 output tokens. I've tested 3 variants:

  1. main branch but with the changes from [Kernel] Use CUTLASS kernels for the FP8 layers with Bias #6270 reverted (i.e., without any use of CUTLASS for the layers with bias)
  2. main branch unchanged
  3. rplinear_bias branch, which should use CUTLASS for every linear layer with bias.

Here are the results:

[Figure: benchmark results for the three variants; point labels indicate the number of concurrent users]

The results look as expected imo

cc @cyang49
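
(For anyone wanting a quick local check rather than the fmperf load test above, a minimal offline smoke test with vLLM's Python API might look like the sketch below; the checkpoint path is a placeholder for a locally built AutoFP8 model, and the exact settings may differ per setup.)

```python
from vllm import LLM, SamplingParams

# Placeholder path to a locally created AutoFP8 checkpoint; adjust as needed.
llm = LLM(model="/path/to/granite-20b-fp8", quantization="fp8")
params = SamplingParams(temperature=0.0, max_tokens=128)

outputs = llm.generate(["Summarize what tensor parallelism does."] * 8, params)
for out in outputs:
    print(out.outputs[0].text[:80])
```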

tlrmchlsmth (Collaborator) commented:

The results look great!

> main branch but with the changes from #6270 reverted (e.g., so without fusion of bias adds)

Just one detail: if you revert #6270, the effect isn't that the bias adds aren't fused. Instead, we were falling back to torch._scaled_mm.

Thanks for linking fmperf, I'll check that out

mgoin (Member) left a comment

Nice simplicity with the final version

@mgoin mgoin merged commit a5314e8 into vllm-project:main Jul 19, 2024
73 checks passed
tdoublep (Member, Author) commented:

> Just one detail: if you revert #6270, the effect isn't that the bias adds aren't fused. Instead, we were falling back to torch._scaled_mm.

You are right - have updated the comment + figure accordingly.

@tdoublep tdoublep deleted the rplinear_bias branch July 19, 2024 13:19
xjpang pushed a commit to xjpang/vllm that referenced this pull request Jul 24, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
3 participants