Skip to content

Commit e1a7c2d

Browse files
JDarDagranCloud Composer Team
authored and
Cloud Composer Team
committed
openlineage, snowflake: add OpenLineage support for Snowflake (#31696)
* Add OpenLineage support for SnowflakeOperator. Signed-off-by: Jakub Dardzinski <[email protected]> * Change how default schema is retrieved. Signed-off-by: Jakub Dardzinski <[email protected]> --------- Signed-off-by: Jakub Dardzinski <[email protected]> GitOrigin-RevId: 5b082c38a66b1a0b6b496e0d3b15a6684339e1d1
1 parent 96347d6 commit e1a7c2d

File tree

4 files changed

+169
-1
lines changed

4 files changed

+169
-1
lines changed

airflow/providers/snowflake/hooks/snowflake.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222
from functools import wraps
2323
from io import StringIO
2424
from pathlib import Path
25-
from typing import Any, Callable, Iterable, Mapping, TypeVar, overload
25+
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, TypeVar, overload
26+
from urllib.parse import urlparse
2627

2728
from cryptography.hazmat.backends import default_backend
2829
from cryptography.hazmat.primitives import serialization
@@ -36,6 +37,9 @@
3637
from airflow.utils.strings import to_boolean
3738

3839
T = TypeVar("T")
40+
if TYPE_CHECKING:
41+
from airflow.providers.openlineage.extractors import OperatorLineage
42+
from airflow.providers.openlineage.sqlparser import DatabaseInfo
3943

4044

4145
def _try_to_boolean(value: Any):
@@ -448,3 +452,68 @@ def _get_cursor(self, conn: Any, return_dictionaries: bool):
448452
finally:
449453
if cursor is not None:
450454
cursor.close()
455+
456+
def get_openlineage_database_info(self, connection) -> DatabaseInfo:
457+
from airflow.providers.openlineage.sqlparser import DatabaseInfo
458+
459+
database = self.database or self._get_field(connection.extra_dejson, "database")
460+
461+
return DatabaseInfo(
462+
scheme=self.get_openlineage_database_dialect(connection),
463+
authority=self._get_openlineage_authority(connection),
464+
information_schema_columns=[
465+
"table_schema",
466+
"table_name",
467+
"column_name",
468+
"ordinal_position",
469+
"data_type",
470+
],
471+
database=database,
472+
is_information_schema_cross_db=True,
473+
is_uppercase_names=True,
474+
)
475+
476+
def get_openlineage_database_dialect(self, _) -> str:
477+
return "snowflake"
478+
479+
def get_openlineage_default_schema(self) -> str | None:
480+
"""
481+
Attempts to get current schema.
482+
483+
Usually ``SELECT CURRENT_SCHEMA();`` should work.
484+
However, apparently you may set ``database`` without ``schema``
485+
and get results from ``SELECT CURRENT_SCHEMAS();`` but not
486+
from ``SELECT CURRENT_SCHEMA();``.
487+
It still may return nothing if no database is set in connection.
488+
"""
489+
schema = self._get_conn_params()["schema"]
490+
if not schema:
491+
current_schemas = self.get_first("SELECT PARSE_JSON(CURRENT_SCHEMAS())[0]::string;")[0]
492+
if current_schemas:
493+
_, schema = current_schemas.split(".")
494+
return schema
495+
496+
def _get_openlineage_authority(self, _) -> str:
497+
from openlineage.common.provider.snowflake import fix_snowflake_sqlalchemy_uri
498+
499+
uri = fix_snowflake_sqlalchemy_uri(self.get_uri())
500+
return urlparse(uri).hostname
501+
502+
def get_openlineage_database_specific_lineage(self, _) -> OperatorLineage | None:
503+
from openlineage.client.facet import ExternalQueryRunFacet
504+
505+
from airflow.providers.openlineage.extractors import OperatorLineage
506+
from airflow.providers.openlineage.sqlparser import SQLParser
507+
508+
connection = self.get_connection(getattr(self, self.conn_name_attr))
509+
namespace = SQLParser.create_namespace(self.get_database_info(connection))
510+
511+
if self.query_ids:
512+
return OperatorLineage(
513+
run_facets={
514+
"externalQuery": ExternalQueryRunFacet(
515+
externalQueryId=self.query_ids[0], source=namespace
516+
)
517+
}
518+
)
519+
return None

generated/provider_dependencies.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@
814814
],
815815
"cross-providers-deps": [
816816
"common.sql",
817+
"openlineage",
817818
"slack"
818819
],
819820
"excluded-python-versions": []

tests/providers/snowflake/hooks/test_snowflake.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,33 @@ def test___ensure_prefixes(self):
621621
"extra__snowflake__private_key_content",
622622
"extra__snowflake__insecure_mode",
623623
]
624+
625+
@pytest.mark.parametrize(
626+
"returned_schema,expected_schema",
627+
[([None], ""), (["DATABASE.SCHEMA"], "SCHEMA")],
628+
)
629+
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
630+
def test_get_openlineage_default_schema_with_no_schema_set(
631+
self, mock_get_first, returned_schema, expected_schema
632+
):
633+
connection_kwargs = {
634+
**BASE_CONNECTION_KWARGS,
635+
"schema": None,
636+
}
637+
with mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()):
638+
hook = SnowflakeHook(snowflake_conn_id="test_conn")
639+
mock_get_first.return_value = returned_schema
640+
assert hook.get_openlineage_default_schema() == expected_schema
641+
642+
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook.get_first")
643+
def test_get_openlineage_default_schema_with_schema_set(self, mock_get_first):
644+
with mock.patch.dict(
645+
"os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**BASE_CONNECTION_KWARGS).get_uri()
646+
):
647+
hook = SnowflakeHook(snowflake_conn_id="test_conn")
648+
assert hook.get_openlineage_default_schema() == BASE_CONNECTION_KWARGS["schema"]
649+
mock_get_first.assert_not_called()
650+
651+
hook_with_schema_param = SnowflakeHook(snowflake_conn_id="test_conn", schema="my_schema")
652+
assert hook_with_schema_param.get_openlineage_default_schema() == "my_schema"
653+
mock_get_first.assert_not_called()

tests/providers/snowflake/operators/test_snowflake_sql.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121

2222
import pytest
2323
from databricks.sql.types import Row
24+
from openlineage.client.facet import SchemaDatasetFacet, SchemaField, SqlJobFacet
25+
from openlineage.client.run import Dataset
2426

27+
from airflow.models.connection import Connection
2528
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
29+
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
2630
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator
2731

2832
DATE = "2017-04-20"
@@ -138,3 +142,67 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
138142
return_last=return_last,
139143
split_statements=split_statement,
140144
)
145+
146+
147+
def test_execute_openlineage_events():
148+
DB_NAME = "DATABASE"
149+
DB_SCHEMA_NAME = "PUBLIC"
150+
151+
class SnowflakeHookForTests(SnowflakeHook):
152+
get_conn = MagicMock(name="conn")
153+
get_connection = MagicMock()
154+
155+
def get_first(self, *_):
156+
return [f"{DB_NAME}.{DB_SCHEMA_NAME}"]
157+
158+
dbapi_hook = SnowflakeHookForTests()
159+
160+
class SnowflakeOperatorForTest(SnowflakeOperator):
161+
def get_db_hook(self):
162+
return dbapi_hook
163+
164+
sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
165+
order_day_of_week VARCHAR(64) NOT NULL,
166+
order_placed_on TIMESTAMP NOT NULL,
167+
orders_placed INTEGER NOT NULL
168+
);
169+
FORGOT TO COMMENT"""
170+
op = SnowflakeOperatorForTest(task_id="snowflake-operator", sql=sql)
171+
rows = [
172+
(DB_SCHEMA_NAME, "POPULAR_ORDERS_DAY_OF_WEEK", "ORDER_DAY_OF_WEEK", 1, "TEXT"),
173+
(DB_SCHEMA_NAME, "POPULAR_ORDERS_DAY_OF_WEEK", "ORDER_PLACED_ON", 2, "TIMESTAMP_NTZ"),
174+
(DB_SCHEMA_NAME, "POPULAR_ORDERS_DAY_OF_WEEK", "ORDERS_PLACED", 3, "NUMBER"),
175+
]
176+
dbapi_hook.get_connection.return_value = Connection(
177+
conn_id="snowflake_default",
178+
conn_type="snowflake",
179+
extra={
180+
"account": "test_account",
181+
"region": "us-east",
182+
"warehouse": "snow-warehouse",
183+
"database": DB_NAME,
184+
},
185+
)
186+
dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
187+
188+
lineage = op.get_openlineage_facets_on_start()
189+
assert len(lineage.inputs) == 0
190+
assert lineage.outputs == [
191+
Dataset(
192+
namespace="snowflake://test_account.us-east.aws",
193+
name=f"{DB_NAME}.{DB_SCHEMA_NAME}.POPULAR_ORDERS_DAY_OF_WEEK",
194+
facets={
195+
"schema": SchemaDatasetFacet(
196+
fields=[
197+
SchemaField(name="ORDER_DAY_OF_WEEK", type="TEXT"),
198+
SchemaField(name="ORDER_PLACED_ON", type="TIMESTAMP_NTZ"),
199+
SchemaField(name="ORDERS_PLACED", type="NUMBER"),
200+
]
201+
)
202+
},
203+
)
204+
]
205+
206+
assert lineage.job_facets == {"sql": SqlJobFacet(query=sql)}
207+
208+
assert lineage.run_facets["extractionError"].failedTasks == 1

0 commit comments

Comments
 (0)