Commit 5d59485

add correct insert job

1 parent 1daad9b commit 5d59485

File tree

2 files changed: +33 -16 lines changed

dlt/destinations/job_client_impl.py

Lines changed: 15 additions & 4 deletions
@@ -115,18 +115,29 @@ def is_sql_job(file_path: str) -> bool:
 
 class ModelLoadJob(RunnableLoadJob):
     """
-    A job to insert rows into a table from a model file which contains a list of select statements
+    A job to insert rows into a table from a model file which contains a single select statement
     """
 
     def __init__(self, file_path: str) -> None:
         super().__init__(file_path)
         self._job_client: "SqlJobClientBase" = None
+        self._sql_client = self._job_client.sql_client
 
     def run(self) -> None:
         with FileStorage.open_zipsafe_ro(self._file_path, "r", encoding="utf-8") as f:
-            sql = f.read()
-            self._sql_client = self._job_client.sql_client
-            self._sql_client.execute_sql(sql)
+            select_statement = f.read()
+
+        insert_statement = self._insert_statement_from_select_statement(select_statement)
+        self._sql_client.execute_sql(insert_statement)
+
+    def _insert_statement_from_select_statement(self, select_statement: str) -> str:
+        """
+        NOTE: Here we generate an insert statement from a select statement, this is the duckdb
+        dialect, we may be able to transpile with sqlglot for each destination or we need
+        to subclass and override this method.
+        """
+        name = self._sql_client.make_qualified_table_name(self._load_table["name"])
+        return f"INSERT INTO {name} {select_statement};"
 
     @staticmethod
     def is_model_job(file_path: str) -> bool:
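
The NOTE in _insert_statement_from_select_statement above suggests transpiling the generated statement per destination with sqlglot. A minimal sketch of what that could look like, assuming sqlglot as the transpiler (the build_insert_statement helper and target_dialect parameter are illustrative names, not part of this commit):

import sqlglot


def build_insert_statement(
    qualified_table_name: str, select_statement: str, target_dialect: str = "duckdb"
) -> str:
    # compose the duckdb-dialect INSERT the same way ModelLoadJob does
    insert_statement = f"INSERT INTO {qualified_table_name} {select_statement}"
    # sqlglot.transpile returns one rewritten statement string per input statement
    return sqlglot.transpile(insert_statement, read="duckdb", write=target_dialect)[0]


# e.g. target_dialect="snowflake" would rewrite duckdb-specific syntax for Snowflake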

tests/load/test_sql_resource.py renamed to tests/load/test_model_item_format.py

Lines changed: 18 additions & 12 deletions
@@ -7,44 +7,50 @@
 from dlt.common.destination.dataset import SupportsReadableDataset
 
 from tests.pipeline.utils import load_table_counts
-
 from dlt.extract.hints import make_hints
 
 
-def test_sql_job() -> None:
+def test_simple_model_jobs() -> None:
     # populate a table with 10 items and retrieve dataset
     pipeline = dlt.pipeline(
         pipeline_name="example_pipeline", destination="duckdb", dataset_name="example_dataset"
     )
     pipeline.run([{"a": i} for i in range(10)], table_name="example_table")
     dataset = pipeline.dataset()
 
+    example_table_columns = dataset.schema.tables["example_table"]["columns"]
+
     # create a resource that generates sql statements to create 2 new tables
+    # we also need to supply all hints so the table can be created
     @dlt.resource()
     def copied_table() -> Any:
         query = dataset["example_table"].limit(5).query()
         yield dlt.mark.with_hints(
-            f"CREATE OR REPLACE TABLE copied_table AS {query}",
-            make_hints(file_format="sql"),
+            query, hints=make_hints(columns=example_table_columns), data_item_format="model"
         )
 
+    @dlt.resource()
+    def copied_table_2() -> Any:
         query = dataset["example_table"].limit(7).query()
         yield dlt.mark.with_hints(
-            f"CREATE OR REPLACE TABLE copied_table2 AS {query}",
-            make_hints(file_format="sql"),
+            query, hints=make_hints(columns=example_table_columns), data_item_format="model"
         )
 
     # run sql jobs
-    pipeline.run(copied_table())
+    pipeline.run([copied_table(), copied_table_2()])
 
     # the two tables where created
-    assert load_table_counts(pipeline, "example_table", "copied_table", "copied_table2") == {
-        "example_table": 10,
+    assert load_table_counts(pipeline, "copied_table", "copied_table_2", "example_table") == {
         "copied_table": 5,
-        "copied_table2": 7,
+        "copied_table_2": 7,
+        "example_table": 10,
     }
 
     # we have a table entry for the main table "copied_table"
     assert "copied_table" in pipeline.default_schema.tables
-    # but no columns, it's up to the user to provide a schema
-    assert len(pipeline.default_schema.tables["copied_table"]["columns"]) == 0
+    # and we only have the three columns from the original table
+    assert set(pipeline.default_schema.tables["copied_table"]["columns"].keys()) == {
+        "a",
+        "_dlt_id",
+        "_dlt_load_id",
+    }
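
For context, each model file written by the test above contains the single SELECT statement returned by dataset["example_table"].limit(5).query(), which ModelLoadJob then wraps into an INSERT against the qualified target table. A rough sketch of the resulting statement (the literal SELECT text depends on how the dataset renders the query, so this is an assumption, not the exact output):

# hypothetical contents of the model file for copied_table
select_statement = 'SELECT * FROM "example_dataset"."example_table" LIMIT 5'
# ModelLoadJob qualifies the target table name and prepends the INSERT clause
qualified_name = '"example_dataset"."copied_table"'
insert_statement = f"INSERT INTO {qualified_name} {select_statement};"
# -> INSERT INTO "example_dataset"."copied_table" SELECT * FROM "example_dataset"."example_table" LIMIT 5;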
