Skip to content

Commit 28249a8

Browse files
committed
ORMfield.batching and SQLAlchemyObjectType.batching
1 parent 631513f commit 28249a8

File tree

6 files changed

+288
-120
lines changed

6 files changed

+288
-120
lines changed

graphene_sqlalchemy/converter.py

+44-20
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,28 @@
33
from singledispatch import singledispatch
44
from sqlalchemy import types
55
from sqlalchemy.dialects import postgresql
6-
from sqlalchemy.orm import interfaces
6+
from sqlalchemy.orm import interfaces, strategies
77

88
from graphene import (ID, Boolean, Dynamic, Enum, Field, Float, Int, List,
99
String)
1010
from graphene.types.json import JSONString
1111

12+
from .batching import get_batch_resolver
1213
from .enums import enum_for_sa_enum
14+
from .fields import (BatchSQLAlchemyConnectionField,
15+
default_connection_field_factory)
1316
from .registry import get_global_registry
17+
from .resolvers import get_attr_resolver, get_custom_resolver
1418

1519
try:
1620
from sqlalchemy_utils import ChoiceType, JSONType, ScalarListType, TSVectorType
1721
except ImportError:
1822
ChoiceType = JSONType = ScalarListType = TSVectorType = object
1923

2024

25+
is_selectin_available = getattr(strategies, 'SelectInLoader', None)
26+
27+
2128
def get_column_doc(column):
2229
return getattr(column, "doc", None)
2330

@@ -26,29 +33,46 @@ def is_column_nullable(column):
2633
return bool(getattr(column, "nullable", True))
2734

2835

29-
def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs):
30-
direction = relationship_prop.direction
31-
model = relationship_prop.mapper.entity
32-
36+
def convert_sqlalchemy_relationship(relationship_prop, obj_type, connection_field_factory, batching,
37+
attr_name, orm_field_name, **field_kwargs):
38+
"""
39+
:param sqlalchemy.RelationshipProperty relationship_prop:
40+
:param Registry registry:
41+
:type function|None connection_field_factory:
42+
:type bool batching:
43+
:param SQLAlchemyObjectType obj_type:
44+
:param str orm_field_name:
45+
:rtype: Dynamic
46+
"""
3347
def dynamic_type():
34-
_type = registry.get_type_for_model(model)
48+
direction = relationship_prop.direction
49+
model = relationship_prop.mapper.entity
50+
type_ = obj_type._meta.registry.get_type_for_model(model)
3551

36-
if not _type:
52+
batching_ = batching if is_selectin_available else False
53+
connection_field_factory_ = connection_field_factory
54+
55+
if not type_:
3756
return None
57+
3858
if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
39-
return Field(
40-
_type,
41-
resolver=resolver,
42-
**field_kwargs
43-
)
44-
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
45-
if _type._meta.connection:
46-
# TODO Add a way to override connection_field_factory
47-
return connection_field_factory(relationship_prop, registry, **field_kwargs)
48-
return Field(
49-
List(_type),
50-
**field_kwargs
51-
)
59+
resolver = get_custom_resolver(obj_type, orm_field_name)
60+
if resolver is None:
61+
resolver = get_batch_resolver(relationship_prop) if batching_ else \
62+
get_attr_resolver(obj_type, relationship_prop.key)
63+
64+
return Field(type_, resolver=resolver, **field_kwargs)
65+
66+
if direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
67+
if not type_._meta.connection:
68+
return Field(List(type_), **field_kwargs)
69+
70+
if connection_field_factory_ is None:
71+
connection_field_factory_ = BatchSQLAlchemyConnectionField.from_relationship if batching_ else \
72+
default_connection_field_factory
73+
74+
# TODO Allow override of connection_field_factory and resolver via ORMField
75+
return connection_field_factory_(relationship_prop, obj_type._meta.registry, **field_kwargs)
5276

5377
return Dynamic(dynamic_type)
5478

graphene_sqlalchemy/resolver.py

Whitespace-only changes.

graphene_sqlalchemy/resolvers.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from graphene.utils.get_unbound_function import get_unbound_function
2+
3+
4+
def get_custom_resolver(obj_type, orm_field_name):
5+
"""
6+
Since `graphene` will call `resolve_<field_name>` on a field only if it
7+
does not have a `resolver`, we need to re-implement that logic here so
8+
users are able to override the default resolvers that we provide.
9+
"""
10+
resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None)
11+
if resolver:
12+
return get_unbound_function(resolver)
13+
14+
return None
15+
16+
17+
def get_attr_resolver(obj_type, model_attr):
18+
"""
19+
In order to support field renaming via `ORMField.model_attr`,
20+
we need to define resolver functions for each field.
21+
22+
:param SQLAlchemyObjectType obj_type:
23+
:param str model_attr: the name of the SQLAlchemy attribute
24+
:rtype: Callable
25+
"""
26+
return lambda root, _info: getattr(root, model_attr, None)

graphene_sqlalchemy/tests/test_batching.py

+184-5
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import graphene
77
from graphene import relay
88

9-
from ..fields import BatchSQLAlchemyConnectionField
10-
from ..types import SQLAlchemyObjectType
9+
from ..fields import (BatchSQLAlchemyConnectionField,
10+
default_connection_field_factory)
11+
from ..types import ORMField, SQLAlchemyObjectType
1112
from .models import Article, HairKind, Pet, Reporter
1213
from .utils import is_sqlalchemy_version_less_than, to_std_dicts
1314

@@ -43,19 +44,19 @@ class ReporterType(SQLAlchemyObjectType):
4344
class Meta:
4445
model = Reporter
4546
interfaces = (relay.Node,)
46-
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
47+
batching = True
4748

4849
class ArticleType(SQLAlchemyObjectType):
4950
class Meta:
5051
model = Article
5152
interfaces = (relay.Node,)
52-
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
53+
batching = True
5354

5455
class PetType(SQLAlchemyObjectType):
5556
class Meta:
5657
model = Pet
5758
interfaces = (relay.Node,)
58-
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
59+
batching = True
5960

6061
class Query(graphene.ObjectType):
6162
articles = graphene.Field(graphene.List(ArticleType))
@@ -513,3 +514,181 @@ def test_many_to_many(session_factory):
513514
},
514515
],
515516
}
517+
518+
519+
def test_disable_batching_via_ormfield(session_factory):
520+
session = session_factory()
521+
reporter_1 = Reporter(first_name='Reporter_1')
522+
session.add(reporter_1)
523+
reporter_2 = Reporter(first_name='Reporter_2')
524+
session.add(reporter_2)
525+
session.commit()
526+
session.close()
527+
528+
class ReporterType(SQLAlchemyObjectType):
529+
class Meta:
530+
model = Reporter
531+
interfaces = (relay.Node,)
532+
batching = True
533+
534+
favorite_article = ORMField(batching=False)
535+
articles = ORMField(batching=False)
536+
537+
class ArticleType(SQLAlchemyObjectType):
538+
class Meta:
539+
model = Article
540+
interfaces = (relay.Node,)
541+
542+
class Query(graphene.ObjectType):
543+
reporters = graphene.Field(graphene.List(ReporterType))
544+
545+
def resolve_reporters(self, info):
546+
return info.context.get('session').query(Reporter).all()
547+
548+
schema = graphene.Schema(query=Query)
549+
550+
# Test one-to-one and many-to-one relationships
551+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
552+
# Starts new session to fully reset the engine / connection logging level
553+
session = session_factory()
554+
schema.execute("""
555+
query {
556+
reporters {
557+
favoriteArticle {
558+
headline
559+
}
560+
}
561+
}
562+
""", context_value={"session": session})
563+
messages = sqlalchemy_logging_handler.messages
564+
565+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
566+
assert len(select_statements) == 2
567+
568+
# Test one-to-many and many-to-many relationships
569+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
570+
# Starts new session to fully reset the engine / connection logging level
571+
session = session_factory()
572+
schema.execute("""
573+
query {
574+
reporters {
575+
articles {
576+
edges {
577+
node {
578+
headline
579+
}
580+
}
581+
}
582+
}
583+
}
584+
""", context_value={"session": session})
585+
messages = sqlalchemy_logging_handler.messages
586+
587+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
588+
assert len(select_statements) == 2
589+
590+
591+
def test_connection_factory_field_overrides_batching_is_false(session_factory):
592+
session = session_factory()
593+
reporter_1 = Reporter(first_name='Reporter_1')
594+
session.add(reporter_1)
595+
reporter_2 = Reporter(first_name='Reporter_2')
596+
session.add(reporter_2)
597+
session.commit()
598+
session.close()
599+
600+
class ReporterType(SQLAlchemyObjectType):
601+
class Meta:
602+
model = Reporter
603+
interfaces = (relay.Node,)
604+
batching = False
605+
connection_field_factory = BatchSQLAlchemyConnectionField.from_relationship
606+
607+
articles = ORMField(batching=False)
608+
609+
class ArticleType(SQLAlchemyObjectType):
610+
class Meta:
611+
model = Article
612+
interfaces = (relay.Node,)
613+
614+
class Query(graphene.ObjectType):
615+
reporters = graphene.Field(graphene.List(ReporterType))
616+
617+
def resolve_reporters(self, info):
618+
return info.context.get('session').query(Reporter).all()
619+
620+
schema = graphene.Schema(query=Query)
621+
622+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
623+
# Starts new session to fully reset the engine / connection logging level
624+
session = session_factory()
625+
schema.execute("""
626+
query {
627+
reporters {
628+
articles {
629+
edges {
630+
node {
631+
headline
632+
}
633+
}
634+
}
635+
}
636+
}
637+
""", context_value={"session": session})
638+
messages = sqlalchemy_logging_handler.messages
639+
640+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
641+
assert len(select_statements) == 1
642+
643+
644+
def test_connection_factory_field_overrides_batching_is_true(session_factory):
645+
session = session_factory()
646+
reporter_1 = Reporter(first_name='Reporter_1')
647+
session.add(reporter_1)
648+
reporter_2 = Reporter(first_name='Reporter_2')
649+
session.add(reporter_2)
650+
session.commit()
651+
session.close()
652+
653+
class ReporterType(SQLAlchemyObjectType):
654+
class Meta:
655+
model = Reporter
656+
interfaces = (relay.Node,)
657+
batching = True
658+
connection_field_factory = default_connection_field_factory
659+
660+
articles = ORMField(batching=True)
661+
662+
class ArticleType(SQLAlchemyObjectType):
663+
class Meta:
664+
model = Article
665+
interfaces = (relay.Node,)
666+
667+
class Query(graphene.ObjectType):
668+
reporters = graphene.Field(graphene.List(ReporterType))
669+
670+
def resolve_reporters(self, info):
671+
return info.context.get('session').query(Reporter).all()
672+
673+
schema = graphene.Schema(query=Query)
674+
675+
with mock_sqlalchemy_logging_handler() as sqlalchemy_logging_handler:
676+
# Starts new session to fully reset the engine / connection logging level
677+
session = session_factory()
678+
schema.execute("""
679+
query {
680+
reporters {
681+
articles {
682+
edges {
683+
node {
684+
headline
685+
}
686+
}
687+
}
688+
}
689+
}
690+
""", context_value={"session": session})
691+
messages = sqlalchemy_logging_handler.messages
692+
693+
select_statements = [message for message in messages if 'SELECT' in message and 'FROM articles' in message]
694+
assert len(select_statements) == 2

0 commit comments

Comments
 (0)