Skip to content

Commit 70a7435

Browse files
authored
[Minor] Improve completion of accelerator name when cloud is specified (#3014)
1 parent 16cdc7f commit 70a7435

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

sky/resources.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,9 @@ def _set_accelerators(
509509

510510
# Canonicalize the accelerator names.
511511
accelerators = {
512-
accelerator_registry.canonicalize_accelerator_name(acc):
513-
acc_count for acc, acc_count in accelerators.items()
512+
accelerator_registry.canonicalize_accelerator_name(
513+
acc, self._cloud): acc_count
514+
for acc, acc_count in accelerators.items()
514515
}
515516

516517
acc, _ = list(accelerators.items())[0]
@@ -1311,8 +1312,9 @@ def __setstate__(self, state):
13111312
accelerators = state.pop('_accelerators', None)
13121313
if accelerators is not None:
13131314
accelerators = {
1314-
accelerator_registry.canonicalize_accelerator_name(acc):
1315-
acc_count for acc, acc_count in accelerators.items()
1315+
accelerator_registry.canonicalize_accelerator_name(
1316+
acc, cloud=None): acc_count
1317+
for acc, acc_count in accelerators.items()
13161318
}
13171319
state['_accelerators'] = accelerators
13181320

sky/utils/accelerator_registry.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
"""Accelerator registry."""
2+
import typing
3+
from typing import Optional
4+
25
from sky.clouds import service_catalog
36
from sky.utils import ux_utils
47

8+
if typing.TYPE_CHECKING:
9+
from sky import clouds
10+
511
# Canonicalized names of all accelerators (except TPUs) supported by SkyPilot.
612
# NOTE: Must include accelerators supported for local clusters.
713
#
@@ -67,8 +73,13 @@ def is_schedulable_non_gpu_accelerator(accelerator_name: str) -> bool:
6773
return False
6874

6975

70-
def canonicalize_accelerator_name(accelerator: str) -> str:
76+
def canonicalize_accelerator_name(accelerator: str,
77+
cloud: Optional['clouds.Cloud']) -> str:
7178
"""Returns the canonical accelerator name."""
79+
cloud_str = None
80+
if cloud is not None:
81+
cloud_str = str(cloud).lower()
82+
7283
# TPU names are always lowercase.
7384
if accelerator.lower().startswith('tpu-'):
7485
return accelerator.lower()
@@ -84,7 +95,8 @@ def canonicalize_accelerator_name(accelerator: str) -> str:
8495
# To cover such cases, we should search the accelerator name
8596
# in the service catalog.
8697
searched = service_catalog.list_accelerators(name_filter=accelerator,
87-
case_sensitive=False)
98+
case_sensitive=False,
99+
clouds=cloud_str)
88100
names = list(searched.keys())
89101

90102
# Exact match.

0 commit comments

Comments
 (0)