
Commit a65ab89

WoosukKwon authored and adityagoel14 committed
[Bugfix] Support 2D input shape in MoE layer (vllm-project#6287)
(cherry picked from commit e72ae80)
1 parent 8473d70 commit a65ab89

2 files changed: +7 −4 lines changed

vllm/model_executor/models/mixtral.py (+3 −2)

@@ -88,12 +88,13 @@ def __init__(self,
             tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
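The change above replaces tuple unpacking of hidden_states.shape, which assumes a 2D tensor, with saving the original shape and restoring it after the expert computation. A minimal sketch of that pattern, using a hypothetical ToyMoE class with an nn.Linear standing in for the fused experts (none of this is vLLM API):

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    """Hypothetical stand-in for a MoE block; a Linear replaces the real experts."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.experts = nn.Linear(hidden_size, hidden_size)  # placeholder for fused experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)  # always 2D from here on
        out = self.experts(hidden_states)
        return out.view(orig_shape)  # restore whatever rank the caller passed in

moe = ToyMoE(hidden_size=8)
assert moe(torch.randn(8)).shape == (8,)        # 1D input round-trips
assert moe(torch.randn(4, 8)).shape == (4, 8)   # 2D input round-trips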

vllm/model_executor/models/qwen2_moe.py (+4 −2)

@@ -126,7 +126,9 @@ def __init__(
                       bias=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         if self.shared_expert is not None:
@@ -145,7 +147,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         final_hidden_states = tensor_model_parallel_all_reduce(
             final_hidden_states)
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)
 
 
 class Qwen2MoeAttention(nn.Module):
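For context, a small illustrative snippet (not from the repository) showing why the previous unpacking fails on a 1D input, and how the rank-agnostic pattern used in this commit handles it by taking hidden_dim from shape[-1]:

import torch

hidden_states = torch.randn(16)  # 1D input, as the NOTE above allows

# Pre-fix pattern: assumes exactly two dimensions, so 1D input fails.
try:
    num_tokens, hidden_dim = hidden_states.shape
except ValueError as e:
    print(f"unpacking a 1D shape fails: {e}")

# Post-fix pattern: works for any rank.
orig_shape = hidden_states.shape
hidden_dim = hidden_states.shape[-1]
flat = hidden_states.view(-1, hidden_dim)   # shape (1, 16)
restored = flat.view(orig_shape)            # shape (16,)
print(flat.shape, restored.shape)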

0 commit comments
