Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for batch insert with returning ids #155

Open
wants to merge 2 commits into
base: azure-1.11
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions sql_server/pyodbc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,18 +326,25 @@ def as_sql(self):
# queries and generate their own placeholders. Doing that isn't
# necessary and it should be possible to use placeholders and
# expressions in bulk inserts too.
can_bulk = (not self.return_id and self.connection.features.has_bulk_insert) and has_fields
can_bulk = self.connection.features.has_bulk_insert and has_fields

if self.connection.features.can_return_ids_from_bulk_insert:
meta = self.query.get_meta()
qn = self.quote_name_unless_alias
result.append("OUTPUT INSERTED.%s" % qn(meta.pk.db_column or meta.pk.column))

placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows)

if self.return_id and self.connection.features.can_return_id_from_insert:
result.insert(0, 'SET NOCOUNT ON')
result.append((values_format + ';') % ', '.join(placeholder_rows[0]))
params = [param_rows[0]]
result.append('SELECT CAST(SCOPE_IDENTITY() AS bigint)')
return [(" ".join(result), tuple(chain.from_iterable(params)))]
if not can_bulk:
if self.return_id and self.connection.features.can_return_id_from_insert:
result.insert(0, 'SET NOCOUNT ON')
result.append((values_format + ';') % ', '.join(placeholder_rows[0]))
params = [param_rows[0]]
result.append('SELECT CAST(SCOPE_IDENTITY() AS bigint)')
return [(" ".join(result), tuple(chain.from_iterable(params)))]

if can_bulk:
result.insert(0, 'SET NOCOUNT ON')
result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows))
sql = [(" ".join(result), tuple(p for ps in param_rows for p in ps))]
else:
Expand Down
1 change: 1 addition & 0 deletions sql_server/pyodbc/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class DatabaseFeatures(BaseDatabaseFeatures):
can_introspect_autofield = True
can_introspect_small_integer_field = True
can_return_id_from_insert = True
can_return_ids_from_bulk_insert = True
can_use_chunked_reads = False
for_update_after_from = True
greatest_least_ignores_nulls = True
Expand Down
8 changes: 8 additions & 0 deletions sql_server/pyodbc/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@ def datetime_trunc_sql(self, lookup_type, field_name, tzname):
sql = "CONVERT(datetime, CONVERT(varchar, %s, 20))" % field_name
return sql, params

def fetch_returned_insert_ids(self, cursor):
    """
    Return the list of newly created IDs read from *cursor*.

    The cursor has just executed an INSERT carrying an
    ``OUTPUT INSERTED.<pk>`` clause (SQL Server's equivalent of
    ``INSERT ... RETURNING``) into a table with an auto-incrementing
    primary key, so every fetched row contains exactly one column:
    the new ID. Returns ``[]`` when no rows were inserted.
    """
    # Each OUTPUT row is a 1-tuple; unwrap to a flat list of IDs.
    return [row[0] for row in cursor.fetchall()]

def for_update_sql(self, nowait=False, skip_locked=False):
"""
Returns the FOR UPDATE SQL clause to lock rows for an update operation.
Expand Down