Skip to content

Commit 1deaebf

Browse files
rebase code generation wrappers refactoring
1 parent cdc2c0c commit 1deaebf

File tree

5 files changed

+71
-11
lines changed

5 files changed

+71
-11
lines changed

tools/scripts/codegen/legacy_c2j_cpp_gen.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def generate(self, executor: ProcessPoolExecutor, max_workers: int, args: dict)
140140
return -1
141141

142142
print(f"Code generation done, (re)generated {len(done)} packages.") # Including defaults and partitions
143+
return 0
143144

144145
def _init_common_java_cli(self,
145146
service_name: str,
@@ -172,7 +173,7 @@ def _init_common_java_cli(self,
172173
run_command += ["--service", service_name]
173174
run_command += ["--outputfile", output_filename]
174175

175-
if service_name in SMITHY_SUPPORTED_CLIENTS:
176+
if service_name in SMITHY_SUPPORTED_CLIENTS or model_files.use_smithy:
176177
run_command += ["--use-smithy-client"]
177178

178179
for key, val in kwargs.items():

tools/scripts/codegen/model_utils.py

+53-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
A set of utils to go through c2j models and their corresponding endpoint rules
88
"""
99
import datetime
10+
import json
1011
import os
1112
import re
1213

@@ -18,6 +19,29 @@
1819
"transcribe-streaming": "transcribestreaming",
1920
"streams.dynamodb": "dynamodbstreams"}
2021

22+
SMITHY_EXCLUSION_CLIENTS = {
23+
# multi auth
24+
"eventbridge"
25+
, "cloudfront-keyvaluestore"
26+
, "cognito-identity"
27+
, "cognito-idp"
28+
# customization
29+
, "machinelearning"
30+
, "apigatewayv2"
31+
, "apigateway"
32+
, "eventbridge"
33+
, "glacier"
34+
, "lambda"
35+
, "polly"
36+
, "sqs"
37+
# bearer token
38+
# ,"codecatalyst"
39+
# bidirectional streaming
40+
, "lexv2-runtime"
41+
, "qbusiness"
42+
, "transcribestreaming"
43+
}
44+
2145
# Regexp to parse C2J model filename to extract service name and date version
2246
SERVICE_MODEL_FILENAME_PATTERN = re.compile(
2347
"^"
@@ -29,12 +53,13 @@
2953

3054
class ServiceModel(object):
3155
# A helper class to store C2j model info and metadata (endpoint rules and tests)
32-
def __init__(self, service_id, c2j_model, endpoint_rule_set, endpoint_tests):
56+
def __init__(self, service_id: str, c2j_model: str, endpoint_rule_set: str, endpoint_tests: str, use_smithy: bool):
3357
self.service_id = service_id # For debugging purposes, not used atm
3458
# only filenames, no filesystem path
3559
self.c2j_model = c2j_model
3660
self.endpoint_rule_set = endpoint_rule_set
3761
self.endpoint_tests = endpoint_tests
62+
self.use_smithy = use_smithy
3863

3964

4065
class ModelUtils(object):
@@ -113,7 +138,8 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict:
113138

114139
# fetch endpoint-rules filename which is based on ServiceId in c2j models:
115140
try:
116-
service_name_to_model_filename[key] = ModelUtils._build_service_model(endpoint_rules_dir,
141+
service_name_to_model_filename[key] = ModelUtils._build_service_model(models_dir,
142+
endpoint_rules_dir,
117143
model_file_date[0])
118144

119145
if key == "s3":
@@ -125,7 +151,8 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict:
125151
service_name_to_model_filename[key] = ServiceModel(service_id=key,
126152
c2j_model=model_file_date[0],
127153
endpoint_rule_set=None,
128-
endpoint_tests=None)
154+
endpoint_tests=None,
155+
use_smithy=False)
129156
if missing:
130157
# TODO: re-enable with endpoints introduction
131158
# print(f"Missing endpoints for services: {missing}")
@@ -137,7 +164,25 @@ def _collect_available_models(models_dir: str, endpoint_rules_dir: str) -> dict:
137164
return service_name_to_model_filename
138165

139166
@staticmethod
140-
def _build_service_model(endpoint_rules_dir: str, c2j_model_filename) -> ServiceModel:
167+
def is_smithy_enabled(service_id, models_dir, c2j_model_filename):
168+
"""Return true if given service id and c2j model file should enable smithy client generation path
169+
170+
:param service_id:
171+
:param models_dir:
172+
:param c2j_model_filename:
173+
:return:
174+
"""
175+
use_smithy = False
176+
if service_id not in SMITHY_EXCLUSION_CLIENTS:
177+
with open(models_dir + "/" + c2j_model_filename, 'r') as json_file:
178+
model = json.load(json_file)
179+
model_protocol = model.get("metadata", dict()).get("protocol", "UNKNOWN_PROTOCOL")
180+
if model_protocol in {"json", "rest-json"}:
181+
use_smithy = True
182+
return use_smithy
183+
184+
@staticmethod
185+
def _build_service_model(models_dir: str, endpoint_rules_dir: str, c2j_model_filename) -> ServiceModel:
141186
"""Return a ServiceModel containing paths to the Service models: C2J model and endpoints (rules and tests).
142187
143188
:param models_dir (str): filepath (absolute or relative) to the dir with c2j models
@@ -153,8 +198,11 @@ def _build_service_model(endpoint_rules_dir: str, c2j_model_filename) -> Service
153198
match = SERVICE_MODEL_FILENAME_PATTERN.match(c2j_model_filename)
154199
service_id = match.group("service")
155200

201+
use_smithy = ModelUtils.is_smithy_enabled(service_id, models_dir, c2j_model_filename)
202+
156203
if os.path.exists(endpoint_rules_filepath) and os.path.exists(endpoint_tests_filepath):
157204
return ServiceModel(service_id=service_id,
158205
c2j_model=c2j_model_filename,
159206
endpoint_rule_set=endpoint_rules_filename,
160-
endpoint_tests=endpoint_tests_filename)
207+
endpoint_tests=endpoint_tests_filename,
208+
use_smithy=use_smithy)

tools/scripts/codegen/protocol_tests_gen.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED, ALL_COMPLETED
1313

1414
from codegen.legacy_c2j_cpp_gen import LegacyC2jCppGen
15-
from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel
15+
from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel, ModelUtils
1616

1717
PROTOCOL_TESTS_BASE_DIR = "tools/code-generation/protocol-tests"
1818
PROTOCOL_TESTS_CLIENT_MODELS = PROTOCOL_TESTS_BASE_DIR + "/api-descriptions"
@@ -98,6 +98,7 @@ def _generate_test_clients(self, executor: ProcessPoolExecutor, max_workers: int
9898

9999
if len(failures):
100100
return -1
101+
return 0
101102

102103
def _collect_test_client_models(self) -> dict:
103104
service_models = dict()
@@ -112,8 +113,9 @@ def _collect_test_client_models(self) -> dict:
112113
if service_model_name in UNSUPPORTED_CLIENTS:
113114
continue
114115

116+
use_smithy = ModelUtils.is_smithy_enabled(service_model_name, self.client_models_dir, filename)
115117
service_models[service_model_name] = ServiceModel(service_model_name, filename,
116-
PROTOCOL_TESTS_ENDPOINT_RULES, None)
118+
PROTOCOL_TESTS_ENDPOINT_RULES, None, use_smithy)
117119
return service_models
118120

119121
def _generate_tests(self):

tools/scripts/codegen/smoke_tests_gen.py

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def generate(self, clients_to_build: set):
4141
self._copy_cpp_codegen_contents(os.path.abspath("tools/code-generation/smithy/codegen"),
4242
"cpp-codegen-smoke-tests-plugin",
4343
os.path.abspath("generated/smoke-tests"))
44+
return 0
45+
else:
46+
return -1
4447

4548
def _generate_smoke_tests(self, smithy_services: List[str], smithy_c2j_data: str):
4649
smithy_codegen_command = [

tools/scripts/run_code_generation.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,24 @@ def main():
141141

142142
with ProcessPoolExecutor(max_workers=max_workers) as executor:
143143
c2j_gen = LegacyC2jCppGen(args, model_utils.models_to_generate)
144-
c2j_gen.generate(executor, max_workers, args)
144+
if c2j_gen.generate(executor, max_workers, args) != 0:
145+
print("ERROR: Failed to generate service client(s)!")
146+
return -1
145147

146148
if args["generate_protocol_tests"]:
147149
protocol_tests_generator = ProtocolTestsGen(args)
148-
protocol_tests_generator.generate(executor, max_workers)
150+
if protocol_tests_generator.generate(executor, max_workers) != 0:
151+
print("ERROR: Failed to generate protocol test(s)!")
152+
return -1
149153

150154
# generate code using smithy for all discoverable clients
151155
# clients_to_build check is present because user can generate only defaults or partitions or protocol-tests
152156
clients_to_build = model_utils.get_clients_to_build()
153157
if args["generate_smoke_tests"] and clients_to_build:
154158
smoke_tests_gen = SmokeTestsGen(args["debug"])
155-
smoke_tests_gen.generate(clients_to_build)
159+
if smoke_tests_gen.generate(clients_to_build) != 0:
160+
print("ERROR: Failed to generate smoke test(s)!")
161+
return -1
156162

157163
return 0
158164

0 commit comments

Comments
 (0)