[Quant][Perf] Use moe_wna16 kernel by default for MoEs with many experts #13236
Changes from all commits
```diff
@@ -10,20 +10,18 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, marlin_moe_permute_scales,
     marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    UnquantizedEmbeddingMethod)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
```
```diff
@@ -44,15 +42,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         (8, True): scalar_types.uint8b128,
     }
 
-    def __init__(
-        self,
-        weight_bits: int,
-        group_size: int,
-        desc_act: bool,
-        is_sym: bool,
-        lm_head_quantized: bool,
-        dynamic: Dict[str, Dict[str, Union[int, bool]]],
-    ) -> None:
+    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+                 is_sym: bool, lm_head_quantized: bool,
+                 dynamic: Dict[str, Dict[str, Union[int, bool]]],
+                 full_config: Dict[str, Any]) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
```

Review comments on the new `__init__` signature:

- "What is full_config?"
- "Oh just the config dict, I see"
- "It is just the original config saved from"
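To make the review exchange above concrete: `full_config` is the raw quantization config dict that `from_config` receives, kept on the instance so that a `MoeWNA16Config` can later be built from the same data. A minimal sketch with example values follows (the dict shown is illustrative, not taken from this PR):

```python
# Illustrative GPTQ quantization config as it might appear in a checkpoint's
# config.json; the keys and values here are example data, not from this PR.
example_gptq_config = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "sym": True,
}

# GPTQMarlinConfig.from_config(example_gptq_config) now passes the whole dict
# through to __init__ as `full_config`, so the instance can later rebuild a
# MoeWNA16Config from exactly the same data:
#
#   cfg = GPTQMarlinConfig.from_config(example_gptq_config)
#   assert cfg.full_config is example_gptq_config
```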
```diff
@@ -90,6 +83,7 @@ def __init__(
         self.group_size = group_size
         self.desc_act = desc_act
         self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config
 
         if (weight_bits, is_sym) not in self.TYPE_MAP:
             raise ValueError("Unsupported quantization config: "
```
```diff
@@ -132,7 +126,7 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized, dynamic)
+                   lm_head_quantized, dynamic, config)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
```
```diff
@@ -155,12 +149,15 @@ def override_quantization_method(cls, hf_quant_cfg,
                 " faster inference")
         return None
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
-                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return GPTQMarlinMoEMethod(self)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)
```
Review comment on this change: "ah good catch"
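Taken together, the dispatch this PR introduces is: GPTQ-Marlin MoE layers with more than 32 experts are routed to the moe_wna16 path, while everything else keeps the existing Marlin fused-MoE method. Below is a standalone sketch of just that decision rule; the constant and function names are made up for illustration, and only the threshold (32) comes from the PR itself (the real check lives in `GPTQMarlinConfig.get_quant_method`):

```python
# Standalone sketch of the expert-count dispatch; names are hypothetical.
MANY_EXPERTS_THRESHOLD = 32

def chosen_moe_method(num_experts: int) -> str:
    """Which quant method a GPTQ-Marlin MoE layer would get after this PR."""
    if num_experts > MANY_EXPERTS_THRESHOLD:
        # Many experts: delegate to moe_wna16, built from the saved
        # full_config, because its kernel is faster in this regime.
        return "moe_wna16"
    # Few experts: keep the existing GPTQMarlinMoEMethod.
    return "gptq_marlin_moe"

# e.g. an 8-expert Mixtral-style MoE stays on the Marlin fused-MoE method,
# while a 64-expert model switches to moe_wna16.
assert chosen_moe_method(8) == "gptq_marlin_moe"
assert chosen_moe_method(64) == "moe_wna16"
```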