Skip to content

Commit c61f145

Browse files
Minor fix in @provides
1 parent 308be36 commit c61f145

File tree

3 files changed

+24
-6
lines changed

3 files changed

+24
-6
lines changed

examples/inaccessible.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
11
import graphene
22

3-
from graphene_federation import inaccessible, shareable, extend, requires, external
3+
from graphene_federation import inaccessible, shareable, extend, requires, external, provides
44

55
from graphene_federation import build_schema
66

77

8-
@extend(fields='x')
98
class Position(graphene.ObjectType):
109
x = graphene.Int(required=True)
1110
y = external(graphene.Int(required=True))
1211
z = inaccessible(graphene.Int(required=True))
13-
a = requires(graphene.Int(required=True), fields="x")
12+
a = provides(graphene.Int(required=True), fields="x")
1413

1514

1615
class Query(graphene.ObjectType):

graphene_federation/provides.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,20 @@ def provides(field, fields: Union[str, list[str]] = None):
3737
fields = fields.split()
3838
field._provides = fields
3939
return field
40+
41+
42+
def get_provides_fields(schema: Schema) -> []:
43+
"""
44+
Find all the extended types from the schema.
45+
They can be easily distinguished from the other type as
46+
the `@provides` decorator adds a `_provides` attribute to them.
47+
"""
48+
provides_fields = {}
49+
for type_name, type_ in schema.graphql_schema.type_map.items():
50+
if not hasattr(type_, "graphene_type"):
51+
continue
52+
for field in list(type_.graphene_type.__dict__):
53+
if getattr(getattr(type_.graphene_type, field), "_provides", False):
54+
provides_fields[type_name] = type_.graphene_type
55+
continue
56+
return provides_fields

graphene_federation/service.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from graphene import ObjectType, String, Field, Schema
1010

1111
from .extend import get_extended_types
12-
from .provides import get_provides_parent_types
12+
from .provides import get_provides_parent_types, get_provides_fields
1313

1414
from .entity import get_entities
1515
from .tag import get_tagged_fields
@@ -103,6 +103,7 @@ def get_sdl(schema: Schema) -> str:
103103
shareable_types = get_shareable_types(schema)
104104
inaccessible_types = get_inaccessible_types(schema)
105105
provides_parent_types = get_provides_parent_types(schema)
106+
provides_fields = get_provides_fields(schema)
106107
entities = get_entities(schema)
107108
shareable_fields = get_shareable_fields(schema)
108109
tagged_fields = get_tagged_fields(schema)
@@ -120,7 +121,7 @@ def get_sdl(schema: Schema) -> str:
120121
_schema_import.append('"@requires"')
121122
if entities:
122123
_schema_import.append('"@key"')
123-
if provides_parent_types:
124+
if provides_parent_types or provides_fields:
124125
_schema_import.append('"@provides"')
125126
if inaccessible_types or inaccessible_fields:
126127
_schema_import.append('"@inaccessible"')
@@ -143,7 +144,8 @@ def get_sdl(schema: Schema) -> str:
143144
# Add fields directives (@external, @provides, @requires, @shareable, @inaccessible)
144145
for entity in set(provides_parent_types.values()) | set(extended_types.values()) | set(
145146
shareable_types.values()) | set(inaccessible_types.values()) | set(
146-
entities.values()) | set(shareable_fields.values()) | set(tagged_fields.values()):
147+
entities.values()) | set(inaccessible_fields.values()) | set(shareable_fields.values()) | set(
148+
tagged_fields.values()) | set(required_fields.values()) | set(provides_fields.values()):
147149
string_schema = add_entity_fields_decorators(entity, schema, string_schema)
148150

149151
# Prepend `extend` keyword to the type definition of extended types

0 commit comments

Comments
 (0)