Skip to content

Commit b1b9199

Browse files
authored
Minor code simplifications (#146)
1 parent 6f51c70 commit b1b9199

File tree

8 files changed

+96
-101
lines changed

8 files changed

+96
-101
lines changed

src/graphql/language/visitor.py

+13-18
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,15 @@ def __init_subclass__(cls) -> None:
183183
kind: Optional[str] = None
184184
else:
185185
attr, kind = attr_kind
186-
if attr in ("enter", "leave"):
187-
if kind:
188-
name = snake_to_camel(kind) + "Node"
189-
node_cls = getattr(ast, name, None)
190-
if (
191-
not node_cls
192-
or not isinstance(node_cls, type)
193-
or not issubclass(node_cls, Node)
194-
):
195-
raise TypeError(f"Invalid AST node kind: {kind}.")
186+
if attr in ("enter", "leave") and kind:
187+
name = snake_to_camel(kind) + "Node"
188+
node_cls = getattr(ast, name, None)
189+
if (
190+
not node_cls
191+
or not isinstance(node_cls, type)
192+
or not issubclass(node_cls, Node)
193+
):
194+
raise TypeError(f"Invalid AST node kind: {kind}.")
196195

197196
def get_visit_fn(self, kind: str, is_leaving: bool = False) -> Callable:
198197
"""Get the visit function for the given node kind and direction."""
@@ -256,22 +255,18 @@ def visit(root: Node, visitor: Visitor) -> Any:
256255
node: Any = parent
257256
parent = ancestors_pop() if ancestors else None
258257
if is_edited:
259-
if in_array:
260-
node = node[:]
261-
else:
262-
node = copy(node)
258+
node = node[:] if in_array else copy(node)
263259
edit_offset = 0
264260
for edit_key, edit_value in edits:
265261
if in_array:
266262
edit_key -= edit_offset
267263
if in_array and (edit_value is REMOVE or edit_value is Ellipsis):
268264
node.pop(edit_key)
269265
edit_offset += 1
266+
elif isinstance(node, list):
267+
node[edit_key] = edit_value
270268
else:
271-
if isinstance(node, list):
272-
node[edit_key] = edit_value
273-
else:
274-
setattr(node, edit_key, edit_value)
269+
setattr(node, edit_key, edit_value)
275270

276271
idx = stack.idx
277272
keys = stack.keys

src/graphql/pyutils/suggestion_list.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def measure(self, option: str, threshold: int) -> Optional[int]:
7373
return None
7474

7575
rows = self._rows
76-
for j in range(0, b_len + 1):
76+
for j in range(b_len + 1):
7777
rows[0][j] = j
7878

7979
for i in range(1, a_len + 1):

src/graphql/subscription/map_async_iterator.py

+20-20
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ async def __anext__(self) -> Any:
3333
if not isasyncgen(self.iterator):
3434
raise StopAsyncIteration
3535
value = await self.iterator.__anext__()
36-
result = self.callback(value)
37-
3836
else:
3937
aclose = ensure_future(self._close_event.wait())
4038
anext = ensure_future(self.iterator.__anext__())
@@ -61,7 +59,8 @@ async def __anext__(self) -> Any:
6159
raise error
6260

6361
value = anext.result()
64-
result = self.callback(value)
62+
63+
result = self.callback(value)
6564

6665
return await result if isawaitable(result) else result
6766

@@ -72,23 +71,24 @@ async def athrow(
7271
traceback: Optional[TracebackType] = None,
7372
) -> None:
7473
"""Throw an exception into the asynchronous iterator."""
75-
if not self.is_closed:
76-
athrow = getattr(self.iterator, "athrow", None)
77-
if athrow:
78-
await athrow(type_, value, traceback)
79-
else:
80-
await self.aclose()
81-
if value is None:
82-
if traceback is None:
83-
raise type_
84-
value = (
85-
type_
86-
if isinstance(value, BaseException)
87-
else cast(Type[BaseException], type_)()
88-
)
89-
if traceback is not None:
90-
value = value.with_traceback(traceback)
91-
raise value
74+
if self.is_closed:
75+
return
76+
athrow = getattr(self.iterator, "athrow", None)
77+
if athrow:
78+
await athrow(type_, value, traceback)
79+
else:
80+
await self.aclose()
81+
if value is None:
82+
if traceback is None:
83+
raise type_
84+
value = (
85+
type_
86+
if isinstance(value, BaseException)
87+
else cast(Type[BaseException], type_)()
88+
)
89+
if traceback is not None:
90+
value = value.with_traceback(traceback)
91+
raise value
9292

9393
async def aclose(self) -> None:
9494
"""Close the iterator."""

src/graphql/utilities/print_schema.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,7 @@ def is_schema_of_common_names(schema: GraphQLSchema) -> bool:
108108
return False
109109

110110
subscription_type = schema.subscription_type
111-
if subscription_type and subscription_type.name != "Subscription":
112-
return False
113-
114-
return True
111+
return not subscription_type or subscription_type.name == "Subscription"
115112

116113

117114
def print_type(type_: GraphQLNamedType) -> str:

src/graphql/utilities/strip_ignored_characters.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,10 @@ def strip_ignored_characters(source: Union[str, Source]) -> str:
7979
# Also prevent case of non-punctuator token following by spread resulting
8080
# in invalid token (e.g.`1...` is invalid Float token).
8181
is_non_punctuator = not is_punctuator_token_kind(current_token.kind)
82-
if was_last_added_token_non_punctuator:
83-
if is_non_punctuator or current_token.kind == TokenKind.SPREAD:
84-
stripped_body += " "
82+
if was_last_added_token_non_punctuator and (
83+
is_non_punctuator or current_token.kind == TokenKind.SPREAD
84+
):
85+
stripped_body += " "
8586

8687
token_body = body[current_token.start : current_token.end]
8788
if token_kind == TokenKind.BLOCK_STRING:

src/graphql/validation/rules/lone_anonymous_operation.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ def __init__(self, context: ASTValidationContext):
2020

2121
def enter_document(self, node: DocumentNode, *_args: Any) -> None:
2222
self.operation_count = sum(
23-
1
23+
isinstance(definition, OperationDefinitionNode)
2424
for definition in node.definitions
25-
if isinstance(definition, OperationDefinitionNode)
2625
)
2726

2827
def enter_operation_definition(

src/graphql/validation/rules/single_field_subscriptions.py

+50-49
Original file line numberDiff line numberDiff line change
@@ -23,62 +23,63 @@ class SingleFieldSubscriptionsRule(ValidationRule):
2323
def enter_operation_definition(
2424
self, node: OperationDefinitionNode, *_args: Any
2525
) -> None:
26-
if node.operation == OperationType.SUBSCRIPTION:
27-
schema = self.context.schema
28-
subscription_type = schema.subscription_type
29-
if subscription_type:
30-
operation_name = node.name.value if node.name else None
31-
variable_values: Dict[str, Any] = {}
32-
document = self.context.document
33-
fragments: Dict[str, FragmentDefinitionNode] = {
34-
definition.name.value: definition
35-
for definition in document.definitions
36-
if isinstance(definition, FragmentDefinitionNode)
37-
}
38-
fields = collect_fields(
39-
schema,
40-
fragments,
41-
variable_values,
42-
subscription_type,
43-
node.selection_set,
44-
{},
45-
set(),
46-
)
47-
if len(fields) > 1:
48-
field_selection_lists = list(fields.values())
49-
extra_field_selection_lists = field_selection_lists[1:]
50-
extra_field_selection = [
51-
field
52-
for fields in extra_field_selection_lists
53-
for field in (
54-
fields
55-
if isinstance(fields, list)
56-
else [cast(FieldNode, fields)]
26+
if node.operation != OperationType.SUBSCRIPTION:
27+
return
28+
schema = self.context.schema
29+
subscription_type = schema.subscription_type
30+
if subscription_type:
31+
operation_name = node.name.value if node.name else None
32+
variable_values: Dict[str, Any] = {}
33+
document = self.context.document
34+
fragments: Dict[str, FragmentDefinitionNode] = {
35+
definition.name.value: definition
36+
for definition in document.definitions
37+
if isinstance(definition, FragmentDefinitionNode)
38+
}
39+
fields = collect_fields(
40+
schema,
41+
fragments,
42+
variable_values,
43+
subscription_type,
44+
node.selection_set,
45+
{},
46+
set(),
47+
)
48+
if len(fields) > 1:
49+
field_selection_lists = list(fields.values())
50+
extra_field_selection_lists = field_selection_lists[1:]
51+
extra_field_selection = [
52+
field
53+
for fields in extra_field_selection_lists
54+
for field in (
55+
fields
56+
if isinstance(fields, list)
57+
else [cast(FieldNode, fields)]
58+
)
59+
]
60+
self.report_error(
61+
GraphQLError(
62+
(
63+
"Anonymous Subscription"
64+
if operation_name is None
65+
else f"Subscription '{operation_name}'"
5766
)
58-
]
67+
+ " must select only one top level field.",
68+
extra_field_selection,
69+
)
70+
)
71+
for field_nodes in fields.values():
72+
field = field_nodes[0]
73+
field_name = field.name.value
74+
if field_name.startswith("__"):
5975
self.report_error(
6076
GraphQLError(
6177
(
6278
"Anonymous Subscription"
6379
if operation_name is None
6480
else f"Subscription '{operation_name}'"
6581
)
66-
+ " must select only one top level field.",
67-
extra_field_selection,
82+
+ " must not select an introspection top level field.",
83+
field_nodes,
6884
)
6985
)
70-
for field_nodes in fields.values():
71-
field = field_nodes[0]
72-
field_name = field.name.value
73-
if field_name.startswith("__"):
74-
self.report_error(
75-
GraphQLError(
76-
(
77-
"Anonymous Subscription"
78-
if operation_name is None
79-
else f"Subscription '{operation_name}'"
80-
)
81-
+ " must not select an introspection top level field.",
82-
field_nodes,
83-
)
84-
)

src/graphql/validation/validation_context.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,12 @@ def report_error(self, error: GraphQLError) -> None:
9797
def get_fragment(self, name: str) -> Optional[FragmentDefinitionNode]:
9898
fragments = self._fragments
9999
if fragments is None:
100-
fragments = {}
101-
for statement in self.document.definitions:
102-
if isinstance(statement, FragmentDefinitionNode):
103-
fragments[statement.name.value] = statement
100+
fragments = {
101+
statement.name.value: statement
102+
for statement in self.document.definitions
103+
if isinstance(statement, FragmentDefinitionNode)
104+
}
105+
104106
self._fragments = fragments
105107
return fragments.get(name)
106108

0 commit comments

Comments
 (0)