This repository was archived by the owner on Apr 8, 2024. It is now read-only.

feat: Do a single connection for whole model #557

Open · wants to merge 1 commit into base: act1
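In short, the change replaces the per-operation `new_connection` context manager with one dbt connection opened once per model run (`adapter.connection_named("fal:model")`) that the helpers then look up via `adapter.connections.get_if_exists()`. A runnable toy version of that pattern, using only the standard library (none of this is fal's actual code; the names just mirror the diff):

```python
import threading
from contextlib import contextmanager

_local = threading.local()

@contextmanager
def connection_named(name: str):
    # Stand-in for adapter.connection_named(): opens one "connection"
    # for the whole block and stores it thread-locally.
    _local.connection = {"name": name, "handle": object()}
    try:
        yield
    finally:
        _local.connection = None

def get_if_exists():
    # Stand-in for adapter.connections.get_if_exists().
    return getattr(_local, "connection", None)

def read_relation_as_df():
    connection = get_if_exists()
    assert connection, "Connection should be present"
    return f"read via {connection['name']}"

with connection_named("fal:model"):
    print(read_relation_as_df())  # read via fal:model
```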
74 changes: 33 additions & 41 deletions adapter/src/dbt/adapters/fal/adapter_support.py
@@ -3,9 +3,8 @@
 
 import pandas as pd
 import sqlalchemy
-from contextlib import contextmanager
 from dbt.adapters.base import BaseAdapter, BaseRelation, RelationType
-from dbt.adapters.base.connections import AdapterResponse, Connection
+from dbt.adapters.base.connections import AdapterResponse
 from dbt.config import RuntimeConfig
 from dbt.parser.manifest import ManifestLoader
@@ -17,11 +16,13 @@
 }
 
 
-def _get_alchemy_engine(adapter: BaseAdapter, connection: Connection) -> Any:
+def _get_alchemy_engine(adapter: BaseAdapter) -> Any:
     # The following code heavily depends on the implementation
     # details of the known adapters, hence it can't work for
     # arbitrary ones.
     adapter_type = adapter.type()
+    connection = adapter.connections.get_if_exists()
+    assert connection, "Connection should be present"
 
     sqlalchemy_kwargs = {}
     format_url = lambda url: url
@@ -80,30 +81,29 @@ def write_df_to_relation(
         return support_duckdb.write_df_to_relation(adapter, dataframe, relation)
 
     else:
-        with new_connection(adapter, "fal:write_df_to_relation") as connection:
-            # TODO: this should probably live in the materialization macro.
-            temp_relation = relation.replace_path(
-                identifier=f"__dbt_fal_temp_{relation.identifier}"
-            )
-            drop_relation_if_it_exists(adapter, temp_relation)
-
-            alchemy_engine = _get_alchemy_engine(adapter, connection)
-
-            # TODO: probably worth handling errors here and returning
-            # a proper adapter response.
-            rows_affected = dataframe.to_sql(
-                con=alchemy_engine,
-                name=temp_relation.identifier,
-                schema=temp_relation.schema,
-                if_exists=if_exists,
-                index=False,
-            )
-            adapter.cache.add(temp_relation)
-            drop_relation_if_it_exists(adapter, relation)
-            adapter.rename_relation(temp_relation, relation)
-            adapter.commit_if_has_connection()
+        # TODO: this should probably live in the materialization macro.
+        temp_relation = relation.replace_path(
+            identifier=f"__dbt_fal_temp_{relation.identifier}"
+        )
+        drop_relation_if_it_exists(adapter, temp_relation)
+
+        alchemy_engine = _get_alchemy_engine(adapter)
+
+        # TODO: probably worth handling errors here and returning
+        # a proper adapter response.
+        rows_affected = dataframe.to_sql(
+            con=alchemy_engine,
+            name=temp_relation.identifier,
+            schema=temp_relation.schema,
+            if_exists=if_exists,
+            index=False,
+        )
+        adapter.cache.add(temp_relation)
+        drop_relation_if_it_exists(adapter, relation)
+        adapter.rename_relation(temp_relation, relation)
+        adapter.commit_if_has_connection()
 
-            return AdapterResponse("OK", rows_affected=rows_affected)
+        return AdapterResponse("OK", rows_affected=rows_affected)
 
 
 def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.DataFrame:
@@ -120,13 +120,12 @@ def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.Data
         return support_duckdb.read_relation_as_df(adapter, relation)
 
     else:
-        with new_connection(adapter, "fal:read_relation_as_df") as connection:
-            alchemy_engine = _get_alchemy_engine(adapter, connection)
-            return pd.read_sql_table(
-                con=alchemy_engine,
-                table_name=relation.identifier,
-                schema=relation.schema,
-            )
+        alchemy_engine = _get_alchemy_engine(adapter)
+        return pd.read_sql_table(
+            con=alchemy_engine,
+            table_name=relation.identifier,
+            schema=relation.schema,
+        )
 
 
 def prepare_for_adapter(adapter: BaseAdapter, function: Any) -> Any:
@@ -163,11 +162,4 @@ def reconstruct_adapter(config: RuntimeConfig) -> BaseAdapter:
 
 def reload_adapter_cache(adapter: BaseAdapter, config: RuntimeConfig) -> None:
     manifest = ManifestLoader.get_full_manifest(config)
-    with new_connection(adapter, "fal:reload_adapter_cache"):
-        adapter.set_relations_cache(manifest, True)
-
-
-@contextmanager
-def new_connection(adapter: BaseAdapter, connection_name: str) -> Connection:
-    with adapter.connection_named(connection_name):
-        yield adapter.connections.get_thread_connection()
+    adapter.set_relations_cache(manifest, True)
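Since `_get_alchemy_engine` no longer receives a `Connection`, it has to wrap the handle of the connection dbt already opened rather than letting SQLAlchemy dial its own. One plausible shape for that, assuming `connection.handle` is the raw DBAPI connection (the usual dbt convention); the helper name and the URL are illustrative only:

```python
import sqlalchemy
from sqlalchemy.pool import StaticPool

def _engine_from_open_handle(connection):
    # Hypothetical sketch: `creator` hands SQLAlchemy the existing DBAPI
    # connection instead of opening a new one; the URL only selects the
    # dialect. StaticPool pins that single connection for reuse.
    return sqlalchemy.create_engine(
        "postgresql://",
        creator=lambda: connection.handle,
        poolclass=StaticPool,
    )
```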
10 changes: 6 additions & 4 deletions adapter/src/dbt/adapters/fal/impl.py
@@ -26,10 +26,12 @@ def _run_with_adapter(code: str, adapter: BaseAdapter) -> Any:
     # main symbol is defined during dbt-fal's compilation
     # and acts as an entrypoint for us to run the model.
     main = retrieve_symbol(code, "main")
-    return main(
-        read_df=prepare_for_adapter(adapter, read_relation_as_df),
-        write_df=prepare_for_adapter(adapter, write_df_to_relation),
-    )
+
+    with adapter.connection_named("fal:model"):
+        return main(
+            read_df=prepare_for_adapter(adapter, read_relation_as_df),
+            write_df=prepare_for_adapter(adapter, write_df_to_relation),
+        )
 
 
 def _isolated_runner(code: str, config: RuntimeConfig) -> Any:
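For context, `main = retrieve_symbol(code, "main")` is what turns the compiled model source into a callable. A hypothetical re-implementation of that mechanism (fal's real helper may differ in details):

```python
def retrieve_symbol(code: str, name: str):
    # Execute the model source in a scratch namespace and return the
    # symbol it defined under `name`.
    namespace: dict = {}
    exec(compile(code, "<fal model>", "exec"), namespace)
    return namespace[name]

source = "def main(read_df, write_df):\n    return read_df() + 1\n"
main = retrieve_symbol(source, "main")
print(main(read_df=lambda: 41, write_df=None))  # 42
```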
7 changes: 3 additions & 4 deletions adapter/src/dbt/adapters/fal/support/duckdb.py
@@ -1,15 +1,14 @@
 from dbt.adapters.base import BaseAdapter, BaseRelation
 from dbt.adapters.base.connections import AdapterResponse
-from dbt.adapters.fal.adapter_support import new_connection
 import pandas as pd
 from dbt.adapters.sql import SQLAdapter
 import duckdb
 
 
 def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.DataFrame:
     db_path = adapter.config.credentials.path
 
     con = duckdb.connect(database=db_path)
 
     df = con.execute(f"SELECT * FROM {relation.identifier}").fetchdf()
     return df
@@ -19,9 +18,9 @@ def write_df_to_relation(
     data: pd.DataFrame,
     relation: BaseRelation,
 ) -> AdapterResponse:
-
-    db_path = adapter.config.credentials.path
+    db_path = adapter.config.credentials.path
     con = duckdb.connect(database=db_path)
 
     rows_affected = con.execute(
         f"CREATE OR REPLACE TABLE {relation.identifier} AS SELECT * FROM data;"
     ).fetchall()[0][0]
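A side note on why `SELECT * FROM data` works with no registration step: DuckDB's Python client resolves unknown table names against pandas DataFrames in the calling scope (a replacement scan). A minimal, self-contained illustration, using an in-memory database instead of the adapter's file path:

```python
import duckdb
import pandas as pd

data = pd.DataFrame({"x": [1, 2, 3]})

con = duckdb.connect(database=":memory:")
# DuckDB resolves the bare name `data` to the local DataFrame above.
con.execute("CREATE OR REPLACE TABLE example AS SELECT * FROM data;")
print(con.execute("SELECT count(*) FROM example").fetchall())  # [(3,)]
```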
26 changes: 14 additions & 12 deletions adapter/src/dbt/adapters/fal/support/snowflake.py
@@ -10,21 +10,23 @@ def read_relation_as_df(adapter: BaseAdapter, relation: BaseRelation) -> pd.Data
 
     assert adapter.type() == "snowflake"
 
-    with new_connection(adapter, "fal-snowflake:read_relation_as_df") as conn:
-        cur = conn.handle.cursor()
-        cur.execute(sql)
-        df: pd.DataFrame = cur.fetch_pandas_all()
+    connection = adapter.connections.get_if_exists()
+    assert connection, "Connection should be present"
 
-        # HACK: manually parse ARRAY and VARIANT since they are returned as strings right now
-        # Related issue: https://github.com/snowflakedb/snowflake-connector-python/issues/544
-        for desc in cur.description:
-            # 5=VARIANT, 10=ARRAY -- https://docs.snowflake.com/en/user-guide/python-connector-api.html#type-codes
-            if desc.type_code in [5, 10]:
-                import json
+    cur = connection.handle.cursor()
+    cur.execute(sql)
+    df: pd.DataFrame = cur.fetch_pandas_all()
 
-                df[desc.name] = df[desc.name].map(lambda v: json.loads(v))
+    # HACK: manually parse ARRAY and VARIANT since they are returned as strings right now
+    # Related issue: https://github.com/snowflakedb/snowflake-connector-python/issues/544
+    for desc in cur.description:
+        # 5=VARIANT, 10=ARRAY -- https://docs.snowflake.com/en/user-guide/python-connector-api.html#type-codes
+        if desc.type_code in [5, 10]:
+            import json
 
-        return df
+            df[desc.name] = df[desc.name].map(lambda v: json.loads(v))
+
+    return df
 
 
 def write_df_to_relation(
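The HACK block is easier to follow with concrete data: `fetch_pandas_all()` returns VARIANT and ARRAY columns as JSON strings, and the cursor's type codes (5 for VARIANT, 10 for ARRAY) say which columns to decode. A self-contained sketch with a faked `cur.description` (real descriptions come from snowflake-connector-python):

```python
import json
import pandas as pd

VARIANT, ARRAY = 5, 10  # Snowflake cursor type codes

# Stand-in for cur.description: (name, type_code) pairs, simplified.
description = [("PAYLOAD", VARIANT), ("ID", 0)]

df = pd.DataFrame({"PAYLOAD": ['{"a": 1}', '{"a": 2}'], "ID": [1, 2]})

for name, type_code in description:
    if type_code in [VARIANT, ARRAY]:
        # Decode JSON strings into Python objects, column by column.
        df[name] = df[name].map(json.loads)

print(df.loc[0, "PAYLOAD"]["a"])  # 1
```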