
[Bugfix] Fix size calculation of processing cache #15114


Merged 3 commits on Mar 19, 2025
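
Summary (inferred from the diff below): the processing cache previously sized entries by applying `sys.getsizeof` to the leaves produced by `json_map_leaves`. Since `MultiModalKwargs` and `MultiModalKwargsItem` are not `dict` subclasses, they were treated as opaque leaves, so the tensors they hold were never counted and entry sizes were drastically underestimated; a cache that underestimates its entries evicts too late and can grow past the configured `capacity_gb`. The fix adds explicit handling for both types and sums tensor `nbytes` per leaf. A minimal, self-contained sketch of the underlying issue (illustrative only, not vLLM code):

```python
import sys
import torch

# A container holding a large tensor: sys.getsizeof sees only the shallow
# Python object, not the tensor payload it references.
item = {"pixel_values": torch.empty((3, 336, 336), dtype=torch.float16)}

shallow = sys.getsizeof(item)                # tens to hundreds of bytes
deep = sum(t.nbytes for t in item.values())  # ~677 KB of actual tensor data

print(f"shallow={shallow} B, deep={deep} B")
```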
48 changes: 46 additions & 2 deletions tests/multimodal/test_processing.py
@@ -7,15 +7,20 @@

import numpy as np
import pytest
import torch
from transformers import ProcessorMixin

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptIndexTargets, PromptInsertion,
PromptReplacement, apply_text_matches,
ProcessingCache, PromptIndexTargets,
PromptInsertion, PromptReplacement,
apply_text_matches,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
@@ -890,6 +895,45 @@ def test_find_mm_placeholders(
assert result == expected


def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
field=MultiModalSharedField(1),
)


def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])


def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs.from_items([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])


# yapf: disable
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
cache = ProcessingCache.get_lru_cache(2048, type(item))
cache[""] = item

assert cache.currsize == expected_size


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
("limit", "num_supported", "is_valid"),
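The new `test_cache_item_size` above checks that `currsize` now tracks the tensor payload of stored items. Each dummy element is a 1-D `int8` tensor of length `size`, so its `nbytes` equals `size`, and the expected values are plain sums: 100; 100 + 110 = 210; and 100 + 110 + 120 + 130 = 460 bytes. A hedged usage sketch along the same lines (helper names taken from the test above):

```python
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.multimodal.processing import ProcessingCache

# Illustrative only: _dummy_item is the helper defined in the test file above.
cache = ProcessingCache.get_lru_cache(2048, MultiModalKwargsItem)
cache["some-hash"] = _dummy_item("a", {"a1": 100, "a2": 110})

# currsize now reflects the summed nbytes of the item's tensors,
# not sys.getsizeof of the wrapper object.
assert cache.currsize == 210  # 100 + 110 bytes of int8 data
```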
60 changes: 46 additions & 14 deletions vllm/multimodal/processing.py
@@ -26,7 +26,7 @@
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange)
MultiModalKwargsItem, NestedTensors, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)

@@ -853,33 +853,62 @@ class ProcessingCache:

@staticmethod
def get_lru_cache(
capacity_gb: int,
capacity_gb: float,
value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]:

def get_size(leaf: object) -> int:
def get_leaf_size(leaf: object) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return get_item_size(leaf.data)

# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return get_item_size(leaf_data)

# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes # sys.getsizeof doesn't work for tensors
return leaf.nbytes

return sys.getsizeof(leaf)

return LRUCache[str, _V](
GiB_bytes * capacity_gb,
getsizeof=lambda x: json_reduce_leaves(
def get_item_size(
value: Union[MultiModalKwargs, MultiModalKwargsItem,
Mapping[str, NestedTensors]]
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b,
json_map_leaves(get_size, x),
),
)
json_map_leaves(get_leaf_size, value),
)

if debug:
logger.debug("Calculated size of %s to be %.2f GiB",
type(value), size / GiB_bytes)

def __init__(self, capacity_gb: int) -> None:
return size

return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)

def __init__(
self,
capacity_gb: float,
*,
debug_cache_hit_ratio_steps: Optional[int] = None,
) -> None:
super().__init__()

# DEBUG: Set to None to disable
self.debug_cache_hit_ratio_steps: Optional[int] = None
self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
self.debug_cache_hits = 0
self.debug_cache_total = 0

self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
self._cache = self.get_lru_cache(
capacity_gb,
MultiModalKwargsItem,
debug=bool(debug_cache_hit_ratio_steps),
)

def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps
@@ -890,6 +919,9 @@ def _maybe_log_cache_stats(self) -> None:
if total > 0 and total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f",
self.debug_cache_hits / total)
logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
self._cache.currsize / GiB_bytes,
self._cache.maxsize / GiB_bytes)

def get(
self,
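For reference, a hedged sketch of how the updated constructor might be exercised; the parameter names come from the diff above, while the log values shown are illustrative:

```python
from vllm.multimodal.processing import ProcessingCache

# Debug logging is now opt-in via the constructor instead of editing the source.
cache = ProcessingCache(capacity_gb=0.5, debug_cache_hit_ratio_steps=32)

# Roughly every 32 lookups, _maybe_log_cache_stats emits both the hit ratio
# and the current fill level, e.g. (values illustrative):
#   ProcessingCache: hit_ratio = 0.87
#   ProcessingCache: size = 0.12 / 0.50 GiB
```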