diff --git a/sql_server/pyodbc/compiler.py b/sql_server/pyodbc/compiler.py index 8f70cd69..bedf20d5 100644 --- a/sql_server/pyodbc/compiler.py +++ b/sql_server/pyodbc/compiler.py @@ -326,18 +326,25 @@ def as_sql(self): # queries and generate their own placeholders. Doing that isn't # necessary and it should be possible to use placeholders and # expressions in bulk inserts too. - can_bulk = (not self.return_id and self.connection.features.has_bulk_insert) and has_fields + can_bulk = self.connection.features.has_bulk_insert and has_fields + + if self.connection.features.can_return_ids_from_bulk_insert: + meta = self.query.get_meta() + qn = self.quote_name_unless_alias + result.append("OUTPUT INSERTED.%s" % qn(meta.pk.db_column or meta.pk.column)) placeholder_rows, param_rows = self.assemble_as_sql(fields, value_rows) - if self.return_id and self.connection.features.can_return_id_from_insert: - result.insert(0, 'SET NOCOUNT ON') - result.append((values_format + ';') % ', '.join(placeholder_rows[0])) - params = [param_rows[0]] - result.append('SELECT CAST(SCOPE_IDENTITY() AS bigint)') - return [(" ".join(result), tuple(chain.from_iterable(params)))] + if not can_bulk: + if self.return_id and self.connection.features.can_return_id_from_insert: + result.insert(0, 'SET NOCOUNT ON') + result.append((values_format + ';') % ', '.join(placeholder_rows[0])) + params = [param_rows[0]] + result.append('SELECT CAST(SCOPE_IDENTITY() AS bigint)') + return [(" ".join(result), tuple(chain.from_iterable(params)))] if can_bulk: + result.insert(0, 'SET NOCOUNT ON') result.append(self.connection.ops.bulk_insert_sql(fields, placeholder_rows)) sql = [(" ".join(result), tuple(p for ps in param_rows for p in ps))] else: diff --git a/sql_server/pyodbc/features.py b/sql_server/pyodbc/features.py index 18945f4d..b84b30e9 100644 --- a/sql_server/pyodbc/features.py +++ b/sql_server/pyodbc/features.py @@ -6,6 +6,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): can_introspect_autofield = True 
can_introspect_small_integer_field = True can_return_id_from_insert = True + can_return_ids_from_bulk_insert = True can_use_chunked_reads = False for_update_after_from = True greatest_least_ignores_nulls = True diff --git a/sql_server/pyodbc/operations.py b/sql_server/pyodbc/operations.py index 749e5b72..3e7b5d66 100644 --- a/sql_server/pyodbc/operations.py +++ b/sql_server/pyodbc/operations.py @@ -221,6 +221,14 @@ def datetime_trunc_sql(self, lookup_type, field_name, tzname): sql = "CONVERT(datetime, CONVERT(varchar, %s, 20))" % field_name return sql, params + def fetch_returned_insert_ids(self, cursor): + """ + Given a cursor object that has just performed an INSERT ... OUTPUT + statement into a table that has an auto-incrementing ID, return the + list of newly created IDs. + """ + return [item[0] for item in cursor.fetchall()] + def for_update_sql(self, nowait=False, skip_locked=False): """ Returns the FOR UPDATE SQL clause to lock rows for an update operation.