File tree 2 files changed +20
-6
lines changed
2 files changed +20
-6
lines changed Original file line number Diff line number Diff line change @@ -509,8 +509,9 @@ def _set_accelerators(
509
509
510
510
# Canonicalize the accelerator names.
511
511
accelerators = {
512
- accelerator_registry .canonicalize_accelerator_name (acc ):
513
- acc_count for acc , acc_count in accelerators .items ()
512
+ accelerator_registry .canonicalize_accelerator_name (
513
+ acc , self ._cloud ): acc_count
514
+ for acc , acc_count in accelerators .items ()
514
515
}
515
516
516
517
acc , _ = list (accelerators .items ())[0 ]
@@ -1311,8 +1312,9 @@ def __setstate__(self, state):
1311
1312
accelerators = state .pop ('_accelerators' , None )
1312
1313
if accelerators is not None :
1313
1314
accelerators = {
1314
- accelerator_registry .canonicalize_accelerator_name (acc ):
1315
- acc_count for acc , acc_count in accelerators .items ()
1315
+ accelerator_registry .canonicalize_accelerator_name (
1316
+ acc , cloud = None ): acc_count
1317
+ for acc , acc_count in accelerators .items ()
1316
1318
}
1317
1319
state ['_accelerators' ] = accelerators
1318
1320
Original file line number Diff line number Diff line change 1
1
"""Accelerator registry."""
2
+ import typing
3
+ from typing import Optional
4
+
2
5
from sky .clouds import service_catalog
3
6
from sky .utils import ux_utils
4
7
8
+ if typing .TYPE_CHECKING :
9
+ from sky import clouds
10
+
5
11
# Canonicalized names of all accelerators (except TPUs) supported by SkyPilot.
6
12
# NOTE: Must include accelerators supported for local clusters.
7
13
#
@@ -67,8 +73,13 @@ def is_schedulable_non_gpu_accelerator(accelerator_name: str) -> bool:
67
73
return False
68
74
69
75
70
- def canonicalize_accelerator_name (accelerator : str ) -> str :
76
+ def canonicalize_accelerator_name (accelerator : str ,
77
+ cloud : Optional ['clouds.Cloud' ]) -> str :
71
78
"""Returns the canonical accelerator name."""
79
+ cloud_str = None
80
+ if cloud is not None :
81
+ cloud_str = str (cloud ).lower ()
82
+
72
83
# TPU names are always lowercase.
73
84
if accelerator .lower ().startswith ('tpu-' ):
74
85
return accelerator .lower ()
@@ -84,7 +95,8 @@ def canonicalize_accelerator_name(accelerator: str) -> str:
84
95
# To cover such cases, we should search the accelerator name
85
96
# in the service catalog.
86
97
searched = service_catalog .list_accelerators (name_filter = accelerator ,
87
- case_sensitive = False )
98
+ case_sensitive = False ,
99
+ clouds = cloud_str )
88
100
names = list (searched .keys ())
89
101
90
102
# Exact match.
You can’t perform that action at this time.
0 commit comments