Skip to content

Commit

Permalink
Add --directory option for storing migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
troyharvey committed May 19, 2024
1 parent 3a98f99 commit 16b4e2c
Show file tree
Hide file tree
Showing 16 changed files with 517 additions and 140 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
strategy:
max-parallel: 6
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11', '3.12']
env:
ACTIONS_ALLOW_UNSECURE_COMMANDS: true
name: Python ${{ matrix.python-version }} tests
Expand Down
19 changes: 19 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

## Test in a dbt Project

1. Install `migro` in edit mode inside a dbt project.

pip install -e ~/Projects/migro

1. Build a distribution wheel.

python setup.py bdist_wheel
Expand All @@ -29,3 +33,18 @@
1. Upload the new version.

twine upload dist/*

## Upgrading Dependencies

1. Install pip tools

pip install pip-tools

1. Update the dependencies in pyproject and setup.py.
1. Compile the requirements file.

pip-compile -o requirements.txt pyproject.toml

1. Compile the dev requirements file.

pip-compile --extra dev -o dev-requirements.txt pyproject.toml
59 changes: 21 additions & 38 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
#
# This file is autogenerated by pip-compile with Python 3.9
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
#
# pip-compile --extra=dev --output-file=dev-requirements.txt pyproject.toml
#
asn1crypto==1.5.1
# via
# oscrypto
# snowflake-connector-python
attrs==22.2.0
# via pytest
black==22.12.0
# via migro (pyproject.toml)
# via snowflake-connector-python
bleach==5.0.1
# via readme-renderer
certifi==2022.12.7
Expand All @@ -27,9 +21,7 @@ charset-normalizer==2.1.1
# requests
# snowflake-connector-python
click==8.1.3
# via
# black
# migro (pyproject.toml)
# via migro (pyproject.toml)
commonmark==0.9.1
# via rich
coverage[toml]==7.0.1
Expand Down Expand Up @@ -64,26 +56,20 @@ markupsafe==2.1.1
# via jinja2
more-itertools==9.0.0
# via jaraco-classes
mypy-extensions==0.4.3
# via black
oscrypto==1.3.0
# via snowflake-connector-python
packaging==22.0
# via pytest
pathspec==0.10.3
# via black
# via
# pytest
# snowflake-connector-python
pkginfo==1.9.2
# via twine
platformdirs==2.6.2
# via black
pluggy==1.0.0
platformdirs==4.2.2
# via snowflake-connector-python
pluggy==1.5.0
# via pytest
psycopg2==2.9.5
# via migro (pyproject.toml)
pycparser==2.21
# via cffi
pycryptodomex==3.16.0
# via snowflake-connector-python
pygments==2.14.0
# via
# readme-renderer
Expand All @@ -92,15 +78,15 @@ pyjwt==2.6.0
# via snowflake-connector-python
pyopenssl==22.1.0
# via snowflake-connector-python
pytest==7.2.0
pytest==8.2.1
# via
# migro (pyproject.toml)
# pytest-cov
pytest-cov==4.0.0
pytest-cov==5.0.0
# via migro (pyproject.toml)
pytz==2022.7
# via snowflake-connector-python
pyyaml==6.0
pyyaml==6.0.1
# via migro (pyproject.toml)
readme-renderer==37.3
# via twine
Expand All @@ -115,34 +101,31 @@ rfc3986==2.0.0
# via twine
rich==13.0.0
# via twine
ruff==0.0.221
ruff==0.4.4
# via migro (pyproject.toml)
six==1.16.0
# via bleach
snowflake-connector-python==2.9.0
snowflake-connector-python==3.10.0
# via migro (pyproject.toml)
sqlparse==0.4.3
sortedcontainers==2.4.0
# via snowflake-connector-python
sqlparse==0.5.0
# via migro (pyproject.toml)
tomli==2.0.1
# via
# black
# coverage
# pytest
twine==4.0.2
tomlkit==0.12.5
# via snowflake-connector-python
twine==5.1.0
# via migro (pyproject.toml)
typing-extensions==4.4.0
# via
# black
# snowflake-connector-python
# via snowflake-connector-python
urllib3==1.26.13
# via
# requests
# snowflake-connector-python
# twine
webencodings==0.5.1
# via bleach
zipp==3.11.0
# via importlib-metadata

# The following packages are considered to be unsafe in a requirements file:
# setuptools
64 changes: 54 additions & 10 deletions migro/database.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
import base64
import psycopg2
import psycopg2.extras
import snowflake.connector
import sqlite3
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from migro import dbt


def get_database_instance(profile=None):
db_config = dbt.get_output(profile=profile)
def get_database_instance(profile=None, target=None):
db_config = dbt.get_target_output(profile_name=profile, target=target)

if db_config["type"] == "redshift":
return RedshiftDatabase(
host=db_config["host"],
user=db_config["user"],
password=db_config["password"]
if "password" in db_config
else db_config["pass"],
password=(
db_config["password"] if "password" in db_config else db_config["pass"]
),
port=db_config["port"],
dbname=db_config["dbname"],
)
Expand All @@ -29,8 +32,11 @@ def get_database_instance(profile=None):
return SnowflakeDatabase(
account=db_config["account"],
user=db_config["user"],
password=db_config["password"],
password=db_config.get("password"),
database=db_config["database"],
warehouse=db_config["warehouse"],
private_key=db_config.get("private_key"),
private_key_passphrase=db_config.get("private_key_passphrase"),
)


Expand Down Expand Up @@ -144,6 +150,9 @@ class SnowflakeDatabase(Database):
user: str
password: str
database: str
warehouse: str
private_key: Optional[str] = None
private_key_passphrase: Optional[str] = None

MIGRATIONS_TABLE_SQL: ClassVar[
str
Expand All @@ -162,6 +171,41 @@ def _get_connection(self):
password=self.password,
account=self.account,
database=self.database,
warehouse=self.warehouse,
private_key=self._get_private_key(),
)

def _get_private_key(self):
"""
base64 decode the private key, decrypt it, and return an instance of AuthByKeyPair
See dbt-snowflake private key code:
https://github.com/dbt-labs/dbt-snowflake/blob/87a6e808dfb025df1eeef3741ad3822635249889/dbt/adapters/snowflake/connections.py#L244
"""
if not self.private_key:
return None

if self.private_key_passphrase:
encoded_passphrase = self.private_key_passphrase.encode()
else:
encoded_passphrase = None

Check warning on line 190 in migro/database.py

View check run for this annotation

Codecov / codecov/patch

migro/database.py#L190

Added line #L190 was not covered by tests

if self.private_key.startswith("-"):
p_key = serialization.load_pem_private_key(
data=bytes(self.private_key, "utf-8"),
password=encoded_passphrase,
backend=default_backend(),
)
else:
p_key = serialization.load_der_private_key(
data=base64.b64decode(self.private_key),
password=encoded_passphrase,
backend=default_backend(),
)

return p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

def create_migrations_table(self):
Expand All @@ -173,9 +217,9 @@ def get_migrations(self):
cur.execute(
(
"""
SELECT ID as "id", MIGRATION as "migration", APPLIED_AT as "applied_at"
FROM public.migrations ORDER BY migration ASC
"""
SELECT ID as "id", MIGRATION as "migration", APPLIED_AT as "applied_at"
FROM PUBLIC.migrations ORDER BY migration ASC
"""
)
)
migrations = cur.fetchall()
Expand Down
72 changes: 59 additions & 13 deletions migro/dbt.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,70 @@
import click
import os
import yaml
from migro import jinja


def get_output(profile=None):
def _get_profiles_path():
"""
dbt best practice is to store the profiles.yml file in the ~/.dbt directory.
But it's common to store the profiles.yml file in the root of the dbt project.
migro checks the local project directory and then checks the .dbt home directory.
Return the path to the profiles.yml file.
"""

if not os.path.isfile("profiles.yml"):
click.echo(click.style("Missing dbt profiles.yml", fg="red"))
exit()
if os.path.isfile("./profiles.yml"):
return "./profiles.yml"
elif os.path.isfile("~/.dbt/profiles.yml"):
return "~/.dbt/profiles.yml"
else:
raise Exception("Missing dbt profiles.yml")

profiles = jinja.render_jinja_template("profiles.yml")
profiles = yaml.safe_load(profiles)

for profile_name, p in profiles.items():
if profile and profile != profile_name:
def _profiles_yaml_to_dict(profiles_path):
"""
Render the profiles.yml file using jinja.
Return the rendered profiles.yml file as a dictionary.
"""

profiles_yaml: str = jinja.render_jinja_template(profiles_path)
return yaml.safe_load(profiles_yaml)


def _get_profile(profiles: dict, profile_name: str = None) -> dict:
"""
Return the target dbt profile from a list of profiles.
"""

for key, profile in profiles.items():
if key == profile_name:
return profile

raise Exception(f"Profile {profile_name} not found in profiles.yml")


def _get_output(profile: dict, target: str = None) -> dict:
"""
Return the target dbt output from a profile.
"""

if not target:
target = profile.get("target")

assert target is not None, "No target specified"

for output_name, output in profile["outputs"].items():
if target != output_name:
continue
return output

raise Exception(f"Target {target} not found in profiles.yml")


target = p["target"]
def get_target_output(profile_name: str = None, target=None) -> dict:
"""
Get the target output database configuration from the dbt profiles.yml file.
"""

for output_key, output in p["outputs"].items():
if len(p["outputs"].keys()) == 1 or target == output_key:
return output
profiles_path = _get_profiles_path()
profiles = _profiles_yaml_to_dict(profiles_path)
profile = _get_profile(profiles, profile_name)
return _get_output(profile, target)
Loading

0 comments on commit 16b4e2c

Please sign in to comment.