Skip to content

Commit 5b32aa0

Browse files
refactor: Apply WHERE clauses to return value of upstream SQLStream.build_query
1 parent eb913b3 commit 5b32aa0

File tree

4 files changed

+22
-40
lines changed

4 files changed

+22
-40
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ dependencies = [
2929
"psycopg2-binary==2.9.10",
3030
"sqlalchemy==2.0.41",
3131
"sshtunnel==0.4.0",
32-
"singer-sdk[faker] @ git+https://github.com/meltano/sdk.git",
32+
"singer-sdk[faker] @ git+https://github.com/meltano/sdk.git@refs/pull/3050/head",
3333
]
3434

3535
[project.urls]
@@ -52,7 +52,7 @@ lint = [
5252
testing = [
5353
"hypothesis>=6.122.1",
5454
"pytest>=8",
55-
"singer-sdk[testing] @ git+https://github.com/meltano/sdk.git",
55+
"singer-sdk[testing] @ git+https://github.com/meltano/sdk.git@refs/pull/3050/head",
5656
]
5757
typing = [
5858
"mypy>=1.8.0",

tap_postgres/client.py

Lines changed: 13 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -203,42 +203,13 @@ def max_record_count(self) -> int | None:
203203
"""Return the maximum number of records to fetch in a single query."""
204204
return self.config.get("max_record_count")
205205

206-
def build_query(self, context: Context | None) -> sa.sql.Select:
206+
def build_query(self, table: sa.Table) -> sa.sql.Select:
207207
"""Build a SQLAlchemy query for the stream."""
208-
selected_column_names = self.get_selected_schema()["properties"].keys()
209-
table = self.connector.get_table(
210-
full_table_name=self.fully_qualified_name,
211-
column_names=selected_column_names,
212-
)
213-
query = table.select()
214-
215-
if self.replication_key:
216-
replication_key_col = table.columns[self.replication_key]
217-
order_by = (
218-
sa.nulls_first(replication_key_col.asc())
219-
if self.supports_nulls_first
220-
else replication_key_col.asc()
221-
)
222-
query = query.order_by(order_by)
223-
224-
start_val = self.get_starting_replication_key_value(context)
225-
if start_val:
226-
query = query.where(replication_key_col >= start_val)
227-
208+
query = super().build_query(table)
228209
stream_options = self.config.get("stream_options", {}).get(self.name, {})
229210
if clauses := stream_options.get("custom_where_clauses"):
230211
query = query.where(*(sa.text(clause.strip()) for clause in clauses))
231212

232-
if self.ABORT_AT_RECORD_COUNT is not None:
233-
# Limit record count to one greater than the abort threshold. This ensures
234-
# `MaxRecordsLimitException` exception is properly raised by caller
235-
# `Stream._sync_records()` if more records are available than can be
236-
# processed.
237-
query = query.limit(self.ABORT_AT_RECORD_COUNT + 1)
238-
239-
if self.max_record_count():
240-
query = query.limit(self.max_record_count())
241-
242213
return query
243214

244215
# Get records from stream
@@ -264,8 +235,18 @@ def get_records(self, context: Context | None) -> t.Iterable[dict[str, t.Any]]:
264235
msg = f"Stream '{self.name}' does not support partitioning."
265236
raise NotImplementedError(msg)
266237

238+
selected_column_names = self.get_selected_schema()["properties"].keys()
239+
table = self.connector.get_table(
240+
full_table_name=self.fully_qualified_name,
241+
column_names=selected_column_names,
242+
)
243+
244+
query = self.build_query(table)
245+
query = self.apply_replication_filter(query, table, context=context)
246+
query = self.apply_abort_query_limit(query)
247+
267248
with self.connector._connect() as conn:
268-
for record in conn.execute(self.build_query(context)).mappings():
249+
for record in conn.execute(query).mappings():
269250
# TODO: Standardize record mapping type
270251
# https://github.com/meltano/sdk/issues/2096
271252
transformed_record = self.post_process(dict(record))

tests/test_stream_class.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ def test_build_query():
5959
table="test_table",
6060
)
6161
stream = PostgresStream(tap, catalog_entry.to_dict(), connector=DummyConnector())
62+
table = sa.Table("test_table", sa.MetaData(), sa.Column("id", sa.Integer))
6263
assert (
63-
str(stream.build_query(None).compile()).replace("\n", "")
64+
str(stream.build_query(table).compile()).replace("\n", "")
6465
== "SELECT test_table.id FROM test_table WHERE id % 2 = 0 AND id % 3 = 0"
6566
)

uv.lock

Lines changed: 5 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)