Skip to content

Commit

Permalink
Add --directory option for storing migrations (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
troyharvey authored May 20, 2024
1 parent 3a98f99 commit 75aeb73
Show file tree
Hide file tree
Showing 17 changed files with 607 additions and 152 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
max-parallel: 6
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11', '3.12']
env:
ACTIONS_ALLOW_UNSECURE_COMMANDS: true
name: Python ${{ matrix.python-version }} tests
Expand Down
19 changes: 19 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

## Test in a dbt Project

1. Install `migro` in edit mode inside a dbt project.

pip install -e ~/Projects/migro

1. Build a distribution wheel.

python setup.py bdist_wheel
Expand All @@ -29,3 +33,18 @@
1. Upload the new version.

twine upload dist/*

## Upgrading Dependencies

1. Install pip tools

pip install pip-tools

1. Update the dependencies in pyproject and setup.py.
1. Compile the requirements file.

pip-compile -o requirements.txt pyproject.toml

1. Compile the dev requirements file.

pip-compile --extra dev -o dev-requirements.txt pyproject.toml
59 changes: 21 additions & 38 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile --extra=dev --output-file=dev-requirements.txt pyproject.toml
#
asn1crypto==1.5.1
# via
# oscrypto
# snowflake-connector-python
attrs==22.2.0
# via pytest
black==22.12.0
# via migro (pyproject.toml)
# via snowflake-connector-python
bleach==5.0.1
# via readme-renderer
certifi==2022.12.7
Expand All @@ -27,9 +21,7 @@ charset-normalizer==2.1.1
# requests
# snowflake-connector-python
click==8.1.3
# via
# black
# migro (pyproject.toml)
# via migro (pyproject.toml)
commonmark==0.9.1
# via rich
coverage[toml]==7.0.1
Expand Down Expand Up @@ -64,26 +56,20 @@ markupsafe==2.1.1
# via jinja2
more-itertools==9.0.0
# via jaraco-classes
mypy-extensions==0.4.3
# via black
oscrypto==1.3.0
# via snowflake-connector-python
packaging==22.0
# via pytest
pathspec==0.10.3
# via black
# via
# pytest
# snowflake-connector-python
pkginfo==1.9.2
# via twine
platformdirs==2.6.2
# via black
pluggy==1.0.0
platformdirs==4.2.2
# via snowflake-connector-python
pluggy==1.5.0
# via pytest
psycopg2==2.9.5
# via migro (pyproject.toml)
pycparser==2.21
# via cffi
pycryptodomex==3.16.0
# via snowflake-connector-python
pygments==2.14.0
# via
# readme-renderer
Expand All @@ -92,15 +78,15 @@ pyjwt==2.6.0
# via snowflake-connector-python
pyopenssl==22.1.0
# via snowflake-connector-python
pytest==7.2.0
pytest==8.2.1
# via
# migro (pyproject.toml)
# pytest-cov
pytest-cov==4.0.0
pytest-cov==5.0.0
# via migro (pyproject.toml)
pytz==2022.7
# via snowflake-connector-python
pyyaml==6.0
pyyaml==6.0.1
# via migro (pyproject.toml)
readme-renderer==37.3
# via twine
Expand All @@ -115,34 +101,31 @@ rfc3986==2.0.0
# via twine
rich==13.0.0
# via twine
ruff==0.0.221
ruff==0.4.4
# via migro (pyproject.toml)
six==1.16.0
# via bleach
snowflake-connector-python==2.9.0
snowflake-connector-python==3.10.0
# via migro (pyproject.toml)
sqlparse==0.4.3
sortedcontainers==2.4.0
# via snowflake-connector-python
sqlparse==0.5.0
# via migro (pyproject.toml)
tomli==2.0.1
# via
# black
# coverage
# pytest
twine==4.0.2
tomlkit==0.12.5
# via snowflake-connector-python
twine==5.1.0
# via migro (pyproject.toml)
typing-extensions==4.4.0
# via
# black
# snowflake-connector-python
# via snowflake-connector-python
urllib3==1.26.13
# via
# requests
# snowflake-connector-python
# twine
webencodings==0.5.1
# via bleach
zipp==3.11.0
# via importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# setuptools
78 changes: 57 additions & 21 deletions migro/database.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
import base64
import psycopg2
import psycopg2.extras
import snowflake.connector
import sqlite3
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from migro import dbt


def get_database_instance(profile=None):
db_config = dbt.get_output(profile=profile)
def get_database_instance(profile=None, target=None):
db_config = dbt.get_target_output(profile_name=profile, target=target)

if db_config["type"] == "redshift":
return RedshiftDatabase(
host=db_config["host"],
user=db_config["user"],
password=db_config["password"]
if "password" in db_config
else db_config["pass"],
password=(
db_config["password"] if "password" in db_config else db_config["pass"]
),
port=db_config["port"],
dbname=db_config["dbname"],
)
Expand All @@ -29,14 +32,16 @@ def get_database_instance(profile=None):
return SnowflakeDatabase(
account=db_config["account"],
user=db_config["user"],
password=db_config["password"],
password=db_config.get("password"),
database=db_config["database"],
warehouse=db_config["warehouse"],
private_key=db_config.get("private_key"),
private_key_passphrase=db_config.get("private_key_passphrase"),
)


@dataclass
class Database:

MIGRATIONS_TABLE_SQL: ClassVar[str]

def _get_connection(self):
Expand All @@ -58,10 +63,7 @@ def get_migrations(self):

@dataclass
class SqliteDatabase(Database):

MIGRATIONS_TABLE_SQL: ClassVar[
str
] = """
MIGRATIONS_TABLE_SQL: ClassVar[str] = """
create table if not exists migrations
(
id integer primary key,
Expand Down Expand Up @@ -95,9 +97,7 @@ class RedshiftDatabase(Database):
port: int
dbname: str

MIGRATIONS_TABLE_SQL: ClassVar[
str
] = """
MIGRATIONS_TABLE_SQL: ClassVar[str] = """
create table if not exists migrations
(
id int identity not null,
Expand Down Expand Up @@ -144,10 +144,11 @@ class SnowflakeDatabase(Database):
user: str
password: str
database: str
warehouse: str
private_key: Optional[str] = None
private_key_passphrase: Optional[str] = None

MIGRATIONS_TABLE_SQL: ClassVar[
str
] = """
MIGRATIONS_TABLE_SQL: ClassVar[str] = """
create table if not exists PUBLIC.migrations
(
id int identity not null,
Expand All @@ -162,6 +163,41 @@ def _get_connection(self):
password=self.password,
account=self.account,
database=self.database,
warehouse=self.warehouse,
private_key=self._get_private_key(),
)

def _get_private_key(self):
"""
base64 decode the private key, decrypt it, and return an instance of AuthByKeyPair
See dbt-snowflake private key code:
https://github.com/dbt-labs/dbt-snowflake/blob/87a6e808dfb025df1eeef3741ad3822635249889/dbt/adapters/snowflake/connections.py#L244
"""
if not self.private_key:
return None

if self.private_key_passphrase:
encoded_passphrase = self.private_key_passphrase.encode()
else:
encoded_passphrase = None

if self.private_key.startswith("-"):
p_key = serialization.load_pem_private_key(
data=bytes(self.private_key, "utf-8"),
password=encoded_passphrase,
backend=default_backend(),
)
else:
p_key = serialization.load_der_private_key(
data=base64.b64decode(self.private_key),
password=encoded_passphrase,
backend=default_backend(),
)

return p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

def create_migrations_table(self):
Expand All @@ -173,9 +209,9 @@ def get_migrations(self):
cur.execute(
(
"""
SELECT ID as "id", MIGRATION as "migration", APPLIED_AT as "applied_at"
FROM public.migrations ORDER BY migration ASC
"""
SELECT ID as "id", MIGRATION as "migration", APPLIED_AT as "applied_at"
FROM PUBLIC.migrations ORDER BY migration ASC
"""
)
)
migrations = cur.fetchall()
Expand Down
76 changes: 63 additions & 13 deletions migro/dbt.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,74 @@
import click
import os
import yaml
from migro import jinja


def get_output(profile=None):
def _get_profiles_path():
"""
dbt best practice is to store the profiles.yml file in the ~/.dbt directory.
But it's common to store the profiles.yml file in the root of the dbt project.
migro checks the local project directory and then checks the .dbt home directory.
Return the path to the profiles.yml file.
"""

if not os.path.isfile("profiles.yml"):
click.echo(click.style("Missing dbt profiles.yml", fg="red"))
exit()
if os.path.isfile("./profiles.yml"):
return "./profiles.yml"
elif os.path.isfile("~/.dbt/profiles.yml"):
return "~/.dbt/profiles.yml"
else:
raise Exception("Missing dbt profiles.yml")

profiles = jinja.render_jinja_template("profiles.yml")
profiles = yaml.safe_load(profiles)

for profile_name, p in profiles.items():
if profile and profile != profile_name:
def _profiles_yaml_to_dict(profiles_path):
"""
Render the profiles.yml file using jinja.
Return the rendered profiles.yml file as a dictionary.
"""

profiles_yaml: str = jinja.render_jinja_template(profiles_path)
return yaml.safe_load(profiles_yaml)


def _get_profile(profiles: dict, profile_name: str = None) -> dict:
"""
Return the target dbt profile from a list of profiles.
"""

for key, profile in profiles.items():
# Return the first profile when no profile_name is specified.
if profile_name is None:
return profile

if key == profile_name:
return profile

raise Exception(f"Profile {profile_name} not found in profiles.yml")


def _get_output(profile: dict, target: str = None) -> dict:
"""
Return the target dbt output from a profile.
"""

if not target:
target = profile.get("target")

assert target is not None, "No target specified"

for output_name, output in profile["outputs"].items():
if target != output_name:
continue
return output

raise Exception(f"Target {target} not found in profiles.yml")


target = p["target"]
def get_target_output(profile_name: str = None, target=None) -> dict:
"""
Get the target output database configuration from the dbt profiles.yml file.
"""

for output_key, output in p["outputs"].items():
if len(p["outputs"].keys()) == 1 or target == output_key:
return output
profiles_path = _get_profiles_path()
profiles = _profiles_yaml_to_dict(profiles_path)
profile = _get_profile(profiles, profile_name)
return _get_output(profile, target)
Loading

0 comments on commit 75aeb73

Please sign in to comment.