
Commit cdf21ec

jeejeelee authored and lulmer committed
[Misc] Qwen2.5 VL support LoRA (vllm-project#13261)
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent d3f0d9c commit cdf21ec

4 files changed: +130 −63 lines changed

docs/source/models/supported_models.md

+1 −1

@@ -854,7 +854,7 @@ See [this page](#generative-models) for more information on how to use generativ
   * Qwen2.5-VL
   * T + I<sup>E+</sup> + V<sup>E+</sup>
   * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
-  *
+  * ✅︎
   * ✅︎
   * ✅︎
 - * `UltravoxModel`

tests/lora/conftest.py

+5

@@ -237,6 +237,11 @@ def qwen2vl_lora_files():
     return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")


+@pytest.fixture(scope="session")
+def qwen25vl_lora_files():
+    return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")

tests/lora/test_qwen2vl.py

+118 −58

@@ -1,83 +1,143 @@
 # SPDX-License-Identifier: Apache-2.0
-
-from typing import List
+from dataclasses import dataclass
+from typing import Dict, List, Optional

 import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION

 import vllm
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform

-MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"

-PROMPT_TEMPLATE = (
-    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
-    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
-    "What is in the image?<|im_end|>\n"
-    "<|im_start|>assistant\n")
+@dataclass
+class TestConfig:
+    model_path: str
+    lora_path: str
+    max_num_seqs: int = 2
+    max_loras: int = 2
+    max_lora_rank: int = 16
+    max_model_len: int = 4096
+    mm_processor_kwargs: Optional[Dict[str, int]] = None
+
+    def __post_init__(self):
+        if self.mm_processor_kwargs is None:
+            self.mm_processor_kwargs = {
+                "min_pixels": 28 * 28,
+                "max_pixels": 1280 * 28 * 28,
+            }
+
+
+class Qwen2VLTester:
+    """Test helper for Qwen2 VL models with LoRA"""
+
+    PROMPT_TEMPLATE = (
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
+        "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+        "What is in the image?<|im_end|>\n"
+        "<|im_start|>assistant\n")
+
+    def __init__(self, config: TestConfig):
+        self.config = config
+        self.llm = self._initialize_llm()
+
+    def _initialize_llm(self) -> vllm.LLM:
+        """Initialize the LLM with given configuration"""
+        return vllm.LLM(
+            model=self.config.model_path,
+            max_num_seqs=self.config.max_num_seqs,
+            enable_lora=True,
+            max_loras=self.config.max_loras,
+            max_lora_rank=self.config.max_lora_rank,
+            trust_remote_code=True,
+            mm_processor_kwargs=self.config.mm_processor_kwargs,
+            max_model_len=self.config.max_model_len,
+        )
+
+    def run_test(self,
+                 images: List[ImageAsset],
+                 expected_outputs: List[str],
+                 lora_id: Optional[int] = None,
+                 temperature: float = 0,
+                 max_tokens: int = 5) -> List[str]:
+
+        sampling_params = vllm.SamplingParams(
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+        inputs = [{
+            "prompt": self.PROMPT_TEMPLATE,
+            "multi_modal_data": {
+                "image": asset.pil_image
+            },
+        } for asset in images]
+
+        lora_request = LoRARequest(str(lora_id), lora_id,
+                                   self.config.lora_path)
+        outputs = self.llm.generate(inputs,
+                                    sampling_params,
+                                    lora_request=lora_request)
+        generated_texts = [
+            output.outputs[0].text.strip() for output in outputs
+        ]

-IMAGE_ASSETS = [
+        # Validate outputs
+        for generated, expected in zip(generated_texts, expected_outputs):
+            assert expected.startswith(
+                generated), f"Generated text {generated} doesn't "
+            f"match expected pattern {expected}"
+
+        return generated_texts
+
+
+TEST_IMAGES = [
     ImageAsset("stop_sign"),
     ImageAsset("cherry_blossom"),
 ]

-# After fine-tuning with LoRA, all generated content should start begin `A`.
-EXPECTED_OUTPUT = [
+EXPECTED_OUTPUTS = [
     "A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]

-
-def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
-    sampling_params = vllm.SamplingParams(
-        temperature=0,
-        max_tokens=5,
-    )
-
-    inputs = [{
-        "prompt": PROMPT_TEMPLATE,
-        "multi_modal_data": {
-            "image": asset.pil_image
-        },
-    } for asset in IMAGE_ASSETS]
-
-    outputs = llm.generate(
-        inputs,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
-        if lora_id else None,
-    )
-    # Print the outputs.
-    generated_texts: List[str] = []
-    for output in outputs:
-        generated_text = output.outputs[0].text.strip()
-        generated_texts.append(generated_text)
-        print(f"Generated text: {generated_text!r}")
-    return generated_texts
+QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
+QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"


 @pytest.mark.xfail(
     current_platform.is_rocm(),
     reason="Qwen2-VL dependency xformers incompatible with ROCm")
 def test_qwen2vl_lora(qwen2vl_lora_files):
-    llm = vllm.LLM(
-        MODEL_PATH,
-        max_num_seqs=2,
-        enable_lora=True,
-        max_loras=2,
-        max_lora_rank=16,
-        trust_remote_code=True,
-        mm_processor_kwargs={
-            "min_pixels": 28 * 28,
-            "max_pixels": 1280 * 28 * 28,
-        },
-        max_model_len=4096,
-    )
-    output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
-    for i in range(len(EXPECTED_OUTPUT)):
-        assert EXPECTED_OUTPUT[i].startswith(output1[i])
-
-    output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2)
-    for i in range(len(EXPECTED_OUTPUT)):
-        assert EXPECTED_OUTPUT[i].startswith(output2[i])
+    """Test Qwen 2.0 VL model with LoRA"""
+    config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
+                        lora_path=qwen2vl_lora_files)
+    tester = Qwen2VLTester(config)
+
+    # Test with different LoRA IDs
+    for lora_id in [1, 2]:
+        tester.run_test(TEST_IMAGES,
+                        expected_outputs=EXPECTED_OUTPUTS,
+                        lora_id=lora_id)
+
+
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
+)
+@pytest.mark.skipif(
+    Version(TRANSFORMERS_VERSION) < Version("4.49.0"),
+    reason="Qwen2.5-VL require transformers version no lower than 4.49.0",
+)
+def test_qwen25vl_lora(qwen25vl_lora_files):
+    """Test Qwen 2.5 VL model with LoRA"""
+    config = TestConfig(model_path=QWEN25VL_MODEL_PATH,
+                        lora_path=qwen25vl_lora_files)
+    tester = Qwen2VLTester(config)
+
+    # Test with different LoRA IDs
+    for lora_id in [1, 2]:
+        tester.run_test(TEST_IMAGES,
+                        expected_outputs=EXPECTED_OUTPUTS,
+                        lora_id=lora_id)
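
Note: outside the test harness, the same flow reduces to a few lines of offline inference. A minimal sketch based on the test above, assuming the adapter has already been downloaded locally; the adapter path, image choice, and max_tokens value are illustrative, not part of this commit:

# Sketch: offline inference with a Qwen2.5-VL LoRA adapter, mirroring the test.
import vllm
from vllm.assets.image import ImageAsset
from vllm.lora.request import LoRARequest

PROMPT = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
          "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "What is in the image?<|im_end|>\n"
          "<|im_start|>assistant\n")

llm = vllm.LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    enable_lora=True,
    max_lora_rank=16,
    max_model_len=4096,
)
outputs = llm.generate(
    [{
        "prompt": PROMPT,
        "multi_modal_data": {"image": ImageAsset("stop_sign").pil_image},
    }],
    vllm.SamplingParams(temperature=0, max_tokens=64),
    # Illustrative local path to the downloaded adapter.
    lora_request=LoRARequest("pokemon-lora", 1, "/path/to/qwen25-vl-lora-pokemon"),
)
print(outputs[0].outputs[0].text)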

vllm/model_executor/models/qwen2_5_vl.py

+6 −4

@@ -734,23 +734,25 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
             "up_proj",
         ],
     }
-    # LoRA specific attributes, TODO: double check
+    # LoRA specific attributes
     supported_lora_modules = [
+        # language model
         "qkv_proj",
         "o_proj",
         "gate_up_proj",
-        "down_proj",
-        "gate_proj"
-        "up_proj",
+        "down_proj",  # Same name with vision encoder
         # vision tower
         "qkv",
+        "gate_proj",
+        "up_proj",
         "attn.proj",  # Distinguish patch_embed.proj
         "fc1",
         "fc2",
         # projector
         "mlp.0",
         "mlp.2"
     ]
+
     embedding_modules = {}
     embedding_padding_modules = []

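Note: supported_lora_modules now separates the language-model projections from the vision-tower and projector modules, so adapters can target either stack. On the language-model side, the packed names (qkv_proj, gate_up_proj) are resolved from the per-projection Hugging Face module names through the packed_modules_mapping shown in the context above. A hedged sketch of the kind of PEFT LoraConfig such an adapter might be trained with; the target module list, rank, and alpha are assumptions for illustration, not taken from this commit or from the adapter used in the tests:

# Sketch: PEFT config targeting the language-model projections that vLLM
# packs into qkv_proj / gate_up_proj (plus o_proj and down_proj) at load time.
# The exact target_modules, rank, and alpha are illustrative assumptions.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
    task_type="CAUSAL_LM",
)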