Skip to content

✨ Add support for SQLAlchemy polymorphic models #1226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from sqlalchemy import inspect
from sqlalchemy.orm import Mapper
from typing_extensions import Annotated, get_args, get_origin

# Reassign variable to make it reexported for mypy
Expand Down Expand Up @@ -64,6 +66,35 @@ def _is_union_type(t: Any) -> bool:
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)


def set_polymorphic_default_value(
self_instance: _TSQLModel,
values: Dict[str, Any],
) -> bool:
"""By default, when init a model, pydantic will set the polymorphic_on
value to field default value. But when inherit a model, the polymorphic_on
should be set to polymorphic_identity value by default."""
cls = type(self_instance)
mapper = inspect(cls)
ret = False
if isinstance(mapper, Mapper):
polymorphic_on = mapper.polymorphic_on
if polymorphic_on is not None:
polymorphic_property = mapper.get_property_by_column(polymorphic_on)
field_info = get_model_fields(cls).get(polymorphic_property.key)
if field_info:
v = values.get(polymorphic_property.key)
# if model is inherited or polymorphic_on is not explicitly set
# set the polymorphic_on by default
if mapper.inherits or v is None:
setattr(
self_instance,
polymorphic_property.key,
mapper.polymorphic_identity,
)
ret = True
return ret


@contextmanager
def partial_init() -> Generator[None, None, None]:
token = finish_init.set(False)
Expand Down Expand Up @@ -290,6 +321,8 @@ def sqlmodel_table_construct(
if value is not Undefined:
setattr(self_instance, key, value)
# End SQLModel override
# Override polymorphic_on default value
set_polymorphic_default_value(self_instance, values)
return self_instance

def sqlmodel_validate(
Expand Down
62 changes: 52 additions & 10 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ipaddress
import uuid
import warnings
import weakref
from datetime import date, datetime, time, timedelta
from decimal import Decimal
Expand Down Expand Up @@ -41,9 +42,10 @@
)
from sqlalchemy import Enum as sa_Enum
from sqlalchemy.orm import (
InstrumentedAttribute,
Mapped,
MappedColumn,
RelationshipProperty,
declared_attr,
registry,
relationship,
)
Expand Down Expand Up @@ -538,7 +540,33 @@ def __new__(
config_kwargs = {
key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs
}
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
is_polymorphic = False
if IS_PYDANTIC_V2:
base_fields = {}
base_annotations = {}
for base in bases[::-1]:
if issubclass(base, BaseModel):
base_fields.update(get_model_fields(base))
base_annotations.update(base.__annotations__)
if hasattr(base, "__tablename__"):
is_polymorphic = True
# use base_fields overwriting the ones from the class for inherit
# if base is a sqlalchemy model, it's attributes will be an InstrumentedAttribute
# thus pydantic will use the value of the attribute as the default value
base_annotations.update(dict_used["__annotations__"])
dict_used["__annotations__"] = base_annotations
base_fields.update(dict_used)
dict_used = base_fields
# if is_polymorphic, disable pydantic `shadows an attribute` warning
if is_polymorphic:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Field name .+ shadows an attribute in parent.+",
)
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
else:
new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs)
new_cls.__annotations__ = {
**relationship_annotations,
**pydantic_annotations,
Expand All @@ -558,9 +586,22 @@ def get_config(name: str) -> Any:

config_table = get_config("table")
if config_table is True:
# sqlalchemy mark a class as table by check if it has __tablename__ attribute
# or if __tablename__ is in __annotations__. Only set __tablename__ if it's
# a table model
if new_cls.__name__ != "SQLModel" and not hasattr(new_cls, "__tablename__"):
setattr(new_cls, "__tablename__", new_cls.__name__.lower()) # noqa: B010
# If it was passed by kwargs, ensure it's also set in config
set_config_value(model=new_cls, parameter="table", value=config_table)
for k, v in get_model_fields(new_cls).items():
original_v = getattr(new_cls, k, None)
if (
isinstance(original_v, InstrumentedAttribute)
and k not in class_dict
):
# The attribute was already set by SQLAlchemy, don't override it
# Needed for polymorphic models, see #36
continue
col = get_column_from_field(v)
setattr(new_cls, k, col)
# Set a config flag to tell FastAPI that this should be read with a field
Expand Down Expand Up @@ -594,7 +635,13 @@ def __init__(
# trying to create a new SQLAlchemy, for a new table, with the same name, that
# triggers an error
base_is_table = any(is_table_model_class(base) for base in bases)
if is_table_model_class(cls) and not base_is_table:
polymorphic_identity = dict_.get("__mapper_args__", {}).get(
"polymorphic_identity"
)
has_polymorphic = polymorphic_identity is not None

# allow polymorphic models inherit from table models
if is_table_model_class(cls) and (not base_is_table or has_polymorphic):
for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
if rel_info.sa_relationship:
# There's a SQLAlchemy relationship declared, that takes precedence
Expand Down Expand Up @@ -702,13 +749,13 @@ def get_sqlalchemy_type(field: Any) -> Any:
raise ValueError(f"{type_} has no matching SQLAlchemy type")


def get_column_from_field(field: Any) -> Column: # type: ignore
def get_column_from_field(field: Any) -> Union[Column, MappedColumn]: # type: ignore
if IS_PYDANTIC_V2:
field_info = field
else:
field_info = field.field_info
sa_column = getattr(field_info, "sa_column", Undefined)
if isinstance(sa_column, Column):
if isinstance(sa_column, Column) or isinstance(sa_column, MappedColumn):
return sa_column
sa_type = get_sqlalchemy_type(field)
primary_key = getattr(field_info, "primary_key", Undefined)
Expand Down Expand Up @@ -772,7 +819,6 @@ def get_column_from_field(field: Any) -> Column: # type: ignore
class SQLModel(BaseModel, metaclass=SQLModelMetaclass, registry=default_registry):
# SQLAlchemy needs to set weakref(s), Pydantic will set the other slots values
__slots__ = ("__weakref__",)
__tablename__: ClassVar[Union[str, Callable[..., str]]]
__sqlmodel_relationships__: ClassVar[Dict[str, RelationshipProperty[Any]]]
__name__: ClassVar[str]
metadata: ClassVar[MetaData]
Expand Down Expand Up @@ -836,10 +882,6 @@ def __repr_args__(self) -> Sequence[Tuple[Optional[str], Any]]:
if not (isinstance(k, str) and k.startswith("_sa_"))
]

@declared_attr # type: ignore
def __tablename__(cls) -> str:
return cls.__name__.lower()

@classmethod
def model_validate(
cls: Type[_TSQLModel],
Expand Down
132 changes: 132 additions & 0 deletions tests/test_polymorphic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from typing import Optional

from sqlalchemy import ForeignKey
from sqlalchemy.orm import mapped_column
from sqlmodel import Field, Session, SQLModel, create_engine, select

from tests.conftest import needs_pydanticv2


@needs_pydanticv2
def test_polymorphic_joined_table(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would love to see a syntax that's more focused on what the user wants and hides implementation details.
Example:

from fquery.sqlmodel import sqlmodel

@sqlmodel
class Hero:
  hero_type: str = polymorphic(default="hero")
  
@sqlmodel
class DarkHero(Hero):
   dark_power: str = field(default="dark", metadata=...)

You can then perform metaprogramming magic inside the sqlmodel decorator to compute exactly the same class definitions you have in the PR here.

Copy link
Author

@PaleNeutron PaleNeutron Feb 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Polymorphic orm has a lot of details, and these details are not concerned at dataclass level. So I think keep the origin sqlalchemy syntax is the best choice for early version. Maybe we can improve it at next PR.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the key phrase in your comment is for early version. If sqlalchemy syntax was good enough, sqlmodel wouldn't exist :)

How other ORMs are doing it:

I think the key points are:

  • Decorator approach hides inheritance for those who don't like it. Instead it injects behavior.
  • fquery.sqlmodel already has one_to_many(). Perhaps it could be enhanced with one_to_many("hero_type")

But like you say - probably best done in two steps.

__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")

__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "hero",
}

class DarkHero(Hero):
__tablename__ = "dark_hero"
id: Optional[int] = Field(
default=None,
sa_column=mapped_column(ForeignKey("hero.id"), primary_key=True),
)
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)

__mapper_args__ = {
"polymorphic_identity": "dark",
}

engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)


@needs_pydanticv2
def test_polymorphic_joined_table_with_sqlmodel_field(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")

__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "hero",
}

class DarkHero(Hero):
__tablename__ = "dark_hero"
id: Optional[int] = Field(
default=None,
primary_key=True,
foreign_key="hero.id",
)
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)

__mapper_args__ = {
"polymorphic_identity": "dark",
}

engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero()
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)


@needs_pydanticv2
def test_polymorphic_single_table(clear_sqlmodel) -> None:
class Hero(SQLModel, table=True):
__tablename__ = "hero"
id: Optional[int] = Field(default=None, primary_key=True)
hero_type: str = Field(default="hero")

__mapper_args__ = {
"polymorphic_on": "hero_type",
"polymorphic_identity": "hero",
}

class DarkHero(Hero):
dark_power: str = Field(
default="dark",
sa_column=mapped_column(
nullable=False, use_existing_column=True, default="dark"
),
)

__mapper_args__ = {
"polymorphic_identity": "dark",
}

engine = create_engine("sqlite:///:memory:", echo=True)
SQLModel.metadata.create_all(engine)
with Session(engine) as db:
hero = Hero()
db.add(hero)
dark_hero = DarkHero(dark_power="pokey")
db.add(dark_hero)
db.commit()
statement = select(DarkHero)
result = db.exec(statement).all()
assert len(result) == 1
assert isinstance(result[0].dark_power, str)