 # SPDX-License-Identifier: Apache-2.0
-
-from typing import List
+from dataclasses import dataclass
+from typing import Dict, List, Optional
 
 import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 import vllm
 from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.platforms import current_platform
 
-MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
 
-PROMPT_TEMPLATE = (
-    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
-    "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
-    "What is in the image?<|im_end|>\n"
-    "<|im_start|>assistant\n")
+@dataclass
+class TestConfig:
+    model_path: str
+    lora_path: str
+    max_num_seqs: int = 2
+    max_loras: int = 2
+    max_lora_rank: int = 16
+    max_model_len: int = 4096
+    mm_processor_kwargs: Optional[Dict[str, int]] = None
+
+    def __post_init__(self):
+        if self.mm_processor_kwargs is None:
+            # Bound the image resize area: from one 28x28 block
+            # up to 1280 of them.
+            self.mm_processor_kwargs = {
+                "min_pixels": 28 * 28,
+                "max_pixels": 1280 * 28 * 28,
+            }
+
+
+class Qwen2VLTester:
+    """Test helper for Qwen2-VL-family models with LoRA."""
+
+    PROMPT_TEMPLATE = (
+        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"
+        "\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
+        "What is in the image?<|im_end|>\n"
+        "<|im_start|>assistant\n")
+
+    def __init__(self, config: TestConfig):
+        self.config = config
+        self.llm = self._initialize_llm()
+
+    def _initialize_llm(self) -> vllm.LLM:
+        """Initialize the LLM with the given configuration."""
+        return vllm.LLM(
+            model=self.config.model_path,
+            max_num_seqs=self.config.max_num_seqs,
+            enable_lora=True,
+            max_loras=self.config.max_loras,
+            max_lora_rank=self.config.max_lora_rank,
+            trust_remote_code=True,
+            mm_processor_kwargs=self.config.mm_processor_kwargs,
+            max_model_len=self.config.max_model_len,
+        )
+
+    def run_test(self,
+                 images: List[ImageAsset],
+                 expected_outputs: List[str],
+                 lora_id: Optional[int] = None,
+                 temperature: float = 0,
+                 max_tokens: int = 5) -> List[str]:
+        """Generate captions for the images, optionally with a LoRA
+        adapter, and check them against the expected outputs."""
+        sampling_params = vllm.SamplingParams(
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+        inputs = [{
+            "prompt": self.PROMPT_TEMPLATE,
+            "multi_modal_data": {
+                "image": asset.pil_image
+            },
+        } for asset in images]
+
+        # Only attach a LoRA request when a LoRA ID was given.
+        lora_request = (LoRARequest(str(lora_id), lora_id,
+                                    self.config.lora_path)
+                        if lora_id else None)
+        outputs = self.llm.generate(inputs,
+                                    sampling_params,
+                                    lora_request=lora_request)
+        generated_texts = [
+            output.outputs[0].text.strip() for output in outputs
+        ]
 
-IMAGE_ASSETS = [
+        # Validate outputs: each expected caption should start with the
+        # generated prefix, since generation is capped at a few tokens.
+        for generated, expected in zip(generated_texts, expected_outputs):
+            assert expected.startswith(generated), (
+                f"Generated text {generated!r} doesn't match expected "
+                f"pattern {expected!r}")
+
+        return generated_texts
+
+
+TEST_IMAGES = [
     ImageAsset("stop_sign"),
     ImageAsset("cherry_blossom"),
 ]
 
-# After fine-tuning with LoRA, all generated content should start begin `A`.
-EXPECTED_OUTPUT = [
+EXPECTED_OUTPUTS = [
     "A red stop sign stands prominently in the foreground, with a traditional Chinese gate and a black SUV in the background, illustrating a blend of modern and cultural elements.",  # noqa: E501
     "A majestic skyscraper stands tall, partially obscured by a vibrant canopy of cherry blossoms, against a clear blue sky.",  # noqa: E501
 ]
 
-
-def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
-    sampling_params = vllm.SamplingParams(
-        temperature=0,
-        max_tokens=5,
-    )
-
-    inputs = [{
-        "prompt": PROMPT_TEMPLATE,
-        "multi_modal_data": {
-            "image": asset.pil_image
-        },
-    } for asset in IMAGE_ASSETS]
-
-    outputs = llm.generate(
-        inputs,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
-        if lora_id else None,
-    )
-    # Print the outputs.
-    generated_texts: List[str] = []
-    for output in outputs:
-        generated_text = output.outputs[0].text.strip()
-        generated_texts.append(generated_text)
-        print(f"Generated text: {generated_text!r}")
-    return generated_texts
+QWEN2VL_MODEL_PATH = "Qwen/Qwen2-VL-2B-Instruct"
+QWEN25VL_MODEL_PATH = "Qwen/Qwen2.5-VL-3B-Instruct"
 
 
 @pytest.mark.xfail(
     current_platform.is_rocm(),
     reason="Qwen2-VL dependency xformers incompatible with ROCm")
 def test_qwen2vl_lora(qwen2vl_lora_files):
-    llm = vllm.LLM(
-        MODEL_PATH,
-        max_num_seqs=2,
-        enable_lora=True,
-        max_loras=2,
-        max_lora_rank=16,
-        trust_remote_code=True,
-        mm_processor_kwargs={
-            "min_pixels": 28 * 28,
-            "max_pixels": 1280 * 28 * 28,
-        },
-        max_model_len=4096,
-    )
-    output1 = do_sample(llm, qwen2vl_lora_files, lora_id=1)
-    for i in range(len(EXPECTED_OUTPUT)):
-        assert EXPECTED_OUTPUT[i].startswith(output1[i])
-
-    output2 = do_sample(llm, qwen2vl_lora_files, lora_id=2)
-    for i in range(len(EXPECTED_OUTPUT)):
-        assert EXPECTED_OUTPUT[i].startswith(output2[i])
+    """Test Qwen 2.0 VL model with LoRA."""
+    config = TestConfig(model_path=QWEN2VL_MODEL_PATH,
+                        lora_path=qwen2vl_lora_files)
+    tester = Qwen2VLTester(config)
+
+    # Test with different LoRA IDs.
+    for lora_id in [1, 2]:
+        tester.run_test(TEST_IMAGES,
+                        expected_outputs=EXPECTED_OUTPUTS,
+                        lora_id=lora_id)
+
+
+@pytest.mark.xfail(
+    current_platform.is_rocm(),
+    reason="Qwen2.5-VL dependency xformers incompatible with ROCm",
+)
+@pytest.mark.skipif(
+    Version(TRANSFORMERS_VERSION) < Version("4.49.0"),
+    reason="Qwen2.5-VL requires transformers >= 4.49.0",
+)
+def test_qwen25vl_lora(qwen25vl_lora_files):
+    """Test Qwen 2.5 VL model with LoRA."""
+    config = TestConfig(model_path=QWEN25VL_MODEL_PATH,
+                        lora_path=qwen25vl_lora_files)
+    tester = Qwen2VLTester(config)
+
+    # Test with different LoRA IDs.
+    for lora_id in [1, 2]:
+        tester.run_test(TEST_IMAGES,
+                        expected_outputs=EXPECTED_OUTPUTS,
+                        lora_id=lora_id)
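
Both tests depend on the `qwen2vl_lora_files` and `qwen25vl_lora_files` pytest fixtures, which are defined outside this diff. Below is a minimal sketch of what such fixtures could look like, assuming the LoRA adapters are hosted on the Hugging Face Hub; the repo IDs are placeholders, not the adapters actually used by the suite.

```python
# Hypothetical conftest.py sketch (not part of this change): the fixtures
# only need to return a local directory containing the LoRA adapter weights.
import pytest
from huggingface_hub import snapshot_download


@pytest.fixture(scope="session")
def qwen2vl_lora_files():
    # Placeholder repo ID; substitute the adapter trained for Qwen2-VL.
    return snapshot_download(repo_id="<org>/qwen2-vl-lora-adapter")


@pytest.fixture(scope="session")
def qwen25vl_lora_files():
    # Placeholder repo ID; substitute the adapter trained for Qwen2.5-VL.
    return snapshot_download(repo_id="<org>/qwen2.5-vl-lora-adapter")
```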