
Update to torch==2.6.0 #12721


Merged: 22 commits merged into main from update-torch-2.6.0 on Mar 14, 2025

Conversation

@mgoin (Member) commented Feb 4, 2025

This only updates torch for CUDA. Successfully built locally on an H100 system with CUDA 12.5 and tested with vllm serve meta-llama/Llama-3.1-8B-Instruct.

We should upgrade the other hardware backends separately. For instance, CPU is blocked by IPEX in Dockerfile.cpu.

FIX #12719
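(As a quick sanity check after a local build like the one described above, something along these lines confirms the upgraded stack is in place — a minimal sketch; the exact version suffix depends on the wheel:

import torch

# Sanity-check the upgraded stack after a local build.
print(torch.__version__)          # expect a 2.6.0 build on this branch
print(torch.version.cuda)         # CUDA toolkit the wheel targets, e.g. 12.x
print(torch.cuda.is_available())  # True on the H100 test machine
)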

Signed-off-by: mgoin <michael@neuralmagic.com>
@mgoin mgoin requested a review from tlrmchlsmth as a code owner February 4, 2025 01:44

github-actions bot commented Feb 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Feb 4, 2025
Signed-off-by: mgoin <michael@neuralmagic.com>
@mgoin changed the title from Update to torch==2.6.0 to [WIP] Update to torch==2.6.0 Feb 4, 2025
@tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) Feb 4, 2025
@mgoin changed the title from [WIP] Update to torch==2.6.0 to Update to torch==2.6.0 Feb 4, 2025
Signed-off-by: mgoin <michael@neuralmagic.com>
@tlrmchlsmth (Collaborator) left a comment

Nice, CI looks green

@houseroad (Collaborator) commented:

Shall we merge #12393 first? cc: @youkaichao

@fialhocoelho (Contributor) left a comment

LGTM. I built vLLM with this PR merged in, and it worked perfectly 🚀

@mgoin (Member, Author) commented Feb 4, 2025

Confirmed that this update breaks V1 in its current state; we should wait for #12393 at least:

VLLM_USE_V1=1 vllm serve meta-llama/Llama-3.1-8B-Instruct
...
ERROR 02-04 15:27:21 core.py:210]   File "/home/mgoin/code/vllm/vllm/compilation/backends.py", line 616, in __call__
ERROR 02-04 15:27:21 core.py:210]     PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
ERROR 02-04 15:27:21 core.py:210]   File "/home/mgoin/code/vllm/vllm/compilation/backends.py", line 424, in run
ERROR 02-04 15:27:21 core.py:210]     return super().run(*fake_args)
ERROR 02-04 15:27:21 core.py:210]            ^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-04 15:27:21 core.py:210]   File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/fx/interpreter.py", line 167, in run
ERROR 02-04 15:27:21 core.py:210]     self.env[node] = self.run_node(node)
ERROR 02-04 15:27:21 core.py:210]                      ^^^^^^^^^^^^^^^^^^^
ERROR 02-04 15:27:21 core.py:210]   File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/fx/interpreter.py", line 230, in run_node
ERROR 02-04 15:27:21 core.py:210]     return getattr(self, n.op)(n.target, args, kwargs)
ERROR 02-04 15:27:21 core.py:210]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 02-04 15:27:21 core.py:210]   File "/home/mgoin/code/vllm/vllm/compilation/backends.py", line 439, in call_module
ERROR 02-04 15:27:21 core.py:210]     compiled_graph_for_general_shape = wrap_inductor(
ERROR 02-04 15:27:21 core.py:210]                                        ^^^^^^^^^^^^^^
ERROR 02-04 15:27:21 core.py:210]   File "/home/mgoin/code/vllm/vllm/compilation/backends.py", line 254, in wrap_inductor
ERROR 02-04 15:27:21 core.py:210]     original_load = FxGraphCache.load
ERROR 02-04 15:27:21 core.py:210]                     ^^^^^^^^^^^^^^^^^
ERROR 02-04 15:27:21 core.py:210] torch._dynamo.exc.BackendCompilerFailed: backend='<vllm.compilation.backends.VllmBackend object at 0x71985bc685c0>' raised:
ERROR 02-04 15:27:21 core.py:210] AttributeError: type object 'FxGraphCache' has no attribute 'load'
ERROR 02-04 15:27:21 core.py:210] 
ERROR 02-04 15:27:21 core.py:210] While executing %submod_0 : [num_users=5] = call_module[target=submod_0](args = (%l_input_ids_, %s0, %l_self_modules_embed_tokens_parameters_weight_, %l_self_modules_layers_modules_0_modules_input_layernorm_parameters_weight_, %l_self_modules_layers_modules_0_modules_self_attn_modules_qkv_proj_parameters_weight_, %l_positions_, %l_self_modules_layers_modules_0_modules_self_attn_modules_rotary_emb_buffers_cos_sin_cache_), kwargs = {})
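For context, the failing line in wrap_inductor grabs FxGraphCache.load, a private inductor hook that torch 2.5.x exposed but torch 2.6 removed. A version-guarded access along these lines is one way to avoid the hard crash (a sketch only, not necessarily the fix that landed):

from torch._inductor.codecache import FxGraphCache

# torch 2.5.x exposed FxGraphCache.load, which vLLM monkey-patched to hook
# inductor's compilation cache; torch 2.6 removed it, hence the
# AttributeError above. Guarding the access avoids the hard crash.
if hasattr(FxGraphCache, "load"):
    original_load = FxGraphCache.load  # torch <= 2.5.x hook point
else:
    original_load = None  # torch >= 2.6: the old hook is gone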

@youkaichao (Member) commented:

@mgoin can you help review and stamp that PR?

@zhouyuan (Contributor) commented Feb 7, 2025

@mgoin Thanks a lot for the update. IPEX for CPU with PyTorch 2.6 will be released next week; I will update here as soon as the binary is out.

Cc: @Guobing-Chen @bigPYJ1151

Thanks, -yuan


mergify bot commented Feb 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @mgoin.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 10, 2025
@jiangshaoping commented:

When will this PR be merged?

@mergify mergify bot removed the needs-rebase label Feb 10, 2025
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
@mgoin (Member, Author) commented Mar 12, 2025

I added the enable_auto_functionalized_v2 guard and merged with main, so let's see where the CI ends up
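For reference, torch 2.6 enables a new auto-functionalization pass (auto_functionalized_v2) by default, and the guard mentioned above plausibly pins it off via the inductor config. A hedged sketch of that shape:

import torch._inductor.config as inductor_config

# torch 2.6 turns on the auto_functionalized_v2 pass by default, which the
# existing custom-op compilation passes were not written against; pinning it
# off keeps the 2.5.x behavior. (Sketch of the guard's likely shape, not the
# exact diff in this PR.)
inductor_config.enable_auto_functionalized_v2 = False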

@mgoin (Member, Author) commented Mar 14, 2025

I kicked off a manual build to get past the docker build timeout: https://buildkite.com/vllm/ci/builds/15327/table

Unfortunately there are a few errors left that seem related to the upgrade. I can't look into this right now, so I'm happy for others to contribute.

@DarkLight1337 (Member) commented:

LoRA and multi-modal tests should be fixed on main; let's see what errors are left.

@DarkLight1337 (Member) commented:

There seems to be an import error in bitsandbytes: https://buildkite.com/vllm/ci/builds/15327#01958c9b-93f5-4350-aa81-c3dcf079bfe7

@ProExpertProg (Contributor) commented:

Yeah, I'm looking into it; it seems that triton==3.2 does not have triton.ops, whereas 3.1 does.
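Since triton 3.2 (which torch 2.6 pulls in) dropped the triton.ops namespace that 3.1 shipped, any import that reaches for it needs a guard; a minimal sketch of the pattern:

# triton 3.2 removed the triton.ops namespace that triton 3.1 shipped, which
# is what breaks the bitsandbytes import above.
try:
    import triton.ops  # noqa: F401  (present in triton <= 3.1)
    HAS_TRITON_OPS = True
except ImportError:
    HAS_TRITON_OPS = False  # triton >= 3.2: skip or vendor the old kernels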

Signed-off-by: luka <luka@neuralmagic.com>
@tlrmchlsmth (Collaborator) commented Mar 14, 2025

Possibly good to go now?? 🤞 🤞

edit: of course not -- I'll fix the pre-commit

tlrmchlsmth and others added 3 commits March 14, 2025 16:05
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@tlrmchlsmth tlrmchlsmth merged commit 14f301b into main Mar 14, 2025
58 checks passed
@tlrmchlsmth tlrmchlsmth deleted the update-torch-2.6.0 branch March 14, 2025 20:58
@xihuai18 commented:

Hi, how can I build vLLM using torch 2.5.1 after this PR? Has anyone succeeded?

@ProExpertProg (Contributor) commented:

> Hi, how can I build vLLM using torch 2.5.1 after this PR? Has anyone succeeded?

Can you try pip install -e . --no-build-isolation in an environment with torch==2.5.1 already installed?

@xihuai18 commented:

> Can you try pip install -e . --no-build-isolation in an environment with torch==2.5.1 already installed?

I am trying:

git clone https://github.com/vllm-project/vllm.git
cd vllm
python use_existing_torch.py
pip install -r requirements/build.txt
pip install -e . --no-build-isolation
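For context, the use_existing_torch.py step is the key one here: as I understand it, it strips the torch pins from the requirements files so the build reuses whatever torch is already installed. A rough sketch of that idea (file paths and pin patterns are assumptions, not the script's exact contents):

import glob
import re

# Drop hard torch pins such as "torch==2.6.0" from the requirements files so
# the build picks up the torch already present in the environment.
PIN = re.compile(r"^(torch|torchvision|torchaudio)\s*==")

for fname in glob.glob("requirements/*.txt"):
    with open(fname) as f:
        kept = [line for line in f if not PIN.match(line.strip())]
    with open(fname, "w") as f:
        f.writelines(kept)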

@ProExpertProg (Contributor) commented:

Are you getting an error? You might need to downgrade other dependencies as well; that would be my only other guess.

@xihuai18 commented:

> Are you getting an error? You might need to downgrade other dependencies as well; that would be my only other guess.

I am building wheels for torch 2.5.1, but I am hitting many errors. I hope vLLM could officially provide wheels for torch 2.5.1, since torch 2.6.0 leads to many dependency problems when using vLLM with integrations such as verl or ms-swift.

@ProExpertProg (Contributor) commented:

Could you create a new issue and post the errors? I don't think providing official 2.5.1 wheels is on the roadmap for v0.8.0+, but you're welcome to use an earlier version or cherry-pick the commits you need.

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
nishith-fujitsu pushed a commit to nishith-fujitsu/vllm that referenced this pull request Apr 9, 2025
Signed-off-by: mgoin <michael@neuralmagic.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Labels: ci/build, ready (ONLY add when PR is ready to merge/full CI is needed)

Successfully merging this pull request may close these issues:

[Installation]: Supporting PyTorch 2.6? (#12719)