Skip to content

CLN: remove sqlalchemy<14 compat #45410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 10 additions & 38 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from pandas.core.base import PandasObject
import pandas.core.common as com
from pandas.core.tools.datetimes import to_datetime
from pandas.util.version import Version


class DatabaseError(OSError):
Expand All @@ -57,16 +56,6 @@ class DatabaseError(OSError):
# -- Helper functions


def _gt14() -> bool:
"""
Check if sqlalchemy.__version__ is at least 1.4.0, when several
deprecations were made.
"""
import sqlalchemy

return Version(sqlalchemy.__version__) >= Version("1.4.0")


def _convert_params(sql, params):
"""Convert SQL and params args to DBAPI2.0 compliant format."""
args = [sql]
Expand Down Expand Up @@ -814,10 +803,7 @@ def sql_schema(self):

def _execute_create(self):
# Inserting table into database, add to MetaData object
if _gt14():
self.table = self.table.to_metadata(self.pd_sql.meta)
else:
self.table = self.table.tometadata(self.pd_sql.meta)
self.table = self.table.to_metadata(self.pd_sql.meta)
self.table.create(bind=self.pd_sql.connectable)

def create(self):
Expand Down Expand Up @@ -986,10 +972,9 @@ def read(self, coerce_float=True, parse_dates=None, columns=None, chunksize=None
if self.index is not None:
for idx in self.index[::-1]:
cols.insert(0, self.table.c[idx])
sql_select = select(*cols) if _gt14() else select(cols)
sql_select = select(*cols)
else:
sql_select = select(self.table) if _gt14() else self.table.select()

sql_select = select(self.table)
result = self.pd_sql.execute(sql_select)
column_names = result.keys()

Expand Down Expand Up @@ -1633,19 +1618,11 @@ def check_case_sensitive(
if not name.isdigit() and not name.islower():
# check for potentially case sensitivity issues (GH7815)
# Only check when name is not a number and name is not lower case
engine = self.connectable.engine
with self.connectable.connect() as conn:
if _gt14():
from sqlalchemy import inspect
from sqlalchemy import inspect

insp = inspect(conn)
table_names = insp.get_table_names(
schema=schema or self.meta.schema
)
else:
table_names = engine.table_names(
schema=schema or self.meta.schema, connection=conn
)
with self.connectable.connect() as conn:
insp = inspect(conn)
table_names = insp.get_table_names(schema=schema or self.meta.schema)
if name not in table_names:
msg = (
f"The provided table name '{name}' is not found exactly as "
Expand Down Expand Up @@ -1749,15 +1726,10 @@ def tables(self):
return self.meta.tables

def has_table(self, name: str, schema: str | None = None):
if _gt14():
from sqlalchemy import inspect
from sqlalchemy import inspect

insp = inspect(self.connectable)
return insp.has_table(name, schema or self.meta.schema)
else:
return self.connectable.run_callable(
self.connectable.dialect.has_table, name, schema or self.meta.schema
)
insp = inspect(self.connectable)
return insp.has_table(name, schema or self.meta.schema)

def get_table(self, table_name: str, schema: str | None = None):
from sqlalchemy import (
Expand Down
93 changes: 28 additions & 65 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
SQLAlchemyEngine,
SQLDatabase,
SQLiteDatabase,
_gt14,
get_engine,
pandasSQL_builder,
read_sql_query,
Expand Down Expand Up @@ -385,10 +384,10 @@ def mysql_pymysql_engine(iris_path, types_data):
"mysql+pymysql://root@localhost:3306/pandas",
connect_args={"client_flag": pymysql.constants.CLIENT.MULTI_STATEMENTS},
)
check_target = sqlalchemy.inspect(engine) if _gt14() else engine
if not check_target.has_table("iris"):
insp = sqlalchemy.inspect(engine)
if not insp.has_table("iris"):
create_and_load_iris(engine, iris_path, "mysql")
if not check_target.has_table("types"):
if not insp.has_table("types"):
for entry in types_data:
entry.pop("DateColWithTz")
create_and_load_types(engine, types_data, "mysql")
Expand All @@ -412,10 +411,10 @@ def postgresql_psycopg2_engine(iris_path, types_data):
engine = sqlalchemy.create_engine(
"postgresql+psycopg2://postgres:postgres@localhost:5432/pandas"
)
check_target = sqlalchemy.inspect(engine) if _gt14() else engine
if not check_target.has_table("iris"):
insp = sqlalchemy.inspect(engine)
if not insp.has_table("iris"):
create_and_load_iris(engine, iris_path, "postgresql")
if not check_target.has_table("types"):
if not insp.has_table("types"):
create_and_load_types(engine, types_data, "postgresql")
yield engine
with engine.connect() as conn:
Expand Down Expand Up @@ -1450,8 +1449,7 @@ def test_query_by_select_obj(self):
)

iris = iris_table_metadata(self.flavor)
iris_select = iris if _gt14() else [iris]
name_select = select(iris_select).where(iris.c.Name == bindparam("name"))
name_select = select(iris).where(iris.c.Name == bindparam("name"))
iris_df = sql.read_sql(name_select, self.conn, params={"name": "Iris-setosa"})
all_names = set(iris_df["Name"])
assert all_names == {"Iris-setosa"}
Expand Down Expand Up @@ -1624,46 +1622,33 @@ def test_to_sql_empty(self, test_frame1):
self._to_sql_empty(test_frame1)

def test_create_table(self):
from sqlalchemy import inspect

temp_conn = self.connect()
temp_frame = DataFrame(
{"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}
)

pandasSQL = sql.SQLDatabase(temp_conn)
assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4

if _gt14():
from sqlalchemy import inspect

insp = inspect(temp_conn)
assert insp.has_table("temp_frame")
else:
assert temp_conn.has_table("temp_frame")
insp = inspect(temp_conn)
assert insp.has_table("temp_frame")

def test_drop_table(self):
temp_conn = self.connect()
from sqlalchemy import inspect

temp_conn = self.connect()
temp_frame = DataFrame(
{"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]}
)

pandasSQL = sql.SQLDatabase(temp_conn)
assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4

if _gt14():
from sqlalchemy import inspect

insp = inspect(temp_conn)
assert insp.has_table("temp_frame")
else:
assert temp_conn.has_table("temp_frame")
insp = inspect(temp_conn)
assert insp.has_table("temp_frame")

pandasSQL.drop_table("temp_frame")

if _gt14():
assert not insp.has_table("temp_frame")
else:
assert not temp_conn.has_table("temp_frame")
assert not insp.has_table("temp_frame")

def test_roundtrip(self, test_frame1):
self._roundtrip(test_frame1)
Expand Down Expand Up @@ -2156,14 +2141,10 @@ def bar(connection, data):
data.to_sql(name="test_foo_data", con=connection, if_exists="append")

def baz(conn):
if _gt14():
# https://github.com/sqlalchemy/sqlalchemy/commit/
# 00b5c10846e800304caa86549ab9da373b42fa5d#r48323973
foo_data = foo(conn)
bar(conn, foo_data)
else:
foo_data = conn.run_callable(foo)
conn.run_callable(bar, foo_data)
# https://github.com/sqlalchemy/sqlalchemy/commit/
# 00b5c10846e800304caa86549ab9da373b42fa5d#r48323973
foo_data = foo(conn)
bar(conn, foo_data)

def main(connectable):
if isinstance(connectable, Engine):
Expand Down Expand Up @@ -2216,14 +2197,9 @@ def test_temporary_table(self):
)
from sqlalchemy.orm import (
Session,
sessionmaker,
declarative_base,
)

if _gt14():
from sqlalchemy.orm import declarative_base
else:
from sqlalchemy.ext.declarative import declarative_base

test_data = "Hello, World!"
expected = DataFrame({"spam": [test_data]})
Base = declarative_base()
Expand All @@ -2234,24 +2210,13 @@ class Temporary(Base):
id = Column(Integer, primary_key=True)
spam = Column(Unicode(30), nullable=False)

if _gt14():
with Session(self.conn) as session:
with session.begin():
conn = session.connection()
Temporary.__table__.create(conn)
session.add(Temporary(spam=test_data))
session.flush()
df = sql.read_sql_query(sql=select(Temporary.spam), con=conn)
else:
Session = sessionmaker()
session = Session(bind=self.conn)
with session.transaction:
with Session(self.conn) as session:
with session.begin():
conn = session.connection()
Temporary.__table__.create(conn)
session.add(Temporary(spam=test_data))
session.flush()
df = sql.read_sql_query(sql=select([Temporary.spam]), con=conn)

df = sql.read_sql_query(sql=select(Temporary.spam), con=conn)
tm.assert_frame_equal(df, expected)

# -- SQL Engine tests (in the base class for now)
Expand Down Expand Up @@ -2349,12 +2314,10 @@ def test_row_object_is_named_tuple(self):
Integer,
String,
)
from sqlalchemy.orm import sessionmaker

if _gt14():
from sqlalchemy.orm import declarative_base
else:
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import (
declarative_base,
sessionmaker,
)

BaseModel = declarative_base()

Expand Down