[Quant][Perf] Use moe_wna16 kernel by default for MoEs with many experts #13236

Merged

2 changes: 1 addition & 1 deletion tests/weight_loading/test_weight_loading.py
@@ -12,7 +12,7 @@
"robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")
MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80")
Contributor

ah good catch



@pytest.mark.skipif(
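MIN_CAPABILITY feeds the pytest.mark.skipif guard that starts on the context line above; dropping the default from 89 (Ada-class GPUs) to 80 (Ampere, e.g. A100) lets the weight-loading test also run on 8.0-capability devices. A minimal sketch of how such an env-driven capability gate typically looks; the helper name and the torch-based check are assumptions, not the test's actual code:

```python
# Hedged sketch (not the PR's code): how an env-driven MIN_CAPABILITY gate
# typically wires into pytest.mark.skipif. Helper names are assumptions.
import os

import pytest
import torch

MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80")


def _device_capability() -> int:
    # (8, 0) on A100 -> 80, (8, 9) on RTX 4090 / L40S -> 89
    if not torch.cuda.is_available():
        return -1
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor


@pytest.mark.skipif(
    _device_capability() < int(MIN_CAPABILITY),
    reason="requires a GPU with sufficient compute capability",
)
def test_weight_loading_smoke():
    ...  # load the model under test and run a short generation
```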
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/awq_marlin.py
@@ -17,6 +17,7 @@
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
@@ -134,7 +135,12 @@ def get_quant_method(self, layer: torch.nn.Module,
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEMethod(self)
if layer.num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return AWQMoEMethod(self)
return None

@classmethod
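Stripped of vLLM internals, the expert-count dispatch that this file (and gptq_marlin.py below) now applies looks roughly like the sketch below. FakeMoELayer and the returned strings are purely illustrative stand-ins; only the 32-expert threshold comes from the diff:

```python
# Hedged sketch (not the PR's code): the expert-count dispatch in isolation.
# Only the 32-expert threshold comes from the diff; everything else is a stand-in.
from dataclasses import dataclass


@dataclass
class FakeMoELayer:
    """Illustrative stand-in for vLLM's FusedMoE layer."""
    num_experts: int


def pick_moe_kernel(layer: FakeMoELayer) -> str:
    # Many-expert MoEs (e.g. models with 100+ routed experts) take the
    # moe_wna16 path; small expert counts keep the fused Marlin MoE method.
    if layer.num_experts > 32:
        return "moe_wna16"
    return "marlin_fused_moe"


assert pick_moe_kernel(FakeMoELayer(num_experts=160)) == "moe_wna16"
assert pick_moe_kernel(FakeMoELayer(num_experts=8)) == "marlin_fused_moe"
```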
35 changes: 16 additions & 19 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -10,20 +10,18 @@
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
@@ -44,15 +42,10 @@ class GPTQMarlinConfig(QuantizationConfig):
(8, True): scalar_types.uint8b128,
}

def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
) -> None:
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool, lm_head_quantized: bool,
dynamic: Dict[str, Dict[str, Union[int, bool]]],
full_config: Dict[str, Any]) -> None:
Contributor

What is full_config?
Can we add a comment

Contributor

Oh just the config dict, I see

Member Author

It is just the original config saved from from_config so we can forward to MoeWNA16Config

if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
@@ -90,6 +83,7 @@ def __init__(
self.group_size = group_size
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.full_config = full_config

if (weight_bits, is_sym) not in self.TYPE_MAP:
raise ValueError("Unsupported quantization config: "
@@ -132,7 +126,7 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized, dynamic)
lm_head_quantized, dynamic, config)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
@@ -155,12 +149,15 @@ def override_quantization_method(cls, hf_quant_cfg,
" faster inference")
return None

def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, FusedMoE):
return GPTQMarlinMoEMethod(self)
if layer.num_experts > 32:
# For MoEs with many experts the moe_wna16 kernel is faster
return MoeWNA16Config.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return GPTQMarlinMoEMethod(self)
return get_linear_quant_method(self, layer, prefix,
GPTQMarlinLinearMethod)

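The review thread above asks what full_config is: from_config simply stashes the raw quantization-config dict so that get_quant_method can later rebuild a MoeWNA16Config from the very same dict. A stripped-down sketch of that forwarding pattern, with hypothetical Tiny* classes standing in for the real config classes:

```python
# Hedged sketch (not the PR's code): why the config object keeps full_config.
# TinyMarlinConfig and TinyWNA16Config are hypothetical stand-ins.
from typing import Any, Dict


class TinyWNA16Config:
    def __init__(self, bits: int) -> None:
        self.bits = bits

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TinyWNA16Config":
        return cls(config["bits"])


class TinyMarlinConfig:
    def __init__(self, bits: int, full_config: Dict[str, Any]) -> None:
        self.bits = bits
        # The untouched dict received by from_config(); kept so another config
        # class can be rebuilt from it later without re-reading the model config.
        self.full_config = full_config

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TinyMarlinConfig":
        return cls(config["bits"], config)

    def get_quant_method(self, num_experts: int):
        if num_experts > 32:
            # Forward the original dict, analogous to
            # MoeWNA16Config.from_config(self.full_config) in the diff.
            return TinyWNA16Config.from_config(self.full_config)
        return self


cfg = TinyMarlinConfig.from_config({"bits": 4, "group_size": 128})
assert isinstance(cfg.get_quant_method(num_experts=64), TinyWNA16Config)
```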
20 changes: 15 additions & 5 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -9,13 +9,8 @@
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer)
from vllm.model_executor.utils import set_weight_attrs
@@ -37,6 +32,12 @@ def __init__(self, linear_quant_method: str, weight_bits: int,
self.linear_quant_method = linear_quant_method
self.full_config = full_config
self.use_marlin = False
# Avoid circular import
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
if self.linear_quant_method == "gptq":
self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
full_config)
@@ -115,6 +116,8 @@ def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
# Avoid circular import
from vllm.model_executor.layers.quantization.awq import AWQConfig
awq_min_capability = AWQConfig.get_min_capability()

gptq_compatible = quant_method == "gptq" and \
@@ -129,6 +132,13 @@ def get_quant_method(self, layer: torch.nn.Module,
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
elif isinstance(layer, LinearBase):
# Avoid circular import
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
if self.linear_quant_method == "gptq":
if self.use_marlin:
return GPTQMarlinConfig.from_config(
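At the user level nothing needs to change: loading a GPTQ- or AWQ-quantized MoE with more than 32 experts should now pick the moe_wna16 path automatically through the dispatch above. A usage sketch; the model name is a placeholder, not a real checkpoint:

```python
# Usage sketch: the model name below is a placeholder, not a real checkpoint.
# No extra flag is needed; the quant-config dispatch above selects moe_wna16
# when the MoE layer reports more than 32 experts.
from vllm import LLM

llm = LLM(model="some-org/gptq-moe-with-many-experts")  # hypothetical
outputs = llm.generate("The moe_wna16 kernel is used when")
print(outputs[0].outputs[0].text)
```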