[Misc] GPTQ Activation Ordering #8135

Merged · 11 commits from kylesayrs/activation-ordering into vllm-project:main on Sep 9, 2024

Conversation

@kylesayrs (Contributor) commented Sep 3, 2024

Activation Ordering

[Figure: activation-ordering-diagram]

Changes

  • Update QuantizationArgs to support an actorder argument
  • Add a weight_g_idx parameter loader which defaults to all -1s
    • If weight_g_idx is loaded with valid values, the parameter is passed to the kernel
    • If weight_g_idx is not loaded, no column reordering is performed (see the sketch after this list)
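
Loosely, the dispatch described above might look like the following sketch (the helper name is hypothetical; this is not the PR's actual code):

```python
from typing import Optional

import torch


def maybe_get_g_idx(weight_g_idx: torch.Tensor) -> Optional[torch.Tensor]:
    """Return g_idx only if the checkpoint actually provided one.

    The parameter loader initializes weight_g_idx to all -1s, so a tensor
    that is still all -1s means no activation ordering was applied and the
    kernel can skip column reordering entirely.
    """
    if torch.all(weight_g_idx == -1):
        return None
    return weight_g_idx


# Example: a freshly initialized (never loaded) parameter is ignored.
assert maybe_get_g_idx(torch.full((512,), -1)) is None
```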

Testing

Inference Script

infer_actorder.py
```python
from vllm import LLM

llm = LLM("nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group")
outputs = llm.generate("The future of AI is")
print(outputs[0].outputs[0].text)
```

Actorder=Group Evaluation

| Meta-Llama-3.1-8B-Instruct-quantized.w4a16 | Ours | Stderr |
|---:|---:|---:|
| 78.57 | 78.87 | ± 00.41 |
| 50.46 | 49.24 | ± 01.52 |
| 76.48 | 76.00 | ± 01.20 |

These results have discrepancies unrelated to activation ordering; more precise results will be posted at a later date.

Actorder=Weight Evaluation

Accuracy

Accuracy evaluations were performed using compressed Meta-Llama-3-8B-Instruct as a base model. For reference, Meta-Llama-3-8B-Instruct-quantized.w4a16 was compressed using AutoGPTQ desc_act=True and achieved 72.25% on GSM-8K (5-shot, strict-match).

The following models were quantized using llm-compressor with the same quantization configuration but 156 calibration samples as opposed to 256.

Group Activation Ordering

vllm (pretrained=/home/ksayers/llm-compressor/llama3_actorder_group,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto                           
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|                                                                                                            
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|                                                                                                            
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7240|±  |0.0123|
|     |       |strict-match    |     5|exact_match|↑  |0.7263|±  |0.0123|

Weight Activation Ordering

vllm (pretrained=/home/ksayers/llm-compressor/llama3_actorder_weight,add_bos_token=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto                          
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|                                                                                                            
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|                                                                                                            
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7331|±  |0.0122|
|     |       |strict-match    |     5|exact_match|↑  |0.7354|±  |0.0122|

Latency

Group Activation Ordering

Avg latency: 2.0740114107728003 seconds
10% percentile latency: 2.0585451871156693 seconds
25% percentile latency: 2.0614405693486333 seconds
50% percentile latency: 2.0715408828109503 seconds
75% percentile latency: 2.0796915404498577 seconds
90% percentile latency: 2.0825428303331135 seconds
99% percentile latency: 2.153371973782778 seconds

Weight Activation Ordering

Avg latency: 2.0243701239426932 seconds
10% percentile latency: 2.0120500519871714 seconds
25% percentile latency: 2.01621649088338 seconds
50% percentile latency: 2.020678504370153 seconds
75% percentile latency: 2.026012707967311 seconds
90% percentile latency: 2.0294162426143885 seconds
99% percentile latency: 2.1007918303459885 seconds

github-actions bot commented Sep 3, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, which starts a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@kylesayrs kylesayrs changed the title GPTQ Activation Ordering [Misc] GPTQ Activation Ordering Sep 3, 2024
@kylesayrs kylesayrs marked this pull request as draft September 5, 2024 15:47
@dsikka (Contributor) left a comment

LGTM. Two open questions, plus: can you confirm the act-order models work fine with 2 and 4 GPUs?

@kylesayrs kylesayrs marked this pull request as ready for review September 6, 2024 13:56
@robertgshaw2-redhat (Collaborator)

/ready

@robertgshaw2-redhat robertgshaw2-redhat added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 6, 2024
@kylesayrs (Contributor Author)

Confirmed that group activation ordering works with tp=2,4

@LucasWilkinson (Collaborator)

I might be mistaken, but actorder="weight" feels like purely a change in how (i.e. in what order) llm-compressor quantizes the weights, in which case why can't vLLM be oblivious to this fact? From vLLM's perspective, actorder="weight" and actorder=None/False appear identical, in the sense that neither case reorders activations.

So the naming feels a bit confusing, since actorder="weight" doesn't sound like "activation ordering" at all. Maybe we can simplify this by adding a quantorder key in the checkpoint (for record keeping) and just setting actorder=None/False; that way vLLM can ignore quantorder completely, and actorder is just a boolean indicating whether activations should be reordered or not (we could leave it as a str/enum for future-proofing if we want, but it would be a boolean at this point in time).

I think we should try to create better separation of concerns, i.e. a better separation between how the model was quantized and how the model needs to be run for inference; otherwise it's just more things/complexity kernel authors need to familiarize themselves with, only to realize (in this case) that it has no impact.

@kylesayrs (Contributor Author)

I agree that there are benefits to separating "actorder" and "quantorder" as two orthogonal arguments, namely that vLLM would only have to check for the "actorder=True" case.

I think one large downside to separating "actorder" from "quantorder" is that we'd essentially be redefining "actorder" as "activation ordering of groups". This means that an llm-compressor user might turn on "actorder", expecting it to also do quantization ordering like it does in GPTQ, and then get a model with additional latency but no performance gain. This is confusing for users, not only because they have to redefine the concept of actorder, but because a potential pitfall case has now been added.

There probably exist better names to explain the ideas of "non-sequential grouping" and "quantization ordering", but folding these cases into the "actorder" argument allows us to leverage users' existing understanding of activation ordering.

@kylesayrs (Contributor Author)

In a follow-up PR, we could keep actorder but add an additional argument, contiguous-groups. In the pydantic model we can validate that the contiguous-groups argument is consistent with actorder (groups are non-contiguous iff actorder="group"). This way an llm-compressor user can use the nice actorder argument, but vLLM only needs to pay attention to contiguous-groups.
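
As a rough sketch of the kind of cross-field validation proposed here (field names, the flag's polarity, and the pydantic version are assumptions, not actual compressed-tensors code):

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class QuantizationArgsSketch(BaseModel):
    """Hypothetical extension of QuantizationArgs, not the real model."""

    actorder: Optional[str] = None       # None, "weight", or "group"
    non_contiguous_groups: bool = False  # the only field vLLM would inspect

    @model_validator(mode="after")
    def _validate_grouping(self):
        # Groups end up non-contiguous exactly when they are
        # activation-ordered, so the two fields must always agree.
        if self.non_contiguous_groups != (self.actorder == "group"):
            raise ValueError(
                "non_contiguous_groups must be set iff actorder='group'")
        return self
```

With a constraint like this in place, vLLM could key off the grouping flag alone and never inspect actorder.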

@LucasWilkinson (Collaborator)

I think the goal here would be to make it so vLLM doesn't have to maintain an enum of actorders, so that users can experiment with orderings in llm-compressor without having to submit PRs to vLLM.

In offline conversations with @kylesayrs, it seems like a good alternative solution would be to have a non-contiguous-groups flag which, when present and true, indicates that a g_idx is present (since g_idx doesn't actually contain the activation order, but rather just a mapping of in-channels to groups).

This way vLLM can completely ignore actorder (opening room for experimentation with things like the weight order (and potential variations of it), or folding activation permutations into the MLP layers [1, 2]) without having to update vLLM. In this case actorder would only exist in the checkpoint as "record keeping", so people can remember what was used to create the checkpoint.

Ideally, though, we wouldn't have compressed-tensors checkpoints in the wild with actorder="group" but no non-contiguous-groups flag, since then we'd need extra code/logic in vLLM for backwards compatibility (which would be a shame for just a couple of models). @mgoin @kylesayrs do we know if there are any models like this in the wild?
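
To illustrate the point that g_idx is a column-to-group mapping rather than an activation order, here is a small standalone sketch (not code from this PR):

```python
import torch

group_size, num_cols = 4, 12

# Contiguous groups (actorder=None or actorder="weight"): the group index
# simply increases with the column index, so the kernel never needs g_idx.
contiguous_g_idx = torch.arange(num_cols) // group_size
# -> tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

# Non-contiguous groups (actorder="group"): columns were assigned to groups
# in activation order during quantization, so the column-to-group mapping is
# scrambled and the kernel must consult g_idx to find each column's scales.
activation_order = torch.randperm(num_cols)  # stand-in importance ranking
non_contiguous_g_idx = torch.argsort(activation_order) // group_size
```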

@kylesayrs (Contributor Author)

@LucasWilkinson A decision has been made that we're not going to support the more generalized checkpoint config in this PR, and will instead continue with a list which specifies which activation orderings come with non-contiguous groupings.

I would characterize the problem as a difference in requirements between recipe configs and checkpoint configs, as well as a problem of coupling between llm-compressor and vllm-CT. While it is a problem that will likely come up again, it's considered out of scope for this PR.

There are some options to make the checkpoint config more extensible/decoupled from llm-compressor in the future. The two that I have proposed are (1) separating the recipe/compression config from the checkpoint config or (2) adding an ActivationOrdering.NON_CONTIGUOUS option to the enum, which would help generalize the config and allow compressors to be more vague while still giving enough information for correct inference.

Option (1) would likely involve some legacy support for backwards compatibility. Option (2) would be backwards compatible. Option (3) is to maintain the coupling between llm-compressor and vllm-CT. Hopefully these options are palatable enough to defer to a separate PR.

@kylesayrs kylesayrs requested review from dsikka and mgoin September 9, 2024 17:17
@LucasWilkinson (Collaborator) commented Sep 9, 2024

Cool, ya, I was just wanting to make my voice heard that I think we should be striving for separation of concerns in our software design, especially across repos, to avoid versioning headaches in the future (and having to coordinate PRs across many repos).

I'm not a huge fan of option 2, as it just continues to conflate activation ordering and group assignment (i.e. mixing something that is fairly GPTQ-specific with something that could be viewed as a more general group-quantization concern).

Option 1 seems like what we should be striving for, as it will give us more flexibility going forward, even outside of this specific actorder case. I'm not sure what you mean by "checkpoint config", but I assume you mean the fields validated by compressed-tensors. I would imagine that for option 1 you would still want to save recipe information for accuracy debugging and general record keeping, but stored in a separate structure or using keys that vLLM and compressed-tensors can just ignore / pass through.

```diff
@@ -232,7 +232,8 @@ def _get_scheme_from_parts(
         return CompressedTensorsWNA16(
             num_bits=weight_quant.num_bits,
             strategy=weight_quant.strategy,
-            group_size=weight_quant.group_size)
+            group_size=weight_quant.group_size,
+            actorder=weight_quant.actorder)
```
Contributor:

Could we just add the condition here?
actorder=weight_quant.actorder == ActivationOrdering.GROUP

Contributor Author:

We could, but it feels more logical to me to keep argument processing within the CompressedTensorsWNA16.__init__ function. This separates responsibilities and makes it clear that the job of _get_scheme_from_parts is to decide which compression scheme applies, not to process the arguments once the scheme is decided.

Contributor Author:

We'd also have to rename the actorder argument of CompressedTensorsWNA16.__init__, otherwise it would be a misnomer.
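
A minimal sketch of the division of responsibility being argued for here (class and attribute names are illustrative, not the PR's exact code): _get_scheme_from_parts only selects the scheme and forwards the raw value, while __init__ interprets it.

```python
from typing import Optional


class CompressedTensorsWNA16Sketch:
    """Illustrative stand-in for CompressedTensorsWNA16, not the real class."""

    def __init__(self, num_bits: int, strategy: str, group_size: int,
                 actorder: Optional[str] = None):
        self.num_bits = num_bits
        self.strategy = strategy
        self.group_size = group_size
        # Only actorder="group" changes inference (the kernel must consult
        # g_idx); actorder="weight" and None behave identically at runtime.
        self.needs_g_idx = actorder == "group"
```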

@kylesayrs (Contributor Author)

@LucasWilkinson A "recipe config" means fields that are important to compression algorithms; a "checkpoint config" means fields relevant to vLLM and serving. I agree that checkpoint configs should include some info about how they were compressed, but those fields wouldn't be a requirement.

@mgoin (Member) left a comment

Thanks for the careful work and discussion! This is good to land for now

@mgoin mgoin merged commit c7cb5c3 into vllm-project:main Sep 9, 2024
50 checks passed
@kylesayrs kylesayrs deleted the kylesayrs/activation-ordering branch September 12, 2024 03:36
dtrifiro pushed a commit to opendatahub-io/vllm that referenced this pull request Sep 12, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Signed-off-by: Alvant <alvasian@yandex.ru>
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
Labels
ready ONLY add when PR is ready to merge/full CI is needed
5 participants