Skip to content

Commit 3abb3ed

Browse files
authored
Merge pull request #414 from KammererTob/master
Fix for #409
2 parents 6ec3320 + bb1e80d commit 3abb3ed

File tree

2 files changed

+67
-12
lines changed

2 files changed

+67
-12
lines changed

src/main/kotlin/graphql/kickstart/tools/SchemaParser.kt

+23-12
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class SchemaParser internal constructor(
7676
val inputObjects: MutableList<GraphQLInputObjectType> = mutableListOf()
7777
inputObjectDefinitions.forEach {
7878
if (inputObjects.none { io -> io.name == it.name }) {
79-
inputObjects.add(createInputObject(it, inputObjects))
79+
inputObjects.add(createInputObject(it, inputObjects, mutableSetOf()))
8080
}
8181
}
8282
val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
@@ -155,7 +155,8 @@ class SchemaParser internal constructor(
155155
return schemaGeneratorDirectiveHelper.onObject(objectType, directiveHelperParameters)
156156
}
157157

158-
private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>): GraphQLInputObjectType {
158+
private fun createInputObject(definition: InputObjectTypeDefinition, inputObjects: List<GraphQLInputObjectType>,
159+
referencingInputObjects: MutableSet<String>): GraphQLInputObjectType {
159160
val extensionDefinitions = inputExtensionDefinitions.filter { it.name == definition.name }
160161

161162
val builder = GraphQLInputObjectType.newInputObject()
@@ -166,14 +167,16 @@ class SchemaParser internal constructor(
166167

167168
builder.withDirectives(*buildDirectives(definition.directives, Introspection.DirectiveLocation.INPUT_OBJECT))
168169

170+
referencingInputObjects.add(definition.name)
171+
169172
(extensionDefinitions + definition).forEach {
170173
it.inputValueDefinitions.forEach { inputDefinition ->
171174
val fieldBuilder = GraphQLInputObjectField.newInputObjectField()
172175
.name(inputDefinition.name)
173176
.definition(inputDefinition)
174177
.description(if (inputDefinition.description != null) inputDefinition.description.content else getDocumentation(inputDefinition))
175178
.defaultValue(buildDefaultValue(inputDefinition.defaultValue))
176-
.type(determineInputType(inputDefinition.type, inputObjects))
179+
.type(determineInputType(inputDefinition.type, inputObjects, referencingInputObjects))
177180
.withDirectives(*buildDirectives(inputDefinition.directives, Introspection.DirectiveLocation.INPUT_FIELD_DEFINITION))
178181
builder.field(fieldBuilder.build())
179182
}
@@ -280,7 +283,7 @@ class SchemaParser internal constructor(
280283
.name(argumentDefinition.name)
281284
.definition(argumentDefinition)
282285
.description(if (argumentDefinition.description != null) argumentDefinition.description.content else getDocumentation(argumentDefinition))
283-
.type(determineInputType(argumentDefinition.type, inputObjects))
286+
.type(determineInputType(argumentDefinition.type, inputObjects, setOf()))
284287
.apply { buildDefaultValue(argumentDefinition.defaultValue)?.let { defaultValue(it) } }
285288
.withDirectives(*buildDirectives(argumentDefinition.directives, Introspection.DirectiveLocation.ARGUMENT_DEFINITION))
286289

@@ -380,7 +383,7 @@ class SchemaParser internal constructor(
380383
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
381384
is InputObjectTypeDefinition -> {
382385
log.info("Create input object")
383-
createInputObject(typeDefinition, inputObjects)
386+
createInputObject(typeDefinition, inputObjects, mutableSetOf())
384387
}
385388
is TypeName -> {
386389
val scalarType = customScalars[typeDefinition.name]
@@ -398,16 +401,19 @@ class SchemaParser internal constructor(
398401
else -> throw SchemaError("Unknown type: $typeDefinition")
399402
}
400403

401-
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>) =
402-
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
404+
private fun determineInputType(typeDefinition: Type<*>, inputObjects: List<GraphQLInputObjectType>, referencingInputObjects: Set<String>) =
405+
determineInputType(GraphQLInputType::class, typeDefinition, permittedTypesForInputObject, inputObjects, referencingInputObjects) as GraphQLInputType
403406

404-
private fun <T : Any> determineInputType(expectedType: KClass<T>, typeDefinition: Type<*>, allowedTypeReferences: Set<String>, inputObjects: List<GraphQLInputObjectType>): GraphQLType =
407+
private fun <T : Any> determineInputType(expectedType: KClass<T>,
408+
typeDefinition: Type<*>, allowedTypeReferences: Set<String>,
409+
inputObjects: List<GraphQLInputObjectType>,
410+
referencingInputObjects: Set<String>): GraphQLType =
405411
when (typeDefinition) {
406412
is ListType -> GraphQLList(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
407413
is NonNullType -> GraphQLNonNull(determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
408414
is InputObjectTypeDefinition -> {
409415
log.info("Create input object")
410-
createInputObject(typeDefinition, inputObjects)
416+
createInputObject(typeDefinition, inputObjects, referencingInputObjects as MutableSet<String>)
411417
}
412418
is TypeName -> {
413419
val scalarType = customScalars[typeDefinition.name]
@@ -425,9 +431,14 @@ class SchemaParser internal constructor(
425431
} else {
426432
val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
427433
if (filteredDefinitions.isNotEmpty()) {
428-
val inputObject = createInputObject(filteredDefinitions[0], inputObjects)
429-
(inputObjects as MutableList).add(inputObject)
430-
inputObject
434+
val referencingInputObject = referencingInputObjects.find { it == typeDefinition.name }
435+
if (referencingInputObject != null) {
436+
GraphQLTypeReference(referencingInputObject)
437+
} else {
438+
val inputObject = createInputObject(filteredDefinitions[0], inputObjects, referencingInputObjects as MutableSet<String>)
439+
(inputObjects as MutableList).add(inputObject)
440+
inputObject
441+
}
431442
} else {
432443
// todo: handle enum type
433444
GraphQLTypeReference(typeDefinition.name)

src/test/groovy/graphql/kickstart/tools/SchemaParserSpec.groovy

+44
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,50 @@ class SchemaParserSpec extends Specification {
368368
noExceptionThrown()
369369
}
370370

371+
def "allow circular relations in input objects"() {
372+
when:
373+
SchemaParser.newParser().schemaString('''\
374+
input A {
375+
id: ID!
376+
b: B
377+
}
378+
input B {
379+
id: ID!
380+
a: A
381+
}
382+
input C {
383+
id: ID!
384+
c: C
385+
}
386+
type Query {}
387+
type Mutation {
388+
test(input: A!): Boolean
389+
testC(input: C!): Boolean
390+
}
391+
'''.stripIndent())
392+
.resolvers(new GraphQLMutationResolver() {
393+
static class A {
394+
String id;
395+
B b;
396+
}
397+
static class B {
398+
String id;
399+
A a;
400+
}
401+
static class C {
402+
String id;
403+
C c;
404+
}
405+
boolean test(A a) { return true }
406+
boolean testC(C c) { return true }
407+
}, new GraphQLQueryResolver() {})
408+
.build()
409+
.makeExecutableSchema()
410+
411+
then:
412+
noExceptionThrown()
413+
}
414+
371415
enum EnumType {
372416
TEST
373417
}

0 commit comments

Comments
 (0)