[Bugfix] V1 Memory Profiling: V0 Sampler Integration without Rejection Sampler #13594

Merged (11 commits) on Feb 22, 2025

vllm/v1/worker/gpu_model_runner.py (28 changes: 26 additions & 2 deletions)
@@ -31,6 +31,7 @@
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
+from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.utils import bind_kv_cache
@@ -1303,11 +1304,34 @@ def profile_run(self) -> None:
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
-            # TODO(woosuk): Consider the memory usage of the sampler.
+            dummy_tensors = lambda v: torch.full(
+                (num_reqs, ), v, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=dummy_tensors(0.5),
+                all_greedy=False,
+                all_random=False,
+                spec_token_ids=None,
+                top_p=dummy_tensors(0.9),
+                top_k=dummy_tensors(logits.size(1) - 1),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=torch.ones_like(logits, dtype=torch.int64),
+                frequency_penalties=dummy_tensors(0.1),
+                presence_penalties=dummy_tensors(0.1),
+                repetition_penalties=dummy_tensors(0.1),
+                output_token_ids=[[] for _ in range(num_reqs)],
+                min_tokens={},
+                logit_bias=[None for _ in range(num_reqs)])
+            sampler_output = self.model.sample(
+                logits=logits, sampling_metadata=dummy_metadata)
         else:
             logits = None
+            sampler_output = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits
+        del hidden_states, logits, sampler_output, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()

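The net effect of the diff: profile_run now builds a dummy SamplingMetadata and invokes self.model.sample on the profiling logits, so the sampler's transient allocations (top-k/top-p masking, penalty tensors, softmax buffers) count toward the peak-memory measurement used to size the KV cache. Below is a minimal, self-contained sketch of the same measurement pattern in plain PyTorch; it is not vLLM code, and apply_top_k plus the shapes are illustrative assumptions.

# Minimal sketch (plain PyTorch, not vLLM code) of the pattern above:
# run one worst-case sampling step on dummy logits, then read the CUDA
# peak-memory counter. `apply_top_k` stands in for the sampler's top-k
# masking; the shapes are assumptions.
import torch

def apply_top_k(logits: torch.Tensor, k: int) -> torch.Tensor:
    # Keep each row's k largest logits and mask the rest to -inf.
    kth_largest = logits.topk(k, dim=-1).values[:, -1:]
    return logits.masked_fill(logits < kth_largest, float("-inf"))

num_reqs, vocab_size = 8, 32_000  # assumed profiling shapes
device = torch.device("cuda")
logits = torch.randn(num_reqs, vocab_size, device=device)

torch.cuda.reset_peak_memory_stats(device)
# k = vocab_size - 1 mirrors top_k=dummy_tensors(logits.size(1) - 1)
# in the diff: close to the most expensive top-k configuration.
masked = apply_top_k(logits, vocab_size - 1)
probs = torch.softmax(masked, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1)
torch.cuda.synchronize()
peak = torch.cuda.max_memory_allocated(device)
print(f"sampling peak memory: {peak:,} bytes")

The same reasoning explains the new del hidden_states, logits, sampler_output, dummy_metadata line: the dummy objects exist only to register a peak, so they are dropped before gc.collect() and the final memory snapshot.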