2 files changed (+7 -4 lines) in vllm/model_executor/models

File 1 of 2:
@@ -88,12 +88,13 @@ def __init__(self,
                                     tp_size=tp_size)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(num_tokens, hidden_size)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
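For context, here is a minimal, self-contained sketch (not vLLM code; toy_moe_forward is a hypothetical stand-in for MixtralMoE.forward) showing why saving orig_shape and restoring it with .view(orig_shape) handles both 1D and 2D inputs, whereas the old unpacking num_tokens, hidden_size = hidden_states.shape only works for 2D tensors:

import torch

def toy_moe_forward(hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    orig_shape = hidden_states.shape          # (hidden_size,) or (num_tokens, hidden_size)
    x = hidden_states.view(-1, hidden_size)   # always 2D for the expert computation
    out = x * 2.0                             # stand-in for gate/experts
    return out.view(orig_shape)               # restore whatever shape the caller passed in

assert toy_moe_forward(torch.randn(8), 8).shape == (8,)
assert toy_moe_forward(torch.randn(4, 8), 8).shape == (4, 8)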
File 2 of 2:
@@ -126,7 +126,9 @@ def __init__(
                                    bias=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_dim = hidden_states.shape
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_dim = hidden_states.shape[-1]
         hidden_states = hidden_states.view(-1, hidden_dim)
         shared_output = None
         if self.shared_expert is not None:
@@ -145,7 +147,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return final_hidden_states.view(orig_shape)
 
 
 class Qwen2MoeAttention(nn.Module):
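The Qwen2MoE hunk additionally derives the hidden dimension from hidden_states.shape[-1], since the later .view(-1, hidden_dim) still needs it and tuple unpacking assumes exactly two dimensions. A toy check (not part of the diff) illustrating the difference:

import torch

two_d = torch.randn(4, 8)
one_d = torch.randn(8)

# shape[-1] yields the hidden dim for either rank.
assert two_d.shape[-1] == 8 and one_d.shape[-1] == 8

try:
    num_tokens, hidden_dim = one_d.shape   # the pre-patch pattern
except ValueError as e:
    print(f"1D input breaks tuple unpacking: {e}")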