[Model] Add LoRA support for TransformersModel #13770

Merged
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -275,7 +275,7 @@ steps:
source_file_dependencies:
- vllm/lora
- tests/lora
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py --ignore=lora/test_transfomers_model.py
parallelism: 4

- label: PyTorch Fullgraph Smoke Test # 9min
@@ -589,6 +589,7 @@ steps:
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_minicpmv_tp.py
- pytest -v -s -x lora/test_transfomers_model.py


- label: Weight Loading Multiple GPU Test # 33min
15 changes: 1 addition & 14 deletions docs/source/models/supported_models.md
@@ -62,20 +62,7 @@ Transformers fallback has supported most of available quantization in vLLM (exce

##### LoRA

LoRA isn't supported on the Transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team!

Usually, `transformers` models load adapter weights via the `load_adapter` API, which depends on PEFT. We need to do some work to either use this API (for now this would result in some weights not being marked as loaded) or replace the modules accordingly.

Hints as to what this would look like:

```python
class TransformersModel(nn.Module, SupportsLoRA):

    def __init__(self, *, vllm_config, prefix: str = ""):
        ...
        self.model.load_adapter(
            vllm_config.load_config.model_loader_extra_config[
                "qlora_adapter_name_or_path"])
```

The blocker is that you need to specify the supported LoRA layers, when ideally we would want to load whatever is inside the checkpoint!
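
For context, a rough sketch of what "specifying the supported LoRA layers" usually means for a vLLM model; this is an illustration only, and the class name and module list are assumptions rather than code from this PR:

```python
import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsLoRA


class MyModelForCausalLM(nn.Module, SupportsLoRA):
    # LoRA target modules are declared up front on the model class rather
    # than discovered from whatever is inside the adapter checkpoint.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    supported_lora_modules = ["qkv_proj", "o_proj", "gate_up_proj", "down_proj"]
```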
The Transformers fallback now supports LoRA. Usage is identical to how LoRA works with models natively supported by vLLM. If you encounter any issues, please open an issue.
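
For reference, a minimal usage sketch assembled from the test added in this PR (`tests/lora/test_transfomers_model.py`); the model and adapter repos are the ones used there, and the short prompt stands in for the longer SQL prompt template used in the test:

```python
from huggingface_hub import snapshot_download

import vllm
from vllm.lora.request import LoRARequest

# Transformers-fallback model and a matching LoRA adapter, as used in the test.
llm = vllm.LLM("ArthurZ/ilama-3.2-1B",
               max_model_len=1024,
               enable_lora=True,
               max_lora_rank=16,
               trust_remote_code=True)
lora_path = snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")

outputs = llm.generate(
    ["How many singers do we have?"],  # placeholder prompt
    vllm.SamplingParams(temperature=0, max_tokens=32),
    lora_request=LoRARequest("sql-lora", 1, lora_path))
print(outputs[0].outputs[0].text)
```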

##### Remote code

5 changes: 5 additions & 0 deletions tests/lora/conftest.py
@@ -240,6 +240,11 @@ def baichuan_regex_lora_files():
return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")


@pytest.fixture(scope="session")
def ilama_lora_files():
return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")


@pytest.fixture(scope="session")
def minicpmv_lora_files():
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
120 changes: 120 additions & 0 deletions tests/lora/test_transfomers_model.py
@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List

import pytest

import vllm
from tests.utils import fork_new_process_for_each_test
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "ArthurZ/ilama-3.2-1B"

PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501

EXPECTED_LORA_OUTPUT = [
"SELECT count(*) FROM singer",
"SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501
"SELECT DISTINCT Country FROM singer WHERE Age > 20",
]


def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
prompts = [
PROMPT_TEMPLATE.format(query="How many singers do we have?"),
PROMPT_TEMPLATE.format(
query=
"What is the average, minimum, and maximum age of all singers from France?" # noqa: E501
),
PROMPT_TEMPLATE.format(
query=
"What are all distinct countries where singers above age 20 are from?" # noqa: E501
),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
if lora_id else None)
# Print the outputs.
generated_texts: List[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
return generated_texts


@pytest.fixture(autouse=True)
def v1(run_with_both_engines_lora):
# Simple autouse wrapper to run both engines for each test
# This can be promoted up to conftest.py to run for every
# test in a package
pass


@pytest.mark.skip_v1
@fork_new_process_for_each_test
def test_ilama_lora(ilama_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=16,
tensor_parallel_size=1,
trust_remote_code=True,
enable_chunked_prefill=True)

output1 = do_sample(llm, ilama_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, ilama_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_ilama_lora_tp4(ilama_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=16,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=False,
enable_chunked_prefill=True)

output1 = do_sample(llm, ilama_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, ilama_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]


@pytest.mark.skip_v1
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_ilama_lora_tp4_fully_sharded_loras(ilama_lora_files):
llm = vllm.LLM(MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
max_lora_rank=16,
tensor_parallel_size=4,
trust_remote_code=True,
fully_sharded_loras=True,
enable_chunked_prefill=True)
output1 = do_sample(llm, ilama_lora_files, lora_id=1)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output1[i] == EXPECTED_LORA_OUTPUT[i]
output2 = do_sample(llm, ilama_lora_files, lora_id=2)
for i in range(len(EXPECTED_LORA_OUTPUT)):
assert output2[i] == EXPECTED_LORA_OUTPUT[i]
25 changes: 18 additions & 7 deletions vllm/lora/layers.py
@@ -401,6 +401,11 @@ def apply(self,
self.output_slices)
return output

@classmethod
def get_source_layer(cls, source_layer: nn.Module) -> type:
# Check parent_cls in case source_layer is a HFCompatibleLinear.
return getattr(source_layer, "parent_cls", type(source_layer))


class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

@@ -443,7 +448,8 @@ def can_replace_layer(
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ReplicatedLinear
source_layer = cls.get_source_layer(source_layer)
return source_layer is ReplicatedLinear


class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
@@ -539,8 +545,9 @@ def can_replace_layer(
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
source_layer = cls.get_source_layer(source_layer)
return source_layer is ColumnParallelLinear or (
source_layer is MergedColumnParallelLinear
and len(packed_modules_list) == 1)


@@ -682,7 +689,8 @@ def can_replace_layer(
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is MergedColumnParallelLinear
source_layer = cls.get_source_layer(source_layer)
return (source_layer is MergedColumnParallelLinear
and len(packed_modules_list) == 2)


@@ -750,7 +758,8 @@ def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
def can_replace_layer(cls, source_layer: nn.Module,
lora_config: LoRAConfig, packed_modules_list: List,
model_config: Optional[PretrainedConfig]) -> bool:
return type(source_layer) is QKVParallelLinear and len(
source_layer = cls.get_source_layer(source_layer)
return source_layer is QKVParallelLinear and len(
packed_modules_list) == 1


@@ -811,7 +820,8 @@ def can_replace_layer(
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return (type(source_layer) is QKVParallelLinear
source_layer = cls.get_source_layer(source_layer)
return (source_layer is QKVParallelLinear
and len(packed_modules_list) == 3)


@@ -896,7 +906,8 @@ def can_replace_layer(
packed_modules_list: List,
model_config: Optional[PretrainedConfig],
) -> bool:
return type(source_layer) is RowParallelLinear
source_layer = cls.get_source_layer(source_layer)
return source_layer is RowParallelLinear


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
25 changes: 14 additions & 11 deletions vllm/lora/utils.py
@@ -66,17 +66,20 @@ def from_layer(layer: nn.Module,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config):
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret

# The Case for HFCompatibleLinear
if (hasattr(layer, "get_lora_class")
and layer.__class__.__name__ == "HFCompatibleLinear"):
lora_cls = layer.get_lora_class(lora_config.fully_sharded_loras)
ret = lora_cls(layer)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
instance_layer = lora_cls(layer)
if layer.__class__.__name__ == "HFCompatibleLinear":
# HACK: Make the forward method compatible with the original
# forward method of the instance_layer.
original_forward = instance_layer.forward

def new_forward(input):
input = input.squeeze(0)
return original_forward(input)[0] # noqa: B023

instance_layer.forward = new_forward
instance_layer.create_lora_weights(max_loras, lora_config,
model_config)
return instance_layer
return layer


43 changes: 6 additions & 37 deletions vllm/model_executor/models/transformers.py
@@ -27,11 +27,6 @@
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA, RowParallelLinearWithShardedLoRA)
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
@@ -43,7 +38,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsQuant
from .interfaces import SupportsLoRA, SupportsQuant
from .utils import maybe_prefix

logger = init_logger(__name__)
@@ -102,44 +97,18 @@ def replace_linear_class(
"rowwise": RowParallelLinear,
}.get(style, ReplicatedLinear)

lora_linear_cls = {
ColumnParallelLinear: {
True: ColumnParallelLinearWithShardedLoRA, # fully sharded
False: ColumnParallelLinearWithLoRA # not fully sharded
},
RowParallelLinear: {
True: RowParallelLinearWithShardedLoRA,
False: RowParallelLinearWithLoRA
},
# ReplicatedLinear doesn't support fully sharded LoRA yet,
# so we use the same class for both cases.
ReplicatedLinear: {
True: ReplicatedLinearWithLoRA,
False: ReplicatedLinearWithLoRA
}
}

class HFCompatibleLinear(vllm_linear_cls):
"""
Wrapper class that removes `output_bias` from returned output.
"""
# NOTE: The LoRA layer needs to use `parent_cls`.
@property
def parent_cls(self) -> type:
return vllm_linear_cls

def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]

@classmethod
def get_lora_class(cls, fully_sharded: bool = False):
"""
Get the LoRA class corresponding to the current transformer
linear class.

Args:
fully_sharded (bool): If True, select the LoRA class variant
that supports fully sharded LoRA. Defaults to False.

"""
return lora_linear_cls[vllm_linear_cls][fully_sharded]

return HFCompatibleLinear(
input_size=linear.in_features,
output_size=linear.out_features,
@@ -148,7 +117,7 @@ def get_lora_class(cls, fully_sharded: bool = False):
)


class TransformersModel(nn.Module, SupportsQuant):
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it