Skip to content

Commit

Permalink
SQL CREATE TABLE in python (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
coilysiren committed Oct 22, 2023
1 parent f10b2a4 commit e8bec6e
Show file tree
Hide file tree
Showing 6 changed files with 174 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/config.yml
Expand Up @@ -28,4 +28,4 @@ jobs:
uses: actions/checkout@v3

- run: pip install invoke pyyaml
- run: invoke test ${{ matrix.language }} any any
- run: invoke test ${{ matrix.language }} any any --snippets
6 changes: 6 additions & 0 deletions data/sql_input_1.sql
@@ -0,0 +1,6 @@
-- https://cratedb.com/docs/sql-99/en/latest/chapters/01.html
-- https://www.postgresql.org/docs/16/sql-createtable.html
-- https://www.postgresql.org/docs/16/sql-select.html
CREATE TABLE city ();
CREATE TABLE town ();
SELECT table_name FROM information_schema.tables WHERE table_schema = 'public';
3 changes: 3 additions & 0 deletions data/sql_output_1.json
@@ -0,0 +1,3 @@
{
"table_name": ["city", "town"]
}
80 changes: 77 additions & 3 deletions snippets/python/sql_test.py
Expand Up @@ -2,7 +2,81 @@
import json


def run_sql(input_sql: list[str]) -> list[str]:
    """Return a one-element list holding a JSON-encoded placeholder result.

    ``input_sql`` is accepted but never inspected; the result is always the
    same hard-coded table listing.
    """
    return [json.dumps({"table_name": ["city"]})]
class SQL:
    """Minimal in-memory SQL interpreter.

    Understands two statements:

    * ``CREATE TABLE <name> ...`` registers a table.
    * ``SELECT <columns> FROM information_schema.tables ...`` lists metadata
      of the registered tables.

    Tables live in ``self.data``, keyed by table name. Each entry holds a
    ``"metadata"`` dict with the keys ``table_name`` and ``table_schema``.
    """

    # Per-instance table store. It is assigned in __init__ rather than given
    # a class-level default so instances can never share one mutable dict.
    data: dict

    def __init__(self) -> None:
        self.data = {}

    def information_schema_tables(self) -> list[dict]:
        """Return the metadata dict of every registered table."""
        return [table["metadata"] for table in self.data.values()]

    def create_table(self, *args, table_schema="public") -> dict:
        """Register a table; a table that already exists is left untouched.

        Args:
            *args: The statement split into words, e.g.
                ``("CREATE", "TABLE", "city", "();")``. The third word is
                the table name.
            table_schema: Schema recorded in the table's metadata.

        Returns:
            Always an empty dict (the statement produces no rows).
        """
        # A statement without a table name is ignored instead of raising.
        if len(args) < 3:
            return {}

        # Accept "city", "city();" and "city;" alike: keep only the part
        # before any column list or statement terminator.
        table_name = args[2].split("(")[0].strip().rstrip(";")
        if table_name and table_name not in self.data:
            self.data[table_name] = {
                "metadata": {
                    "table_name": table_name,
                    "table_schema": table_schema,
                },
            }
        return {}

    create_table.sql = "CREATE TABLE"

    def select(self, *args) -> dict:
        """Run a ``SELECT`` and return the result as column -> list of values.

        Only ``information_schema.tables`` is supported as the source; any
        other source, or a statement without ``FROM``, yields an empty dict.
        A ``WHERE`` clause is accepted but not applied yet.

        Args:
            *args: The statement split into words, starting with ``SELECT``.

        Returns:
            A dict mapping each selected column name to the list of that
            column's values, one per table. Unknown columns map to ``[]``.
        """
        try:
            from_index = args.index("FROM")
        except ValueError:
            return {}
        if from_index + 1 >= len(args):
            return {}

        # The select keys are the comma-separated words between SELECT and
        # FROM; strip them so "a, b" and "a,b" give the same keys.
        raw_keys = " ".join(args[1:from_index]).split(",")
        select_keys = [key.strip() for key in raw_keys if key.strip()]

        # The source may carry the statement terminator when there is no
        # WHERE clause ("... FROM information_schema.tables;").
        from_value = args[from_index + 1].strip().rstrip(";")

        # consider "information_schema.tables" a special case until
        # we figure out why it's so different from the others
        if from_value != "information_schema.tables":
            return {}

        target = self.information_schema_tables()
        # Each column collects only its own values, one per table.
        return {
            key: [row[key] for row in target if key in row]
            for key in select_keys
        }

    select.sql = "SELECT"

    # Dispatch table: leading SQL keyword(s) -> handler function. The values
    # are plain functions, so callers pass the instance explicitly.
    sql_map = {
        create_table.sql: create_table,
        select.sql: select,
    }

    def run(self, input_sql: list[str]) -> list[str]:
        """Execute SQL statements, one per line, and return the last result.

        Blank lines and ``--`` comment lines are skipped. Each remaining
        line is split on whitespace and dispatched on its longest leading
        word sequence found in ``sql_map``; lines matching no handler are
        ignored.

        Args:
            input_sql: Lines of SQL, one statement per line.

        Returns:
            A one-element list containing the JSON encoding of the result of
            the last dispatched statement (``{}`` if none was dispatched).
        """
        output = {}

        for line in input_sql:
            statement = line.strip()
            if not statement or statement.startswith("--"):
                continue

            words = statement.split()
            # Try the longest prefix first, down to a single word, so that
            # "CREATE TABLE" wins over any shorter keyword.
            for size in range(len(words), 0, -1):
                key = " ".join(words[:size])
                if func := self.sql_map.get(key):
                    output = func(self, *words)
                    break

        return [json.dumps(output)]

82 changes: 78 additions & 4 deletions src/python/sql_test.py
Expand Up @@ -12,14 +12,88 @@
import json


def run_sql(input_sql: list[str]) -> list[str]:
    """Return a one-element list holding a JSON-encoded placeholder result.

    ``input_sql`` is accepted but never inspected; the result is always the
    same hard-coded table listing.
    """
    return [json.dumps({"table_name": ["city"]})]
class SQL:
    """Minimal in-memory SQL interpreter.

    Understands two statements:

    * ``CREATE TABLE <name> ...`` registers a table.
    * ``SELECT <columns> FROM information_schema.tables ...`` lists metadata
      of the registered tables.

    Tables live in ``self.data``, keyed by table name. Each entry holds a
    ``"metadata"`` dict with the keys ``table_name`` and ``table_schema``.
    """

    # Per-instance table store. It is assigned in __init__ rather than given
    # a class-level default so instances can never share one mutable dict.
    data: dict

    def __init__(self) -> None:
        self.data = {}

    def information_schema_tables(self) -> list[dict]:
        """Return the metadata dict of every registered table."""
        return [table["metadata"] for table in self.data.values()]

    def create_table(self, *args, table_schema="public") -> dict:
        """Register a table; a table that already exists is left untouched.

        Args:
            *args: The statement split into words, e.g.
                ``("CREATE", "TABLE", "city", "();")``. The third word is
                the table name.
            table_schema: Schema recorded in the table's metadata.

        Returns:
            Always an empty dict (the statement produces no rows).
        """
        # A statement without a table name is ignored instead of raising.
        if len(args) < 3:
            return {}

        # Accept "city", "city();" and "city;" alike: keep only the part
        # before any column list or statement terminator.
        table_name = args[2].split("(")[0].strip().rstrip(";")
        if table_name and table_name not in self.data:
            self.data[table_name] = {
                "metadata": {
                    "table_name": table_name,
                    "table_schema": table_schema,
                },
            }
        return {}

    create_table.sql = "CREATE TABLE"

    def select(self, *args) -> dict:
        """Run a ``SELECT`` and return the result as column -> list of values.

        Only ``information_schema.tables`` is supported as the source; any
        other source, or a statement without ``FROM``, yields an empty dict.
        A ``WHERE`` clause is accepted but not applied yet.

        Args:
            *args: The statement split into words, starting with ``SELECT``.

        Returns:
            A dict mapping each selected column name to the list of that
            column's values, one per table. Unknown columns map to ``[]``.
        """
        try:
            from_index = args.index("FROM")
        except ValueError:
            return {}
        if from_index + 1 >= len(args):
            return {}

        # The select keys are the comma-separated words between SELECT and
        # FROM; strip them so "a, b" and "a,b" give the same keys.
        raw_keys = " ".join(args[1:from_index]).split(",")
        select_keys = [key.strip() for key in raw_keys if key.strip()]

        # The source may carry the statement terminator when there is no
        # WHERE clause ("... FROM information_schema.tables;").
        from_value = args[from_index + 1].strip().rstrip(";")

        # consider "information_schema.tables" a special case until
        # we figure out why it's so different from the others
        if from_value != "information_schema.tables":
            return {}

        target = self.information_schema_tables()
        # Each column collects only its own values, one per table.
        return {
            key: [row[key] for row in target if key in row]
            for key in select_keys
        }

    select.sql = "SELECT"

    # Dispatch table: leading SQL keyword(s) -> handler function. The values
    # are plain functions, so callers pass the instance explicitly.
    sql_map = {
        create_table.sql: create_table,
        select.sql: select,
    }

    def run(self, input_sql: list[str]) -> list[str]:
        """Execute SQL statements, one per line, and return the last result.

        Blank lines and ``--`` comment lines are skipped. Each remaining
        line is split on whitespace and dispatched on its longest leading
        word sequence found in ``sql_map``; lines matching no handler are
        ignored.

        Args:
            input_sql: Lines of SQL, one statement per line.

        Returns:
            A one-element list containing the JSON encoding of the result of
            the last dispatched statement (``{}`` if none was dispatched).
        """
        output = {}

        for line in input_sql:
            statement = line.strip()
            if not statement or statement.startswith("--"):
                continue

            words = statement.split()
            # Try the longest prefix first, down to a single word, so that
            # "CREATE TABLE" wins over any shorter keyword.
            for size in range(len(words), 0, -1):
                key = " ".join(words[:size])
                if func := self.sql_map.get(key):
                    output = func(self, *words)
                    break

        return [json.dumps(output)]


######################
# business logic end #
######################

if __name__ == "__main__":
helpers.run(run_sql)
helpers.run(SQL().run)
18 changes: 9 additions & 9 deletions tasks.py
@@ -1,4 +1,5 @@
# builtin packages
import unittest
import filecmp
import glob
import os
Expand Down Expand Up @@ -163,6 +164,8 @@ def generate(self, language, config, script_path, input_file_path):
docker_run_test_list = [
"docker",
"run",
"--rm",
f"--name={language}",
f"--volume={self.base_directory}:/workdir",
"-w=/workdir",
]
Expand Down Expand Up @@ -297,13 +300,9 @@ def run_tests(self, input_script):
prepared_file_data = json.load(reader)
with open(ctx.script_output_file_path, "r", encoding="utf-8") as reader:
script_output_file_data = json.load(reader)
if prepared_file_data == script_output_file_data:
self.set_success_status(True)
print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded")
else:
self.set_success_status(False)
print(f"\t🔴 {ctx.script_relative_path} on {ctx.input_file_path} failed, reason:")
print(f"\t\t output file {ctx.script_output_file_name} has does not match the prepared file")
unittest.TestCase().assertDictEqual(prepared_file_data, script_output_file_data)
self.set_success_status(True)
print(f"\t🟢 {ctx.script_relative_path} on {ctx.input_file_path} succeeded")
continue

# check if the output file matches the prepared file
Expand Down Expand Up @@ -392,12 +391,13 @@ def show_results(self):


@invoke.task
def test(ctx: invoke.Context, language, input_script, input_data_index):
def test(ctx: invoke.Context, language, input_script, input_data_index, snippets=False):
# language is the programming language to run scripts in
# input_script is the name of a script you want to run
runner = TestRunner(ctx, language, input_data_index)
runner.run_tests(input_script)
runner.generate_snippets(input_script)
if snippets:
runner.generate_snippets(input_script)
runner.show_results()


Expand Down

0 comments on commit e8bec6e

Please sign in to comment.