Skip to content

Commit 3a1c5c4

Browse files
authored
add typechecking for boto clients/resources (#5319)
* add typechecking for boto clients/resources * import module instead of class
1 parent 6f0054e commit 3a1c5c4

File tree

8 files changed

+122
-35
lines changed

8 files changed

+122
-35
lines changed

requirements-dev.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ aiohttp==3.9.3
4040

4141
# For mocking AWS
4242
moto==5.1.2
43+
44+
# boto3 type checking
45+
boto3-stubs[s3,ec2,sts,iam,service-quotas]

sky/adaptors/aws.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,31 @@
3131
import logging
3232
import threading
3333
import time
34-
from typing import Any, Callable
34+
import typing
35+
from typing import Callable, Literal, Optional, TypeVar
3536

3637
from sky.adaptors import common
3738
from sky.utils import annotations
3839
from sky.utils import common_utils
3940

41+
if typing.TYPE_CHECKING:
42+
import boto3
43+
_ = boto3 # Supress pylint use before assignment error
44+
import mypy_boto3_ec2
45+
import mypy_boto3_iam
46+
import mypy_boto3_s3
47+
import mypy_boto3_service_quotas
48+
import mypy_boto3_sts
49+
4050
_IMPORT_ERROR_MESSAGE = ('Failed to import dependencies for AWS. '
4151
'Try pip install "skypilot[aws]"')
4252
boto3 = common.LazyImport('boto3', import_error_message=_IMPORT_ERROR_MESSAGE)
4353
botocore = common.LazyImport('botocore',
4454
import_error_message=_IMPORT_ERROR_MESSAGE)
4555
_LAZY_MODULES = (boto3, botocore)
4656

57+
T = TypeVar('T')
58+
4759
logger = logging.getLogger(__name__)
4860
_session_creation_lock = threading.RLock()
4961

@@ -73,8 +85,8 @@ def _assert_kwargs_builtin_type(kwargs):
7385
f'kwargs should not contain none built-in types: {kwargs}')
7486

7587

76-
def _create_aws_object(creation_fn_or_cls: Callable[[], Any],
77-
object_name: str) -> Any:
88+
def _create_aws_object(creation_fn_or_cls: Callable[[], T],
89+
object_name: str) -> T:
7890
"""Create an AWS object.
7991
8092
Args:
@@ -123,6 +135,25 @@ def session(check_credentials: bool = True):
123135
return s
124136

125137

138+
# New typing overloads can be added as needed.
139+
@typing.overload
140+
def resource(service_name: Literal['ec2'],
141+
**kwargs) -> 'mypy_boto3_ec2.ServiceResource':
142+
...
143+
144+
145+
@typing.overload
146+
def resource(service_name: Literal['s3'],
147+
**kwargs) -> 'mypy_boto3_s3.ServiceResource':
148+
...
149+
150+
151+
@typing.overload
152+
def resource(service_name: Literal['iam'],
153+
**kwargs) -> 'mypy_boto3_iam.ServiceResource':
154+
...
155+
156+
126157
# Avoid caching the resource/client objects. If we are using the assumed role,
127158
# the credentials will be automatically rotated, but the cached resource/client
128159
# object will only refresh the credentials with a fixed 15 minutes interval,
@@ -142,7 +173,7 @@ def resource(service_name: str, **kwargs):
142173
"""
143174
_assert_kwargs_builtin_type(kwargs)
144175

145-
max_attempts = kwargs.pop('max_attempts', None)
176+
max_attempts: Optional[int] = kwargs.pop('max_attempts', None)
146177
if max_attempts is not None:
147178
config = botocore_config().Config(
148179
retries={'max_attempts': max_attempts})
@@ -158,6 +189,28 @@ def resource(service_name: str, **kwargs):
158189
service_name, **kwargs), 'resource')
159190

160191

192+
# New typing overloads can be added as needed.
193+
@typing.overload
194+
def client(service_name: Literal['s3'], **kwargs) -> 'mypy_boto3_s3.Client':
195+
pass
196+
197+
198+
@typing.overload
199+
def client(service_name: Literal['ec2'], **kwargs) -> 'mypy_boto3_ec2.Client':
200+
pass
201+
202+
203+
@typing.overload
204+
def client(service_name: Literal['sts'], **kwargs) -> 'mypy_boto3_sts.Client':
205+
pass
206+
207+
208+
@typing.overload
209+
def client(service_name: Literal['service-quotas'],
210+
**kwargs) -> 'mypy_boto3_service_quotas.Client':
211+
pass
212+
213+
161214
def client(service_name: str, **kwargs):
162215
"""Create an AWS client of a certain service.
163216

sky/adaptors/common.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import functools
33
import importlib
44
import threading
5+
import types
56
from typing import Any, Callable, Optional, Tuple
67

78

8-
class LazyImport:
9+
class LazyImport(types.ModuleType):
910
"""Lazy importer for modules.
1011
1112
This is mainly used in two cases:

sky/clouds/aws.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -312,13 +312,12 @@ def get_image_size(cls, image_id: str, region: Optional[str]) -> float:
312312
'Example: ami-0729d913a335efca7')
313313
try:
314314
client = aws.client('ec2', region_name=region)
315-
image_info = client.describe_images(ImageIds=[image_id])
316-
image_info = image_info.get('Images', [])
315+
image_info = client.describe_images(ImageIds=[image_id]).get(
316+
'Images', [])
317317
if not image_info:
318318
with ux_utils.print_exception_no_traceback():
319319
raise ValueError(image_not_found_message)
320-
image_info = image_info[0]
321-
image_size = image_info['BlockDeviceMappings'][0]['Ebs'][
320+
image_size = image_info[0]['BlockDeviceMappings'][0]['Ebs'][
322321
'VolumeSize']
323322
except (aws.botocore_exceptions().NoCredentialsError,
324323
aws.botocore_exceptions().ProfileNotFound):

sky/clouds/service_catalog/data_fetchers/fetch_aws.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import textwrap
1414
import traceback
1515
import typing
16-
from typing import Dict, List, Optional, Set, Tuple, Union
16+
from typing import List, Optional, Set, Tuple, Union
1717

1818
import numpy as np
1919

@@ -24,6 +24,7 @@
2424
from sky.utils import ux_utils
2525

2626
if typing.TYPE_CHECKING:
27+
from mypy_boto3_ec2 import type_defs as ec2_type_defs
2728
import pandas as pd
2829
else:
2930
pd = adaptors_common.LazyImport('pandas')
@@ -192,7 +193,7 @@ def _get_spot_pricing_table(region: str) -> 'pd.DataFrame':
192193
paginator = client.get_paginator('describe_spot_price_history')
193194
response_iterator = paginator.paginate(ProductDescriptions=['Linux/UNIX'],
194195
StartTime=datetime.datetime.utcnow())
195-
ret: List[Dict[str, str]] = []
196+
ret: List['ec2_type_defs.SpotPriceTypeDef'] = []
196197
for response in response_iterator:
197198
# response['SpotPriceHistory'] is a list of dicts, each dict is like:
198199
# {

sky/data/storage.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
from sky.utils import ux_utils
3939

4040
if typing.TYPE_CHECKING:
41-
import boto3 # type: ignore
4241
from google.cloud import storage # type: ignore
42+
import mypy_boto3_s3
4343

4444
logger = sky_logging.init_logger(__name__)
4545

@@ -1363,7 +1363,7 @@ def __init__(self,
13631363
is_sky_managed: Optional[bool] = None,
13641364
sync_on_reconstruction: bool = True,
13651365
_bucket_sub_path: Optional[str] = None):
1366-
self.client: 'boto3.client.Client'
1366+
self.client: 'mypy_boto3_s3.Client'
13671367
self.bucket: 'StorageHandle'
13681368
# TODO(romilb): This is purely a stopgap fix for
13691369
# https://github.com/skypilot-org/skypilot/issues/3405
@@ -3295,7 +3295,7 @@ def __init__(self,
32953295
is_sky_managed: Optional[bool] = None,
32963296
sync_on_reconstruction: Optional[bool] = True,
32973297
_bucket_sub_path: Optional[str] = None):
3298-
self.client: 'boto3.client.Client'
3298+
self.client: 'mypy_boto3_s3.Client'
32993299
self.bucket: 'StorageHandle'
33003300
super().__init__(name, source, region, is_sky_managed,
33013301
sync_on_reconstruction, _bucket_sub_path)
@@ -4700,7 +4700,7 @@ def __init__(self,
47004700
is_sky_managed: Optional[bool] = None,
47014701
sync_on_reconstruction: bool = True,
47024702
_bucket_sub_path: Optional[str] = None):
4703-
self.client: 'boto3.client.Client'
4703+
self.client: 'mypy_boto3_s3.Client'
47044704
self.bucket: 'StorageHandle'
47054705
super().__init__(name, source, region, is_sky_managed,
47064706
sync_on_reconstruction, _bucket_sub_path)

sky/provision/aws/config.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import json
1212
import logging
1313
import time
14+
import typing
1415
from typing import Any, Dict, List, Optional, Set, Tuple
1516

1617
import colorama
@@ -23,6 +24,10 @@
2324
from sky.utils import annotations
2425
from sky.utils import common_utils
2526

27+
if typing.TYPE_CHECKING:
28+
import mypy_boto3_ec2
29+
from mypy_boto3_ec2 import type_defs as ec2_type_defs
30+
2631
logger = sky_logging.init_logger(__name__)
2732

2833
RAY = 'ray-autoscaler'
@@ -223,7 +228,8 @@ def _get_role(role_name: str):
223228

224229

225230
@annotations.lru_cache(scope='request', maxsize=128) # Keep bounded.
226-
def _get_route_tables(ec2, vpc_id: Optional[str], region: str,
231+
def _get_route_tables(ec2: 'mypy_boto3_ec2.ServiceResource',
232+
vpc_id: Optional[str], region: str,
227233
main: bool) -> List[Any]:
228234
"""Get route tables associated with a VPC and region
229235
@@ -248,7 +254,8 @@ def _get_route_tables(ec2, vpc_id: Optional[str], region: str,
248254
'RouteTables', [])
249255

250256

251-
def _is_subnet_public(ec2, subnet_id, vpc_id: Optional[str]) -> bool:
257+
def _is_subnet_public(ec2: 'mypy_boto3_ec2.ServiceResource', subnet_id,
258+
vpc_id: Optional[str]) -> bool:
252259
"""Checks if a subnet is public by existence of a route to an IGW.
253260
254261
Conventionally, public subnets connect to a IGW, and private subnets to a
@@ -441,10 +448,14 @@ def _get_pruned_subnets(current_subnets: List[Any]) -> Set[str]:
441448
return subnets, first_subnet_vpc_id
442449

443450

444-
def _vpc_id_from_security_group_ids(ec2, sg_ids: List[str]) -> Any:
451+
def _vpc_id_from_security_group_ids(ec2: 'mypy_boto3_ec2.ServiceResource',
452+
sg_ids: List[str]) -> Any:
445453
# sort security group IDs to support deterministic unit test stubbing
446454
sg_ids = sorted(set(sg_ids))
447-
filters = [{'Name': 'group-id', 'Values': sg_ids}]
455+
filters: List['ec2_type_defs.FilterTypeDef'] = [{
456+
'Name': 'group-id',
457+
'Values': sg_ids
458+
}]
448459
security_groups = ec2.security_groups.filter(Filters=filters)
449460
vpc_ids = [sg.vpc_id for sg in security_groups]
450461
vpc_ids = list(set(vpc_ids))
@@ -462,15 +473,19 @@ def _vpc_id_from_security_group_ids(ec2, sg_ids: List[str]) -> Any:
462473
return vpc_ids[0]
463474

464475

465-
def _get_vpc_id_by_name(ec2, vpc_name: str, region: str) -> str:
476+
def _get_vpc_id_by_name(ec2: 'mypy_boto3_ec2.ServiceResource', vpc_name: str,
477+
region: str) -> str:
466478
"""Returns the VPC ID of the unique VPC with a given name.
467479
468480
Exits with code 1 if:
469481
- No VPC with the given name is found in the current region.
470482
- More than 1 VPC with the given name are found in the current region.
471483
"""
472484
# Look in the 'Name' tag (shown as Name column in console).
473-
filters = [{'Name': 'tag:Name', 'Values': [vpc_name]}]
485+
filters: List['ec2_type_defs.FilterTypeDef'] = [{
486+
'Name': 'tag:Name',
487+
'Values': [vpc_name]
488+
}]
474489
vpcs = list(ec2.vpcs.filter(Filters=filters))
475490
if not vpcs:
476491
_skypilot_log_error_and_exit_for_failover(
@@ -486,8 +501,9 @@ def _get_vpc_id_by_name(ec2, vpc_name: str, region: str) -> str:
486501
return vpcs[0].id
487502

488503

489-
def _get_subnet_and_vpc_id(ec2, security_group_ids: Optional[List[str]],
490-
region: str, availability_zone: Optional[str],
504+
def _get_subnet_and_vpc_id(ec2: 'mypy_boto3_ec2.ServiceResource',
505+
security_group_ids: Optional[List[str]], region: str,
506+
availability_zone: Optional[str],
491507
use_internal_ips: bool,
492508
vpc_name: Optional[str]) -> Tuple[Any, str]:
493509
if vpc_name is not None:
@@ -514,7 +530,8 @@ def _get_subnet_and_vpc_id(ec2, security_group_ids: Optional[List[str]],
514530
return subnets, vpc_id
515531

516532

517-
def _configure_security_group(ec2, vpc_id: str, expected_sg_name: str,
533+
def _configure_security_group(ec2: 'mypy_boto3_ec2.ServiceResource',
534+
vpc_id: str, expected_sg_name: str,
518535
extended_ip_rules: List) -> List[str]:
519536
security_group = _get_or_create_vpc_security_group(ec2, vpc_id,
520537
expected_sg_name)
@@ -551,7 +568,8 @@ def _configure_security_group(ec2, vpc_id: str, expected_sg_name: str,
551568
return sg_ids
552569

553570

554-
def _get_or_create_vpc_security_group(ec2, vpc_id: str,
571+
def _get_or_create_vpc_security_group(ec2: 'mypy_boto3_ec2.ServiceResource',
572+
vpc_id: str,
555573
expected_sg_name: str) -> Any:
556574
"""Find or create a security group in the specified VPC.
557575
@@ -612,7 +630,8 @@ def _get_or_create_vpc_security_group(ec2, vpc_id: str,
612630
return security_group
613631

614632

615-
def _get_security_group_from_vpc_id(ec2, vpc_id: str,
633+
def _get_security_group_from_vpc_id(ec2: 'mypy_boto3_ec2.ServiceResource',
634+
vpc_id: str,
616635
group_name: str) -> Optional[Any]:
617636
"""Get security group by VPC ID and group name."""
618637
existing_groups = list(

0 commit comments

Comments
 (0)