Skip to content

Commit 9df5329

Browse files
committed
Update sqlite/duckdb to overwrite an existing db schema and add tests
1 parent 321e0e8 commit 9df5329

File tree

6 files changed

+152
-26
lines changed

6 files changed

+152
-26
lines changed

dsi/backends/duckdb.py

Lines changed: 75 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,56 @@ def ingest_artifacts(self, collection, isVerbose=False):
175175
"""
176176
artifacts = collection
177177

178+
self.cur.execute("BEGIN TRANSACTION")
179+
180+
if self.list() is not None and list(artifacts.keys()) == ["dsi_relations"]:
181+
pk_list = artifacts["dsi_relations"]["primary_key"]
182+
fk_list = artifacts["dsi_relations"]["foreign_key"]
183+
pk_tables = set(t[0] for t in pk_list)
184+
fk_tables = set(t[0] for t in fk_list if t[0] != None)
185+
all_schema_tables = pk_tables.union(fk_tables)
186+
db_tables = [t[0] for t in self.list() if t[0] != "dsi_units"]
187+
188+
# check if tables from dsi_relations are all in the db
189+
if all_schema_tables.issubset(set(db_tables)):
190+
circ, _ = self.check_table_relations(all_schema_tables, artifacts["dsi_relations"])
191+
if circ:
192+
return (ValueError, f"A complex schema with a circular dependency cannot be ingested into a DuckDB backend.")
193+
194+
drop_order = all_schema_tables
195+
collect = self.process_artifacts()
196+
if "dsi_relations" in collect.keys():
197+
curr_pk_tables = set(t[0] for t in collect["dsi_relations"]["primary_key"])
198+
curr_fk_tables = set(t[0] for t in collect["dsi_relations"]["foreign_key"] if t[0] != None)
199+
curr_schema_tables = curr_pk_tables.union(curr_fk_tables)
200+
201+
# need to drop and reingest all tables in old schema and new schema
202+
all_schema_tables = all_schema_tables.union(curr_schema_tables)
203+
204+
_, ord_tables1 = self.check_table_relations(all_schema_tables, collect["dsi_relations"])
205+
drop_order = ord_tables1
206+
207+
for table in drop_order:
208+
self.cur.execute(f'DROP TABLE IF EXISTS "{table}";')
209+
try:
210+
self.con.commit()
211+
except Exception as e:
212+
self.cur.execute("ROLLBACK")
213+
self.cur.execute("CHECKPOINT")
214+
return (duckdb.Error, e)
215+
216+
#do not reingest tables not in old or new schema as they will be the same
217+
non_schema_tables = set(db_tables) - all_schema_tables
218+
for t in non_schema_tables:
219+
del collect[t]
220+
221+
collect["dsi_relations"] = artifacts["dsi_relations"]
222+
artifacts = collect
223+
224+
else:
225+
print("WARNING: Complex schemas can only be ingested if all referenced data tables are loaded into DSI.")
226+
227+
178228
table_order = artifacts.keys()
179229
if "dsi_relations" in artifacts.keys():
180230
circular, ordered_tables = self.check_table_relations(artifacts.keys(), artifacts["dsi_relations"])
@@ -184,10 +234,8 @@ def ingest_artifacts(self, collection, isVerbose=False):
184234
else:
185235
table_order = list(reversed(ordered_tables)) # ingest primary key tables first then children
186236

187-
self.cur.execute("BEGIN TRANSACTION")
188237
if self.runTable:
189-
runTable_create = "CREATE TABLE IF NOT EXISTS runTable " \
190-
"(run_id INTEGER PRIMARY KEY, run_timestamp TEXT UNIQUE);"
238+
runTable_create = "CREATE TABLE IF NOT EXISTS runTable (run_id INTEGER PRIMARY KEY, run_timestamp TEXT UNIQUE);"
191239
self.cur.execute(runTable_create)
192240

193241
sequence_run_id = "CREATE SEQUENCE IF NOT EXISTS seq_run_id START 1;"
@@ -387,13 +435,16 @@ def notebook(self, interactive=False):
387435
def read_to_artifact(self):
388436
return self.process_artifacts()
389437

390-
def process_artifacts(self):
438+
def process_artifacts(self, only_units_relations = False):
391439
"""
392440
Reads data from the DuckDB database into a nested OrderedDict.
393441
Keys are table names, and values are OrderedDicts containing table data.
394442
395443
If the database contains PK/FK relationships, they are stored in a special `dsi_relations` table.
396444
445+
`only_units_relations` : bool, default=False
446+
**USERS SHOULD IGNORE THIS FLAG.** Used internally by duckdb.py.
447+
397448
`return` : OrderedDict
398449
A nested OrderedDict containing all data from the DuckDB database.
399450
"""
@@ -404,20 +455,22 @@ def process_artifacts(self):
404455
SELECT table_name FROM information_schema.tables
405456
WHERE table_schema = 'main' AND table_type = 'BASE TABLE'
406457
""").fetchall()
407-
for item in tableList:
408-
tableName = self.duckdb_compatible_name(item[0])
409458

410-
tableInfo = self.cur.execute(f"PRAGMA table_info({tableName});").fetchdf()
411-
colDict = OrderedDict((self.duckdb_compatible_name(col), []) for col in tableInfo['name'])
459+
if only_units_relations == False:
460+
for item in tableList:
461+
tableName = self.duckdb_compatible_name(item[0])
462+
463+
tableInfo = self.cur.execute(f"PRAGMA table_info({tableName});").fetchdf()
464+
colDict = OrderedDict((self.duckdb_compatible_name(col), []) for col in tableInfo['name'])
412465

413-
data = self.cur.execute(f"SELECT * FROM {tableName};").fetchall()
414-
for row in data:
415-
for colName, val in zip(colDict.keys(), row):
416-
if val == "NULL":
417-
colDict[colName].append(None)
418-
else:
419-
colDict[colName].append(val)
420-
artifact[tableName] = colDict
466+
data = self.cur.execute(f"SELECT * FROM {tableName};").fetchall()
467+
for row in data:
468+
for colName, val in zip(colDict.keys(), row):
469+
if val == "NULL":
470+
colDict[colName].append(None)
471+
else:
472+
colDict[colName].append(val)
473+
artifact[tableName] = colDict
421474

422475
pk_list = []
423476
fkData = self.cur.execute(f"""
@@ -743,6 +796,8 @@ def list(self):
743796
SELECT table_name FROM information_schema.tables
744797
WHERE table_schema = 'main' AND table_type = 'BASE TABLE'
745798
""").fetchall()
799+
if not tableList:
800+
return None
746801
tableList = [self.duckdb_compatible_name(table[0]) for table in tableList]
747802

748803
info_list = []
@@ -839,12 +894,13 @@ def summary_helper(self, table_name):
839894
col_info = self.cur.execute(f"PRAGMA table_info({table_name})").fetchall()
840895

841896
numeric_types = {'INTEGER', 'REAL', 'FLOAT', 'NUMERIC', 'DECIMAL', 'DOUBLE', 'BIGINT'}
842-
headers = ['column', 'type', 'min', 'max', 'avg', 'std_dev']
897+
headers = ['column', 'type', 'unique', 'min', 'max', 'avg', 'std_dev']
843898
rows = []
844899

845900
for col in col_info:
846901
col_name = col[1]
847902
col_type = col[2].upper()
903+
unique_vals = self.cur.execute(f"SELECT COUNT(DISTINCT {col_name}) FROM {table_name};").fetchone()[0]
848904
is_primary = col[5] > 0
849905
display_name = f"{col_name}*" if is_primary else col_name
850906

@@ -863,7 +919,7 @@ def summary_helper(self, table_name):
863919

864920
if avg_val != None and std_dev == None:
865921
std_dev = 0
866-
rows.append([display_name, col_type, min_val, max_val, avg_val, std_dev])
922+
rows.append([display_name, col_type, unique_vals, min_val, max_val, avg_val, std_dev])
867923

868924
return headers, rows
869925

@@ -1007,7 +1063,7 @@ def visit(node):
10071063
if any(visit(node) for node in list(graph.keys())):
10081064
return True, None # Circular dependency detected
10091065

1010-
# Step 3: Order tables from least dependencies to most (if no circular dependencies)
1066+
# Order tables from least dependencies to most (if no circular dependencies)
10111067
in_degree = {table: 0 for table in tables}
10121068
for child in graph:
10131069
for parent in graph[child]:

dsi/backends/sqlite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,12 @@ def ingest_artifacts(self, collection, isVerbose=False):
208208
create_stmt += f"{col[1]} {col[2]}, "
209209

210210
if fk_dict:
211-
fk_stmt = "FOREIGN KEY "
211+
fk_stmt = ""
212212
for k, v in fk_dict.items():
213213
if k not in create_stmt:
214214
msg = f"Input schema references a nonexistent column, {k}, in the foreign_key section of {table}"
215215
raise ValueError(msg)
216-
fk_stmt += f"({k}) REFERENCES {v[0]}({v[1]}), "
216+
fk_stmt += f"FOREIGN KEY ({k}) REFERENCES {v[0]}({v[1]}), "
217217
create_stmt += fk_stmt
218218
create_stmt = create_stmt[:-2] + ");"
219219

dsi/dsi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -304,10 +304,10 @@ def read(self, filenames, reader_name, table_name = None):
304304
sys.exit(f"read() ERROR: {e}")
305305
self.t.active_metadata = OrderedDict()
306306

307-
if len(table_keys) > 1:
308-
print(f"Loaded {filenames} into tables: {', '.join(table_keys)}")
309-
else:
307+
if len(table_keys) == 1:
310308
print(f"Loaded {filenames} into the table {table_keys[0]}")
309+
else:
310+
print(f"Loaded {filenames} into tables: {', '.join(table_keys)}")
311311

312312
def query(self, statement, collection = False, update = False):
313313
"""

dsi/tests/test_dsi.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import textwrap
66
from pandas import DataFrame
77
from collections import OrderedDict
8+
import hashlib
89

910
def test_list_functions():
1011
test = DSI()
@@ -888,6 +889,31 @@ def test_query_update_schema_sqlite_backend():
888889
assert data['i'].tolist() == [123,234]
889890
assert data['new_col'].tolist() == ["test1", "test1"]
890891

892+
def test_overwrite_schema_sqlite_backend():
893+
dbpath = 'data.db'
894+
if os.path.exists(dbpath):
895+
os.remove(dbpath)
896+
897+
test = DSI(filename=dbpath, backend_name= "Sqlite")
898+
test.schema(filename="examples/test/yaml1_schema.json")
899+
test.read(filenames=["examples/test/student_test1.yml", "examples/test/student_test2.yml"], reader_name='YAML1')
900+
test.write(filename="full_erd.png", writer_name="ER_Diagram")
901+
902+
test.schema(filename="examples/test/yaml1_circular_schema.json")
903+
test.write(filename="new_erd.png", writer_name="ER_Diagram")
904+
905+
def file_hash(path):
906+
sha = hashlib.sha256()
907+
with open(path, "rb") as f:
908+
sha.update(f.read())
909+
return sha.hexdigest()
910+
911+
hash1 = file_hash("full_erd.png")
912+
hash2 = file_hash("new_erd.png")
913+
914+
assert hash1 != hash2
915+
916+
891917

892918
# DUCKDB
893919
# DUCKDB
@@ -1730,4 +1756,46 @@ def test_query_update_schema_duckdb_backend():
17301756

17311757
data = test.get_table("math", collection=True, update=True)
17321758
assert data['specification'].tolist() == [123,234]
1733-
assert data['new_col'].tolist() == ["test1", "test1"]
1759+
assert data['new_col'].tolist() == ["test1", "test1"]
1760+
1761+
def test_overwrite_schema_duckdb_backend():
1762+
dbpath = 'data.db'
1763+
if os.path.exists(dbpath):
1764+
os.remove(dbpath)
1765+
1766+
test = DSI(filename=dbpath, backend_name= "DuckDB")
1767+
test.schema(filename="examples/test/yaml1_schema.json")
1768+
test.read(filenames=["examples/test/student_test1.yml", "examples/test/student_test2.yml"], reader_name='YAML1')
1769+
test.write(filename="full_erd.png", writer_name="ER_Diagram")
1770+
1771+
#loophole to assign new schema since there isnt another schema file that can be used with yaml data (circular wont work here)
1772+
new_schema = OrderedDict({'primary_key': [('address', 'i'), ('math', 'specification')], 'foreign_key': [('math', 'b'), (None, None)]})
1773+
test.read(filenames=new_schema, reader_name="Collection", table_name="dsi_relations") #loophole to assign new schema since
1774+
test.write(filename="new_erd.png", writer_name="ER_Diagram")
1775+
1776+
def file_hash(path):
1777+
sha = hashlib.sha256()
1778+
with open(path, "rb") as f:
1779+
sha.update(f.read())
1780+
return sha.hexdigest()
1781+
1782+
hash1 = file_hash("full_erd.png")
1783+
hash2 = file_hash("new_erd.png")
1784+
1785+
assert hash1 != hash2
1786+
1787+
def test_fail_overwrite_schema_duckdb_backend():
1788+
dbpath = 'data.db'
1789+
if os.path.exists(dbpath):
1790+
os.remove(dbpath)
1791+
1792+
test = DSI(filename=dbpath, backend_name= "DuckDB")
1793+
test.schema(filename="examples/test/yaml1_schema.json")
1794+
test.read(filenames=["examples/test/student_test1.yml", "examples/test/student_test2.yml"], reader_name='YAML1')
1795+
test.write(filename="full_erd.png", writer_name="ER_Diagram")
1796+
1797+
try:
1798+
test.schema(filename="examples/test/yaml1_circular_schema.json") # should not allow circular dependency overwrite
1799+
assert False
1800+
except:
1801+
assert True

examples/test/coreterminal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
# a.load_module('plugin', 'Schema', 'reader', filename="example_schema.json")
1111
a.load_module('plugin', 'Schema', 'reader', filename="yaml1_schema.json")
12+
# a.load_module('plugin', 'Schema', 'reader', filename="yaml1_circular_schema.json")
1213

1314
a.load_module('plugin', 'YAML1', 'reader', filenames=["student_test1.yml", "student_test2.yml"])
1415
# a.load_module('plugin', 'TOML1', 'reader', filenames=["results.toml", "results1.toml"], target_table_prefix = "results")

examples/test/dsi_example.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
# test.list_writers()
1111

1212
''' Example uses of loading DSI readers '''
13-
# test.schema(filename="yaml1_schema.json") # must be loaded first
13+
# test.schema(filename="yaml1_circular_schema.json") # must be loaded first
14+
test.schema(filename="yaml1_schema.json") # must be loaded first
1415
# test.schema(filename="example_schema.json") # must be loaded first
1516

1617
test.read(filenames=["student_test1.yml", "student_test2.yml"], reader_name='YAML1')

0 commit comments

Comments
 (0)