Skip to content

feat: SQLAlchemy 2.0 support #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 14, 2023
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ jobs:
strategy:
max-parallel: 10
matrix:
sql-alchemy: ["1.2", "1.3", "1.4"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
sql-alchemy: [ "1.2", "1.3", "1.4","2.0" ]
python-version: [ "3.7", "3.8", "3.9", "3.10" ]

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ __pycache__/
.Python
env/
.venv/
venv/
build/
develop-eggs/
dist/
Expand Down
20 changes: 18 additions & 2 deletions graphene_sqlalchemy/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,13 @@
import sqlalchemy
from sqlalchemy.orm import Session, strategies
from sqlalchemy.orm.query import QueryContext
from sqlalchemy.util import immutabledict

from .utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, is_graphene_version_less_than
from .utils import (
SQL_VERSION_HIGHER_EQUAL_THAN_1_4,
SQL_VERSION_HIGHER_EQUAL_THAN_2,
is_graphene_version_less_than,
)


def get_data_loader_impl() -> Any: # pragma: no cover
Expand Down Expand Up @@ -76,7 +81,18 @@ async def batch_load_fn(self, parents):
query_context = parent_mapper_query._compile_context()
else:
query_context = QueryContext(session.query(parent_mapper.entity))
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
if SQL_VERSION_HIGHER_EQUAL_THAN_2: # pragma: no cover
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
states,
None,
child_mapper,
None,
None, # recursion depth can be none
immutabledict(), # default value for selectinload->lazyload
)
elif SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
self.selectin_loader._load_for_path(
query_context,
parent_mapper._path_registry,
Expand Down
23 changes: 18 additions & 5 deletions graphene_sqlalchemy/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,23 @@
String,
Table,
func,
select,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, column_property, composite, mapper, relationship
from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter
from sqlalchemy.sql.type_api import TypeEngine

from graphene_sqlalchemy.tests.utils import wrap_select_func
from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4, SQL_VERSION_HIGHER_EQUAL_THAN_2

# fmt: off
import sqlalchemy
if SQL_VERSION_HIGHER_EQUAL_THAN_2:
from sqlalchemy.sql.sqltypes import HasExpressionLookup # noqa # isort:skip
else:
from sqlalchemy.sql.sqltypes import _LookupExpressionAdapter as HasExpressionLookup # noqa # isort:skip
# fmt: on

PetKind = Enum("cat", "dog", name="pet_kind")


Expand Down Expand Up @@ -119,7 +128,7 @@ def hybrid_prop_list(self) -> List[int]:
return [1, 2, 3]

column_prop = column_property(
select([func.cast(func.count(id), Integer)]), doc="Column property"
wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property"
)

composite_prop = composite(
Expand Down Expand Up @@ -163,7 +172,11 @@ def __subclasses__(cls):

editor_table = Table("editors", Base.metadata, autoload=True)

mapper(ReflectedEditor, editor_table)
# TODO Remove when switching min sqlalchemy version to SQLAlchemy 1.4
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
Base.registry.map_imperatively(ReflectedEditor, editor_table)
else:
mapper(ReflectedEditor, editor_table)


############################################
Expand Down Expand Up @@ -337,7 +350,7 @@ class Employee(Person):
############################################


class CustomIntegerColumn(_LookupExpressionAdapter, TypeEngine):
class CustomIntegerColumn(HasExpressionLookup, TypeEngine):
"""
Custom Column Type that our converters don't recognize
Adapted from sqlalchemy.Integer
Expand Down
5 changes: 3 additions & 2 deletions graphene_sqlalchemy/tests/models_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
String,
Table,
func,
select,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import column_property, relationship

from graphene_sqlalchemy.tests.utils import wrap_select_func

PetKind = Enum("cat", "dog", name="pet_kind")


Expand Down Expand Up @@ -61,7 +62,7 @@ class Reporter(Base):
favorite_article = relationship("Article", uselist=False)

column_prop = column_property(
select([func.cast(func.count(id), Integer)]), doc="Column property"
wrap_select_func(func.cast(func.count(id), Integer)), doc="Column property"
)


Expand Down
47 changes: 32 additions & 15 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,28 @@
import sys
from typing import Dict, Tuple, Union

import graphene
import pytest
import sqlalchemy
import sqlalchemy_utils as sqa_utils
from sqlalchemy import Column, func, select, types
from graphene.relay import Node
from graphene.types.structures import Structure
from sqlalchemy import Column, func, types
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import column_property, composite

import graphene
from graphene.relay import Node
from graphene.types.structures import Structure

from .models import (
Article,
CompositeFullName,
Pet,
Reporter,
ShoppingCart,
ShoppingCartItem,
)
from .utils import wrap_select_func
from ..converter import (
convert_sqlalchemy_column,
convert_sqlalchemy_composite,
Expand All @@ -27,6 +35,7 @@
from ..fields import UnsortedSQLAlchemyConnectionField, default_connection_field_factory
from ..registry import Registry, get_global_registry
from ..types import ORMField, SQLAlchemyObjectType
from ..utils import is_sqlalchemy_version_less_than
from .models import (
Article,
CompositeFullName,
Expand Down Expand Up @@ -204,9 +213,9 @@ def prop_method() -> int | str:
return "not allowed in gql schema"

with pytest.raises(
ValueError,
match=r"Cannot convert hybrid_property Union to "
r"graphene.Union: the Union contains scalars. \.*",
ValueError,
match=r"Cannot convert hybrid_property Union to "
r"graphene.Union: the Union contains scalars. \.*",
):
get_hybrid_property_type(prop_method)

Expand Down Expand Up @@ -460,7 +469,7 @@ class TestEnum(enum.IntEnum):

def test_should_columproperty_convert():
field = get_field_from_column(
column_property(select([func.sum(func.cast(id, types.Integer))]).where(id == 1))
column_property(wrap_select_func(func.sum(func.cast(id, types.Integer))).where(id == 1))
)

assert field.type == graphene.Int
Expand All @@ -477,10 +486,18 @@ def test_should_jsontype_convert_jsonstring():
assert get_field(types.JSON).type == graphene.JSONString


@pytest.mark.skipif(
(not is_sqlalchemy_version_less_than("2.0.0b1")),
reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy",
)
def test_should_variant_int_convert_int():
assert get_field(types.Variant(types.Integer(), {})).type == graphene.Int


@pytest.mark.skipif(
(not is_sqlalchemy_version_less_than("2.0.0b1")),
reason="SQLAlchemy >=2.0 does not support this: Variant is no longer used in SQLAlchemy",
)
def test_should_variant_string_convert_string():
assert get_field(types.Variant(types.String(), {})).type == graphene.String

Expand Down Expand Up @@ -811,8 +828,8 @@ class Meta:
)

for (
hybrid_prop_name,
hybrid_prop_expected_return_type,
hybrid_prop_name,
hybrid_prop_expected_return_type,
) in shopping_cart_item_expected_types.items():
hybrid_prop_field = ShoppingCartItemType._meta.fields[hybrid_prop_name]

Expand All @@ -823,7 +840,7 @@ class Meta:
str(hybrid_prop_expected_return_type),
)
assert (
hybrid_prop_field.description is None
hybrid_prop_field.description is None
) # "doc" is ignored by hybrid property

###################################################
Expand Down Expand Up @@ -870,8 +887,8 @@ class Meta:
)

for (
hybrid_prop_name,
hybrid_prop_expected_return_type,
hybrid_prop_name,
hybrid_prop_expected_return_type,
) in shopping_cart_expected_types.items():
hybrid_prop_field = ShoppingCartType._meta.fields[hybrid_prop_name]

Expand All @@ -882,5 +899,5 @@ class Meta:
str(hybrid_prop_expected_return_type),
)
assert (
hybrid_prop_field.description is None
hybrid_prop_field.description is None
) # "doc" is ignored by hybrid property
13 changes: 12 additions & 1 deletion graphene_sqlalchemy/tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import inspect
import re

from sqlalchemy import select

from graphene_sqlalchemy.utils import SQL_VERSION_HIGHER_EQUAL_THAN_1_4


def to_std_dicts(value):
"""Convert nested ordered dicts to normal dicts for better comparison."""
Expand All @@ -18,8 +22,15 @@ def remove_cache_miss_stat(message):
return re.sub(r"\[generated in \d+.?\d*s\]\s", "", message)


async def eventually_await_session(session, func, *args):
def wrap_select_func(query):
# TODO remove this when we drop support for sqa < 2.0
if SQL_VERSION_HIGHER_EQUAL_THAN_1_4:
return select(query)
else:
return select([query])


async def eventually_await_session(session, func, *args):
if inspect.iscoroutinefunction(getattr(session, func)):
await getattr(session, func)(*args)
else:
Expand Down
8 changes: 7 additions & 1 deletion graphene_sqlalchemy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,18 @@ def is_graphene_version_less_than(version_string): # pragma: no cover

SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = False

if not is_sqlalchemy_version_less_than("1.4"):
if not is_sqlalchemy_version_less_than("1.4"): # pragma: no cover
from sqlalchemy.ext.asyncio import AsyncSession

SQL_VERSION_HIGHER_EQUAL_THAN_1_4 = True


SQL_VERSION_HIGHER_EQUAL_THAN_2 = False

if not is_sqlalchemy_version_less_than("2.0.0b1"): # pragma: no cover
SQL_VERSION_HIGHER_EQUAL_THAN_2 = True


def get_session(context):
return context.get("session")

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# To keep things simple, we only support newer versions of Graphene
"graphene>=3.0.0b7",
"promise>=2.3",
"SQLAlchemy>=1.1,<2",
"SQLAlchemy>=1.1",
"aiodataloader>=0.2.0,<1.0",
]

Expand Down
8 changes: 6 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tox]
envlist = pre-commit,py{37,38,39,310}-sql{12,13,14}
envlist = pre-commit,py{37,38,39,310}-sql{12,13,14,20}
skipsdist = true
minversion = 3.7.0

Expand All @@ -15,6 +15,7 @@ SQLALCHEMY =
1.2: sql12
1.3: sql13
1.4: sql14
2.0: sql20

[testenv]
passenv = GITHUB_*
Expand All @@ -23,8 +24,11 @@ deps =
sql12: sqlalchemy>=1.2,<1.3
sql13: sqlalchemy>=1.3,<1.4
sql14: sqlalchemy>=1.4,<1.5
sql20: sqlalchemy>=2.0.0b3
setenv =
SQLALCHEMY_WARN_20 = 1
commands =
pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs}
python -W always -m pytest graphene_sqlalchemy --cov=graphene_sqlalchemy --cov-report=term --cov-report=xml {posargs}

[testenv:pre-commit]
basepython=python3.10
Expand Down