Skip to content

Introduce VLLM_CUDART_SO_PATH to allow users specify the .so path #12998

New issue

Have a question about this project? No Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “No Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? No Sign in to your account

Merged
merged 1 commit into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion vllm/distributed/device_communicators/cuda_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
"""

import ctypes
import glob
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

# this line makes it possible to directly load `libcudart.so` using `ctypes`
import torch # noqa

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand Down Expand Up @@ -60,6 +62,29 @@ def find_loaded_library(lib_name) -> Optional[str]:
return path


def get_cudart_lib_path_from_env() -> Optional[str]:
"""
In some system, find_loaded_library() may not work. So we allow users to
specify the path through environment variable VLLM_CUDART_SO_PATH.
"""
cudart_so_env = envs.VLLM_CUDART_SO_PATH
if cudart_so_env is not None:
cudart_paths = [
cudart_so_env,
]
for path in cudart_paths:
file_paths = glob.glob(path)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is it a glob ? I think it should just be a path that can be used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, cleaned up in #13203.

We were thinking specify some partial path. But using the single path is cleaner. Thanks for pointing it out.

if len(file_paths) > 0:
logger.info(
"Found cudart library at %s through env var"
"VLLM_CUDART_SO_PATH=%s",
file_paths[0],
cudart_so_env,
)
return file_paths[0]
return None


class CudaRTLibrary:
exported_functions = [
# ​cudaError_t cudaSetDevice ( int device )
Expand Down Expand Up @@ -105,8 +130,13 @@ class CudaRTLibrary:
def __init__(self, so_file: Optional[str] = None):
if so_file is None:
so_file = find_loaded_library("libcudart")
if so_file is None:
so_file = get_cudart_lib_path_from_env()
assert so_file is not None, \
"libcudart is not loaded in the current process"
(
"libcudart is not loaded in the current process, "
"try setting VLLM_CUDART_SO_PATH"
)
if so_file not in CudaRTLibrary.path_to_library_cache:
lib = ctypes.CDLL(so_file)
CudaRTLibrary.path_to_library_cache[so_file] = lib
Expand Down
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_CUDART_SO_PATH: Optional[str] = None


def get_default_cache_root():
Expand Down Expand Up @@ -572,6 +573,11 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
# models the alignment is already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),

# In some system, find_loaded_library() may not work. So we allow users to
# specify the path through environment variable VLLM_CUDART_SO_PATH.
"VLLM_CUDART_SO_PATH":
lambda: os.getenv("VLLM_CUDART_SO_PATH", None),
}

# end-env-vars-definition
Expand Down