diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000000..9909382d46 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,31 @@ +name: Python Code Formatting (Black) + +on: + push: + branches: + - master + +jobs: + format: + runs-on: ubuntu-latest + + container: + image: python:3.7.4-alpine + + steps: + - uses: actions/checkout@v1 + - name: Install Black + run: apk add gcc musl-dev && pip install black + - name: Run Black + run: black redash tests migrations/versions + - name: Commit formatted code + uses: EndBug/add-and-commit@v2.1.0 + with: + author_name: Black + author_email: team@redash.io + message: "Autoformatted Python code with Black" + path: "." + pattern: "*.py" + force: false + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8bd64799b9..e9c28e6bc6 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -46,8 +46,8 @@ When creating a new bug report, please make sure to: If you would like to suggest an enhancement or ask for a new feature: -- Please check [the roadmap](https://trello.com/b/b2LUHU7A/redash-roadmap) for existing Trello card for what you want to suggest/ask. If there is, feel free to upvote it to signal interest or add your comments. -- If there is no existing card, open a thread in [the forum](https://discuss.redash.io/c/feature-requests) to start a discussion about what you want to suggest. Try to provide as much details and context as possible and include information about *the problem you want to solve* rather only *your proposed solution*. +- Please check [the forum](https://discuss.redash.io/c/feature-requests/5) for existing threads about what you want to suggest/ask. If there is, feel free to upvote it to signal interest or add your comments. +- If there is no open thread, you're welcome to start one to have a discussion about what you want to suggest. Try to provide as much details and context as possible and include information about *the problem you want to solve* rather only *your proposed solution*. ### Pull Requests @@ -55,9 +55,9 @@ If you would like to suggest an enhancement or ask for a new feature: - Include screenshots and animated GIFs in your pull request whenever possible. - Please add [documentation](#documentation) for new features or changes in functionality along with the code. - Please follow existing code style: - - Python: we use PEP8 for Python. - - Javascript: we use Airbnb's style guides for [JavaScript](https://github.com/airbnb/javascript#naming-conventions) and [React](https://github.com/airbnb/javascript/blob/master/react) (currently we don't follow Airbnb's convention for naming files, but we're gradually fixing this). To make it automatic and easy, we recommend using [Prettier](https://github.com/prettier/prettier). - + - Python: we use [Black](https://github.com/psf/black) to auto format the code. + - Javascript: we use [Prettier](https://github.com/prettier/prettier) to auto-format the code. + ### Documentation The project's documentation can be found at [https://redash.io/help/](https://redash.io/help/). The [documentation sources](https://github.com/getredash/website/tree/master/src/pages/kb) are hosted on GitHub. To contribute edits / new pages, you can use GitHub's interface. Click the "Edit on GitHub" link on the documentation page to quickly open the edit interface. @@ -66,9 +66,9 @@ The project's documentation can be found at [https://redash.io/help/](https://re ### Release Method -We publish a stable release every ~2 months, although the goal is to get to a stable release every month. You can see the change log on [GitHub releases page](https://github.com/getredash/redash/releases). +We publish a stable release every ~3-4 months, although the goal is to get to a stable release every month. -Every build of the master branch updates the latest *RC release*. These releases are usually stable, but might contain regressions and therefore recommended for "advanced users" only. +Every build of the master branch updates the *redash/redash:preview* Docker Image. These releases are usually stable, but might contain regressions and therefore recommended for "advanced users" only. When we release a new stable release, we also update the *latest* Docker image tag, the EC2 AMIs and GCE images. diff --git a/migrations/versions/0f740a081d20_inline_tags.py b/migrations/versions/0f740a081d20_inline_tags.py index f38f1d094e..815a249bd4 100644 --- a/migrations/versions/0f740a081d20_inline_tags.py +++ b/migrations/versions/0f740a081d20_inline_tags.py @@ -14,20 +14,20 @@ # revision identifiers, used by Alembic. -revision = '0f740a081d20' -down_revision = 'a92d92aa678e' +revision = "0f740a081d20" +down_revision = "a92d92aa678e" branch_labels = None depends_on = None def upgrade(): - tags_regex = re.compile('^([\w\s]+):|#([\w-]+)', re.I | re.U) + tags_regex = re.compile("^([\w\s]+):|#([\w-]+)", re.I | re.U) connection = op.get_bind() dashboards = connection.execute("SELECT id, name FROM dashboards") update_query = text("UPDATE dashboards SET tags = :tags WHERE id = :id") - + for dashboard in dashboards: tags = compact(flatten(tags_regex.findall(dashboard[1]))) if tags: diff --git a/migrations/versions/1daa601d3ae5_add_columns_for_disabled_users.py b/migrations/versions/1daa601d3ae5_add_columns_for_disabled_users.py index a4277b2ab8..e7dce4e13f 100644 --- a/migrations/versions/1daa601d3ae5_add_columns_for_disabled_users.py +++ b/migrations/versions/1daa601d3ae5_add_columns_for_disabled_users.py @@ -10,18 +10,15 @@ # revision identifiers, used by Alembic. -revision = '1daa601d3ae5' -down_revision = '969126bd800f' +revision = "1daa601d3ae5" +down_revision = "969126bd800f" branch_labels = None depends_on = None def upgrade(): - op.add_column( - 'users', - sa.Column('disabled_at', sa.DateTime(True), nullable=True) - ) + op.add_column("users", sa.Column("disabled_at", sa.DateTime(True), nullable=True)) def downgrade(): - op.drop_column('users', 'disabled_at') + op.drop_column("users", "disabled_at") diff --git a/migrations/versions/5ec5c84ba61e_.py b/migrations/versions/5ec5c84ba61e_.py index ef4e54fd38..d1cbe9184a 100644 --- a/migrations/versions/5ec5c84ba61e_.py +++ b/migrations/versions/5ec5c84ba61e_.py @@ -12,24 +12,28 @@ # revision identifiers, used by Alembic. -revision = '5ec5c84ba61e' -down_revision = '7671dca4e604' +revision = "5ec5c84ba61e" +down_revision = "7671dca4e604" branch_labels = None depends_on = None def upgrade(): conn = op.get_bind() - op.add_column('queries', sa.Column('search_vector', su.TSVectorType())) - op.create_index('ix_queries_search_vector', 'queries', ['search_vector'], - unique=False, postgresql_using='gin') - ss.sync_trigger(conn, 'queries', 'search_vector', - ['name', 'description', 'query']) + op.add_column("queries", sa.Column("search_vector", su.TSVectorType())) + op.create_index( + "ix_queries_search_vector", + "queries", + ["search_vector"], + unique=False, + postgresql_using="gin", + ) + ss.sync_trigger(conn, "queries", "search_vector", ["name", "description", "query"]) def downgrade(): conn = op.get_bind() - ss.drop_trigger(conn, 'queries', 'search_vector') - op.drop_index('ix_queries_search_vector', table_name='queries') - op.drop_column('queries', 'search_vector') + ss.drop_trigger(conn, "queries", "search_vector") + op.drop_index("ix_queries_search_vector", table_name="queries") + op.drop_column("queries", "search_vector") diff --git a/migrations/versions/640888ce445d_.py b/migrations/versions/640888ce445d_.py index e33eee8d5f..0a9edb9b40 100644 --- a/migrations/versions/640888ce445d_.py +++ b/migrations/versions/640888ce445d_.py @@ -15,93 +15,110 @@ # revision identifiers, used by Alembic. -revision = '640888ce445d' -down_revision = '71477dadd6ef' +revision = "640888ce445d" +down_revision = "71477dadd6ef" branch_labels = None depends_on = None def upgrade(): # Copy "schedule" column into "old_schedule" column - op.add_column('queries', sa.Column('old_schedule', sa.String(length=10), nullable=True)) + op.add_column( + "queries", sa.Column("old_schedule", sa.String(length=10), nullable=True) + ) queries = table( - 'queries', - sa.Column('schedule', sa.String(length=10)), - sa.Column('old_schedule', sa.String(length=10))) + "queries", + sa.Column("schedule", sa.String(length=10)), + sa.Column("old_schedule", sa.String(length=10)), + ) - op.execute( - queries - .update() - .values({'old_schedule': queries.c.schedule})) + op.execute(queries.update().values({"old_schedule": queries.c.schedule})) # Recreate "schedule" column as a dict type - op.drop_column('queries', 'schedule') - op.add_column('queries', sa.Column('schedule', MutableDict.as_mutable(PseudoJSON), nullable=False, server_default=json.dumps({}))) + op.drop_column("queries", "schedule") + op.add_column( + "queries", + sa.Column( + "schedule", + MutableDict.as_mutable(PseudoJSON), + nullable=False, + server_default=json.dumps({}), + ), + ) # Move over values from old_schedule queries = table( - 'queries', - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('schedule', MutableDict.as_mutable(PseudoJSON)), - sa.Column('old_schedule', sa.String(length=10))) + "queries", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("schedule", MutableDict.as_mutable(PseudoJSON)), + sa.Column("old_schedule", sa.String(length=10)), + ) conn = op.get_bind() for query in conn.execute(queries.select()): schedule_json = { - 'interval': None, - 'until': None, - 'day_of_week': None, - 'time': None + "interval": None, + "until": None, + "day_of_week": None, + "time": None, } if query.old_schedule is not None: if ":" in query.old_schedule: - schedule_json['interval'] = 86400 - schedule_json['time'] = query.old_schedule + schedule_json["interval"] = 86400 + schedule_json["time"] = query.old_schedule else: - schedule_json['interval'] = query.old_schedule + schedule_json["interval"] = query.old_schedule conn.execute( - queries - .update() - .where(queries.c.id == query.id) - .values(schedule=MutableDict(schedule_json))) + queries.update() + .where(queries.c.id == query.id) + .values(schedule=MutableDict(schedule_json)) + ) + + op.drop_column("queries", "old_schedule") - op.drop_column('queries', 'old_schedule') def downgrade(): - op.add_column('queries', sa.Column('old_schedule', MutableDict.as_mutable(PseudoJSON), nullable=False, server_default=json.dumps({}))) + op.add_column( + "queries", + sa.Column( + "old_schedule", + MutableDict.as_mutable(PseudoJSON), + nullable=False, + server_default=json.dumps({}), + ), + ) queries = table( - 'queries', - sa.Column('schedule', MutableDict.as_mutable(PseudoJSON)), - sa.Column('old_schedule', MutableDict.as_mutable(PseudoJSON))) + "queries", + sa.Column("schedule", MutableDict.as_mutable(PseudoJSON)), + sa.Column("old_schedule", MutableDict.as_mutable(PseudoJSON)), + ) - op.execute( - queries - .update() - .values({'old_schedule': queries.c.schedule})) + op.execute(queries.update().values({"old_schedule": queries.c.schedule})) - op.drop_column('queries', 'schedule') - op.add_column('queries', sa.Column('schedule', sa.String(length=10), nullable=True)) + op.drop_column("queries", "schedule") + op.add_column("queries", sa.Column("schedule", sa.String(length=10), nullable=True)) queries = table( - 'queries', - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('schedule', sa.String(length=10)), - sa.Column('old_schedule', MutableDict.as_mutable(PseudoJSON))) + "queries", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("schedule", sa.String(length=10)), + sa.Column("old_schedule", MutableDict.as_mutable(PseudoJSON)), + ) conn = op.get_bind() for query in conn.execute(queries.select()): - scheduleValue = query.old_schedule['interval'] + scheduleValue = query.old_schedule["interval"] if scheduleValue <= 86400: - scheduleValue = query.old_schedule['time'] + scheduleValue = query.old_schedule["time"] conn.execute( - queries - .update() - .where(queries.c.id == query.id) - .values(schedule=scheduleValue)) + queries.update() + .where(queries.c.id == query.id) + .values(schedule=scheduleValue) + ) - op.drop_column('queries', 'old_schedule') + op.drop_column("queries", "old_schedule") diff --git a/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py b/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py index fe0ae00082..e14a893299 100644 --- a/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py +++ b/migrations/versions/65fc9ede4746_add_is_draft_status_to_queries_and_.py @@ -12,7 +12,7 @@ # revision identifiers, used by Alembic. from sqlalchemy.exc import ProgrammingError -revision = '65fc9ede4746' +revision = "65fc9ede4746" down_revision = None branch_labels = None depends_on = None @@ -20,18 +20,28 @@ def upgrade(): try: - op.add_column('queries', sa.Column('is_draft', sa.Boolean, default=True, index=True)) - op.add_column('dashboards', sa.Column('is_draft', sa.Boolean, default=True, index=True)) + op.add_column( + "queries", sa.Column("is_draft", sa.Boolean, default=True, index=True) + ) + op.add_column( + "dashboards", sa.Column("is_draft", sa.Boolean, default=True, index=True) + ) op.execute("UPDATE queries SET is_draft = (name = 'New Query')") op.execute("UPDATE dashboards SET is_draft = false") except ProgrammingError as e: # The columns might exist if you ran the old migrations. if 'column "is_draft" of relation "queries" already exists' in str(e): - print("Can't run this migration as you already have is_draft columns, please run:") - print("./manage.py db stamp {} # you might need to alter the command to match your environment.".format(revision)) + print( + "Can't run this migration as you already have is_draft columns, please run:" + ) + print( + "./manage.py db stamp {} # you might need to alter the command to match your environment.".format( + revision + ) + ) exit() def downgrade(): - op.drop_column('queries', 'is_draft') - op.drop_column('dashboards', 'is_draft') + op.drop_column("queries", "is_draft") + op.drop_column("dashboards", "is_draft") diff --git a/migrations/versions/6b5be7e0a0ef_.py b/migrations/versions/6b5be7e0a0ef_.py index 9be2b96181..d7a727fe04 100644 --- a/migrations/versions/6b5be7e0a0ef_.py +++ b/migrations/versions/6b5be7e0a0ef_.py @@ -11,8 +11,8 @@ # revision identifiers, used by Alembic. -revision = '6b5be7e0a0ef' -down_revision = '5ec5c84ba61e' +revision = "6b5be7e0a0ef" +down_revision = "5ec5c84ba61e" branch_labels = None depends_on = None @@ -23,7 +23,7 @@ def upgrade(): conn = op.get_bind() metadata = sa.MetaData(bind=conn) - queries = sa.Table('queries', metadata, autoload=True) + queries = sa.Table("queries", metadata, autoload=True) @ss.vectorizer(queries.c.id) def integer_vectorizer(column): @@ -31,18 +31,22 @@ def integer_vectorizer(column): ss.sync_trigger( conn, - 'queries', - 'search_vector', - ['id', 'name', 'description', 'query'], - metadata=metadata + "queries", + "search_vector", + ["id", "name", "description", "query"], + metadata=metadata, ) def downgrade(): conn = op.get_bind() - ss.drop_trigger(conn, 'queries', 'search_vector') - op.drop_index('ix_queries_search_vector', table_name='queries') - op.create_index('ix_queries_search_vector', 'queries', ['search_vector'], - unique=False, postgresql_using='gin') - ss.sync_trigger(conn, 'queries', 'search_vector', - ['name', 'description', 'query']) + ss.drop_trigger(conn, "queries", "search_vector") + op.drop_index("ix_queries_search_vector", table_name="queries") + op.create_index( + "ix_queries_search_vector", + "queries", + ["search_vector"], + unique=False, + postgresql_using="gin", + ) + ss.sync_trigger(conn, "queries", "search_vector", ["name", "description", "query"]) diff --git a/migrations/versions/71477dadd6ef_favorites_unique_constraint.py b/migrations/versions/71477dadd6ef_favorites_unique_constraint.py index 15c6587782..6bc5d99665 100644 --- a/migrations/versions/71477dadd6ef_favorites_unique_constraint.py +++ b/migrations/versions/71477dadd6ef_favorites_unique_constraint.py @@ -10,15 +10,17 @@ # revision identifiers, used by Alembic. -revision = '71477dadd6ef' -down_revision = '0f740a081d20' +revision = "71477dadd6ef" +down_revision = "0f740a081d20" branch_labels = None depends_on = None def upgrade(): - op.create_unique_constraint('unique_favorite', 'favorites', ['object_type', 'object_id', 'user_id']) + op.create_unique_constraint( + "unique_favorite", "favorites", ["object_type", "object_id", "user_id"] + ) def downgrade(): - op.drop_constraint('unique_favorite', 'favorites', type_='unique') + op.drop_constraint("unique_favorite", "favorites", type_="unique") diff --git a/migrations/versions/73beceabb948_bring_back_null_schedule.py b/migrations/versions/73beceabb948_bring_back_null_schedule.py index 189282ef43..b510639dd2 100644 --- a/migrations/versions/73beceabb948_bring_back_null_schedule.py +++ b/migrations/versions/73beceabb948_bring_back_null_schedule.py @@ -13,8 +13,8 @@ from redash.models import MutableDict, PseudoJSON # revision identifiers, used by Alembic. -revision = '73beceabb948' -down_revision = 'e7f8a917aa8e' +revision = "73beceabb948" +down_revision = "e7f8a917aa8e" branch_labels = None depends_on = None @@ -26,30 +26,32 @@ def is_empty_schedule(schedule): if schedule == {}: return True - if schedule.get('interval') is None and schedule.get('until') is None and schedule.get('day_of_week') is None and schedule.get('time') is None: + if ( + schedule.get("interval") is None + and schedule.get("until") is None + and schedule.get("day_of_week") is None + and schedule.get("time") is None + ): return True return False def upgrade(): - op.alter_column('queries', 'schedule', - nullable=True, - server_default=None) + op.alter_column("queries", "schedule", nullable=True, server_default=None) queries = table( - 'queries', - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('schedule', MutableDict.as_mutable(PseudoJSON))) + "queries", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column("schedule", MutableDict.as_mutable(PseudoJSON)), + ) conn = op.get_bind() for query in conn.execute(queries.select()): if is_empty_schedule(query.schedule): conn.execute( - queries - .update() - .where(queries.c.id == query.id) - .values(schedule=None)) + queries.update().where(queries.c.id == query.id).values(schedule=None) + ) def downgrade(): diff --git a/migrations/versions/7671dca4e604_.py b/migrations/versions/7671dca4e604_.py index ee981b1e39..d73887fbaa 100644 --- a/migrations/versions/7671dca4e604_.py +++ b/migrations/versions/7671dca4e604_.py @@ -10,16 +10,18 @@ # revision identifiers, used by Alembic. -revision = '7671dca4e604' -down_revision = 'd1eae8b9893e' +revision = "7671dca4e604" +down_revision = "d1eae8b9893e" branch_labels = None depends_on = None def upgrade(): - op.add_column('users', sa.Column('profile_image_url', sa.String(), - nullable=True, server_default=None)) + op.add_column( + "users", + sa.Column("profile_image_url", sa.String(), nullable=True, server_default=None), + ) def downgrade(): - op.drop_column('users', 'profile_image_url') + op.drop_column("users", "profile_image_url") diff --git a/migrations/versions/969126bd800f_.py b/migrations/versions/969126bd800f_.py index a62bd840d3..17eec1153d 100644 --- a/migrations/versions/969126bd800f_.py +++ b/migrations/versions/969126bd800f_.py @@ -14,8 +14,8 @@ # revision identifiers, used by Alembic. -revision = '969126bd800f' -down_revision = '6b5be7e0a0ef' +revision = "969126bd800f" +down_revision = "6b5be7e0a0ef" branch_labels = None depends_on = None @@ -26,17 +26,18 @@ def upgrade(): print("Updating dashboards position data:") dashboard_result = db.session.execute("SELECT id, layout FROM dashboards") for dashboard in dashboard_result: - print(" Updating dashboard: {}".format(dashboard['id'])) - layout = simplejson.loads(dashboard['layout']) + print(" Updating dashboard: {}".format(dashboard["id"])) + layout = simplejson.loads(dashboard["layout"]) print(" Building widgets map:") widgets = {} widget_result = db.session.execute( - "SELECT id, options, width FROM widgets WHERE dashboard_id=:dashboard_id", - {"dashboard_id" : dashboard['id']}) + "SELECT id, options, width FROM widgets WHERE dashboard_id=:dashboard_id", + {"dashboard_id": dashboard["id"]}, + ) for w in widget_result: - print(" Widget: {}".format(w['id'])) - widgets[w['id']] = w + print(" Widget: {}".format(w["id"])) + widgets[w["id"]] = w widget_result.close() print(" Iterating over layout:") @@ -52,25 +53,32 @@ def upgrade(): if widget is None: continue - options = simplejson.loads(widget['options']) or {} - options['position'] = { + options = simplejson.loads(widget["options"]) or {} + options["position"] = { "row": row_index, "col": column_index * column_size, - "sizeX": column_size * widget.width + "sizeX": column_size * widget.width, } db.session.execute( "UPDATE widgets SET options=:options WHERE id=:id", - {"options" : simplejson.dumps(options), "id" : widget_id}) + {"options": simplejson.dumps(options), "id": widget_id}, + ) dashboard_result.close() db.session.commit() # Remove legacy columns no longer in use. - op.drop_column('widgets', 'type') - op.drop_column('widgets', 'query_id') + op.drop_column("widgets", "type") + op.drop_column("widgets", "query_id") def downgrade(): - op.add_column('widgets', sa.Column('query_id', sa.INTEGER(), autoincrement=False, nullable=True)) - op.add_column('widgets', sa.Column('type', sa.VARCHAR(length=100), autoincrement=False, nullable=True)) + op.add_column( + "widgets", + sa.Column("query_id", sa.INTEGER(), autoincrement=False, nullable=True), + ) + op.add_column( + "widgets", + sa.Column("type", sa.VARCHAR(length=100), autoincrement=False, nullable=True), + ) diff --git a/migrations/versions/98af61feea92_add_encrypted_options_to_data_sources.py b/migrations/versions/98af61feea92_add_encrypted_options_to_data_sources.py index 2ca5e9cd75..23670adfee 100644 --- a/migrations/versions/98af61feea92_add_encrypted_options_to_data_sources.py +++ b/migrations/versions/98af61feea92_add_encrypted_options_to_data_sources.py @@ -13,36 +13,52 @@ from redash import settings from redash.utils.configuration import ConfigurationContainer -from redash.models.types import EncryptedConfiguration, Configuration, MutableDict, MutableList, PseudoJSON +from redash.models.types import ( + EncryptedConfiguration, + Configuration, + MutableDict, + MutableList, + PseudoJSON, +) # revision identifiers, used by Alembic. -revision = '98af61feea92' -down_revision = '73beceabb948' +revision = "98af61feea92" +down_revision = "73beceabb948" branch_labels = None depends_on = None def upgrade(): - op.add_column('data_sources', sa.Column('encrypted_options', postgresql.BYTEA(), nullable=True)) + op.add_column( + "data_sources", + sa.Column("encrypted_options", postgresql.BYTEA(), nullable=True), + ) # copy values data_sources = table( - 'data_sources', - sa.Column('id', sa.Integer, primary_key=True), - sa.Column('encrypted_options', ConfigurationContainer.as_mutable(EncryptedConfiguration(sa.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine))), - sa.Column('options', ConfigurationContainer.as_mutable(Configuration))) + "data_sources", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column( + "encrypted_options", + ConfigurationContainer.as_mutable( + EncryptedConfiguration( + sa.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine + ) + ), + ), + sa.Column("options", ConfigurationContainer.as_mutable(Configuration)), + ) conn = op.get_bind() for ds in conn.execute(data_sources.select()): conn.execute( - data_sources - .update() + data_sources.update() .where(data_sources.c.id == ds.id) - .values(encrypted_options=ds.options)) + .values(encrypted_options=ds.options) + ) - op.drop_column('data_sources', 'options') - op.alter_column('data_sources', 'encrypted_options', - nullable=False) + op.drop_column("data_sources", "options") + op.alter_column("data_sources", "encrypted_options", nullable=False) def downgrade(): diff --git a/migrations/versions/a92d92aa678e_inline_tags.py b/migrations/versions/a92d92aa678e_inline_tags.py index 8e9465b316..f79924dc62 100644 --- a/migrations/versions/a92d92aa678e_inline_tags.py +++ b/migrations/versions/a92d92aa678e_inline_tags.py @@ -13,17 +13,21 @@ from redash import models # revision identifiers, used by Alembic. -revision = 'a92d92aa678e' -down_revision = 'e7004224f284' +revision = "a92d92aa678e" +down_revision = "e7004224f284" branch_labels = None depends_on = None def upgrade(): - op.add_column('dashboards', sa.Column('tags', postgresql.ARRAY(sa.Unicode()), nullable=True)) - op.add_column('queries', sa.Column('tags', postgresql.ARRAY(sa.Unicode()), nullable=True)) + op.add_column( + "dashboards", sa.Column("tags", postgresql.ARRAY(sa.Unicode()), nullable=True) + ) + op.add_column( + "queries", sa.Column("tags", postgresql.ARRAY(sa.Unicode()), nullable=True) + ) def downgrade(): - op.drop_column('queries', 'tags') - op.drop_column('dashboards', 'tags') + op.drop_column("queries", "tags") + op.drop_column("dashboards", "tags") diff --git a/migrations/versions/d1eae8b9893e_.py b/migrations/versions/d1eae8b9893e_.py index 9d7d5fc5da..3badf402fd 100644 --- a/migrations/versions/d1eae8b9893e_.py +++ b/migrations/versions/d1eae8b9893e_.py @@ -10,16 +10,20 @@ # revision identifiers, used by Alembic. -revision = 'd1eae8b9893e' -down_revision = '65fc9ede4746' +revision = "d1eae8b9893e" +down_revision = "65fc9ede4746" branch_labels = None depends_on = None def upgrade(): - op.add_column('queries', sa.Column('schedule_failures', sa.Integer(), - nullable=False, server_default='0')) + op.add_column( + "queries", + sa.Column( + "schedule_failures", sa.Integer(), nullable=False, server_default="0" + ), + ) def downgrade(): - op.drop_column('queries', 'schedule_failures') + op.drop_column("queries", "schedule_failures") diff --git a/migrations/versions/d4c798575877_create_favorites.py b/migrations/versions/d4c798575877_create_favorites.py index f71271ae1c..4693f5cb1b 100644 --- a/migrations/versions/d4c798575877_create_favorites.py +++ b/migrations/versions/d4c798575877_create_favorites.py @@ -10,24 +10,25 @@ # revision identifiers, used by Alembic. -revision = 'd4c798575877' -down_revision = '1daa601d3ae5' +revision = "d4c798575877" +down_revision = "1daa601d3ae5" branch_labels = None depends_on = None def upgrade(): - op.create_table('favorites', - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('id', sa.Integer(), nullable=False), - sa.Column('object_type', sa.Unicode(length=255), nullable=False), - sa.Column('object_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), - sa.PrimaryKeyConstraint('id') + op.create_table( + "favorites", + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("object_type", sa.Unicode(length=255), nullable=False), + sa.Column("object_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.PrimaryKeyConstraint("id"), ) def downgrade(): - op.drop_table('favorites') + op.drop_table("favorites") diff --git a/migrations/versions/e5c7a4e2df4d_remove_query_tracker_keys.py b/migrations/versions/e5c7a4e2df4d_remove_query_tracker_keys.py index 0bffe8db54..9f5e5acb31 100644 --- a/migrations/versions/e5c7a4e2df4d_remove_query_tracker_keys.py +++ b/migrations/versions/e5c7a4e2df4d_remove_query_tracker_keys.py @@ -12,15 +12,15 @@ # revision identifiers, used by Alembic. -revision = 'e5c7a4e2df4d' -down_revision = '98af61feea92' +revision = "e5c7a4e2df4d" +down_revision = "98af61feea92" branch_labels = None depends_on = None -DONE_LIST = 'query_task_trackers:done' -WAITING_LIST = 'query_task_trackers:waiting' -IN_PROGRESS_LIST = 'query_task_trackers:in_progress' +DONE_LIST = "query_task_trackers:done" +WAITING_LIST = "query_task_trackers:waiting" +IN_PROGRESS_LIST = "query_task_trackers:in_progress" def prune(list_name, keep_count, max_keys=100): diff --git a/migrations/versions/e7004224f284_add_org_id_to_favorites.py b/migrations/versions/e7004224f284_add_org_id_to_favorites.py index d220802195..1d7bce148e 100644 --- a/migrations/versions/e7004224f284_add_org_id_to_favorites.py +++ b/migrations/versions/e7004224f284_add_org_id_to_favorites.py @@ -10,17 +10,17 @@ # revision identifiers, used by Alembic. -revision = 'e7004224f284' -down_revision = 'd4c798575877' +revision = "e7004224f284" +down_revision = "d4c798575877" branch_labels = None depends_on = None def upgrade(): - op.add_column('favorites', sa.Column('org_id', sa.Integer(), nullable=False)) - op.create_foreign_key(None, 'favorites', 'organizations', ['org_id'], ['id']) + op.add_column("favorites", sa.Column("org_id", sa.Integer(), nullable=False)) + op.create_foreign_key(None, "favorites", "organizations", ["org_id"], ["id"]) def downgrade(): - op.drop_constraint(None, 'favorites', type_='foreignkey') - op.drop_column('favorites', 'org_id') + op.drop_constraint(None, "favorites", type_="foreignkey") + op.drop_column("favorites", "org_id") diff --git a/migrations/versions/e7f8a917aa8e_add_user_details_json_column.py b/migrations/versions/e7f8a917aa8e_add_user_details_json_column.py index cd47bb774c..a5a827091c 100644 --- a/migrations/versions/e7f8a917aa8e_add_user_details_json_column.py +++ b/migrations/versions/e7f8a917aa8e_add_user_details_json_column.py @@ -10,15 +10,23 @@ from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. -revision = 'e7f8a917aa8e' -down_revision = '640888ce445d' +revision = "e7f8a917aa8e" +down_revision = "640888ce445d" branch_labels = None depends_on = None def upgrade(): - op.add_column('users', sa.Column('details', postgresql.JSON(astext_type=sa.Text()), server_default='{}', nullable=True)) + op.add_column( + "users", + sa.Column( + "details", + postgresql.JSON(astext_type=sa.Text()), + server_default="{}", + nullable=True, + ), + ) def downgrade(): - op.drop_column('users', 'details') + op.drop_column("users", "details") diff --git a/redash/__init__.py b/redash/__init__.py index 87ac101c34..32bd7a4e44 100644 --- a/redash/__init__.py +++ b/redash/__init__.py @@ -15,12 +15,13 @@ from .query_runner import import_query_runners from .destinations import import_destinations -__version__ = '9.0.0-alpha' +__version__ = "9.0.0-alpha" if os.environ.get("REMOTE_DEBUG"): import ptvsd - ptvsd.enable_attach(address=('0.0.0.0', 5678)) + + ptvsd.enable_attach(address=("0.0.0.0", 5678)) def setup_logging(): @@ -32,7 +33,12 @@ def setup_logging(): # Make noisy libraries less noisy if settings.LOG_LEVEL != "DEBUG": - for name in ["passlib", "requests.packages.urllib3", "snowflake.connector", "apiclient"]: + for name in [ + "passlib", + "requests.packages.urllib3", + "snowflake.connector", + "apiclient", + ]: logging.getLogger(name).setLevel("ERROR") @@ -42,7 +48,9 @@ def setup_logging(): rq_redis_connection = redis.from_url(settings.RQ_REDIS_URL) mail = Mail() migrate = Migrate() -statsd_client = StatsClient(host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX) +statsd_client = StatsClient( + host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX +) limiter = Limiter(key_func=get_ipaddr, storage_uri=settings.LIMITER_STORAGE) import_query_runners(settings.QUERY_RUNNERS) diff --git a/redash/app.py b/redash/app.py index 8bca2de5f1..20bcf20927 100644 --- a/redash/app.py +++ b/redash/app.py @@ -6,17 +6,20 @@ class Redash(Flask): """A custom Flask app for Redash""" + def __init__(self, *args, **kwargs): - kwargs.update({ - 'template_folder': settings.STATIC_ASSETS_PATH, - 'static_folder': settings.STATIC_ASSETS_PATH, - 'static_url_path': '/static', - }) + kwargs.update( + { + "template_folder": settings.STATIC_ASSETS_PATH, + "static_folder": settings.STATIC_ASSETS_PATH, + "static_url_path": "/static", + } + ) super(Redash, self).__init__(__name__, *args, **kwargs) # Make sure we get the right referral address even behind proxies like nginx. self.wsgi_app = ProxyFix(self.wsgi_app, settings.PROXIES_COUNT) # Configure Redash using our settings - self.config.from_object('redash.settings') + self.config.from_object("redash.settings") def create_app(): diff --git a/redash/authentication/__init__.py b/redash/authentication/__init__.py index d0058694d0..7aeea8f332 100644 --- a/redash/authentication/__init__.py +++ b/redash/authentication/__init__.py @@ -15,16 +15,18 @@ from werkzeug.exceptions import Unauthorized login_manager = LoginManager() -logger = logging.getLogger('authentication') +logger = logging.getLogger("authentication") def get_login_url(external=False, next="/"): if settings.MULTI_ORG and current_org == None: - login_url = '/' + login_url = "/" elif settings.MULTI_ORG: - login_url = url_for('redash.login', org_slug=current_org.slug, next=next, _external=external) + login_url = url_for( + "redash.login", org_slug=current_org.slug, next=next, _external=external + ) else: - login_url = url_for('redash.login', next=next, _external=external) + login_url = url_for("redash.login", next=next, _external=external) return login_url @@ -60,24 +62,28 @@ def load_user(user_id_with_identity): def request_loader(request): user = None - if settings.AUTH_TYPE == 'hmac': + if settings.AUTH_TYPE == "hmac": user = hmac_load_user_from_request(request) - elif settings.AUTH_TYPE == 'api_key': + elif settings.AUTH_TYPE == "api_key": user = api_key_load_user_from_request(request) else: - logger.warning("Unknown authentication type ({}). Using default (HMAC).".format(settings.AUTH_TYPE)) + logger.warning( + "Unknown authentication type ({}). Using default (HMAC).".format( + settings.AUTH_TYPE + ) + ) user = hmac_load_user_from_request(request) - if org_settings['auth_jwt_login_enabled'] and user is None: + if org_settings["auth_jwt_login_enabled"] and user is None: user = jwt_token_load_user_from_request(request) return user def hmac_load_user_from_request(request): - signature = request.args.get('signature') - expires = float(request.args.get('expires') or 0) - query_id = request.view_args.get('query_id', None) - user_id = request.args.get('user_id', None) + signature = request.args.get("signature") + expires = float(request.args.get("expires") or 0) + query_id = request.view_args.get("query_id", None) + user_id = request.args.get("user_id", None) # TODO: 3600 should be a setting if signature and time.time() < expires <= time.time() + 3600: @@ -93,7 +99,12 @@ def hmac_load_user_from_request(request): calculated_signature = sign(query.api_key, request.path, expires) if query.api_key and signature == calculated_signature: - return models.ApiUser(query.api_key, query.org, list(query.groups.keys()), name="ApiKey: Query {}".format(query.id)) + return models.ApiUser( + query.api_key, + query.org, + list(query.groups.keys()), + name="ApiKey: Query {}".format(query.id), + ) return None @@ -118,22 +129,27 @@ def get_user_from_api_key(api_key, query_id): if query_id: query = models.Query.get_by_id_and_org(query_id, org) if query and query.api_key == api_key: - user = models.ApiUser(api_key, query.org, list(query.groups.keys()), name="ApiKey: Query {}".format(query.id)) + user = models.ApiUser( + api_key, + query.org, + list(query.groups.keys()), + name="ApiKey: Query {}".format(query.id), + ) return user def get_api_key_from_request(request): - api_key = request.args.get('api_key', None) + api_key = request.args.get("api_key", None) if api_key is not None: return api_key - if request.headers.get('Authorization'): - auth_header = request.headers.get('Authorization') - api_key = auth_header.replace('Key ', '', 1) - elif request.view_args is not None and request.view_args.get('token'): - api_key = request.view_args['token'] + if request.headers.get("Authorization"): + auth_header = request.headers.get("Authorization") + api_key = auth_header.replace("Key ", "", 1) + elif request.view_args is not None and request.view_args.get("token"): + api_key = request.view_args["token"] return api_key @@ -141,7 +157,7 @@ def get_api_key_from_request(request): def api_key_load_user_from_request(request): api_key = get_api_key_from_request(request) if request.view_args is not None: - query_id = request.view_args.get('query_id', None) + query_id = request.view_args.get("query_id", None) user = get_user_from_api_key(api_key, query_id) else: user = None @@ -154,44 +170,44 @@ def jwt_token_load_user_from_request(request): payload = None - if org_settings['auth_jwt_auth_cookie_name']: - jwt_token = request.cookies.get(org_settings['auth_jwt_auth_cookie_name'], None) - elif org_settings['auth_jwt_auth_header_name']: - jwt_token = request.headers.get(org_settings['auth_jwt_auth_header_name'], None) + if org_settings["auth_jwt_auth_cookie_name"]: + jwt_token = request.cookies.get(org_settings["auth_jwt_auth_cookie_name"], None) + elif org_settings["auth_jwt_auth_header_name"]: + jwt_token = request.headers.get(org_settings["auth_jwt_auth_header_name"], None) else: return None if jwt_token: payload, token_is_valid = jwt_auth.verify_jwt_token( jwt_token, - expected_issuer=org_settings['auth_jwt_auth_issuer'], - expected_audience=org_settings['auth_jwt_auth_audience'], - algorithms=org_settings['auth_jwt_auth_algorithms'], - public_certs_url=org_settings['auth_jwt_auth_public_certs_url'], + expected_issuer=org_settings["auth_jwt_auth_issuer"], + expected_audience=org_settings["auth_jwt_auth_audience"], + algorithms=org_settings["auth_jwt_auth_algorithms"], + public_certs_url=org_settings["auth_jwt_auth_public_certs_url"], ) if not token_is_valid: - raise Unauthorized('Invalid JWT token') + raise Unauthorized("Invalid JWT token") if not payload: return try: - user = models.User.get_by_email_and_org(payload['email'], org) + user = models.User.get_by_email_and_org(payload["email"], org) except models.NoResultFound: - user = create_and_login_user(current_org, payload['email'], payload['email']) + user = create_and_login_user(current_org, payload["email"], payload["email"]) return user def log_user_logged_in(app, user): event = { - 'org_id': user.org_id, - 'user_id': user.id, - 'action': 'login', - 'object_type': 'redash', - 'timestamp': int(time.time()), - 'user_agent': request.user_agent.string, - 'ip': request.remote_addr + "org_id": user.org_id, + "user_id": user.id, + "action": "login", + "object_type": "redash", + "timestamp": int(time.time()), + "user_agent": request.user_agent.string, + "ip": request.remote_addr, } record_event.delay(event) @@ -199,8 +215,10 @@ def log_user_logged_in(app, user): @login_manager.unauthorized_handler def redirect_to_login(): - if request.is_xhr or '/api/' in request.path: - response = jsonify({'message': "Couldn't find resource. Please login and try again."}) + if request.is_xhr or "/api/" in request.path: + response = jsonify( + {"message": "Couldn't find resource. Please login and try again."} + ) response.status_code = 404 return response @@ -213,17 +231,22 @@ def logout_and_redirect_to_index(): logout_user() if settings.MULTI_ORG and current_org == None: - index_url = '/' + index_url = "/" elif settings.MULTI_ORG: - index_url = url_for('redash.index', org_slug=current_org.slug, _external=False) + index_url = url_for("redash.index", org_slug=current_org.slug, _external=False) else: - index_url = url_for('redash.index', _external=False) + index_url = url_for("redash.index", _external=False) return redirect(index_url) def init_app(app): - from redash.authentication import google_oauth, saml_auth, remote_user_auth, ldap_auth + from redash.authentication import ( + google_oauth, + saml_auth, + remote_user_auth, + ldap_auth, + ) login_manager.init_app(app) login_manager.anonymous_user = models.AnonymousUser @@ -251,8 +274,14 @@ def create_and_login_user(org, name, email, picture=None): models.db.session.commit() except NoResultFound: logger.debug("Creating user object (%r)", name) - user_object = models.User(org=org, name=name, email=email, is_invitation_pending=False, - _profile_image_url=picture, group_ids=[org.default_group.id]) + user_object = models.User( + org=org, + name=name, + email=email, + is_invitation_pending=False, + _profile_image_url=picture, + group_ids=[org.default_group.id], + ) models.db.session.add(user_object) models.db.session.commit() @@ -263,18 +292,18 @@ def create_and_login_user(org, name, email, picture=None): def get_next_path(unsafe_next_path): if not unsafe_next_path: - return '' + return "" # Preventing open redirection attacks parts = list(urlsplit(unsafe_next_path)) - parts[0] = '' # clear scheme - parts[1] = '' # clear netloc + parts[0] = "" # clear scheme + parts[1] = "" # clear netloc safe_next_path = urlunsplit(parts) # If the original path was a URL, we might end up with an empty - # safe url, which will redirect to the login page. Changing to + # safe url, which will redirect to the login page. Changing to # relative root to redirect to the app root after login. if not safe_next_path: - safe_next_path = './' + safe_next_path = "./" return safe_next_path diff --git a/redash/authentication/account.py b/redash/authentication/account.py index c20b60aab2..c826a71aa4 100644 --- a/redash/authentication/account.py +++ b/redash/authentication/account.py @@ -4,6 +4,7 @@ from redash import settings from redash.tasks import send_mail from redash.utils import base_url + # noinspection PyUnresolvedReferences from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature @@ -42,12 +43,9 @@ def validate_token(token): def send_verify_email(user, org): - context = { - 'user': user, - 'verify_url': verify_link_for_user(user), - } - html_content = render_template('emails/verify.html', **context) - text_content = render_template('emails/verify.txt', **context) + context = {"user": user, "verify_url": verify_link_for_user(user)} + html_content = render_template("emails/verify.html", **context) + text_content = render_template("emails/verify.txt", **context) subject = "{}, please verify your email address".format(user.name) send_mail.delay([user.email], subject, html_content, text_content) @@ -55,8 +53,8 @@ def send_verify_email(user, org): def send_invite_email(inviter, invited, invite_url, org): context = dict(inviter=inviter, invited=invited, org=org, invite_url=invite_url) - html_content = render_template('emails/invite.html', **context) - text_content = render_template('emails/invite.txt', **context) + html_content = render_template("emails/invite.html", **context) + text_content = render_template("emails/invite.txt", **context) subject = "{} invited you to join Redash".format(inviter.name) send_mail.delay([invited.email], subject, html_content, text_content) @@ -65,17 +63,17 @@ def send_invite_email(inviter, invited, invite_url, org): def send_password_reset_email(user): reset_link = reset_link_for_user(user) context = dict(user=user, reset_link=reset_link) - html_content = render_template('emails/reset.html', **context) - text_content = render_template('emails/reset.txt', **context) + html_content = render_template("emails/reset.html", **context) + text_content = render_template("emails/reset.txt", **context) subject = "Reset your password" send_mail.delay([user.email], subject, html_content, text_content) return reset_link - + def send_user_disabled_email(user): - html_content = render_template('emails/reset_disabled.html', user=user) - text_content = render_template('emails/reset_disabled.txt', user=user) + html_content = render_template("emails/reset_disabled.html", user=user) + text_content = render_template("emails/reset_disabled.txt", user=user) subject = "Your Redash account is disabled" send_mail.delay([user.email], subject, html_content, text_content) diff --git a/redash/authentication/google_oauth.py b/redash/authentication/google_oauth.py index f0de8b2425..59d49ef90e 100644 --- a/redash/authentication/google_oauth.py +++ b/redash/authentication/google_oauth.py @@ -4,35 +4,43 @@ from flask_oauthlib.client import OAuth from redash import models, settings -from redash.authentication import create_and_login_user, logout_and_redirect_to_index, get_next_path +from redash.authentication import ( + create_and_login_user, + logout_and_redirect_to_index, + get_next_path, +) from redash.authentication.org_resolving import current_org -logger = logging.getLogger('google_oauth') +logger = logging.getLogger("google_oauth") oauth = OAuth() -blueprint = Blueprint('google_oauth', __name__) +blueprint = Blueprint("google_oauth", __name__) def google_remote_app(): - if 'google' not in oauth.remote_apps: - oauth.remote_app('google', - base_url='https://www.google.com/accounts/', - authorize_url='https://accounts.google.com/o/oauth2/auth?prompt=select_account+consent', - request_token_url=None, - request_token_params={ - 'scope': 'https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile', - }, - access_token_url='https://accounts.google.com/o/oauth2/token', - access_token_method='POST', - consumer_key=settings.GOOGLE_CLIENT_ID, - consumer_secret=settings.GOOGLE_CLIENT_SECRET) + if "google" not in oauth.remote_apps: + oauth.remote_app( + "google", + base_url="https://www.google.com/accounts/", + authorize_url="https://accounts.google.com/o/oauth2/auth?prompt=select_account+consent", + request_token_url=None, + request_token_params={ + "scope": "https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" + }, + access_token_url="https://accounts.google.com/o/oauth2/token", + access_token_method="POST", + consumer_key=settings.GOOGLE_CLIENT_ID, + consumer_secret=settings.GOOGLE_CLIENT_SECRET, + ) return oauth.google def get_user_profile(access_token): - headers = {'Authorization': 'OAuth {}'.format(access_token)} - response = requests.get('https://www.googleapis.com/oauth2/v1/userinfo', headers=headers) + headers = {"Authorization": "OAuth {}".format(access_token)} + response = requests.get( + "https://www.googleapis.com/oauth2/v1/userinfo", headers=headers + ) if response.status_code == 401: logger.warning("Failed getting user profile (response code 401).") @@ -45,8 +53,8 @@ def verify_profile(org, profile): if org.is_public: return True - email = profile['email'] - domain = email.split('@')[-1] + email = profile["email"] + domain = email.split("@")[-1] if domain in org.google_apps_domains: return True @@ -57,52 +65,60 @@ def verify_profile(org, profile): return False -@blueprint.route('//oauth/google', endpoint="authorize_org") +@blueprint.route("//oauth/google", endpoint="authorize_org") def org_login(org_slug): - session['org_slug'] = current_org.slug - return redirect(url_for(".authorize", next=request.args.get('next', None))) + session["org_slug"] = current_org.slug + return redirect(url_for(".authorize", next=request.args.get("next", None))) -@blueprint.route('/oauth/google', endpoint="authorize") +@blueprint.route("/oauth/google", endpoint="authorize") def login(): - callback = url_for('.callback', _external=True) - next_path = request.args.get('next', url_for("redash.index", org_slug=session.get('org_slug'))) + callback = url_for(".callback", _external=True) + next_path = request.args.get( + "next", url_for("redash.index", org_slug=session.get("org_slug")) + ) logger.debug("Callback url: %s", callback) logger.debug("Next is: %s", next_path) return google_remote_app().authorize(callback=callback, state=next_path) -@blueprint.route('/oauth/google_callback', endpoint="callback") +@blueprint.route("/oauth/google_callback", endpoint="callback") def authorized(): resp = google_remote_app().authorized_response() - access_token = resp['access_token'] + access_token = resp["access_token"] if access_token is None: logger.warning("Access token missing in call back request.") flash("Validation error. Please retry.") - return redirect(url_for('redash.login')) + return redirect(url_for("redash.login")) profile = get_user_profile(access_token) if profile is None: flash("Validation error. Please retry.") - return redirect(url_for('redash.login')) + return redirect(url_for("redash.login")) - if 'org_slug' in session: - org = models.Organization.get_by_slug(session.pop('org_slug')) + if "org_slug" in session: + org = models.Organization.get_by_slug(session.pop("org_slug")) else: org = current_org if not verify_profile(org, profile): - logger.warning("User tried to login with unauthorized domain name: %s (org: %s)", profile['email'], org) - flash("Your Google Apps account ({}) isn't allowed.".format(profile['email'])) - return redirect(url_for('redash.login', org_slug=org.slug)) - - picture_url = "%s?sz=40" % profile['picture'] - user = create_and_login_user(org, profile['name'], profile['email'], picture_url) + logger.warning( + "User tried to login with unauthorized domain name: %s (org: %s)", + profile["email"], + org, + ) + flash("Your Google Apps account ({}) isn't allowed.".format(profile["email"])) + return redirect(url_for("redash.login", org_slug=org.slug)) + + picture_url = "%s?sz=40" % profile["picture"] + user = create_and_login_user(org, profile["name"], profile["email"], picture_url) if user is None: return logout_and_redirect_to_index() - unsafe_next_path = request.args.get('state') or url_for("redash.index", org_slug=org.slug) + unsafe_next_path = request.args.get("state") or url_for( + "redash.index", org_slug=org.slug + ) next_path = get_next_path(unsafe_next_path) return redirect(next_path) diff --git a/redash/authentication/jwt_auth.py b/redash/authentication/jwt_auth.py index 355591cb93..81904235ae 100644 --- a/redash/authentication/jwt_auth.py +++ b/redash/authentication/jwt_auth.py @@ -3,7 +3,7 @@ import requests import simplejson -logger = logging.getLogger('jwt_auth') +logger = logging.getLogger("jwt_auth") def get_public_keys(url): @@ -18,10 +18,12 @@ def get_public_keys(url): r = requests.get(url) r.raise_for_status() data = r.json() - if 'keys' in data: + if "keys" in data: public_keys = [] - for key_dict in data['keys']: - public_key = jwt.algorithms.RSAAlgorithm.from_jwk(simplejson.dumps(key_dict)) + for key_dict in data["keys"]: + public_key = jwt.algorithms.RSAAlgorithm.from_jwk( + simplejson.dumps(key_dict) + ) public_keys.append(public_key) get_public_keys.key_cache[url] = public_keys @@ -34,13 +36,15 @@ def get_public_keys(url): get_public_keys.key_cache = {} -def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms, public_certs_url): +def verify_jwt_token( + jwt_token, expected_issuer, expected_audience, algorithms, public_certs_url +): # https://developers.cloudflare.com/access/setting-up-access/validate-jwt-tokens/ # https://cloud.google.com/iap/docs/signed-headers-howto # Loop through the keys since we can't pass the key set to the decoder keys = get_public_keys(public_certs_url) - key_id = jwt.get_unverified_header(jwt_token).get('kid', '') + key_id = jwt.get_unverified_header(jwt_token).get("kid", "") if key_id and isinstance(keys, dict): keys = [keys.get(key_id)] @@ -50,14 +54,11 @@ def verify_jwt_token(jwt_token, expected_issuer, expected_audience, algorithms, try: # decode returns the claims which has the email if you need it payload = jwt.decode( - jwt_token, - key=key, - audience=expected_audience, - algorithms=algorithms + jwt_token, key=key, audience=expected_audience, algorithms=algorithms ) - issuer = payload['iss'] + issuer = payload["iss"] if issuer != expected_issuer: - raise Exception('Wrong issuer: {}'.format(issuer)) + raise Exception("Wrong issuer: {}".format(issuer)) valid_token = True break except Exception as e: diff --git a/redash/authentication/ldap_auth.py b/redash/authentication/ldap_auth.py index b57056e298..e102b3f516 100644 --- a/redash/authentication/ldap_auth.py +++ b/redash/authentication/ldap_auth.py @@ -10,54 +10,62 @@ from ldap3 import Server, Connection except ImportError: if settings.LDAP_LOGIN_ENABLED: - sys.exit("The ldap3 library was not found. This is required to use LDAP authentication (see requirements.txt).") + sys.exit( + "The ldap3 library was not found. This is required to use LDAP authentication (see requirements.txt)." + ) -from redash.authentication import create_and_login_user, logout_and_redirect_to_index, get_next_path +from redash.authentication import ( + create_and_login_user, + logout_and_redirect_to_index, + get_next_path, +) from redash.authentication.org_resolving import current_org from redash.handlers.base import org_scoped_rule -logger = logging.getLogger('ldap_auth') +logger = logging.getLogger("ldap_auth") -blueprint = Blueprint('ldap_auth', __name__) +blueprint = Blueprint("ldap_auth", __name__) -@blueprint.route(org_scoped_rule("/ldap/login"), methods=['GET', 'POST']) +@blueprint.route(org_scoped_rule("/ldap/login"), methods=["GET", "POST"]) def login(org_slug=None): index_url = url_for("redash.index", org_slug=org_slug) - unsafe_next_path = request.args.get('next', index_url) + unsafe_next_path = request.args.get("next", index_url) next_path = get_next_path(unsafe_next_path) if not settings.LDAP_LOGIN_ENABLED: logger.error("Cannot use LDAP for login without being enabled in settings") - return redirect(url_for('redash.index', next=next_path)) + return redirect(url_for("redash.index", next=next_path)) if current_user.is_authenticated: return redirect(next_path) - if request.method == 'POST': - ldap_user = auth_ldap_user(request.form['email'], request.form['password']) + if request.method == "POST": + ldap_user = auth_ldap_user(request.form["email"], request.form["password"]) if ldap_user is not None: user = create_and_login_user( current_org, ldap_user[settings.LDAP_DISPLAY_NAME_KEY][0], - ldap_user[settings.LDAP_EMAIL_KEY][0] + ldap_user[settings.LDAP_EMAIL_KEY][0], ) if user is None: return logout_and_redirect_to_index() - return redirect(next_path or url_for('redash.index')) + return redirect(next_path or url_for("redash.index")) else: flash("Incorrect credentials.") - return render_template("login.html", - org_slug=org_slug, - next=next_path, - email=request.form.get('email', ''), - show_password_login=True, - username_prompt=settings.LDAP_CUSTOM_USERNAME_PROMPT, - hide_forgot_password=True) + return render_template( + "login.html", + org_slug=org_slug, + next=next_path, + email=request.form.get("email", ""), + show_password_login=True, + username_prompt=settings.LDAP_CUSTOM_USERNAME_PROMPT, + hide_forgot_password=True, + ) def auth_ldap_user(username, password): @@ -68,12 +76,16 @@ def auth_ldap_user(username, password): settings.LDAP_BIND_DN, password=settings.LDAP_BIND_DN_PASSWORD, authentication=settings.LDAP_AUTH_METHOD, - auto_bind=True + auto_bind=True, ) else: conn = Connection(server, auto_bind=True) - conn.search(settings.LDAP_SEARCH_DN, settings.LDAP_SEARCH_TEMPLATE % {"username": username}, attributes=[settings.LDAP_DISPLAY_NAME_KEY, settings.LDAP_EMAIL_KEY]) + conn.search( + settings.LDAP_SEARCH_DN, + settings.LDAP_SEARCH_TEMPLATE % {"username": username}, + attributes=[settings.LDAP_DISPLAY_NAME_KEY, settings.LDAP_EMAIL_KEY], + ) if len(conn.entries) == 0: return None diff --git a/redash/authentication/org_resolving.py b/redash/authentication/org_resolving.py index 0eacaad5f6..5aa54cdd1b 100644 --- a/redash/authentication/org_resolving.py +++ b/redash/authentication/org_resolving.py @@ -7,13 +7,13 @@ def _get_current_org(): - if 'org' in g: + if "org" in g: return g.org if request.view_args is None: - slug = g.get('org_slug', 'default') + slug = g.get("org_slug", "default") else: - slug = request.view_args.get('org_slug', g.get('org_slug', 'default')) + slug = request.view_args.get("org_slug", g.get("org_slug", "default")) g.org = Organization.get_by_slug(slug) logging.debug("Current organization: %s (slug: %s)", g.org, slug) diff --git a/redash/authentication/remote_user_auth.py b/redash/authentication/remote_user_auth.py index 8402aeb23d..7cba295ccd 100644 --- a/redash/authentication/remote_user_auth.py +++ b/redash/authentication/remote_user_auth.py @@ -1,23 +1,29 @@ import logging from flask import redirect, url_for, Blueprint, request -from redash.authentication import create_and_login_user, logout_and_redirect_to_index, get_next_path +from redash.authentication import ( + create_and_login_user, + logout_and_redirect_to_index, + get_next_path, +) from redash.authentication.org_resolving import current_org from redash.handlers.base import org_scoped_rule from redash import settings -logger = logging.getLogger('remote_user_auth') +logger = logging.getLogger("remote_user_auth") -blueprint = Blueprint('remote_user_auth', __name__) +blueprint = Blueprint("remote_user_auth", __name__) @blueprint.route(org_scoped_rule("/remote_user/login")) def login(org_slug=None): - unsafe_next_path = request.args.get('next') + unsafe_next_path = request.args.get("next") next_path = get_next_path(unsafe_next_path) if not settings.REMOTE_USER_LOGIN_ENABLED: - logger.error("Cannot use remote user for login without being enabled in settings") - return redirect(url_for('redash.index', next=next_path, org_slug=org_slug)) + logger.error( + "Cannot use remote user for login without being enabled in settings" + ) + return redirect(url_for("redash.index", next=next_path, org_slug=org_slug)) email = request.headers.get(settings.REMOTE_USER_HEADER) @@ -25,12 +31,16 @@ def login(org_slug=None): # falsey value. Special case that here so it Just Works for more installs. # '(null)' should never really be a value that anyone wants to legitimately # use as a redash user email. - if email == '(null)': + if email == "(null)": email = None if not email: - logger.error("Cannot use remote user for login when it's not provided in the request (looked in headers['" + settings.REMOTE_USER_HEADER + "'])") - return redirect(url_for('redash.index', next=next_path, org_slug=org_slug)) + logger.error( + "Cannot use remote user for login when it's not provided in the request (looked in headers['" + + settings.REMOTE_USER_HEADER + + "'])" + ) + return redirect(url_for("redash.index", next=next_path, org_slug=org_slug)) logger.info("Logging in " + email + " via remote user") @@ -38,4 +48,4 @@ def login(org_slug=None): if user is None: return logout_and_redirect_to_index() - return redirect(next_path or url_for('redash.index', org_slug=org_slug), code=302) + return redirect(next_path or url_for("redash.index", org_slug=org_slug), code=302) diff --git a/redash/authentication/saml_auth.py b/redash/authentication/saml_auth.py index fd15f7d053..6aa67e7bad 100644 --- a/redash/authentication/saml_auth.py +++ b/redash/authentication/saml_auth.py @@ -8,8 +8,8 @@ from saml2.config import Config as Saml2Config from saml2.saml import NAMEID_FORMAT_TRANSIENT -logger = logging.getLogger('saml_auth') -blueprint = Blueprint('saml_auth', __name__) +logger = logging.getLogger("saml_auth") +blueprint = Blueprint("saml_auth", __name__) def get_saml_client(org): @@ -23,34 +23,30 @@ def get_saml_client(org): acs_url = url_for("saml_auth.idp_initiated", org_slug=org.slug, _external=True) saml_settings = { - 'metadata': { - "remote": [{ - "url": metadata_url - }] - }, - 'service': { - 'sp': { - 'endpoints': { - 'assertion_consumer_service': [ + "metadata": {"remote": [{"url": metadata_url}]}, + "service": { + "sp": { + "endpoints": { + "assertion_consumer_service": [ (acs_url, BINDING_HTTP_REDIRECT), - (acs_url, BINDING_HTTP_POST) - ], + (acs_url, BINDING_HTTP_POST), + ] }, # Don't verify that the incoming requests originate from us via # the built-in cache for authn request ids in pysaml2 - 'allow_unsolicited': True, + "allow_unsolicited": True, # Don't sign authn requests, since signed requests only make # sense in a situation where you control both the SP and IdP - 'authn_requests_signed': False, - 'logout_requests_signed': True, - 'want_assertions_signed': True, - 'want_response_signed': False, - }, + "authn_requests_signed": False, + "logout_requests_signed": True, + "want_assertions_signed": True, + "want_response_signed": False, + } }, } if entity_id is not None and entity_id != "": - saml_settings['entityid'] = entity_id + saml_settings["entityid"] = entity_id sp_config = Saml2Config() sp_config.load(saml_settings) @@ -60,26 +56,29 @@ def get_saml_client(org): return saml_client -@blueprint.route(org_scoped_rule('/saml/callback'), methods=['POST']) +@blueprint.route(org_scoped_rule("/saml/callback"), methods=["POST"]) def idp_initiated(org_slug=None): if not current_org.get_setting("auth_saml_enabled"): logger.error("SAML Login is not enabled") - return redirect(url_for('redash.index', org_slug=org_slug)) + return redirect(url_for("redash.index", org_slug=org_slug)) saml_client = get_saml_client(current_org) try: authn_response = saml_client.parse_authn_request_response( - request.form['SAMLResponse'], - entity.BINDING_HTTP_POST) + request.form["SAMLResponse"], entity.BINDING_HTTP_POST + ) except Exception: - logger.error('Failed to parse SAML response', exc_info=True) - flash('SAML login failed. Please try again later.') - return redirect(url_for('redash.login', org_slug=org_slug)) + logger.error("Failed to parse SAML response", exc_info=True) + flash("SAML login failed. Please try again later.") + return redirect(url_for("redash.login", org_slug=org_slug)) authn_response.get_identity() user_info = authn_response.get_subject() email = user_info.text - name = "%s %s" % (authn_response.ava['FirstName'][0], authn_response.ava['LastName'][0]) + name = "%s %s" % ( + authn_response.ava["FirstName"][0], + authn_response.ava["LastName"][0], + ) # This is what as known as "Just In Time (JIT) provisioning". # What that means is that, if a user in a SAML assertion @@ -88,11 +87,11 @@ def idp_initiated(org_slug=None): if user is None: return logout_and_redirect_to_index() - if 'RedashGroups' in authn_response.ava: - group_names = authn_response.ava.get('RedashGroups') + if "RedashGroups" in authn_response.ava: + group_names = authn_response.ava.get("RedashGroups") user.update_group_assignments(group_names) - url = url_for('redash.index', org_slug=org_slug) + url = url_for("redash.index", org_slug=org_slug) return redirect(url) @@ -101,10 +100,10 @@ def idp_initiated(org_slug=None): def sp_initiated(org_slug=None): if not current_org.get_setting("auth_saml_enabled"): logger.error("SAML Login is not enabled") - return redirect(url_for('redash.index', org_slug=org_slug)) + return redirect(url_for("redash.index", org_slug=org_slug)) saml_client = get_saml_client(current_org) - nameid_format = current_org.get_setting('auth_saml_nameid_format') + nameid_format = current_org.get_setting("auth_saml_nameid_format") if nameid_format is None or nameid_format == "": nameid_format = NAMEID_FORMAT_TRANSIENT @@ -112,8 +111,8 @@ def sp_initiated(org_slug=None): redirect_url = None # Select the IdP URL to send the AuthN request to - for key, value in info['headers']: - if key == 'Location': + for key, value in info["headers"]: + if key == "Location": redirect_url = value response = redirect(redirect_url, code=302) @@ -124,6 +123,6 @@ def sp_initiated(org_slug=None): # http://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf # We set those headers here as a "belt and suspenders" approach, # since enterprise environments don't always conform to RFCs - response.headers['Cache-Control'] = 'no-cache, no-store' - response.headers['Pragma'] = 'no-cache' + response.headers["Cache-Control"] = "no-cache, no-store" + response.headers["Pragma"] = "no-cache" return response diff --git a/redash/cli/__init__.py b/redash/cli/__init__.py index d3bc522f8a..c61d8b8306 100644 --- a/redash/cli/__init__.py +++ b/redash/cli/__init__.py @@ -1,5 +1,3 @@ - - import click import simplejson from flask import current_app @@ -17,10 +15,9 @@ def create(group): @app.shell_context_processor def shell_context(): from redash import models, settings - return { - 'models': models, - 'settings': settings, - } + + return {"models": models, "settings": settings} + return app @@ -58,7 +55,7 @@ def check_settings(): @manager.command() -@click.argument('email', default=settings.MAIL_DEFAULT_SENDER, required=False) +@click.argument("email", default=settings.MAIL_DEFAULT_SENDER, required=False) def send_test_mail(email=None): """ Send test message to EMAIL (default: the address you defined in MAIL_DEFAULT_SENDER) @@ -69,8 +66,11 @@ def send_test_mail(email=None): if email is None: email = settings.MAIL_DEFAULT_SENDER - mail.send(Message(subject="Test Message from Redash", recipients=[email], - body="Test message.")) + mail.send( + Message( + subject="Test Message from Redash", recipients=[email], body="Test message." + ) + ) @manager.command() @@ -79,13 +79,14 @@ def ipython(): import sys import IPython from flask.globals import _app_ctx_stack + app = _app_ctx_stack.top.app - banner = 'Python %s on %s\nIPython: %s\nRedash version: %s\n' % ( + banner = "Python %s on %s\nIPython: %s\nRedash version: %s\n" % ( sys.version, sys.platform, IPython.__version__, - __version__ + __version__, ) ctx = {} diff --git a/redash/cli/data_sources.py b/redash/cli/data_sources.py index 90d33fae7e..8049398659 100644 --- a/redash/cli/data_sources.py +++ b/redash/cli/data_sources.py @@ -1,4 +1,3 @@ - from sys import exit import click @@ -7,32 +6,39 @@ from sqlalchemy.orm.exc import NoResultFound from redash import models -from redash.query_runner import (get_configuration_schema_for_query_runner_type, - query_runners) +from redash.query_runner import ( + get_configuration_schema_for_query_runner_type, + query_runners, +) from redash.utils import json_loads from redash.utils.configuration import ConfigurationContainer manager = AppGroup(help="Data sources management commands.") -@manager.command(name='list') -@click.option('--org', 'organization', default=None, - help="The organization the user belongs to (leave blank for " - "all organizations).") +@manager.command(name="list") +@click.option( + "--org", + "organization", + default=None, + help="The organization the user belongs to (leave blank for " "all organizations).", +) def list_command(organization=None): """List currently configured data sources.""" if organization: org = models.Organization.get_by_slug(organization) - data_sources = models.DataSource.query.filter( - models.DataSource.org == org) + data_sources = models.DataSource.query.filter(models.DataSource.org == org) else: data_sources = models.DataSource.query for i, ds in enumerate(data_sources.order_by(models.DataSource.name)): if i > 0: print("-" * 20) - print("Id: {}\nName: {}\nType: {}\nOptions: {}".format( - ds.id, ds.name, ds.type, ds.options.to_json())) + print( + "Id: {}\nName: {}\nType: {}\nOptions: {}".format( + ds.id, ds.name, ds.type, ds.options.to_json() + ) + ) @manager.command() @@ -46,26 +52,33 @@ def list_types(): def validate_data_source_type(type): if type not in query_runners.keys(): - print("Error: the type \"{}\" is not supported (supported types: {})." - .format(type, ", ".join(query_runners.keys()))) + print( + 'Error: the type "{}" is not supported (supported types: {}).'.format( + type, ", ".join(query_runners.keys()) + ) + ) print("OJNK") exit(1) @manager.command() -@click.argument('name') -@click.option('--org', 'organization', default='default', - help="The organization the user belongs to " - "(leave blank for 'default').") -def test(name, organization='default'): +@click.argument("name") +@click.option( + "--org", + "organization", + default="default", + help="The organization the user belongs to " "(leave blank for 'default').", +) +def test(name, organization="default"): """Test connection to data source by issuing a trivial query.""" try: org = models.Organization.get_by_slug(organization) data_source = models.DataSource.query.filter( - models.DataSource.name == name, - models.DataSource.org == org).one() - print("Testing connection to data source: {} (id={})".format( - name, data_source.id)) + models.DataSource.name == name, models.DataSource.org == org + ).one() + print( + "Testing connection to data source: {} (id={})".format(name, data_source.id) + ) try: data_source.query_runner.test_connection() except Exception as e: @@ -79,15 +92,16 @@ def test(name, organization='default'): @manager.command() -@click.argument('name', default=None, required=False) -@click.option('--type', default=None, - help="new type for the data source") -@click.option('--options', default=None, - help="updated options for the data source") -@click.option('--org', 'organization', default='default', - help="The organization the user belongs to (leave blank for " - "'default').") -def new(name=None, type=None, options=None, organization='default'): +@click.argument("name", default=None, required=False) +@click.option("--type", default=None, help="new type for the data source") +@click.option("--options", default=None, help="updated options for the data source") +@click.option( + "--org", + "organization", + default="default", + help="The organization the user belongs to (leave blank for " "'default').", +) +def new(name=None, type=None, options=None, organization="default"): """Create new data source.""" if name is None: @@ -100,8 +114,7 @@ def new(name=None, type=None, options=None, organization='default'): idx = 0 while idx < 1 or idx > len(list(query_runners.keys())): - idx = click.prompt("[{}-{}]".format(1, len(query_runners.keys())), - type=int) + idx = click.prompt("[{}-{}]".format(1, len(query_runners.keys())), type=int) type = list(query_runners.keys())[idx - 1] else: @@ -111,28 +124,28 @@ def new(name=None, type=None, options=None, organization='default'): schema = query_runner.configuration_schema() if options is None: - types = { - 'string': text_type, - 'number': int, - 'boolean': bool - } + types = {"string": text_type, "number": int, "boolean": bool} options_obj = {} - for k, prop in schema['properties'].items(): - required = k in schema.get('required', []) + for k, prop in schema["properties"].items(): + required = k in schema.get("required", []) default_value = "<>" if required: default_value = None - prompt = prop.get('title', k.capitalize()) + prompt = prop.get("title", k.capitalize()) if required: prompt = "{} (required)".format(prompt) else: prompt = "{} (optional)".format(prompt) - value = click.prompt(prompt, default=default_value, - type=types[prop['type']], show_default=False) + value = click.prompt( + prompt, + default=default_value, + type=types[prop["type"]], + show_default=False, + ) if value != default_value: options_obj[k] = value @@ -144,28 +157,37 @@ def new(name=None, type=None, options=None, organization='default'): print("Error: invalid configuration.") exit() - print("Creating {} data source ({}) with options:\n{}".format( - type, name, options.to_json())) + print( + "Creating {} data source ({}) with options:\n{}".format( + type, name, options.to_json() + ) + ) data_source = models.DataSource.create_with_group( - name=name, type=type, options=options, - org=models.Organization.get_by_slug(organization)) + name=name, + type=type, + options=options, + org=models.Organization.get_by_slug(organization), + ) models.db.session.commit() print("Id: {}".format(data_source.id)) @manager.command() -@click.argument('name') -@click.option('--org', 'organization', default='default', - help="The organization the user belongs to (leave blank for " - "'default').") -def delete(name, organization='default'): +@click.argument("name") +@click.option( + "--org", + "organization", + default="default", + help="The organization the user belongs to (leave blank for " "'default').", +) +def delete(name, organization="default"): """Delete data source by name.""" try: org = models.Organization.get_by_slug(organization) data_source = models.DataSource.query.filter( - models.DataSource.name == name, - models.DataSource.org == org).one() + models.DataSource.name == name, models.DataSource.org == org + ).one() print("Deleting data source: {} (id={})".format(name, data_source.id)) models.db.session.delete(data_source) models.db.session.commit() @@ -182,31 +204,30 @@ def update_attr(obj, attr, new_value): @manager.command() -@click.argument('name') -@click.option('--name', 'new_name', default=None, - help="new name for the data source") -@click.option('--options', default=None, - help="updated options for the data source") -@click.option('--type', default=None, - help="new type for the data source") -@click.option('--org', 'organization', default='default', - help="The organization the user belongs to (leave blank for " - "'default').") -def edit(name, new_name=None, options=None, type=None, organization='default'): +@click.argument("name") +@click.option("--name", "new_name", default=None, help="new name for the data source") +@click.option("--options", default=None, help="updated options for the data source") +@click.option("--type", default=None, help="new type for the data source") +@click.option( + "--org", + "organization", + default="default", + help="The organization the user belongs to (leave blank for " "'default').", +) +def edit(name, new_name=None, options=None, type=None, organization="default"): """Edit data source settings (name, options, type).""" try: if type is not None: validate_data_source_type(type) org = models.Organization.get_by_slug(organization) data_source = models.DataSource.query.filter( - models.DataSource.name == name, - models.DataSource.org == org).one() + models.DataSource.name == name, models.DataSource.org == org + ).one() update_attr(data_source, "name", new_name) update_attr(data_source, "type", type) if options is not None: - schema = get_configuration_schema_for_query_runner_type( - data_source.type) + schema = get_configuration_schema_for_query_runner_type(data_source.type) options = json_loads(options) data_source.options.set_schema(schema) data_source.options.update(options) diff --git a/redash/cli/database.py b/redash/cli/database.py index b8946f0aef..ce55b73c6f 100644 --- a/redash/cli/database.py +++ b/redash/cli/database.py @@ -19,7 +19,7 @@ def _wait_for_db_connection(db): retried = False while not retried: try: - db.engine.execute('SELECT 1;') + db.engine.execute("SELECT 1;") return except DatabaseError: time.sleep(30) @@ -51,10 +51,9 @@ def drop_tables(): @manager.command() -@argument('old_secret') -@argument('new_secret') -@option('--show-sql/--no-show-sql', default=False, - help="show sql for debug") +@argument("old_secret") +@argument("new_secret") +@option("--show-sql/--no-show-sql", default=False, help="show sql for debug") def reencrypt(old_secret, new_secret, show_sql): """Reencrypt data encrypted by OLD_SECRET with NEW_SECRET.""" from redash.models import db @@ -63,26 +62,39 @@ def reencrypt(old_secret, new_secret, show_sql): if show_sql: import logging + logging.basicConfig() - logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) - - table_for_select = sqlalchemy.Table('data_sources', sqlalchemy.MetaData(), - Column('id', db.Integer, primary_key=True), - Column('encrypted_options', - ConfigurationContainer.as_mutable( - EncryptedConfiguration( - db.Text, old_secret, FernetEngine)))) - table_for_update = sqlalchemy.Table('data_sources', sqlalchemy.MetaData(), - Column('id', db.Integer, primary_key=True), - Column('encrypted_options', - ConfigurationContainer.as_mutable( - EncryptedConfiguration( - db.Text, new_secret, FernetEngine)))) + logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) + + table_for_select = sqlalchemy.Table( + "data_sources", + sqlalchemy.MetaData(), + Column("id", db.Integer, primary_key=True), + Column( + "encrypted_options", + ConfigurationContainer.as_mutable( + EncryptedConfiguration(db.Text, old_secret, FernetEngine) + ), + ), + ) + table_for_update = sqlalchemy.Table( + "data_sources", + sqlalchemy.MetaData(), + Column("id", db.Integer, primary_key=True), + Column( + "encrypted_options", + ConfigurationContainer.as_mutable( + EncryptedConfiguration(db.Text, new_secret, FernetEngine) + ), + ), + ) update = table_for_update.update() data_sources = db.session.execute(select([table_for_select])) for ds in data_sources: - stmt = update.where(table_for_update.c.id == ds['id']).values(encrypted_options=ds['encrypted_options']) + stmt = update.where(table_for_update.c.id == ds["id"]).values( + encrypted_options=ds["encrypted_options"] + ) db.session.execute(stmt) data_sources.close() diff --git a/redash/cli/groups.py b/redash/cli/groups.py index 195d3c466a..1770057feb 100644 --- a/redash/cli/groups.py +++ b/redash/cli/groups.py @@ -1,4 +1,3 @@ - from sys import exit from sqlalchemy.orm.exc import NoResultFound @@ -11,17 +10,23 @@ @manager.command() -@argument('name') -@option('--org', 'organization', default='default', - help="The organization the user belongs to (leave blank for " - "'default').") -@option('--permissions', default=None, - help="Comma separated list of permissions ('create_dashboard'," - " 'create_query', 'edit_dashboard', 'edit_query', " - "'view_query', 'view_source', 'execute_query', 'list_users'," - " 'schedule_query', 'list_dashboards', 'list_alerts'," - " 'list_data_sources') (leave blank for default).") -def create(name, permissions=None, organization='default'): +@argument("name") +@option( + "--org", + "organization", + default="default", + help="The organization the user belongs to (leave blank for " "'default').", +) +@option( + "--permissions", + default=None, + help="Comma separated list of permissions ('create_dashboard'," + " 'create_query', 'edit_dashboard', 'edit_query', " + "'view_query', 'view_source', 'execute_query', 'list_users'," + " 'schedule_query', 'list_dashboards', 'list_alerts'," + " 'list_data_sources') (leave blank for default).", +) +def create(name, permissions=None, organization="default"): print("Creating group (%s)..." % (name)) org = models.Organization.get_by_slug(organization) @@ -31,9 +36,7 @@ def create(name, permissions=None, organization='default'): print("permissions: [%s]" % ",".join(permissions)) try: - models.db.session.add(models.Group( - name=name, org=org, - permissions=permissions)) + models.db.session.add(models.Group(name=name, org=org, permissions=permissions)) models.db.session.commit() except Exception as e: print("Failed create group: %s" % e) @@ -41,13 +44,16 @@ def create(name, permissions=None, organization='default'): @manager.command() -@argument('group_id') -@option('--permissions', default=None, - help="Comma separated list of permissions ('create_dashboard'," - " 'create_query', 'edit_dashboard', 'edit_query'," - " 'view_query', 'view_source', 'execute_query', 'list_users'," - " 'schedule_query', 'list_dashboards', 'list_alerts'," - " 'list_data_sources') (leave blank for default).") +@argument("group_id") +@option( + "--permissions", + default=None, + help="Comma separated list of permissions ('create_dashboard'," + " 'create_query', 'edit_dashboard', 'edit_query'," + " 'view_query', 'view_source', 'execute_query', 'list_users'," + " 'schedule_query', 'list_dashboards', 'list_alerts'," + " 'list_data_sources') (leave blank for default).", +) def change_permissions(group_id, permissions=None): print("Change permissions of group %s ..." % group_id) @@ -58,8 +64,10 @@ def change_permissions(group_id, permissions=None): exit(1) permissions = extract_permissions_string(permissions) - print("current permissions [%s] will be modify to [%s]" % ( - ",".join(group.permissions), ",".join(permissions))) + print( + "current permissions [%s] will be modify to [%s]" + % (",".join(group.permissions), ",".join(permissions)) + ) group.permissions = permissions @@ -75,14 +83,18 @@ def extract_permissions_string(permissions): if permissions is None: permissions = models.Group.DEFAULT_PERMISSIONS else: - permissions = permissions.split(',') + permissions = permissions.split(",") permissions = [p.strip() for p in permissions] return permissions -@manager.command(name='list') -@option('--org', 'organization', default=None, - help="The organization to limit to (leave blank for all).") +@manager.command(name="list") +@option( + "--org", + "organization", + default=None, + help="The organization to limit to (leave blank for all).", +) def list_command(organization=None): """List all groups""" if organization: @@ -95,8 +107,15 @@ def list_command(organization=None): if i > 0: print("-" * 20) - print("Id: {}\nName: {}\nType: {}\nOrganization: {}\nPermissions: [{}]".format( - group.id, group.name, group.type, group.org.slug, ",".join(group.permissions))) + print( + "Id: {}\nName: {}\nType: {}\nOrganization: {}\nPermissions: [{}]".format( + group.id, + group.name, + group.type, + group.org.slug, + ",".join(group.permissions), + ) + ) members = models.Group.members(group.id) user_names = [m.name for m in members] diff --git a/redash/cli/organization.py b/redash/cli/organization.py index ad45a3fc38..45c73551fc 100644 --- a/redash/cli/organization.py +++ b/redash/cli/organization.py @@ -1,4 +1,3 @@ - from click import argument from flask.cli import AppGroup @@ -8,28 +7,34 @@ @manager.command() -@argument('domains') +@argument("domains") def set_google_apps_domains(domains): """ Sets the allowable domains to the comma separated list DOMAINS. """ organization = models.Organization.query.first() k = models.Organization.SETTING_GOOGLE_APPS_DOMAINS - organization.settings[k] = domains.split(',') + organization.settings[k] = domains.split(",") models.db.session.add(organization) models.db.session.commit() - print("Updated list of allowed domains to: {}".format( - organization.google_apps_domains)) + print( + "Updated list of allowed domains to: {}".format( + organization.google_apps_domains + ) + ) @manager.command() def show_google_apps_domains(): organization = models.Organization.query.first() - print("Current list of Google Apps domains: {}".format( - ', '.join(organization.google_apps_domains))) + print( + "Current list of Google Apps domains: {}".format( + ", ".join(organization.google_apps_domains) + ) + ) -@manager.command(name='list') +@manager.command(name="list") def list_command(): """List all organizations""" orgs = models.Organization.query diff --git a/redash/cli/queries.py b/redash/cli/queries.py index 00810b749e..f71bdbabe6 100644 --- a/redash/cli/queries.py +++ b/redash/cli/queries.py @@ -6,8 +6,8 @@ @manager.command() -@argument('query_id') -@argument('tag') +@argument("query_id") +@argument("tag") def add_tag(query_id, tag): from redash import models @@ -32,8 +32,8 @@ def add_tag(query_id, tag): @manager.command() -@argument('query_id') -@argument('tag') +@argument("query_id") +@argument("tag") def remove_tag(query_id, tag): from redash import models diff --git a/redash/cli/rq.py b/redash/cli/rq.py index 00a82ed102..dab96f55c9 100644 --- a/redash/cli/rq.py +++ b/redash/cli/rq.py @@ -9,7 +9,11 @@ from sqlalchemy.orm import configure_mappers from redash import rq_redis_connection -from redash.schedule import rq_scheduler, schedule_periodic_jobs, periodic_job_definitions +from redash.schedule import ( + rq_scheduler, + schedule_periodic_jobs, + periodic_job_definitions, +) manager = AppGroup(help="RQ management commands.") @@ -22,15 +26,15 @@ def scheduler(): @manager.command() -@argument('queues', nargs=-1) +@argument("queues", nargs=-1) def worker(queues): - # Configure any SQLAlchemy mappers loaded until now so that the mapping configuration - # will already be available to the forked work horses and they won't need + # Configure any SQLAlchemy mappers loaded until now so that the mapping configuration + # will already be available to the forked work horses and they won't need # to spend valuable time re-doing that on every fork. configure_mappers() if not queues: - queues = ['periodic', 'emails', 'default', 'schemas'] + queues = ["periodic", "emails", "default", "schemas"] with Connection(rq_redis_connection): w = Worker(queues, log_job_description=False) diff --git a/redash/cli/users.py b/redash/cli/users.py index 72b6a5544e..af7ba99284 100644 --- a/redash/cli/users.py +++ b/redash/cli/users.py @@ -1,4 +1,3 @@ - from sys import exit from click import BOOL, argument, option, prompt @@ -15,8 +14,8 @@ def build_groups(org, groups, is_admin): if isinstance(groups, string_types): - groups = groups.split(',') - groups.remove('') # in case it was empty string + groups = groups.split(",") + groups.remove("") # in case it was empty string groups = [int(g) for g in groups] if groups is None: @@ -29,11 +28,14 @@ def build_groups(org, groups, is_admin): @manager.command() -@argument('email') -@option('--org', 'organization', default='default', - help="the organization the user belongs to, (leave blank for " - "'default').") -def grant_admin(email, organization='default'): +@argument("email") +@option( + "--org", + "organization", + default="default", + help="the organization the user belongs to, (leave blank for " "'default').", +) +def grant_admin(email, organization="default"): """ Grant admin access to user EMAIL. """ @@ -54,28 +56,47 @@ def grant_admin(email, organization='default'): @manager.command() -@argument('email') -@argument('name') -@option('--org', 'organization', default='default', - help="The organization the user belongs to (leave blank for " - "'default').") -@option('--admin', 'is_admin', is_flag=True, default=False, - help="set user as admin") -@option('--google', 'google_auth', is_flag=True, - default=False, help="user uses Google Auth to login") -@option('--password', 'password', default=None, - help="Password for users who don't use Google Auth " - "(leave blank for prompt).") -@option('--groups', 'groups', default=None, - help="Comma separated list of groups (leave blank for " - "default).") -def create(email, name, groups, is_admin=False, google_auth=False, - password=None, organization='default'): +@argument("email") +@argument("name") +@option( + "--org", + "organization", + default="default", + help="The organization the user belongs to (leave blank for " "'default').", +) +@option("--admin", "is_admin", is_flag=True, default=False, help="set user as admin") +@option( + "--google", + "google_auth", + is_flag=True, + default=False, + help="user uses Google Auth to login", +) +@option( + "--password", + "password", + default=None, + help="Password for users who don't use Google Auth " "(leave blank for prompt).", +) +@option( + "--groups", + "groups", + default=None, + help="Comma separated list of groups (leave blank for " "default).", +) +def create( + email, + name, + groups, + is_admin=False, + google_auth=False, + password=None, + organization="default", +): """ Create user EMAIL with display name NAME. """ - print("Creating user (%s, %s) in organization %s..." % (email, name, - organization)) + print("Creating user (%s, %s) in organization %s..." % (email, name, organization)) print("Admin: %r" % is_admin) print("Login with Google Auth: %r\n" % google_auth) @@ -84,8 +105,7 @@ def create(email, name, groups, is_admin=False, google_auth=False, user = models.User(org=org, email=email, name=name, group_ids=groups) if not password and not google_auth: - password = prompt("Password", hide_input=True, - confirmation_prompt=True) + password = prompt("Password", hide_input=True, confirmation_prompt=True) if not google_auth: user.hash_password(password) @@ -98,20 +118,36 @@ def create(email, name, groups, is_admin=False, google_auth=False, @manager.command() -@argument('email') -@argument('name') -@option('--org', 'organization', default='default', - help="The organization the root user belongs to (leave blank for 'default').") -@option('--google', 'google_auth', is_flag=True, - default=False, help="user uses Google Auth to login") -@option('--password', 'password', default=None, - help="Password for root user who don't use Google Auth " - "(leave blank for prompt).") -def create_root(email, name, google_auth=False, password=None, organization='default'): +@argument("email") +@argument("name") +@option( + "--org", + "organization", + default="default", + help="The organization the root user belongs to (leave blank for 'default').", +) +@option( + "--google", + "google_auth", + is_flag=True, + default=False, + help="user uses Google Auth to login", +) +@option( + "--password", + "password", + default=None, + help="Password for root user who don't use Google Auth " + "(leave blank for prompt).", +) +def create_root(email, name, google_auth=False, password=None, organization="default"): """ Create root user. """ - print("Creating root user (%s, %s) in organization %s..." % (email, name, organization)) + print( + "Creating root user (%s, %s) in organization %s..." + % (email, name, organization) + ) print("Login with Google Auth: %r\n" % google_auth) user = models.User.query.filter(models.User.email == email).first() @@ -119,21 +155,35 @@ def create_root(email, name, google_auth=False, password=None, organization='def print("User [%s] is already exists." % email) exit(1) - slug = 'default' - default_org = models.Organization.query.filter(models.Organization.slug == slug).first() + slug = "default" + default_org = models.Organization.query.filter( + models.Organization.slug == slug + ).first() if default_org is None: default_org = models.Organization(name=organization, slug=slug, settings={}) - admin_group = models.Group(name='admin', permissions=['admin', 'super_admin'], - org=default_org, type=models.Group.BUILTIN_GROUP) - default_group = models.Group(name='default', permissions=models.Group.DEFAULT_PERMISSIONS, - org=default_org, type=models.Group.BUILTIN_GROUP) + admin_group = models.Group( + name="admin", + permissions=["admin", "super_admin"], + org=default_org, + type=models.Group.BUILTIN_GROUP, + ) + default_group = models.Group( + name="default", + permissions=models.Group.DEFAULT_PERMISSIONS, + org=default_org, + type=models.Group.BUILTIN_GROUP, + ) models.db.session.add_all([default_org, admin_group, default_group]) models.db.session.commit() - user = models.User(org=default_org, email=email, name=name, - group_ids=[admin_group.id, default_group.id]) + user = models.User( + org=default_org, + email=email, + name=name, + group_ids=[admin_group.id, default_group.id], + ) if not google_auth: user.hash_password(password) @@ -146,10 +196,13 @@ def create_root(email, name, google_auth=False, password=None, organization='def @manager.command() -@argument('email') -@option('--org', 'organization', default=None, - help="The organization the user belongs to (leave blank for all" - " organizations).") +@argument("email") +@option( + "--org", + "organization", + default=None, + help="The organization the user belongs to (leave blank for all" " organizations).", +) def delete(email, organization=None): """ Delete user EMAIL. @@ -157,22 +210,25 @@ def delete(email, organization=None): if organization: org = models.Organization.get_by_slug(organization) deleted_count = models.User.query.filter( - models.User.email == email, - models.User.org == org.id, + models.User.email == email, models.User.org == org.id ).delete() else: deleted_count = models.User.query.filter(models.User.email == email).delete( - synchronize_session=False) + synchronize_session=False + ) models.db.session.commit() print("Deleted %d users." % deleted_count) @manager.command() -@argument('email') -@argument('password') -@option('--org', 'organization', default=None, - help="The organization the user belongs to (leave blank for all " - "organizations).") +@argument("email") +@argument("password") +@option( + "--org", + "organization", + default=None, + help="The organization the user belongs to (leave blank for all " "organizations).", +) def password(email, password, organization=None): """ Resets password for EMAIL to PASSWORD. @@ -180,8 +236,7 @@ def password(email, password, organization=None): if organization: org = models.Organization.get_by_slug(organization) user = models.User.query.filter( - models.User.email == email, - models.User.org == org, + models.User.email == email, models.User.org == org ).first() else: user = models.User.query.filter(models.User.email == email).first() @@ -197,17 +252,23 @@ def password(email, password, organization=None): @manager.command() -@argument('email') -@argument('name') -@argument('inviter_email') -@option('--org', 'organization', default='default', - help="The organization the user belongs to (leave blank for 'default')") -@option('--admin', 'is_admin', type=BOOL, default=False, - help="set user as admin") -@option('--groups', 'groups', default=None, - help="Comma seperated list of groups (leave blank for default).") -def invite(email, name, inviter_email, groups, is_admin=False, - organization='default'): +@argument("email") +@argument("name") +@argument("inviter_email") +@option( + "--org", + "organization", + default="default", + help="The organization the user belongs to (leave blank for 'default')", +) +@option("--admin", "is_admin", type=BOOL, default=False, help="set user as admin") +@option( + "--groups", + "groups", + default=None, + help="Comma seperated list of groups (leave blank for default).", +) +def invite(email, name, inviter_email, groups, is_admin=False, organization="default"): """ Sends an invitation to the given NAME and EMAIL from INVITER_EMAIL. """ @@ -230,10 +291,13 @@ def invite(email, name, inviter_email, groups, is_admin=False, print("The inviter [%s] was not found." % inviter_email) -@manager.command(name='list') -@option('--org', 'organization', default=None, - help="The organization the user belongs to (leave blank for all" - " organizations)") +@manager.command(name="list") +@option( + "--org", + "organization", + default=None, + help="The organization the user belongs to (leave blank for all" " organizations)", +) def list_command(organization=None): """List all users""" if organization: @@ -245,8 +309,11 @@ def list_command(organization=None): if i > 0: print("-" * 20) - print("Id: {}\nName: {}\nEmail: {}\nOrganization: {}\nActive: {}".format( - user.id, user.name, user.email, user.org.name, not(user.is_disabled))) + print( + "Id: {}\nName: {}\nEmail: {}\nOrganization: {}\nActive: {}".format( + user.id, user.name, user.email, user.org.name, not (user.is_disabled) + ) + ) groups = models.Group.query.filter(models.Group.id.in_(user.group_ids)).all() group_names = [group.name for group in groups] diff --git a/redash/destinations/__init__.py b/redash/destinations/__init__.py index 509974708b..a84d735904 100644 --- a/redash/destinations/__init__.py +++ b/redash/destinations/__init__.py @@ -2,12 +2,7 @@ logger = logging.getLogger(__name__) -__all__ = [ - 'BaseDestination', - 'register', - 'get_destination', - 'import_destinations' -] +__all__ = ["BaseDestination", "register", "get_destination", "import_destinations"] class BaseDestination(object): @@ -26,7 +21,7 @@ def type(cls): @classmethod def icon(cls): - return 'fa-bullseye' + return "fa-bullseye" @classmethod def enabled(cls): @@ -42,10 +37,10 @@ def notify(self, alert, query, user, new_state, app, host, options): @classmethod def to_dict(cls): return { - 'name': cls.name(), - 'type': cls.type(), - 'icon': cls.icon(), - 'configuration_schema': cls.configuration_schema() + "name": cls.name(), + "type": cls.type(), + "icon": cls.icon(), + "configuration_schema": cls.configuration_schema(), } @@ -55,10 +50,17 @@ def to_dict(cls): def register(destination_class): global destinations if destination_class.enabled(): - logger.debug("Registering %s (%s) destinations.", destination_class.name(), destination_class.type()) + logger.debug( + "Registering %s (%s) destinations.", + destination_class.name(), + destination_class.type(), + ) destinations[destination_class.type()] = destination_class else: - logger.warning("%s destination enabled but not supported, not registering. Either disable or install missing dependencies.", destination_class.name()) + logger.warning( + "%s destination enabled but not supported, not registering. Either disable or install missing dependencies.", + destination_class.name(), + ) def get_destination(destination_type, configuration): diff --git a/redash/destinations/chatwork.py b/redash/destinations/chatwork.py index 24522a0ce4..d3dda8288a 100644 --- a/redash/destinations/chatwork.py +++ b/redash/destinations/chatwork.py @@ -5,64 +5,72 @@ class ChatWork(BaseDestination): - ALERTS_DEFAULT_MESSAGE_TEMPLATE = '{alert_name} changed state to {new_state}.\\n{alert_url}\\n{query_url}' + ALERTS_DEFAULT_MESSAGE_TEMPLATE = ( + "{alert_name} changed state to {new_state}.\\n{alert_url}\\n{query_url}" + ) @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'api_token': { - 'type': 'string', - 'title': 'API Token' + "type": "object", + "properties": { + "api_token": {"type": "string", "title": "API Token"}, + "room_id": {"type": "string", "title": "Room ID"}, + "message_template": { + "type": "string", + "default": ChatWork.ALERTS_DEFAULT_MESSAGE_TEMPLATE, + "title": "Message Template", }, - 'room_id': { - 'type': 'string', - 'title': 'Room ID' - }, - 'message_template': { - 'type': 'string', - 'default': ChatWork.ALERTS_DEFAULT_MESSAGE_TEMPLATE, - 'title': 'Message Template' - } }, - 'required': ['message_template', 'api_token', 'room_id'] + "required": ["message_template", "api_token", "room_id"], } @classmethod def icon(cls): - return 'fa-comment' + return "fa-comment" def notify(self, alert, query, user, new_state, app, host, options): try: # Documentation: http://developer.chatwork.com/ja/endpoint_rooms.html#POST-rooms-room_id-messages - url = 'https://api.chatwork.com/v2/rooms/{room_id}/messages'.format(room_id=options.get('room_id')) + url = "https://api.chatwork.com/v2/rooms/{room_id}/messages".format( + room_id=options.get("room_id") + ) - message = '' + message = "" if alert.custom_subject: - message = alert.custom_subject + '\n' + message = alert.custom_subject + "\n" if alert.custom_body: message += alert.custom_body else: - alert_url = '{host}/alerts/{alert_id}'.format(host=host, alert_id=alert.id) - query_url = '{host}/queries/{query_id}'.format(host=host, query_id=query.id) - message_template = options.get('message_template', ChatWork.ALERTS_DEFAULT_MESSAGE_TEMPLATE) - message += message_template.replace('\\n', '\n').format( + alert_url = "{host}/alerts/{alert_id}".format( + host=host, alert_id=alert.id + ) + query_url = "{host}/queries/{query_id}".format( + host=host, query_id=query.id + ) + message_template = options.get( + "message_template", ChatWork.ALERTS_DEFAULT_MESSAGE_TEMPLATE + ) + message += message_template.replace("\\n", "\n").format( alert_name=alert.name, new_state=new_state.upper(), alert_url=alert_url, - query_url=query_url + query_url=query_url, ) - headers = {'X-ChatWorkToken': options.get('api_token')} - payload = {'body': message} + headers = {"X-ChatWorkToken": options.get("api_token")} + payload = {"body": message} resp = requests.post(url, headers=headers, data=payload, timeout=5.0) logging.warning(resp.text) if resp.status_code != 200: - logging.error('ChatWork send ERROR. status_code => {status}'.format(status=resp.status_code)) + logging.error( + "ChatWork send ERROR. status_code => {status}".format( + status=resp.status_code + ) + ) except Exception: - logging.exception('ChatWork send ERROR.') + logging.exception("ChatWork send ERROR.") register(ChatWork) diff --git a/redash/destinations/email.py b/redash/destinations/email.py index 378eae7e65..b345bd42df 100644 --- a/redash/destinations/email.py +++ b/redash/destinations/email.py @@ -6,31 +6,30 @@ class Email(BaseDestination): - @classmethod def configuration_schema(cls): return { "type": "object", "properties": { - "addresses": { - "type": "string" - }, + "addresses": {"type": "string"}, "subject_template": { "type": "string", "default": settings.ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE, - "title": "Subject Template" - } + "title": "Subject Template", + }, }, "required": ["addresses"], - "extra_options": ["subject_template"] + "extra_options": ["subject_template"], } @classmethod def icon(cls): - return 'fa-envelope' + return "fa-envelope" def notify(self, alert, query, user, new_state, app, host, options): - recipients = [email for email in options.get('addresses', '').split(',') if email] + recipients = [ + email for email in options.get("addresses", "").split(",") if email + ] if not recipients: logging.warning("No emails given. Skipping send.") @@ -41,23 +40,23 @@ def notify(self, alert, query, user, new_state, app, host, options): html = """ Check alert / check query
. - """.format(host=host, alert_id=alert.id, query_id=query.id) + """.format( + host=host, alert_id=alert.id, query_id=query.id + ) logging.debug("Notifying: %s", recipients) try: - alert_name = alert.name.encode('utf-8', 'ignore') + alert_name = alert.name.encode("utf-8", "ignore") state = new_state.upper() if alert.custom_subject: subject = alert.custom_subject else: - subject_template = options.get('subject_template', settings.ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE) + subject_template = options.get( + "subject_template", settings.ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE + ) subject = subject_template.format(alert_name=alert_name, state=state) - message = Message( - recipients=recipients, - subject=subject, - html=html - ) + message = Message(recipients=recipients, subject=subject, html=html) mail.send(message) except Exception: logging.exception("Mail send error.") diff --git a/redash/destinations/hangoutschat.py b/redash/destinations/hangoutschat.py index 0a3063a435..bc52f3de69 100644 --- a/redash/destinations/hangoutschat.py +++ b/redash/destinations/hangoutschat.py @@ -21,28 +21,30 @@ def configuration_schema(cls): "properties": { "url": { "type": "string", - "title": "Webhook URL (get it from the room settings)" + "title": "Webhook URL (get it from the room settings)", }, "icon_url": { "type": "string", - "title": "Icon URL (32x32 or multiple, png format)" - } + "title": "Icon URL (32x32 or multiple, png format)", + }, }, - "required": ["url"] + "required": ["url"], } @classmethod def icon(cls): - return 'fa-bolt' + return "fa-bolt" def notify(self, alert, query, user, new_state, app, host, options): try: if new_state == "triggered": - message = "Triggered" + message = 'Triggered' elif new_state == "ok": - message = "Went back to normal" + message = 'Went back to normal' else: - message = "Unable to determine status. Check Query and Alert configuration." + message = ( + "Unable to determine status. Check Query and Alert configuration." + ) if alert.custom_subject: title = alert.custom_subject @@ -52,59 +54,53 @@ def notify(self, alert, query, user, new_state, app, host, options): data = { "cards": [ { - "header": { - "title": title - }, + "header": {"title": title}, "sections": [ - { - "widgets": [ - { - "textParagraph": { - "text": message - } - } - ] - } - ] + {"widgets": [{"textParagraph": {"text": message}}]} + ], } ] } if alert.custom_body: - data["cards"][0]["sections"].append({ - "widgets": [ - { - "textParagraph": { - "text": alert.custom_body - } - } - ] - }) + data["cards"][0]["sections"].append( + {"widgets": [{"textParagraph": {"text": alert.custom_body}}]} + ) if options.get("icon_url"): data["cards"][0]["header"]["imageUrl"] = options.get("icon_url") # Hangouts Chat will create a blank card if an invalid URL (no hostname) is posted. if host: - data["cards"][0]["sections"][0]["widgets"].append({ - "buttons": [ - { - "textButton": { - "text": "OPEN QUERY", - "onClick": { - "openLink": { - "url": "{host}/queries/{query_id}".format(host=host, query_id=query.id) - } + data["cards"][0]["sections"][0]["widgets"].append( + { + "buttons": [ + { + "textButton": { + "text": "OPEN QUERY", + "onClick": { + "openLink": { + "url": "{host}/queries/{query_id}".format( + host=host, query_id=query.id + ) + } + }, } } - } - ] - }) + ] + } + ) headers = {"Content-Type": "application/json; charset=UTF-8"} - resp = requests.post(options.get("url"), data=json_dumps(data), headers=headers, timeout=5.0) + resp = requests.post( + options.get("url"), data=json_dumps(data), headers=headers, timeout=5.0 + ) if resp.status_code != 200: - logging.error("webhook send ERROR. status_code => {status}".format(status=resp.status_code)) + logging.error( + "webhook send ERROR. status_code => {status}".format( + status=resp.status_code + ) + ) except Exception: logging.exception("webhook send ERROR.") diff --git a/redash/destinations/hipchat.py b/redash/destinations/hipchat.py index 8d41e97d9a..add7ee1a60 100644 --- a/redash/destinations/hipchat.py +++ b/redash/destinations/hipchat.py @@ -7,9 +7,9 @@ colors = { - Alert.OK_STATE: 'green', - Alert.TRIGGERED_STATE: 'red', - Alert.UNKNOWN_STATE: 'yellow' + Alert.OK_STATE: "green", + Alert.TRIGGERED_STATE: "red", + Alert.UNKNOWN_STATE: "yellow", } @@ -22,35 +22,38 @@ def configuration_schema(cls): "properties": { "url": { "type": "string", - "title": "HipChat Notification URL (get it from the Integrations page)" - }, + "title": "HipChat Notification URL (get it from the Integrations page)", + } }, - "required": ["url"] + "required": ["url"], } @classmethod def icon(cls): - return 'fa-comment-o' + return "fa-comment-o" def notify(self, alert, query, user, new_state, app, host, options): try: - alert_url = '{host}/alerts/{alert_id}'.format(host=host, alert_id=alert.id) - query_url = '{host}/queries/{query_id}'.format(host=host, query_id=query.id) + alert_url = "{host}/alerts/{alert_id}".format(host=host, alert_id=alert.id) + query_url = "{host}/queries/{query_id}".format(host=host, query_id=query.id) message = '{alert_name} changed state to {new_state} (based on this query).'.format( - alert_name=alert.name, new_state=new_state.upper(), + alert_name=alert.name, + new_state=new_state.upper(), alert_url=alert_url, - query_url=query_url) + query_url=query_url, + ) - data = { - 'message': message, - 'color': colors.get(new_state, 'green') - } - headers = {'Content-Type': 'application/json'} - response = requests.post(options['url'], data=json_dumps(data), headers=headers, timeout=5.0) + data = {"message": message, "color": colors.get(new_state, "green")} + headers = {"Content-Type": "application/json"} + response = requests.post( + options["url"], data=json_dumps(data), headers=headers, timeout=5.0 + ) if response.status_code != 204: - logging.error('Bad status code received from HipChat: %d', response.status_code) + logging.error( + "Bad status code received from HipChat: %d", response.status_code + ) except Exception: logging.exception("HipChat Send ERROR.") diff --git a/redash/destinations/mattermost.py b/redash/destinations/mattermost.py index 2765f4fe55..5d601ff6ff 100644 --- a/redash/destinations/mattermost.py +++ b/redash/destinations/mattermost.py @@ -9,30 +9,18 @@ class Mattermost(BaseDestination): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string', - 'title': 'Mattermost Webhook URL' - }, - 'username': { - 'type': 'string', - 'title': 'Username' - }, - 'icon_url': { - 'type': 'string', - 'title': 'Icon (URL)' - }, - 'channel': { - 'type': 'string', - 'title': 'Channel' - } - } + "type": "object", + "properties": { + "url": {"type": "string", "title": "Mattermost Webhook URL"}, + "username": {"type": "string", "title": "Username"}, + "icon_url": {"type": "string", "title": "Icon (URL)"}, + "channel": {"type": "string", "title": "Channel"}, + }, } @classmethod def icon(cls): - return 'fa-bolt' + return "fa-bolt" def notify(self, alert, query, user, new_state, app, host, options): if alert.custom_subject: @@ -41,27 +29,32 @@ def notify(self, alert, query, user, new_state, app, host, options): text = "#### " + alert.name + " just triggered" else: text = "#### " + alert.name + " went back to normal" - payload = {'text': text} + payload = {"text": text} if alert.custom_body: - payload['attachments'] = [{'fields': [{ - "title": "Description", - "value": alert.custom_body - }]}] + payload["attachments"] = [ + {"fields": [{"title": "Description", "value": alert.custom_body}]} + ] - if options.get('username'): - payload['username'] = options.get('username') - if options.get('icon_url'): - payload['icon_url'] = options.get('icon_url') - if options.get('channel'): - payload['channel'] = options.get('channel') + if options.get("username"): + payload["username"] = options.get("username") + if options.get("icon_url"): + payload["icon_url"] = options.get("icon_url") + if options.get("channel"): + payload["channel"] = options.get("channel") try: - resp = requests.post(options.get('url'), data=json_dumps(payload), timeout=5.0) + resp = requests.post( + options.get("url"), data=json_dumps(payload), timeout=5.0 + ) logging.warning(resp.text) if resp.status_code != 200: - logging.error("Mattermost webhook send ERROR. status_code => {status}".format(status=resp.status_code)) + logging.error( + "Mattermost webhook send ERROR. status_code => {status}".format( + status=resp.status_code + ) + ) except Exception: logging.exception("Mattermost webhook send ERROR.") diff --git a/redash/destinations/pagerduty.py b/redash/destinations/pagerduty.py index 3ed49ad017..3a844fa10d 100644 --- a/redash/destinations/pagerduty.py +++ b/redash/destinations/pagerduty.py @@ -11,8 +11,8 @@ class PagerDuty(BaseDestination): - KEY_STRING = '{alert_id}_{query_id}' - DESCRIPTION_STR = 'Alert: {alert_name}' + KEY_STRING = "{alert_id}_{query_id}" + DESCRIPTION_STR = "Alert: {alert_name}" @classmethod def enabled(cls): @@ -21,55 +21,55 @@ def enabled(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'integration_key': { - 'type': 'string', - 'title': 'PagerDuty Service Integration Key' + "type": "object", + "properties": { + "integration_key": { + "type": "string", + "title": "PagerDuty Service Integration Key", + }, + "description": { + "type": "string", + "title": "Description for the event, defaults to alert name", }, - 'description': { - 'type': 'string', - 'title': 'Description for the event, defaults to alert name', - } }, - "required": ["integration_key"] + "required": ["integration_key"], } @classmethod def icon(cls): - return 'creative-commons-pd-alt' + return "creative-commons-pd-alt" def notify(self, alert, query, user, new_state, app, host, options): if alert.custom_subject: default_desc = alert.custom_subject - elif options.get('description'): - default_desc = options.get('description') + elif options.get("description"): + default_desc = options.get("description") else: default_desc = self.DESCRIPTION_STR.format(alert_name=alert.name) incident_key = self.KEY_STRING.format(alert_id=alert.id, query_id=query.id) data = { - 'routing_key': options.get('integration_key'), - 'incident_key': incident_key, - 'dedup_key': incident_key, - 'payload': { - 'summary': default_desc, - 'severity': 'error', - 'source': 'redash', - } + "routing_key": options.get("integration_key"), + "incident_key": incident_key, + "dedup_key": incident_key, + "payload": { + "summary": default_desc, + "severity": "error", + "source": "redash", + }, } if alert.custom_body: - data['payload']['custom_details'] = alert.custom_body + data["payload"]["custom_details"] = alert.custom_body - if new_state == 'triggered': - data['event_action'] = 'trigger' + if new_state == "triggered": + data["event_action"] = "trigger" elif new_state == "unknown": - logging.info('Unknown state, doing nothing') + logging.info("Unknown state, doing nothing") return else: - data['event_action'] = 'resolve' + data["event_action"] = "resolve" try: diff --git a/redash/destinations/slack.py b/redash/destinations/slack.py index 4c10c8e73b..edbd6f2c2f 100644 --- a/redash/destinations/slack.py +++ b/redash/destinations/slack.py @@ -9,54 +9,40 @@ class Slack(BaseDestination): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string', - 'title': 'Slack Webhook URL' - }, - 'username': { - 'type': 'string', - 'title': 'Username' - }, - 'icon_emoji': { - 'type': 'string', - 'title': 'Icon (Emoji)' - }, - 'icon_url': { - 'type': 'string', - 'title': 'Icon (URL)' - }, - 'channel': { - 'type': 'string', - 'title': 'Channel' - } - } + "type": "object", + "properties": { + "url": {"type": "string", "title": "Slack Webhook URL"}, + "username": {"type": "string", "title": "Username"}, + "icon_emoji": {"type": "string", "title": "Icon (Emoji)"}, + "icon_url": {"type": "string", "title": "Icon (URL)"}, + "channel": {"type": "string", "title": "Channel"}, + }, } @classmethod def icon(cls): - return 'fa-slack' + return "fa-slack" def notify(self, alert, query, user, new_state, app, host, options): # Documentation: https://api.slack.com/docs/attachments fields = [ { "title": "Query", - "value": "{host}/queries/{query_id}".format(host=host, query_id=query.id), - "short": True + "value": "{host}/queries/{query_id}".format( + host=host, query_id=query.id + ), + "short": True, }, { "title": "Alert", - "value": "{host}/alerts/{alert_id}".format(host=host, alert_id=alert.id), - "short": True - } + "value": "{host}/alerts/{alert_id}".format( + host=host, alert_id=alert.id + ), + "short": True, + }, ] if alert.custom_body: - fields.append({ - "title": "Description", - "value": alert.custom_body - }) + fields.append({"title": "Description", "value": alert.custom_body}) if new_state == "triggered": if alert.custom_subject: text = alert.custom_subject @@ -67,18 +53,28 @@ def notify(self, alert, query, user, new_state, app, host, options): text = alert.name + " went back to normal" color = "#27ae60" - payload = {'attachments': [{'text': text, 'color': color, 'fields': fields}]} + payload = {"attachments": [{"text": text, "color": color, "fields": fields}]} - if options.get('username'): payload['username'] = options.get('username') - if options.get('icon_emoji'): payload['icon_emoji'] = options.get('icon_emoji') - if options.get('icon_url'): payload['icon_url'] = options.get('icon_url') - if options.get('channel'): payload['channel'] = options.get('channel') + if options.get("username"): + payload["username"] = options.get("username") + if options.get("icon_emoji"): + payload["icon_emoji"] = options.get("icon_emoji") + if options.get("icon_url"): + payload["icon_url"] = options.get("icon_url") + if options.get("channel"): + payload["channel"] = options.get("channel") try: - resp = requests.post(options.get('url'), data=json_dumps(payload), timeout=5.0) + resp = requests.post( + options.get("url"), data=json_dumps(payload), timeout=5.0 + ) logging.warning(resp.text) if resp.status_code != 200: - logging.error("Slack send ERROR. status_code => {status}".format(status=resp.status_code)) + logging.error( + "Slack send ERROR. status_code => {status}".format( + status=resp.status_code + ) + ) except Exception: logging.exception("Slack send ERROR.") diff --git a/redash/destinations/webhook.py b/redash/destinations/webhook.py index 42144ff3fa..83581e9e37 100644 --- a/redash/destinations/webhook.py +++ b/redash/destinations/webhook.py @@ -13,40 +13,48 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "url": { - "type": "string", - }, - "username": { - "type": "string" - }, - "password": { - "type": "string" - } + "url": {"type": "string"}, + "username": {"type": "string"}, + "password": {"type": "string"}, }, "required": ["url"], - "secret": ["password"] + "secret": ["password"], } @classmethod def icon(cls): - return 'fa-bolt' + return "fa-bolt" def notify(self, alert, query, user, new_state, app, host, options): try: data = { - 'event': 'alert_state_change', - 'alert': serialize_alert(alert, full=False), - 'url_base': host, + "event": "alert_state_change", + "alert": serialize_alert(alert, full=False), + "url_base": host, } - data['alert']['description'] = alert.custom_body - data['alert']['title'] = alert.custom_subject - - headers = {'Content-Type': 'application/json'} - auth = HTTPBasicAuth(options.get('username'), options.get('password')) if options.get('username') else None - resp = requests.post(options.get('url'), data=json_dumps(data), auth=auth, headers=headers, timeout=5.0) + data["alert"]["description"] = alert.custom_body + data["alert"]["title"] = alert.custom_subject + + headers = {"Content-Type": "application/json"} + auth = ( + HTTPBasicAuth(options.get("username"), options.get("password")) + if options.get("username") + else None + ) + resp = requests.post( + options.get("url"), + data=json_dumps(data), + auth=auth, + headers=headers, + timeout=5.0, + ) if resp.status_code != 200: - logging.error("webhook send ERROR. status_code => {status}".format(status=resp.status_code)) + logging.error( + "webhook send ERROR. status_code => {status}".format( + status=resp.status_code + ) + ) except Exception: logging.exception("webhook send ERROR.") diff --git a/redash/handlers/__init__.py b/redash/handlers/__init__.py index 8de5cedaac..a7a05ed0f7 100644 --- a/redash/handlers/__init__.py +++ b/redash/handlers/__init__.py @@ -8,13 +8,13 @@ from redash.security import talisman -@routes.route('/ping', methods=['GET']) +@routes.route("/ping", methods=["GET"]) @talisman(force_https=False) def ping(): - return 'PONG.' + return "PONG." -@routes.route('/status.json') +@routes.route("/status.json") @login_required @require_super_admin def status_api(): @@ -23,6 +23,15 @@ def status_api(): def init_app(app): - from redash.handlers import embed, queries, static, authentication, admin, setup, organization + from redash.handlers import ( + embed, + queries, + static, + authentication, + admin, + setup, + organization, + ) + app.register_blueprint(routes) api.init_app(app) diff --git a/redash/handlers/admin.py b/redash/handlers/admin.py index ba8dd30ec9..f07e38f776 100644 --- a/redash/handlers/admin.py +++ b/redash/handlers/admin.py @@ -11,56 +11,59 @@ from redash.monitor import celery_tasks, rq_status -@routes.route('/api/admin/queries/outdated', methods=['GET']) +@routes.route("/api/admin/queries/outdated", methods=["GET"]) @require_super_admin @login_required def outdated_queries(): - manager_status = redis_connection.hgetall('redash:status') - query_ids = json_loads(manager_status.get('query_ids', '[]')) + manager_status = redis_connection.hgetall("redash:status") + query_ids = json_loads(manager_status.get("query_ids", "[]")) if query_ids: outdated_queries = ( models.Query.query.outerjoin(models.QueryResult) - .filter(models.Query.id.in_(query_ids)) - .order_by(models.Query.created_at.desc()) + .filter(models.Query.id.in_(query_ids)) + .order_by(models.Query.created_at.desc()) ) else: outdated_queries = [] - record_event(current_org, current_user._get_current_object(), { - 'action': 'list', - 'object_type': 'outdated_queries', - }) + record_event( + current_org, + current_user._get_current_object(), + {"action": "list", "object_type": "outdated_queries"}, + ) response = { - 'queries': QuerySerializer(outdated_queries, with_stats=True, with_last_modified_by=False).serialize(), - 'updated_at': manager_status['last_refresh_at'], + "queries": QuerySerializer( + outdated_queries, with_stats=True, with_last_modified_by=False + ).serialize(), + "updated_at": manager_status["last_refresh_at"], } return json_response(response) -@routes.route('/api/admin/queries/tasks', methods=['GET']) +@routes.route("/api/admin/queries/tasks", methods=["GET"]) @require_super_admin @login_required def queries_tasks(): - record_event(current_org, current_user._get_current_object(), { - 'action': 'list', - 'object_type': 'celery_tasks' - }) + record_event( + current_org, + current_user._get_current_object(), + {"action": "list", "object_type": "celery_tasks"}, + ) - response = { - 'tasks': celery_tasks(), - } + response = {"tasks": celery_tasks()} return json_response(response) -@routes.route('/api/admin/queries/rq_status', methods=['GET']) +@routes.route("/api/admin/queries/rq_status", methods=["GET"]) @require_super_admin @login_required def queries_rq_status(): - record_event(current_org, current_user._get_current_object(), { - 'action': 'list', - 'object_type': 'rq_status' - }) + record_event( + current_org, + current_user._get_current_object(), + {"action": "list", "object_type": "rq_status"}, + ) return json_response(rq_status()) diff --git a/redash/handlers/alerts.py b/redash/handlers/alerts.py index 03a0d73f04..7929df9c05 100644 --- a/redash/handlers/alerts.py +++ b/redash/handlers/alerts.py @@ -5,43 +5,48 @@ from redash import models from redash.serializers import serialize_alert -from redash.handlers.base import (BaseResource, get_object_or_404, - require_fields) -from redash.permissions import (require_access, require_admin_or_owner, - require_permission, view_only) +from redash.handlers.base import BaseResource, get_object_or_404, require_fields +from redash.permissions import ( + require_access, + require_admin_or_owner, + require_permission, + view_only, +) from redash.utils import json_dumps class AlertResource(BaseResource): def get(self, alert_id): - alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) + alert = get_object_or_404( + models.Alert.get_by_id_and_org, alert_id, self.current_org + ) require_access(alert, self.current_user, view_only) - self.record_event({ - 'action': 'view', - 'object_id': alert.id, - 'object_type': 'alert' - }) + self.record_event( + {"action": "view", "object_id": alert.id, "object_type": "alert"} + ) return serialize_alert(alert) def post(self, alert_id): req = request.get_json(True) - params = project(req, ('options', 'name', 'query_id', 'rearm')) - alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) + params = project(req, ("options", "name", "query_id", "rearm")) + alert = get_object_or_404( + models.Alert.get_by_id_and_org, alert_id, self.current_org + ) require_admin_or_owner(alert.user.id) self.update_model(alert, params) models.db.session.commit() - self.record_event({ - 'action': 'edit', - 'object_id': alert.id, - 'object_type': 'alert' - }) + self.record_event( + {"action": "edit", "object_id": alert.id, "object_type": "alert"} + ) return serialize_alert(alert) def delete(self, alert_id): - alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) + alert = get_object_or_404( + models.Alert.get_by_id_and_org, alert_id, self.current_org + ) require_admin_or_owner(alert.user_id) models.db.session.delete(alert) models.db.session.commit() @@ -49,68 +54,65 @@ def delete(self, alert_id): class AlertMuteResource(BaseResource): def post(self, alert_id): - alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) + alert = get_object_or_404( + models.Alert.get_by_id_and_org, alert_id, self.current_org + ) require_admin_or_owner(alert.user.id) - alert.options['muted'] = True + alert.options["muted"] = True models.db.session.commit() - self.record_event({ - 'action': 'mute', - 'object_id': alert.id, - 'object_type': 'alert' - }) + self.record_event( + {"action": "mute", "object_id": alert.id, "object_type": "alert"} + ) def delete(self, alert_id): - alert = get_object_or_404(models.Alert.get_by_id_and_org, alert_id, self.current_org) + alert = get_object_or_404( + models.Alert.get_by_id_and_org, alert_id, self.current_org + ) require_admin_or_owner(alert.user.id) - alert.options['muted'] = False + alert.options["muted"] = False models.db.session.commit() - self.record_event({ - 'action': 'unmute', - 'object_id': alert.id, - 'object_type': 'alert' - }) + self.record_event( + {"action": "unmute", "object_id": alert.id, "object_type": "alert"} + ) class AlertListResource(BaseResource): def post(self): req = request.get_json(True) - require_fields(req, ('options', 'name', 'query_id')) + require_fields(req, ("options", "name", "query_id")) - query = models.Query.get_by_id_and_org(req['query_id'], - self.current_org) + query = models.Query.get_by_id_and_org(req["query_id"], self.current_org) require_access(query, self.current_user, view_only) alert = models.Alert( - name=req['name'], + name=req["name"], query_rel=query, user=self.current_user, - rearm=req.get('rearm'), - options=req['options'], + rearm=req.get("rearm"), + options=req["options"], ) models.db.session.add(alert) models.db.session.flush() models.db.session.commit() - self.record_event({ - 'action': 'create', - 'object_id': alert.id, - 'object_type': 'alert' - }) + self.record_event( + {"action": "create", "object_id": alert.id, "object_type": "alert"} + ) return serialize_alert(alert) - @require_permission('list_alerts') + @require_permission("list_alerts") def get(self): - self.record_event({ - 'action': 'list', - 'object_type': 'alert' - }) - return [serialize_alert(alert) for alert in models.Alert.all(group_ids=self.current_user.group_ids)] + self.record_event({"action": "list", "object_type": "alert"}) + return [ + serialize_alert(alert) + for alert in models.Alert.all(group_ids=self.current_user.group_ids) + ] class AlertSubscriptionListResource(BaseResource): @@ -119,22 +121,26 @@ def post(self, alert_id): alert = models.Alert.get_by_id_and_org(alert_id, self.current_org) require_access(alert, self.current_user, view_only) - kwargs = {'alert': alert, 'user': self.current_user} + kwargs = {"alert": alert, "user": self.current_user} - if 'destination_id' in req: - destination = models.NotificationDestination.get_by_id_and_org(req['destination_id'], self.current_org) - kwargs['destination'] = destination + if "destination_id" in req: + destination = models.NotificationDestination.get_by_id_and_org( + req["destination_id"], self.current_org + ) + kwargs["destination"] = destination subscription = models.AlertSubscription(**kwargs) models.db.session.add(subscription) models.db.session.commit() - self.record_event({ - 'action': 'subscribe', - 'object_id': alert_id, - 'object_type': 'alert', - 'destination': req.get('destination_id') - }) + self.record_event( + { + "action": "subscribe", + "object_id": alert_id, + "object_type": "alert", + "destination": req.get("destination_id"), + } + ) d = subscription.to_dict() return d @@ -155,8 +161,6 @@ def delete(self, alert_id, subscriber_id): models.db.session.delete(subscription) models.db.session.commit() - self.record_event({ - 'action': 'unsubscribe', - 'object_id': alert_id, - 'object_type': 'alert' - }) + self.record_event( + {"action": "unsubscribe", "object_id": alert_id, "object_type": "alert"} + ) diff --git a/redash/handlers/api.py b/redash/handlers/api.py index bca2c527d5..ef479ef3c3 100644 --- a/redash/handlers/api.py +++ b/redash/handlers/api.py @@ -2,56 +2,86 @@ from flask_restful import Api from werkzeug.wrappers import Response -from redash.handlers.alerts import (AlertListResource, - AlertResource, AlertMuteResource, - AlertSubscriptionListResource, - AlertSubscriptionResource) +from redash.handlers.alerts import ( + AlertListResource, + AlertResource, + AlertMuteResource, + AlertSubscriptionListResource, + AlertSubscriptionResource, +) from redash.handlers.base import org_scoped_rule -from redash.handlers.dashboards import (DashboardFavoriteListResource, - DashboardListResource, - DashboardResource, - DashboardShareResource, - DashboardTagsResource, - PublicDashboardResource) -from redash.handlers.data_sources import (DataSourceListResource, - DataSourcePauseResource, - DataSourceResource, - DataSourceSchemaResource, - DataSourceTestResource, - DataSourceTypeListResource) -from redash.handlers.destinations import (DestinationListResource, - DestinationResource, - DestinationTypeListResource) +from redash.handlers.dashboards import ( + DashboardFavoriteListResource, + DashboardListResource, + DashboardResource, + DashboardShareResource, + DashboardTagsResource, + PublicDashboardResource, +) +from redash.handlers.data_sources import ( + DataSourceListResource, + DataSourcePauseResource, + DataSourceResource, + DataSourceSchemaResource, + DataSourceTestResource, + DataSourceTypeListResource, +) +from redash.handlers.destinations import ( + DestinationListResource, + DestinationResource, + DestinationTypeListResource, +) from redash.handlers.events import EventsResource -from redash.handlers.favorites import (DashboardFavoriteResource, - QueryFavoriteResource) -from redash.handlers.groups import (GroupDataSourceListResource, - GroupDataSourceResource, GroupListResource, - GroupMemberListResource, - GroupMemberResource, GroupResource) -from redash.handlers.permissions import (CheckPermissionResource, - ObjectPermissionsListResource) -from redash.handlers.queries import (MyQueriesResource, QueryArchiveResource, - QueryFavoriteListResource, - QueryForkResource, QueryListResource, - QueryRecentResource, QueryRefreshResource, - QueryResource, QuerySearchResource, - QueryTagsResource, - QueryRegenerateApiKeyResource) -from redash.handlers.query_results import (JobResource, - QueryResultDropdownResource, - QueryDropdownsResource, - QueryResultListResource, - QueryResultResource) -from redash.handlers.query_snippets import (QuerySnippetListResource, - QuerySnippetResource) +from redash.handlers.favorites import DashboardFavoriteResource, QueryFavoriteResource +from redash.handlers.groups import ( + GroupDataSourceListResource, + GroupDataSourceResource, + GroupListResource, + GroupMemberListResource, + GroupMemberResource, + GroupResource, +) +from redash.handlers.permissions import ( + CheckPermissionResource, + ObjectPermissionsListResource, +) +from redash.handlers.queries import ( + MyQueriesResource, + QueryArchiveResource, + QueryFavoriteListResource, + QueryForkResource, + QueryListResource, + QueryRecentResource, + QueryRefreshResource, + QueryResource, + QuerySearchResource, + QueryTagsResource, + QueryRegenerateApiKeyResource, +) +from redash.handlers.query_results import ( + JobResource, + QueryResultDropdownResource, + QueryDropdownsResource, + QueryResultListResource, + QueryResultResource, +) +from redash.handlers.query_snippets import ( + QuerySnippetListResource, + QuerySnippetResource, +) from redash.handlers.settings import OrganizationSettings -from redash.handlers.users import (UserDisableResource, UserInviteResource, - UserListResource, - UserRegenerateApiKeyResource, - UserResetPasswordResource, UserResource) -from redash.handlers.visualizations import (VisualizationListResource, - VisualizationResource) +from redash.handlers.users import ( + UserDisableResource, + UserInviteResource, + UserListResource, + UserRegenerateApiKeyResource, + UserResetPasswordResource, + UserResource, +) +from redash.handlers.visualizations import ( + VisualizationListResource, + VisualizationResource, +) from redash.handlers.widgets import WidgetListResource, WidgetResource from redash.utils import json_dumps @@ -65,7 +95,7 @@ def add_org_resource(self, resource, *urls, **kwargs): api = ApiExt() -@api.representation('application/json') +@api.representation("application/json") def json_representation(data, code, headers=None): # Flask-Restful checks only for flask.Response but flask-login uses werkzeug.wrappers.Response if isinstance(data, Response): @@ -75,91 +105,211 @@ def json_representation(data, code, headers=None): return resp -api.add_org_resource(AlertResource, '/api/alerts/', endpoint='alert') -api.add_org_resource(AlertMuteResource, '/api/alerts//mute', endpoint='alert_mute') -api.add_org_resource(AlertSubscriptionListResource, '/api/alerts//subscriptions', endpoint='alert_subscriptions') -api.add_org_resource(AlertSubscriptionResource, '/api/alerts//subscriptions/', endpoint='alert_subscription') -api.add_org_resource(AlertListResource, '/api/alerts', endpoint='alerts') - -api.add_org_resource(DashboardListResource, '/api/dashboards', endpoint='dashboards') -api.add_org_resource(DashboardResource, '/api/dashboards/', endpoint='dashboard') -api.add_org_resource(PublicDashboardResource, '/api/dashboards/public/', endpoint='public_dashboard') -api.add_org_resource(DashboardShareResource, '/api/dashboards//share', endpoint='dashboard_share') - -api.add_org_resource(DataSourceTypeListResource, '/api/data_sources/types', endpoint='data_source_types') -api.add_org_resource(DataSourceListResource, '/api/data_sources', endpoint='data_sources') -api.add_org_resource(DataSourceSchemaResource, '/api/data_sources//schema') -api.add_org_resource(DataSourcePauseResource, '/api/data_sources//pause') -api.add_org_resource(DataSourceTestResource, '/api/data_sources//test') -api.add_org_resource(DataSourceResource, '/api/data_sources/', endpoint='data_source') - -api.add_org_resource(GroupListResource, '/api/groups', endpoint='groups') -api.add_org_resource(GroupResource, '/api/groups/', endpoint='group') -api.add_org_resource(GroupMemberListResource, '/api/groups//members', endpoint='group_members') -api.add_org_resource(GroupMemberResource, '/api/groups//members/', endpoint='group_member') -api.add_org_resource(GroupDataSourceListResource, '/api/groups//data_sources', endpoint='group_data_sources') -api.add_org_resource(GroupDataSourceResource, '/api/groups//data_sources/', endpoint='group_data_source') - -api.add_org_resource(EventsResource, '/api/events', endpoint='events') - -api.add_org_resource(QueryFavoriteListResource, '/api/queries/favorites', endpoint='query_favorites') -api.add_org_resource(QueryFavoriteResource, '/api/queries//favorite', endpoint='query_favorite') -api.add_org_resource(DashboardFavoriteListResource, '/api/dashboards/favorites', endpoint='dashboard_favorites') -api.add_org_resource(DashboardFavoriteResource, '/api/dashboards//favorite', endpoint='dashboard_favorite') - -api.add_org_resource(QueryTagsResource, '/api/queries/tags', endpoint='query_tags') -api.add_org_resource(DashboardTagsResource, '/api/dashboards/tags', endpoint='dashboard_tags') - -api.add_org_resource(QuerySearchResource, '/api/queries/search', endpoint='queries_search') -api.add_org_resource(QueryRecentResource, '/api/queries/recent', endpoint='recent_queries') -api.add_org_resource(QueryArchiveResource, '/api/queries/archive', endpoint='queries_archive') -api.add_org_resource(QueryListResource, '/api/queries', endpoint='queries') -api.add_org_resource(MyQueriesResource, '/api/queries/my', endpoint='my_queries') -api.add_org_resource(QueryRefreshResource, '/api/queries//refresh', endpoint='query_refresh') -api.add_org_resource(QueryResource, '/api/queries/', endpoint='query') -api.add_org_resource(QueryForkResource, '/api/queries//fork', endpoint='query_fork') -api.add_org_resource(QueryRegenerateApiKeyResource, - '/api/queries//regenerate_api_key', - endpoint='query_regenerate_api_key') - -api.add_org_resource(ObjectPermissionsListResource, '/api///acl', endpoint='object_permissions') -api.add_org_resource(CheckPermissionResource, '/api///acl/', endpoint='check_permissions') - -api.add_org_resource(QueryResultListResource, '/api/query_results', endpoint='query_results') -api.add_org_resource(QueryResultDropdownResource, '/api/queries//dropdown', endpoint='query_result_dropdown') -api.add_org_resource(QueryDropdownsResource, '/api/queries//dropdowns/', endpoint='query_result_dropdowns') -api.add_org_resource(QueryResultResource, - '/api/query_results/.', - '/api/query_results/', - '/api/queries//results', - '/api/queries//results.', - '/api/queries//results/.', - endpoint='query_result') -api.add_org_resource(JobResource, - '/api/jobs/', - '/api/queries//jobs/', - endpoint='job') - -api.add_org_resource(UserListResource, '/api/users', endpoint='users') -api.add_org_resource(UserResource, '/api/users/', endpoint='user') -api.add_org_resource(UserInviteResource, '/api/users//invite', endpoint='user_invite') -api.add_org_resource(UserResetPasswordResource, '/api/users//reset_password', endpoint='user_reset_password') -api.add_org_resource(UserRegenerateApiKeyResource, - '/api/users//regenerate_api_key', - endpoint='user_regenerate_api_key') -api.add_org_resource(UserDisableResource, '/api/users//disable', endpoint='user_disable') - -api.add_org_resource(VisualizationListResource, '/api/visualizations', endpoint='visualizations') -api.add_org_resource(VisualizationResource, '/api/visualizations/', endpoint='visualization') - -api.add_org_resource(WidgetListResource, '/api/widgets', endpoint='widgets') -api.add_org_resource(WidgetResource, '/api/widgets/', endpoint='widget') - -api.add_org_resource(DestinationTypeListResource, '/api/destinations/types', endpoint='destination_types') -api.add_org_resource(DestinationResource, '/api/destinations/', endpoint='destination') -api.add_org_resource(DestinationListResource, '/api/destinations', endpoint='destinations') - -api.add_org_resource(QuerySnippetResource, '/api/query_snippets/', endpoint='query_snippet') -api.add_org_resource(QuerySnippetListResource, '/api/query_snippets', endpoint='query_snippets') - -api.add_org_resource(OrganizationSettings, '/api/settings/organization', endpoint='organization_settings') +api.add_org_resource(AlertResource, "/api/alerts/", endpoint="alert") +api.add_org_resource( + AlertMuteResource, "/api/alerts//mute", endpoint="alert_mute" +) +api.add_org_resource( + AlertSubscriptionListResource, + "/api/alerts//subscriptions", + endpoint="alert_subscriptions", +) +api.add_org_resource( + AlertSubscriptionResource, + "/api/alerts//subscriptions/", + endpoint="alert_subscription", +) +api.add_org_resource(AlertListResource, "/api/alerts", endpoint="alerts") + +api.add_org_resource(DashboardListResource, "/api/dashboards", endpoint="dashboards") +api.add_org_resource( + DashboardResource, "/api/dashboards/", endpoint="dashboard" +) +api.add_org_resource( + PublicDashboardResource, + "/api/dashboards/public/", + endpoint="public_dashboard", +) +api.add_org_resource( + DashboardShareResource, + "/api/dashboards//share", + endpoint="dashboard_share", +) + +api.add_org_resource( + DataSourceTypeListResource, "/api/data_sources/types", endpoint="data_source_types" +) +api.add_org_resource( + DataSourceListResource, "/api/data_sources", endpoint="data_sources" +) +api.add_org_resource( + DataSourceSchemaResource, "/api/data_sources//schema" +) +api.add_org_resource( + DataSourcePauseResource, "/api/data_sources//pause" +) +api.add_org_resource(DataSourceTestResource, "/api/data_sources//test") +api.add_org_resource( + DataSourceResource, "/api/data_sources/", endpoint="data_source" +) + +api.add_org_resource(GroupListResource, "/api/groups", endpoint="groups") +api.add_org_resource(GroupResource, "/api/groups/", endpoint="group") +api.add_org_resource( + GroupMemberListResource, "/api/groups//members", endpoint="group_members" +) +api.add_org_resource( + GroupMemberResource, + "/api/groups//members/", + endpoint="group_member", +) +api.add_org_resource( + GroupDataSourceListResource, + "/api/groups//data_sources", + endpoint="group_data_sources", +) +api.add_org_resource( + GroupDataSourceResource, + "/api/groups//data_sources/", + endpoint="group_data_source", +) + +api.add_org_resource(EventsResource, "/api/events", endpoint="events") + +api.add_org_resource( + QueryFavoriteListResource, "/api/queries/favorites", endpoint="query_favorites" +) +api.add_org_resource( + QueryFavoriteResource, "/api/queries//favorite", endpoint="query_favorite" +) +api.add_org_resource( + DashboardFavoriteListResource, + "/api/dashboards/favorites", + endpoint="dashboard_favorites", +) +api.add_org_resource( + DashboardFavoriteResource, + "/api/dashboards//favorite", + endpoint="dashboard_favorite", +) + +api.add_org_resource(QueryTagsResource, "/api/queries/tags", endpoint="query_tags") +api.add_org_resource( + DashboardTagsResource, "/api/dashboards/tags", endpoint="dashboard_tags" +) + +api.add_org_resource( + QuerySearchResource, "/api/queries/search", endpoint="queries_search" +) +api.add_org_resource( + QueryRecentResource, "/api/queries/recent", endpoint="recent_queries" +) +api.add_org_resource( + QueryArchiveResource, "/api/queries/archive", endpoint="queries_archive" +) +api.add_org_resource(QueryListResource, "/api/queries", endpoint="queries") +api.add_org_resource(MyQueriesResource, "/api/queries/my", endpoint="my_queries") +api.add_org_resource( + QueryRefreshResource, "/api/queries//refresh", endpoint="query_refresh" +) +api.add_org_resource(QueryResource, "/api/queries/", endpoint="query") +api.add_org_resource( + QueryForkResource, "/api/queries//fork", endpoint="query_fork" +) +api.add_org_resource( + QueryRegenerateApiKeyResource, + "/api/queries//regenerate_api_key", + endpoint="query_regenerate_api_key", +) + +api.add_org_resource( + ObjectPermissionsListResource, + "/api///acl", + endpoint="object_permissions", +) +api.add_org_resource( + CheckPermissionResource, + "/api///acl/", + endpoint="check_permissions", +) + +api.add_org_resource( + QueryResultListResource, "/api/query_results", endpoint="query_results" +) +api.add_org_resource( + QueryResultDropdownResource, + "/api/queries//dropdown", + endpoint="query_result_dropdown", +) +api.add_org_resource( + QueryDropdownsResource, + "/api/queries//dropdowns/", + endpoint="query_result_dropdowns", +) +api.add_org_resource( + QueryResultResource, + "/api/query_results/.", + "/api/query_results/", + "/api/queries//results", + "/api/queries//results.", + "/api/queries//results/.", + endpoint="query_result", +) +api.add_org_resource( + JobResource, + "/api/jobs/", + "/api/queries//jobs/", + endpoint="job", +) + +api.add_org_resource(UserListResource, "/api/users", endpoint="users") +api.add_org_resource(UserResource, "/api/users/", endpoint="user") +api.add_org_resource( + UserInviteResource, "/api/users//invite", endpoint="user_invite" +) +api.add_org_resource( + UserResetPasswordResource, + "/api/users//reset_password", + endpoint="user_reset_password", +) +api.add_org_resource( + UserRegenerateApiKeyResource, + "/api/users//regenerate_api_key", + endpoint="user_regenerate_api_key", +) +api.add_org_resource( + UserDisableResource, "/api/users//disable", endpoint="user_disable" +) + +api.add_org_resource( + VisualizationListResource, "/api/visualizations", endpoint="visualizations" +) +api.add_org_resource( + VisualizationResource, + "/api/visualizations/", + endpoint="visualization", +) + +api.add_org_resource(WidgetListResource, "/api/widgets", endpoint="widgets") +api.add_org_resource(WidgetResource, "/api/widgets/", endpoint="widget") + +api.add_org_resource( + DestinationTypeListResource, "/api/destinations/types", endpoint="destination_types" +) +api.add_org_resource( + DestinationResource, "/api/destinations/", endpoint="destination" +) +api.add_org_resource( + DestinationListResource, "/api/destinations", endpoint="destinations" +) + +api.add_org_resource( + QuerySnippetResource, "/api/query_snippets/", endpoint="query_snippet" +) +api.add_org_resource( + QuerySnippetListResource, "/api/query_snippets", endpoint="query_snippets" +) + +api.add_org_resource( + OrganizationSettings, "/api/settings/organization", endpoint="organization_settings" +) diff --git a/redash/handlers/authentication.py b/redash/handlers/authentication.py index dd1eea6451..8447ac30f0 100644 --- a/redash/handlers/authentication.py +++ b/redash/handlers/authentication.py @@ -5,11 +5,14 @@ from flask_login import current_user, login_required, login_user, logout_user from redash import __version__, limiter, models, settings from redash.authentication import current_org, get_login_url, get_next_path -from redash.authentication.account import (BadSignature, SignatureExpired, - send_password_reset_email, - send_user_disabled_email, - send_verify_email, - validate_token) +from redash.authentication.account import ( + BadSignature, + SignatureExpired, + send_password_reset_email, + send_user_disabled_email, + send_verify_email, + validate_token, +) from redash.handlers import routes from redash.handlers.base import json_response, org_scoped_rule from redash.version_check import get_latest_version @@ -20,9 +23,11 @@ def get_google_auth_url(next_path): if settings.MULTI_ORG: - google_auth_url = url_for('google_oauth.authorize_org', next=next_path, org_slug=current_org.slug) + google_auth_url = url_for( + "google_oauth.authorize_org", next=next_path, org_slug=current_org.slug + ) else: - google_auth_url = url_for('google_oauth.authorize', next=next_path) + google_auth_url = url_for("google_oauth.authorize", next=next_path) return google_auth_url @@ -32,90 +37,125 @@ def render_token_login_page(template, org_slug, token, invite): org = current_org._get_current_object() user = models.User.get_by_id_and_org(user_id, org) except NoResultFound: - logger.exception("Bad user id in token. Token= , User id= %s, Org=%s", user_id, token, org_slug) - return render_template("error.html", error_message="Invalid invite link. Please ask for a new one."), 400 + logger.exception( + "Bad user id in token. Token= , User id= %s, Org=%s", + user_id, + token, + org_slug, + ) + return ( + render_template( + "error.html", + error_message="Invalid invite link. Please ask for a new one.", + ), + 400, + ) except (SignatureExpired, BadSignature): logger.exception("Failed to verify invite token: %s, org=%s", token, org_slug) - return render_template("error.html", - error_message="Your invite link has expired. Please ask for a new one."), 400 - - if invite and user.details.get('is_invitation_pending') is False: - return render_template("error.html", - error_message=("This invitation has already been accepted. " - "Please try resetting your password instead.")), 400 + return ( + render_template( + "error.html", + error_message="Your invite link has expired. Please ask for a new one.", + ), + 400, + ) + + if invite and user.details.get("is_invitation_pending") is False: + return ( + render_template( + "error.html", + error_message=( + "This invitation has already been accepted. " + "Please try resetting your password instead." + ), + ), + 400, + ) status_code = 200 - if request.method == 'POST': - if 'password' not in request.form: - flash('Bad Request') + if request.method == "POST": + if "password" not in request.form: + flash("Bad Request") status_code = 400 - elif not request.form['password']: - flash('Cannot use empty password.') + elif not request.form["password"]: + flash("Cannot use empty password.") status_code = 400 - elif len(request.form['password']) < 6: - flash('Password length is too short (<6).') + elif len(request.form["password"]) < 6: + flash("Password length is too short (<6).") status_code = 400 else: if invite: user.is_invitation_pending = False - user.hash_password(request.form['password']) + user.hash_password(request.form["password"]) models.db.session.add(user) login_user(user) models.db.session.commit() - return redirect(url_for('redash.index', org_slug=org_slug)) - - google_auth_url = get_google_auth_url(url_for('redash.index', org_slug=org_slug)) - - return render_template(template, - show_google_openid=settings.GOOGLE_OAUTH_ENABLED, - google_auth_url=google_auth_url, - show_saml_login=current_org.get_setting('auth_saml_enabled'), - show_remote_user_login=settings.REMOTE_USER_LOGIN_ENABLED, - show_ldap_login=settings.LDAP_LOGIN_ENABLED, - org_slug=org_slug, - user=user), status_code - - -@routes.route(org_scoped_rule('/invite/'), methods=['GET', 'POST']) + return redirect(url_for("redash.index", org_slug=org_slug)) + + google_auth_url = get_google_auth_url(url_for("redash.index", org_slug=org_slug)) + + return ( + render_template( + template, + show_google_openid=settings.GOOGLE_OAUTH_ENABLED, + google_auth_url=google_auth_url, + show_saml_login=current_org.get_setting("auth_saml_enabled"), + show_remote_user_login=settings.REMOTE_USER_LOGIN_ENABLED, + show_ldap_login=settings.LDAP_LOGIN_ENABLED, + org_slug=org_slug, + user=user, + ), + status_code, + ) + + +@routes.route(org_scoped_rule("/invite/"), methods=["GET", "POST"]) def invite(token, org_slug=None): return render_token_login_page("invite.html", org_slug, token, True) -@routes.route(org_scoped_rule('/reset/'), methods=['GET', 'POST']) +@routes.route(org_scoped_rule("/reset/"), methods=["GET", "POST"]) def reset(token, org_slug=None): return render_token_login_page("reset.html", org_slug, token, False) -@routes.route(org_scoped_rule('/verify/'), methods=['GET']) +@routes.route(org_scoped_rule("/verify/"), methods=["GET"]) def verify(token, org_slug=None): try: user_id = validate_token(token) org = current_org._get_current_object() user = models.User.get_by_id_and_org(user_id, org) except (BadSignature, NoResultFound): - logger.exception("Failed to verify email verification token: %s, org=%s", token, org_slug) - return render_template("error.html", - error_message="Your verification link is invalid. Please ask for a new one."), 400 + logger.exception( + "Failed to verify email verification token: %s, org=%s", token, org_slug + ) + return ( + render_template( + "error.html", + error_message="Your verification link is invalid. Please ask for a new one.", + ), + 400, + ) user.is_email_verified = True models.db.session.add(user) models.db.session.commit() template_context = {"org_slug": org_slug} if settings.MULTI_ORG else {} - next_url = url_for('redash.index', **template_context) + next_url = url_for("redash.index", **template_context) return render_template("verify.html", next_url=next_url) -@routes.route(org_scoped_rule('/forgot'), methods=['GET', 'POST']) +@routes.route(org_scoped_rule("/forgot"), methods=["GET", "POST"]) def forgot_password(org_slug=None): - if not current_org.get_setting('auth_password_login_enabled'): + if not current_org.get_setting("auth_password_login_enabled"): abort(404) submitted = False - if request.method == 'POST' and request.form['email']: + if request.method == "POST" and request.form["email"]: submitted = True - email = request.form['email'] + email = request.form["email"] try: org = current_org._get_current_object() user = models.User.get_by_email_and_org(email, org) @@ -129,38 +169,44 @@ def forgot_password(org_slug=None): return render_template("forgot.html", submitted=submitted) -@routes.route(org_scoped_rule('/verification_email/'), methods=['POST']) +@routes.route(org_scoped_rule("/verification_email/"), methods=["POST"]) def verification_email(org_slug=None): if not current_user.is_email_verified: send_verify_email(current_user, current_org) - return json_response({ - "message": "Please check your email inbox in order to verify your email address." - }) + return json_response( + { + "message": "Please check your email inbox in order to verify your email address." + } + ) -@routes.route(org_scoped_rule('/login'), methods=['GET', 'POST']) +@routes.route(org_scoped_rule("/login"), methods=["GET", "POST"]) @limiter.limit(settings.THROTTLE_LOGIN_PATTERN) def login(org_slug=None): # We intentionally use == as otherwise it won't actually use the proxy. So weird :O # noinspection PyComparisonWithNone if current_org == None and not settings.MULTI_ORG: - return redirect('/setup') + return redirect("/setup") elif current_org == None: - return redirect('/') + return redirect("/") - index_url = url_for('redash.index', org_slug=org_slug) - unsafe_next_path = request.args.get('next', index_url) + index_url = url_for("redash.index", org_slug=org_slug) + unsafe_next_path = request.args.get("next", index_url) next_path = get_next_path(unsafe_next_path) if current_user.is_authenticated: return redirect(next_path) - if request.method == 'POST': + if request.method == "POST": try: org = current_org._get_current_object() - user = models.User.get_by_email_and_org(request.form['email'], org) - if user and not user.is_disabled and user.verify_password(request.form['password']): - remember = ('remember' in request.form) + user = models.User.get_by_email_and_org(request.form["email"], org) + if ( + user + and not user.is_disabled + and user.verify_password(request.form["password"]) + ): + remember = "remember" in request.form login_user(user, remember=remember) return redirect(next_path) else: @@ -170,19 +216,21 @@ def login(org_slug=None): google_auth_url = get_google_auth_url(next_path) - return render_template("login.html", - org_slug=org_slug, - next=next_path, - email=request.form.get('email', ''), - show_google_openid=settings.GOOGLE_OAUTH_ENABLED, - google_auth_url=google_auth_url, - show_password_login=current_org.get_setting('auth_password_login_enabled'), - show_saml_login=current_org.get_setting('auth_saml_enabled'), - show_remote_user_login=settings.REMOTE_USER_LOGIN_ENABLED, - show_ldap_login=settings.LDAP_LOGIN_ENABLED) - - -@routes.route(org_scoped_rule('/logout')) + return render_template( + "login.html", + org_slug=org_slug, + next=next_path, + email=request.form.get("email", ""), + show_google_openid=settings.GOOGLE_OAUTH_ENABLED, + google_auth_url=google_auth_url, + show_password_login=current_org.get_setting("auth_password_login_enabled"), + show_saml_login=current_org.get_setting("auth_saml_enabled"), + show_remote_user_login=settings.REMOTE_USER_LOGIN_ENABLED, + show_ldap_login=settings.LDAP_LOGIN_ENABLED, + ) + + +@routes.route(org_scoped_rule("/logout")) def logout(org_slug=None): logout_user() return redirect(get_login_url(next=None)) @@ -190,64 +238,67 @@ def logout(org_slug=None): def base_href(): if settings.MULTI_ORG: - base_href = url_for('redash.index', _external=True, org_slug=current_org.slug) + base_href = url_for("redash.index", _external=True, org_slug=current_org.slug) else: - base_href = url_for('redash.index', _external=True) + base_href = url_for("redash.index", _external=True) return base_href def date_time_format_config(): - date_format = current_org.get_setting('date_format') + date_format = current_org.get_setting("date_format") date_format_list = set(["DD/MM/YY", "MM/DD/YY", "YYYY-MM-DD", settings.DATE_FORMAT]) - time_format = current_org.get_setting('time_format') + time_format = current_org.get_setting("time_format") time_format_list = set(["HH:mm", "HH:mm:ss", "HH:mm:ss.SSS", settings.TIME_FORMAT]) return { - 'dateFormat': date_format, - 'dateFormatList': list(date_format_list), - 'timeFormatList': list(time_format_list), - 'dateTimeFormat': "{0} {1}".format(date_format, time_format), + "dateFormat": date_format, + "dateFormatList": list(date_format_list), + "timeFormatList": list(time_format_list), + "dateTimeFormat": "{0} {1}".format(date_format, time_format), } def number_format_config(): return { - 'integerFormat': current_org.get_setting('integer_format'), - 'floatFormat': current_org.get_setting('float_format'), + "integerFormat": current_org.get_setting("integer_format"), + "floatFormat": current_org.get_setting("float_format"), } def client_config(): if not current_user.is_api_user() and current_user.is_authenticated: client_config = { - 'newVersionAvailable': bool(get_latest_version()), - 'version': __version__ + "newVersionAvailable": bool(get_latest_version()), + "version": __version__, } else: client_config = {} - - if current_user.has_permission('admin') and current_org.get_setting('beacon_consent') is None: - client_config['showBeaconConsentMessage'] = True + + if ( + current_user.has_permission("admin") + and current_org.get_setting("beacon_consent") is None + ): + client_config["showBeaconConsentMessage"] = True defaults = { - 'allowScriptsInUserInput': settings.ALLOW_SCRIPTS_IN_USER_INPUT, - 'showPermissionsControl': current_org.get_setting("feature_show_permissions_control"), - 'allowCustomJSVisualizations': settings.FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS, - 'autoPublishNamedQueries': settings.FEATURE_AUTO_PUBLISH_NAMED_QUERIES, - 'extendedAlertOptions': settings.FEATURE_EXTENDED_ALERT_OPTIONS, - 'mailSettingsMissing': not settings.email_server_is_configured(), - 'dashboardRefreshIntervals': settings.DASHBOARD_REFRESH_INTERVALS, - 'queryRefreshIntervals': settings.QUERY_REFRESH_INTERVALS, - 'googleLoginEnabled': settings.GOOGLE_OAUTH_ENABLED, - 'pageSize': settings.PAGE_SIZE, - 'pageSizeOptions': settings.PAGE_SIZE_OPTIONS, - 'tableCellMaxJSONSize': settings.TABLE_CELL_MAX_JSON_SIZE, + "allowScriptsInUserInput": settings.ALLOW_SCRIPTS_IN_USER_INPUT, + "showPermissionsControl": current_org.get_setting( + "feature_show_permissions_control" + ), + "allowCustomJSVisualizations": settings.FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS, + "autoPublishNamedQueries": settings.FEATURE_AUTO_PUBLISH_NAMED_QUERIES, + "extendedAlertOptions": settings.FEATURE_EXTENDED_ALERT_OPTIONS, + "mailSettingsMissing": not settings.email_server_is_configured(), + "dashboardRefreshIntervals": settings.DASHBOARD_REFRESH_INTERVALS, + "queryRefreshIntervals": settings.QUERY_REFRESH_INTERVALS, + "googleLoginEnabled": settings.GOOGLE_OAUTH_ENABLED, + "pageSize": settings.PAGE_SIZE, + "pageSizeOptions": settings.PAGE_SIZE_OPTIONS, + "tableCellMaxJSONSize": settings.TABLE_CELL_MAX_JSON_SIZE, } client_config.update(defaults) - client_config.update({ - 'basePath': base_href() - }) + client_config.update({"basePath": base_href()}) client_config.update(date_time_format_config()) client_config.update(number_format_config()) @@ -258,43 +309,41 @@ def messages(): messages = [] if not current_user.is_email_verified: - messages.append('email-not-verified') + messages.append("email-not-verified") if settings.ALLOW_PARAMETERS_IN_EMBEDS: - messages.append('using-deprecated-embed-feature') + messages.append("using-deprecated-embed-feature") return messages -@routes.route('/api/config', methods=['GET']) +@routes.route("/api/config", methods=["GET"]) def config(org_slug=None): - return json_response({ - 'org_slug': current_org.slug, - 'client_config': client_config() - }) + return json_response( + {"org_slug": current_org.slug, "client_config": client_config()} + ) -@routes.route(org_scoped_rule('/api/session'), methods=['GET']) +@routes.route(org_scoped_rule("/api/session"), methods=["GET"]) @login_required def session(org_slug=None): if current_user.is_api_user(): - user = { - 'permissions': [], - 'apiKey': current_user.id - } + user = {"permissions": [], "apiKey": current_user.id} else: user = { - 'profile_image_url': current_user.profile_image_url, - 'id': current_user.id, - 'name': current_user.name, - 'email': current_user.email, - 'groups': current_user.group_ids, - 'permissions': current_user.permissions + "profile_image_url": current_user.profile_image_url, + "id": current_user.id, + "name": current_user.name, + "email": current_user.email, + "groups": current_user.group_ids, + "permissions": current_user.permissions, } - return json_response({ - 'user': user, - 'messages': messages(), - 'org_slug': current_org.slug, - 'client_config': client_config() - }) + return json_response( + { + "user": user, + "messages": messages(), + "org_slug": current_org.slug, + "client_config": client_config(), + } + ) diff --git a/redash/handlers/base.py b/redash/handlers/base.py index e4bc8e66b2..26db713003 100644 --- a/redash/handlers/base.py +++ b/redash/handlers/base.py @@ -15,7 +15,9 @@ from sqlalchemy.dialects import postgresql from sqlalchemy_utils import sort_query -routes = Blueprint('redash', __name__, template_folder=settings.fix_assets_path('templates')) +routes = Blueprint( + "redash", __name__, template_folder=settings.fix_assets_path("templates") +) class BaseResource(Resource): @@ -26,7 +28,7 @@ def __init__(self, *args, **kwargs): self._user = None def dispatch_request(self, *args, **kwargs): - kwargs.pop('org_slug', None) + kwargs.pop("org_slug", None) return super(BaseResource, self).dispatch_request(*args, **kwargs) @@ -49,24 +51,14 @@ def update_model(self, model, updates): def record_event(org, user, options): if user.is_api_user(): - options.update({ - 'api_key': user.name, - 'org_id': org.id - }) + options.update({"api_key": user.name, "org_id": org.id}) else: - options.update({ - 'user_id': user.id, - 'user_name': user.name, - 'org_id': org.id - }) + options.update({"user_id": user.id, "user_name": user.name, "org_id": org.id}) - options.update({ - 'user_agent': request.user_agent.string, - 'ip': request.remote_addr - }) + options.update({"user_agent": request.user_agent.string, "ip": request.remote_addr}) - if 'timestamp' not in options: - options['timestamp'] = int(time.time()) + if "timestamp" not in options: + options["timestamp"] = int(time.time()) record_event_task.delay(options) @@ -91,13 +83,13 @@ def paginate(query_set, page, page_size, serializer, **kwargs): count = query_set.count() if page < 1: - abort(400, message='Page must be positive integer.') + abort(400, message="Page must be positive integer.") if (page - 1) * page_size + 1 > count > 0: - abort(400, message='Page is out of range.') + abort(400, message="Page is out of range.") if page_size > 250 or page_size < 1: - abort(400, message='Page size is out of range (1-250).') + abort(400, message="Page size is out of range (1-250).") results = query_set.paginate(page, page_size) @@ -107,12 +99,7 @@ def paginate(query_set, page, page_size, serializer, **kwargs): else: items = [serializer(result) for result in results.items] - return { - 'count': count, - 'page': page, - 'page_size': page_size, - 'results': items, - } + return {"count": count, "page": page, "page_size": page_size, "results": items} def org_scoped_rule(rule): @@ -123,13 +110,15 @@ def org_scoped_rule(rule): def json_response(response): - return current_app.response_class(json_dumps(response), mimetype='application/json') + return current_app.response_class(json_dumps(response), mimetype="application/json") def filter_by_tags(result_set, column): - if request.args.getlist('tags'): - tags = request.args.getlist('tags') - result_set = result_set.filter(cast(column, postgresql.ARRAY(db.Text)).contains(tags)) + if request.args.getlist("tags"): + tags = request.args.getlist("tags") + result_set = result_set.filter( + cast(column, postgresql.ARRAY(db.Text)).contains(tags) + ) return result_set @@ -139,7 +128,7 @@ def order_results(results, default_order, allowed_orders, fallback=True): "order" request query parameter or the given default order. """ # See if a particular order has been requested - requested_order = request.args.get('order', '').strip() + requested_order = request.args.get("order", "").strip() # and if not (and no fallback is wanted) return results as is if not requested_order and not fallback: diff --git a/redash/handlers/dashboards.py b/redash/handlers/dashboards.py index 954e9da4f7..d5b7488f94 100644 --- a/redash/handlers/dashboards.py +++ b/redash/handlers/dashboards.py @@ -3,12 +3,19 @@ from flask_restful import abort from redash import models, serializers -from redash.handlers.base import (BaseResource, get_object_or_404, paginate, - filter_by_tags, - order_results as _order_results) -from redash.permissions import (can_modify, require_admin_or_owner, - require_object_modify_permission, - require_permission) +from redash.handlers.base import ( + BaseResource, + get_object_or_404, + paginate, + filter_by_tags, + order_results as _order_results, +) +from redash.permissions import ( + can_modify, + require_admin_or_owner, + require_object_modify_permission, + require_permission, +) from redash.security import csp_allows_embeding from redash.serializers import serialize_dashboard from sqlalchemy.orm.exc import StaleDataError @@ -16,21 +23,19 @@ # Ordering map for relationships order_map = { - 'name': 'lowercase_name', - '-name': '-lowercase_name', - 'created_at': 'created_at', - '-created_at': '-created_at', + "name": "lowercase_name", + "-name": "-lowercase_name", + "created_at": "created_at", + "-created_at": "-created_at", } order_results = partial( - _order_results, - default_order='-created_at', - allowed_orders=order_map, + _order_results, default_order="-created_at", allowed_orders=order_map ) class DashboardListResource(BaseResource): - @require_permission('list_dashboards') + @require_permission("list_dashboards") def get(self): """ Lists all accessible dashboards. @@ -43,7 +48,7 @@ def get(self): Responds with an array of :ref:`dashboard ` objects. """ - search_term = request.args.get('q') + search_term = request.args.get("q") if search_term: results = models.Dashboard.search( @@ -54,9 +59,7 @@ def get(self): ) else: results = models.Dashboard.all( - self.current_org, - self.current_user.group_ids, - self.current_user.id, + self.current_org, self.current_user.group_ids, self.current_user.id ) results = filter_by_tags(results, models.Dashboard.tags) @@ -66,8 +69,8 @@ def get(self): # provides an order by search rank ordered_results = order_results(results, fallback=not bool(search_term)) - page = request.args.get('page', 1, type=int) - page_size = request.args.get('page_size', 25, type=int) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 25, type=int) response = paginate( ordered_results, @@ -77,20 +80,15 @@ def get(self): ) if search_term: - self.record_event({ - 'action': 'search', - 'object_type': 'dashboard', - 'term': search_term, - }) + self.record_event( + {"action": "search", "object_type": "dashboard", "term": search_term} + ) else: - self.record_event({ - 'action': 'list', - 'object_type': 'dashboard', - }) + self.record_event({"action": "list", "object_type": "dashboard"}) return response - @require_permission('create_dashboard') + @require_permission("create_dashboard") def post(self): """ Creates a new dashboard. @@ -100,18 +98,20 @@ def post(self): Responds with a :ref:`dashboard `. """ dashboard_properties = request.get_json(force=True) - dashboard = models.Dashboard(name=dashboard_properties['name'], - org=self.current_org, - user=self.current_user, - is_draft=True, - layout='[]') + dashboard = models.Dashboard( + name=dashboard_properties["name"], + org=self.current_org, + user=self.current_user, + is_draft=True, + layout="[]", + ) models.db.session.add(dashboard) models.db.session.commit() return serialize_dashboard(dashboard) class DashboardResource(BaseResource): - @require_permission('list_dashboards') + @require_permission("list_dashboards") def get(self, dashboard_slug=None): """ Retrieves a dashboard. @@ -146,25 +146,32 @@ def get(self, dashboard_slug=None): :>json string widget.created_at: ISO format timestamp for widget creation :>json string widget.updated_at: ISO format timestamp for last widget modification """ - dashboard = get_object_or_404(models.Dashboard.get_by_slug_and_org, dashboard_slug, self.current_org) - response = serialize_dashboard(dashboard, with_widgets=True, user=self.current_user) + dashboard = get_object_or_404( + models.Dashboard.get_by_slug_and_org, dashboard_slug, self.current_org + ) + response = serialize_dashboard( + dashboard, with_widgets=True, user=self.current_user + ) api_key = models.ApiKey.get_by_object(dashboard) if api_key: - response['public_url'] = url_for('redash.public_dashboard', token=api_key.api_key, org_slug=self.current_org.slug, _external=True) - response['api_key'] = api_key.api_key + response["public_url"] = url_for( + "redash.public_dashboard", + token=api_key.api_key, + org_slug=self.current_org.slug, + _external=True, + ) + response["api_key"] = api_key.api_key - response['can_edit'] = can_modify(dashboard, self.current_user) + response["can_edit"] = can_modify(dashboard, self.current_user) - self.record_event({ - 'action': 'view', - 'object_id': dashboard.id, - 'object_type': 'dashboard', - }) + self.record_event( + {"action": "view", "object_id": dashboard.id, "object_type": "dashboard"} + ) return response - @require_permission('edit_dashboard') + @require_permission("edit_dashboard") def post(self, dashboard_slug): """ Modifies a dashboard. @@ -182,16 +189,25 @@ def post(self, dashboard_slug): require_object_modify_permission(dashboard, self.current_user) - updates = project(dashboard_properties, ('name', 'layout', 'version', 'tags', - 'is_draft', 'dashboard_filters_enabled')) + updates = project( + dashboard_properties, + ( + "name", + "layout", + "version", + "tags", + "is_draft", + "dashboard_filters_enabled", + ), + ) # SQLAlchemy handles the case where a concurrent transaction beats us # to the update. But we still have to make sure that we're not starting # out behind. - if 'version' in updates and updates['version'] != dashboard.version: + if "version" in updates and updates["version"] != dashboard.version: abort(409) - updates['changed_by'] = self.current_user + updates["changed_by"] = self.current_user self.update_model(dashboard, updates) models.db.session.add(dashboard) @@ -200,17 +216,17 @@ def post(self, dashboard_slug): except StaleDataError: abort(409) - result = serialize_dashboard(dashboard, with_widgets=True, user=self.current_user) + result = serialize_dashboard( + dashboard, with_widgets=True, user=self.current_user + ) - self.record_event({ - 'action': 'edit', - 'object_id': dashboard.id, - 'object_type': 'dashboard', - }) + self.record_event( + {"action": "edit", "object_id": dashboard.id, "object_type": "dashboard"} + ) return result - @require_permission('edit_dashboard') + @require_permission("edit_dashboard") def delete(self, dashboard_slug): """ Archives a dashboard. @@ -219,18 +235,18 @@ def delete(self, dashboard_slug): Responds with the archived :ref:`dashboard `. """ - dashboard = models.Dashboard.get_by_slug_and_org(dashboard_slug, self.current_org) + dashboard = models.Dashboard.get_by_slug_and_org( + dashboard_slug, self.current_org + ) dashboard.is_archived = True dashboard.record_changes(changed_by=self.current_user) models.db.session.add(dashboard) d = serialize_dashboard(dashboard, with_widgets=True, user=self.current_user) models.db.session.commit() - self.record_event({ - 'action': 'archive', - 'object_id': dashboard.id, - 'object_type': 'dashboard', - }) + self.record_event( + {"action": "archive", "object_id": dashboard.id, "object_type": "dashboard"} + ) return d @@ -269,15 +285,22 @@ def post(self, dashboard_id): models.db.session.flush() models.db.session.commit() - public_url = url_for('redash.public_dashboard', token=api_key.api_key, org_slug=self.current_org.slug, _external=True) + public_url = url_for( + "redash.public_dashboard", + token=api_key.api_key, + org_slug=self.current_org.slug, + _external=True, + ) - self.record_event({ - 'action': 'activate_api_key', - 'object_id': dashboard.id, - 'object_type': 'dashboard', - }) + self.record_event( + { + "action": "activate_api_key", + "object_id": dashboard.id, + "object_type": "dashboard", + } + ) - return {'public_url': public_url, 'api_key': api_key.api_key} + return {"public_url": public_url, "api_key": api_key.api_key} def delete(self, dashboard_id): """ @@ -294,38 +317,39 @@ def delete(self, dashboard_id): models.db.session.add(api_key) models.db.session.commit() - self.record_event({ - 'action': 'deactivate_api_key', - 'object_id': dashboard.id, - 'object_type': 'dashboard', - }) + self.record_event( + { + "action": "deactivate_api_key", + "object_id": dashboard.id, + "object_type": "dashboard", + } + ) class DashboardTagsResource(BaseResource): - @require_permission('list_dashboards') + @require_permission("list_dashboards") def get(self): """ Lists all accessible dashboards. """ tags = models.Dashboard.all_tags(self.current_org, self.current_user) - return { - 'tags': [ - { - 'name': name, - 'count': count, - } - for name, count in tags - ] - } + return {"tags": [{"name": name, "count": count} for name, count in tags]} class DashboardFavoriteListResource(BaseResource): def get(self): - search_term = request.args.get('q') + search_term = request.args.get("q") if search_term: - base_query = models.Dashboard.search(self.current_org, self.current_user.group_ids, self.current_user.id, search_term) - favorites = models.Dashboard.favorites(self.current_user, base_query=base_query) + base_query = models.Dashboard.search( + self.current_org, + self.current_user.group_ids, + self.current_user.id, + search_term, + ) + favorites = models.Dashboard.favorites( + self.current_user, base_query=base_query + ) else: favorites = models.Dashboard.favorites(self.current_user) @@ -336,18 +360,20 @@ def get(self): # provides an order by search rank favorites = order_results(favorites, fallback=not bool(search_term)) - page = request.args.get('page', 1, type=int) - page_size = request.args.get('page_size', 25, type=int) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 25, type=int) response = paginate(favorites, page, page_size, serialize_dashboard) - self.record_event({ - 'action': 'load_favorites', - 'object_type': 'dashboard', - 'params': { - 'q': search_term, - 'tags': request.args.getlist('tags'), - 'page': page + self.record_event( + { + "action": "load_favorites", + "object_type": "dashboard", + "params": { + "q": search_term, + "tags": request.args.getlist("tags"), + "page": page, + }, } - }) + ) return response diff --git a/redash/handlers/data_sources.py b/redash/handlers/data_sources.py index 7ab2ffe43c..4dee409696 100644 --- a/redash/handlers/data_sources.py +++ b/redash/handlers/data_sources.py @@ -8,10 +8,17 @@ from redash import models from redash.handlers.base import BaseResource, get_object_or_404, require_fields -from redash.permissions import (require_access, require_admin, - require_permission, view_only) -from redash.query_runner import (get_configuration_schema_for_query_runner_type, - query_runners, NotSupported) +from redash.permissions import ( + require_access, + require_admin, + require_permission, + view_only, +) +from redash.query_runner import ( + get_configuration_schema_for_query_runner_type, + query_runners, + NotSupported, +) from redash.utils import filter_none from redash.utils.configuration import ConfigurationContainer, ValidationError @@ -19,77 +26,92 @@ class DataSourceTypeListResource(BaseResource): @require_admin def get(self): - available_query_runners = [q for q in query_runners.values() if not q.deprecated] - return [q.to_dict() for q in sorted(available_query_runners, key=lambda q: q.name())] + available_query_runners = [ + q for q in query_runners.values() if not q.deprecated + ] + return [ + q.to_dict() for q in sorted(available_query_runners, key=lambda q: q.name()) + ] class DataSourceResource(BaseResource): @require_admin def get(self, data_source_id): - data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org) + data_source = models.DataSource.get_by_id_and_org( + data_source_id, self.current_org + ) ds = data_source.to_dict(all=True) - self.record_event({ - 'action': 'view', - 'object_id': data_source_id, - 'object_type': 'datasource', - }) + self.record_event( + {"action": "view", "object_id": data_source_id, "object_type": "datasource"} + ) return ds @require_admin def post(self, data_source_id): - data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org) + data_source = models.DataSource.get_by_id_and_org( + data_source_id, self.current_org + ) req = request.get_json(True) - schema = get_configuration_schema_for_query_runner_type(req['type']) + schema = get_configuration_schema_for_query_runner_type(req["type"]) if schema is None: abort(400) try: data_source.options.set_schema(schema) - data_source.options.update(filter_none(req['options'])) + data_source.options.update(filter_none(req["options"])) except ValidationError: abort(400) - data_source.type = req['type'] - data_source.name = req['name'] + data_source.type = req["type"] + data_source.name = req["name"] models.db.session.add(data_source) try: models.db.session.commit() except IntegrityError as e: - if req['name'] in str(e): - abort(400, message="Data source with the name {} already exists.".format(req['name'])) + if req["name"] in str(e): + abort( + 400, + message="Data source with the name {} already exists.".format( + req["name"] + ), + ) abort(400) - self.record_event({ - 'action': 'edit', - 'object_id': data_source.id, - 'object_type': 'datasource', - }) + self.record_event( + {"action": "edit", "object_id": data_source.id, "object_type": "datasource"} + ) return data_source.to_dict(all=True) @require_admin def delete(self, data_source_id): - data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org) + data_source = models.DataSource.get_by_id_and_org( + data_source_id, self.current_org + ) data_source.delete() - self.record_event({ - 'action': 'delete', - 'object_id': data_source_id, - 'object_type': 'datasource', - }) + self.record_event( + { + "action": "delete", + "object_id": data_source_id, + "object_type": "datasource", + } + ) - return make_response('', 204) + return make_response("", 204) class DataSourceListResource(BaseResource): - @require_permission('list_data_sources') + @require_permission("list_data_sources") def get(self): - if self.current_user.has_permission('admin'): + if self.current_user.has_permission("admin"): data_sources = models.DataSource.all(self.current_org) else: - data_sources = models.DataSource.all(self.current_org, group_ids=self.current_user.group_ids) + data_sources = models.DataSource.all( + self.current_org, group_ids=self.current_user.group_ids + ) response = {} for ds in data_sources: @@ -98,74 +120,85 @@ def get(self): try: d = ds.to_dict() - d['view_only'] = all(project(ds.groups, self.current_user.group_ids).values()) + d["view_only"] = all( + project(ds.groups, self.current_user.group_ids).values() + ) response[ds.id] = d except AttributeError: - logging.exception("Error with DataSource#to_dict (data source id: %d)", ds.id) - - self.record_event({ - 'action': 'list', - 'object_id': 'admin/data_sources', - 'object_type': 'datasource', - }) + logging.exception( + "Error with DataSource#to_dict (data source id: %d)", ds.id + ) + + self.record_event( + { + "action": "list", + "object_id": "admin/data_sources", + "object_type": "datasource", + } + ) - return sorted(list(response.values()), key=lambda d: d['name'].lower()) + return sorted(list(response.values()), key=lambda d: d["name"].lower()) @require_admin def post(self): req = request.get_json(True) - require_fields(req, ('options', 'name', 'type')) + require_fields(req, ("options", "name", "type")) - schema = get_configuration_schema_for_query_runner_type(req['type']) + schema = get_configuration_schema_for_query_runner_type(req["type"]) if schema is None: abort(400) - config = ConfigurationContainer(filter_none(req['options']), schema) + config = ConfigurationContainer(filter_none(req["options"]), schema) if not config.is_valid(): abort(400) try: - datasource = models.DataSource.create_with_group(org=self.current_org, - name=req['name'], - type=req['type'], - options=config) + datasource = models.DataSource.create_with_group( + org=self.current_org, name=req["name"], type=req["type"], options=config + ) models.db.session.commit() except IntegrityError as e: - if req['name'] in str(e): - abort(400, message="Data source with the name {} already exists.".format(req['name'])) + if req["name"] in str(e): + abort( + 400, + message="Data source with the name {} already exists.".format( + req["name"] + ), + ) abort(400) - self.record_event({ - 'action': 'create', - 'object_id': datasource.id, - 'object_type': 'datasource' - }) + self.record_event( + { + "action": "create", + "object_id": datasource.id, + "object_type": "datasource", + } + ) return datasource.to_dict(all=True) class DataSourceSchemaResource(BaseResource): def get(self, data_source_id): - data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org) + data_source = get_object_or_404( + models.DataSource.get_by_id_and_org, data_source_id, self.current_org + ) require_access(data_source, self.current_user, view_only) - refresh = request.args.get('refresh') is not None + refresh = request.args.get("refresh") is not None response = {} try: - response['schema'] = data_source.get_schema(refresh) + response["schema"] = data_source.get_schema(refresh) except NotSupported: - response['error'] = { - 'code': 1, - 'message': 'Data source type does not support retrieving schema' + response["error"] = { + "code": 1, + "message": "Data source type does not support retrieving schema", } except Exception: - response['error'] = { - 'code': 2, - 'message': 'Error retrieving schema.' - } + response["error"] = {"code": 2, "message": "Error retrieving schema."} return response @@ -173,39 +206,49 @@ def get(self, data_source_id): class DataSourcePauseResource(BaseResource): @require_admin def post(self, data_source_id): - data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org) + data_source = get_object_or_404( + models.DataSource.get_by_id_and_org, data_source_id, self.current_org + ) data = request.get_json(force=True, silent=True) if data: - reason = data.get('reason') + reason = data.get("reason") else: - reason = request.args.get('reason') + reason = request.args.get("reason") data_source.pause(reason) - self.record_event({ - 'action': 'pause', - 'object_id': data_source.id, - 'object_type': 'datasource' - }) + self.record_event( + { + "action": "pause", + "object_id": data_source.id, + "object_type": "datasource", + } + ) return data_source.to_dict() @require_admin def delete(self, data_source_id): - data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org) + data_source = get_object_or_404( + models.DataSource.get_by_id_and_org, data_source_id, self.current_org + ) data_source.resume() - self.record_event({ - 'action': 'resume', - 'object_id': data_source.id, - 'object_type': 'datasource' - }) + self.record_event( + { + "action": "resume", + "object_id": data_source.id, + "object_type": "datasource", + } + ) return data_source.to_dict() class DataSourceTestResource(BaseResource): @require_admin def post(self, data_source_id): - data_source = get_object_or_404(models.DataSource.get_by_id_and_org, data_source_id, self.current_org) + data_source = get_object_or_404( + models.DataSource.get_by_id_and_org, data_source_id, self.current_org + ) response = {} try: @@ -215,10 +258,12 @@ def post(self, data_source_id): else: response = {"message": "success", "ok": True} - self.record_event({ - 'action': 'test', - 'object_id': data_source_id, - 'object_type': 'datasource', - 'result': response, - }) + self.record_event( + { + "action": "test", + "object_id": data_source_id, + "object_type": "datasource", + "result": response, + } + ) return response diff --git a/redash/handlers/destinations.py b/redash/handlers/destinations.py index 261da2bb39..4f4b122e68 100644 --- a/redash/handlers/destinations.py +++ b/redash/handlers/destinations.py @@ -3,8 +3,10 @@ from sqlalchemy.exc import IntegrityError from redash import models -from redash.destinations import (destinations, - get_configuration_schema_for_destination_type) +from redash.destinations import ( + destinations, + get_configuration_schema_for_destination_type, +) from redash.handlers.base import BaseResource, require_fields from redash.permissions import require_admin from redash.utils.configuration import ConfigurationContainer, ValidationError @@ -20,53 +22,68 @@ def get(self): class DestinationResource(BaseResource): @require_admin def get(self, destination_id): - destination = models.NotificationDestination.get_by_id_and_org(destination_id, self.current_org) + destination = models.NotificationDestination.get_by_id_and_org( + destination_id, self.current_org + ) d = destination.to_dict(all=True) - self.record_event({ - 'action': 'view', - 'object_id': destination_id, - 'object_type': 'destination', - }) + self.record_event( + { + "action": "view", + "object_id": destination_id, + "object_type": "destination", + } + ) return d @require_admin def post(self, destination_id): - destination = models.NotificationDestination.get_by_id_and_org(destination_id, self.current_org) + destination = models.NotificationDestination.get_by_id_and_org( + destination_id, self.current_org + ) req = request.get_json(True) - schema = get_configuration_schema_for_destination_type(req['type']) + schema = get_configuration_schema_for_destination_type(req["type"]) if schema is None: abort(400) try: - destination.type = req['type'] - destination.name = req['name'] + destination.type = req["type"] + destination.name = req["name"] destination.options.set_schema(schema) - destination.options.update(req['options']) + destination.options.update(req["options"]) models.db.session.add(destination) models.db.session.commit() except ValidationError: abort(400) except IntegrityError as e: - if 'name' in str(e): - abort(400, message="Alert Destination with the name {} already exists.".format(req['name'])) + if "name" in str(e): + abort( + 400, + message="Alert Destination with the name {} already exists.".format( + req["name"] + ), + ) abort(500) return destination.to_dict(all=True) @require_admin def delete(self, destination_id): - destination = models.NotificationDestination.get_by_id_and_org(destination_id, self.current_org) + destination = models.NotificationDestination.get_by_id_and_org( + destination_id, self.current_org + ) models.db.session.delete(destination) models.db.session.commit() - self.record_event({ - 'action': 'delete', - 'object_id': destination_id, - 'object_type': 'destination' - }) + self.record_event( + { + "action": "delete", + "object_id": destination_id, + "object_type": "destination", + } + ) - return make_response('', 204) + return make_response("", 204) class DestinationListResource(BaseResource): @@ -81,39 +98,48 @@ def get(self): d = ds.to_dict() response[ds.id] = d - self.record_event({ - 'action': 'list', - 'object_id': 'admin/destinations', - 'object_type': 'destination', - }) + self.record_event( + { + "action": "list", + "object_id": "admin/destinations", + "object_type": "destination", + } + ) return list(response.values()) @require_admin def post(self): req = request.get_json(True) - require_fields(req, ('options', 'name', 'type')) + require_fields(req, ("options", "name", "type")) - schema = get_configuration_schema_for_destination_type(req['type']) + schema = get_configuration_schema_for_destination_type(req["type"]) if schema is None: abort(400) - config = ConfigurationContainer(req['options'], schema) + config = ConfigurationContainer(req["options"], schema) if not config.is_valid(): abort(400) - destination = models.NotificationDestination(org=self.current_org, - name=req['name'], - type=req['type'], - options=config, - user=self.current_user) + destination = models.NotificationDestination( + org=self.current_org, + name=req["name"], + type=req["type"], + options=config, + user=self.current_user, + ) try: models.db.session.add(destination) models.db.session.commit() except IntegrityError as e: - if 'name' in str(e): - abort(400, message="Alert Destination with the name {} already exists.".format(req['name'])) + if "name" in str(e): + abort( + 400, + message="Alert Destination with the name {} already exists.".format( + req["name"] + ), + ) abort(500) return destination.to_dict(all=True) diff --git a/redash/handlers/embed.py b/redash/handlers/embed.py index f677b29a14..e7a5cfb576 100644 --- a/redash/handlers/embed.py +++ b/redash/handlers/embed.py @@ -1,33 +1,37 @@ - - from flask import request from .authentication import current_org from flask_login import current_user, login_required from redash import models from redash.handlers import routes -from redash.handlers.base import (get_object_or_404, org_scoped_rule, - record_event) +from redash.handlers.base import get_object_or_404, org_scoped_rule, record_event from redash.handlers.static import render_index from redash.security import csp_allows_embeding -@routes.route(org_scoped_rule('/embed/query//visualization/'), methods=['GET']) +@routes.route( + org_scoped_rule("/embed/query//visualization/"), + methods=["GET"], +) @login_required @csp_allows_embeding def embed(query_id, visualization_id, org_slug=None): - record_event(current_org, current_user._get_current_object(), { - 'action': 'view', - 'object_id': visualization_id, - 'object_type': 'visualization', - 'query_id': query_id, - 'embed': True, - 'referer': request.headers.get('Referer') - }) + record_event( + current_org, + current_user._get_current_object(), + { + "action": "view", + "object_id": visualization_id, + "object_type": "visualization", + "query_id": query_id, + "embed": True, + "referer": request.headers.get("Referer"), + }, + ) return render_index() -@routes.route(org_scoped_rule('/public/dashboards/'), methods=['GET']) +@routes.route(org_scoped_rule("/public/dashboards/"), methods=["GET"]) @login_required @csp_allows_embeding def public_dashboard(token, org_slug=None): @@ -37,12 +41,16 @@ def public_dashboard(token, org_slug=None): api_key = get_object_or_404(models.ApiKey.get_by_api_key, token) dashboard = api_key.object - record_event(current_org, current_user, { - 'action': 'view', - 'object_id': dashboard.id, - 'object_type': 'dashboard', - 'public': True, - 'headless': 'embed' in request.args, - 'referer': request.headers.get('Referer') - }) + record_event( + current_org, + current_user, + { + "action": "view", + "object_id": dashboard.id, + "object_type": "dashboard", + "public": True, + "headless": "embed" in request.args, + "referer": request.headers.get("Referer"), + }, + ) return render_index() diff --git a/redash/handlers/events.py b/redash/handlers/events.py index 10b09356ee..6ecbe84758 100644 --- a/redash/handlers/events.py +++ b/redash/handlers/events.py @@ -14,44 +14,46 @@ def get_location(ip): with maxminddb.open_database(geolite2.geolite2_database()) as reader: try: match = reader.get(ip) - return match['country']['names']['en'] + return match["country"]["names"]["en"] except Exception: return "Unknown" def event_details(event): details = {} - if event.object_type == 'data_source' and event.action == 'execute_query': - details['query'] = event.additional_properties['query'] - details['data_source'] = event.object_id - elif event.object_type == 'page' and event.action == 'view': - details['page'] = event.object_id + if event.object_type == "data_source" and event.action == "execute_query": + details["query"] = event.additional_properties["query"] + details["data_source"] = event.object_id + elif event.object_type == "page" and event.action == "view": + details["page"] = event.object_id else: - details['object_id'] = event.object_id - details['object_type'] = event.object_type + details["object_id"] = event.object_id + details["object_type"] = event.object_type return details def serialize_event(event): d = { - 'org_id': event.org_id, - 'user_id': event.user_id, - 'action': event.action, - 'object_type': event.object_type, - 'object_id': event.object_id, - 'created_at': event.created_at + "org_id": event.org_id, + "user_id": event.user_id, + "action": event.action, + "object_type": event.object_type, + "object_id": event.object_id, + "created_at": event.created_at, } if event.user_id: - d['user_name'] = event.additional_properties.get('user_name', 'User {}'.format(event.user_id)) + d["user_name"] = event.additional_properties.get( + "user_name", "User {}".format(event.user_id) + ) if not event.user_id: - d['user_name'] = event.additional_properties.get('api_key', 'Unknown') + d["user_name"] = event.additional_properties.get("api_key", "Unknown") - d['browser'] = str(parse_ua(event.additional_properties.get('user_agent', ''))) - d['location'] = get_location(event.additional_properties.get('ip')) - d['details'] = event_details(event) + d["browser"] = str(parse_ua(event.additional_properties.get("user_agent", ""))) + d["location"] = get_location(event.additional_properties.get("ip")) + d["details"] = event_details(event) return d @@ -64,6 +66,6 @@ def post(self): @require_admin def get(self): - page = request.args.get('page', 1, type=int) - page_size = request.args.get('page_size', 25, type=int) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 25, type=int) return paginate(self.current_org.events, page, page_size, serialize_event) diff --git a/redash/handlers/favorites.py b/redash/handlers/favorites.py index fe46ad673b..71ac3a20b8 100644 --- a/redash/handlers/favorites.py +++ b/redash/handlers/favorites.py @@ -2,77 +2,91 @@ from sqlalchemy.exc import IntegrityError from redash import models -from redash.handlers.base import (BaseResource, - get_object_or_404, paginate) +from redash.handlers.base import BaseResource, get_object_or_404, paginate from redash.permissions import require_access, view_only class QueryFavoriteResource(BaseResource): def post(self, query_id): - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_access(query, self.current_user, view_only) - fav = models.Favorite(org_id=self.current_org.id, object=query, user=self.current_user) + fav = models.Favorite( + org_id=self.current_org.id, object=query, user=self.current_user + ) models.db.session.add(fav) try: models.db.session.commit() except IntegrityError as e: - if 'unique_favorite' in str(e): + if "unique_favorite" in str(e): models.db.session.rollback() else: raise e - self.record_event({ - 'action': 'favorite', - 'object_id': query.id, - 'object_type': 'query' - }) + self.record_event( + {"action": "favorite", "object_id": query.id, "object_type": "query"} + ) def delete(self, query_id): - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_access(query, self.current_user, view_only) models.Favorite.query.filter( models.Favorite.object_id == query_id, - models.Favorite.object_type == 'Query', + models.Favorite.object_type == "Query", models.Favorite.user == self.current_user, ).delete() models.db.session.commit() - self.record_event({ - 'action': 'favorite', - 'object_id': query.id, - 'object_type': 'query' - }) + self.record_event( + {"action": "favorite", "object_id": query.id, "object_type": "query"} + ) class DashboardFavoriteResource(BaseResource): def post(self, object_id): - dashboard = get_object_or_404(models.Dashboard.get_by_slug_and_org, object_id, self.current_org) - fav = models.Favorite(org_id=self.current_org.id, object=dashboard, user=self.current_user) + dashboard = get_object_or_404( + models.Dashboard.get_by_slug_and_org, object_id, self.current_org + ) + fav = models.Favorite( + org_id=self.current_org.id, object=dashboard, user=self.current_user + ) models.db.session.add(fav) try: models.db.session.commit() except IntegrityError as e: - if 'unique_favorite' in str(e): + if "unique_favorite" in str(e): models.db.session.rollback() else: raise e - self.record_event({ - 'action': 'favorite', - 'object_id': dashboard.id, - 'object_type': 'dashboard' - }) + self.record_event( + { + "action": "favorite", + "object_id": dashboard.id, + "object_type": "dashboard", + } + ) def delete(self, object_id): - dashboard = get_object_or_404(models.Dashboard.get_by_slug_and_org, object_id, self.current_org) - models.Favorite.query.filter(models.Favorite.object == dashboard, models.Favorite.user == self.current_user).delete() + dashboard = get_object_or_404( + models.Dashboard.get_by_slug_and_org, object_id, self.current_org + ) + models.Favorite.query.filter( + models.Favorite.object == dashboard, + models.Favorite.user == self.current_user, + ).delete() models.db.session.commit() - self.record_event({ - 'action': 'unfavorite', - 'object_id': dashboard.id, - 'object_type': 'dashboard' - }) + self.record_event( + { + "action": "unfavorite", + "object_id": dashboard.id, + "object_type": "dashboard", + } + ) diff --git a/redash/handlers/groups.py b/redash/handlers/groups.py index eddb54a589..40839e0345 100644 --- a/redash/handlers/groups.py +++ b/redash/handlers/groups.py @@ -9,31 +9,28 @@ class GroupListResource(BaseResource): @require_admin def post(self): - name = request.json['name'] + name = request.json["name"] group = models.Group(name=name, org=self.current_org) models.db.session.add(group) models.db.session.commit() - self.record_event({ - 'action': 'create', - 'object_id': group.id, - 'object_type': 'group' - }) + self.record_event( + {"action": "create", "object_id": group.id, "object_type": "group"} + ) return group.to_dict() def get(self): - if self.current_user.has_permission('admin'): + if self.current_user.has_permission("admin"): groups = models.Group.all(self.current_org) else: groups = models.Group.query.filter( - models.Group.id.in_(self.current_user.group_ids)) + models.Group.id.in_(self.current_user.group_ids) + ) - self.record_event({ - 'action': 'list', - 'object_id': 'groups', - 'object_type': 'group', - }) + self.record_event( + {"action": "list", "object_id": "groups", "object_type": "group"} + ) return [g.to_dict() for g in groups] @@ -46,28 +43,27 @@ def post(self, group_id): if group.type == models.Group.BUILTIN_GROUP: abort(400, message="Can't modify built-in groups.") - group.name = request.json['name'] + group.name = request.json["name"] models.db.session.commit() - self.record_event({ - 'action': 'edit', - 'object_id': group.id, - 'object_type': 'group' - }) + self.record_event( + {"action": "edit", "object_id": group.id, "object_type": "group"} + ) return group.to_dict() def get(self, group_id): - if not (self.current_user.has_permission('admin') or int(group_id) in self.current_user.group_ids): + if not ( + self.current_user.has_permission("admin") + or int(group_id) in self.current_user.group_ids + ): abort(403) group = models.Group.get_by_id_and_org(group_id, self.current_org) - self.record_event({ - 'action': 'view', - 'object_id': group_id, - 'object_type': 'group', - }) + self.record_event( + {"action": "view", "object_id": group_id, "object_type": "group"} + ) return group.to_dict() @@ -89,23 +85,28 @@ def delete(self, group_id): class GroupMemberListResource(BaseResource): @require_admin def post(self, group_id): - user_id = request.json['user_id'] + user_id = request.json["user_id"] user = models.User.get_by_id_and_org(user_id, self.current_org) group = models.Group.get_by_id_and_org(group_id, self.current_org) user.group_ids.append(group.id) models.db.session.commit() - self.record_event({ - 'action': 'add_member', - 'object_id': group.id, - 'object_type': 'group', - 'member_id': user.id - }) + self.record_event( + { + "action": "add_member", + "object_id": group.id, + "object_type": "group", + "member_id": user.id, + } + ) return user.to_dict() - @require_permission('list_users') + @require_permission("list_users") def get(self, group_id): - if not (self.current_user.has_permission('admin') or int(group_id) in self.current_user.group_ids): + if not ( + self.current_user.has_permission("admin") + or int(group_id) in self.current_user.group_ids + ): abort(403) members = models.Group.members(group_id) @@ -119,54 +120,59 @@ def delete(self, group_id, user_id): user.group_ids.remove(int(group_id)) models.db.session.commit() - self.record_event({ - 'action': 'remove_member', - 'object_id': group_id, - 'object_type': 'group', - 'member_id': user.id - }) + self.record_event( + { + "action": "remove_member", + "object_id": group_id, + "object_type": "group", + "member_id": user.id, + } + ) def serialize_data_source_with_group(data_source, data_source_group): d = data_source.to_dict() - d['view_only'] = data_source_group.view_only + d["view_only"] = data_source_group.view_only return d class GroupDataSourceListResource(BaseResource): @require_admin def post(self, group_id): - data_source_id = request.json['data_source_id'] - data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org) + data_source_id = request.json["data_source_id"] + data_source = models.DataSource.get_by_id_and_org( + data_source_id, self.current_org + ) group = models.Group.get_by_id_and_org(group_id, self.current_org) data_source_group = data_source.add_group(group) models.db.session.commit() - self.record_event({ - 'action': 'add_data_source', - 'object_id': group_id, - 'object_type': 'group', - 'member_id': data_source.id - }) + self.record_event( + { + "action": "add_data_source", + "object_id": group_id, + "object_type": "group", + "member_id": data_source.id, + } + ) return serialize_data_source_with_group(data_source, data_source_group) @require_admin def get(self, group_id): - group = get_object_or_404(models.Group.get_by_id_and_org, group_id, - self.current_org) + group = get_object_or_404( + models.Group.get_by_id_and_org, group_id, self.current_org + ) # TOOD: move to models - data_sources = (models.DataSource.query - .join(models.DataSourceGroup) - .filter(models.DataSourceGroup.group == group)) + data_sources = models.DataSource.query.join(models.DataSourceGroup).filter( + models.DataSourceGroup.group == group + ) - self.record_event({ - 'action': 'list', - 'object_id': group_id, - 'object_type': 'group', - }) + self.record_event( + {"action": "list", "object_id": group_id, "object_type": "group"} + ) return [ds.to_dict(with_permissions_for=group) for ds in data_sources] @@ -174,34 +180,42 @@ def get(self, group_id): class GroupDataSourceResource(BaseResource): @require_admin def post(self, group_id, data_source_id): - data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org) + data_source = models.DataSource.get_by_id_and_org( + data_source_id, self.current_org + ) group = models.Group.get_by_id_and_org(group_id, self.current_org) - view_only = request.json['view_only'] + view_only = request.json["view_only"] data_source_group = data_source.update_group_permission(group, view_only) models.db.session.commit() - self.record_event({ - 'action': 'change_data_source_permission', - 'object_id': group_id, - 'object_type': 'group', - 'member_id': data_source.id, - 'view_only': view_only - }) + self.record_event( + { + "action": "change_data_source_permission", + "object_id": group_id, + "object_type": "group", + "member_id": data_source.id, + "view_only": view_only, + } + ) return serialize_data_source_with_group(data_source, data_source_group) @require_admin def delete(self, group_id, data_source_id): - data_source = models.DataSource.get_by_id_and_org(data_source_id, self.current_org) + data_source = models.DataSource.get_by_id_and_org( + data_source_id, self.current_org + ) group = models.Group.get_by_id_and_org(group_id, self.current_org) data_source.remove_group(group) models.db.session.commit() - self.record_event({ - 'action': 'remove_data_source', - 'object_id': group_id, - 'object_type': 'group', - 'member_id': data_source.id - }) + self.record_event( + { + "action": "remove_data_source", + "object_id": group_id, + "object_type": "group", + "member_id": data_source.id, + } + ) diff --git a/redash/handlers/organization.py b/redash/handlers/organization.py index 5c9858e750..f39548f8ce 100644 --- a/redash/handlers/organization.py +++ b/redash/handlers/organization.py @@ -6,15 +6,21 @@ from redash.authentication import current_org -@routes.route(org_scoped_rule('/api/organization/status'), methods=['GET']) +@routes.route(org_scoped_rule("/api/organization/status"), methods=["GET"]) @login_required def organization_status(org_slug=None): counters = { - 'users': models.User.all(current_org).count(), - 'alerts': models.Alert.all(group_ids=current_user.group_ids).count(), - 'data_sources': models.DataSource.all(current_org, group_ids=current_user.group_ids).count(), - 'queries': models.Query.all_queries(current_user.group_ids, current_user.id, include_drafts=True).count(), - 'dashboards': models.Dashboard.query.filter(models.Dashboard.org == current_org, models.Dashboard.is_archived == False).count(), + "users": models.User.all(current_org).count(), + "alerts": models.Alert.all(group_ids=current_user.group_ids).count(), + "data_sources": models.DataSource.all( + current_org, group_ids=current_user.group_ids + ).count(), + "queries": models.Query.all_queries( + current_user.group_ids, current_user.id, include_drafts=True + ).count(), + "dashboards": models.Dashboard.query.filter( + models.Dashboard.org == current_org, models.Dashboard.is_archived == False + ).count(), } return json_response(dict(object_counters=counters)) diff --git a/redash/handlers/permissions.py b/redash/handlers/permissions.py index 142946b67c..94bf111eb6 100644 --- a/redash/handlers/permissions.py +++ b/redash/handlers/permissions.py @@ -8,10 +8,7 @@ from sqlalchemy.orm.exc import NoResultFound -model_to_types = { - 'queries': Query, - 'dashboards': Dashboard -} +model_to_types = {"queries": Query, "dashboards": Dashboard} def get_model_from_type(type): @@ -44,63 +41,66 @@ def post(self, object_type, object_id): req = request.get_json(True) - access_type = req['access_type'] + access_type = req["access_type"] if access_type not in ACCESS_TYPES: - abort(400, message='Unknown access type.') + abort(400, message="Unknown access type.") try: - grantee = User.get_by_id_and_org(req['user_id'], self.current_org) + grantee = User.get_by_id_and_org(req["user_id"], self.current_org) except NoResultFound: - abort(400, message='User not found.') + abort(400, message="User not found.") - permission = AccessPermission.grant(obj, access_type, grantee, self.current_user) + permission = AccessPermission.grant( + obj, access_type, grantee, self.current_user + ) db.session.commit() - self.record_event({ - 'action': 'grant_permission', - 'object_id': object_id, - 'object_type': object_type, - 'grantee': grantee.id, - 'access_type': access_type, - }) + self.record_event( + { + "action": "grant_permission", + "object_id": object_id, + "object_type": object_type, + "grantee": grantee.id, + "access_type": access_type, + } + ) return permission.to_dict() def delete(self, object_type, object_id): model = get_model_from_type(object_type) - obj = get_object_or_404(model.get_by_id_and_org, object_id, - self.current_org) + obj = get_object_or_404(model.get_by_id_and_org, object_id, self.current_org) require_admin_or_owner(obj.user_id) req = request.get_json(True) - grantee_id = req['user_id'] - access_type = req['access_type'] + grantee_id = req["user_id"] + access_type = req["access_type"] - grantee = User.query.get(req['user_id']) + grantee = User.query.get(req["user_id"]) if grantee is None: - abort(400, message='User not found.') + abort(400, message="User not found.") AccessPermission.revoke(obj, grantee, access_type) db.session.commit() - self.record_event({ - 'action': 'revoke_permission', - 'object_id': object_id, - 'object_type': object_type, - 'access_type': access_type, - 'grantee_id': grantee_id - }) + self.record_event( + { + "action": "revoke_permission", + "object_id": object_id, + "object_type": object_type, + "access_type": access_type, + "grantee_id": grantee_id, + } + ) class CheckPermissionResource(BaseResource): def get(self, object_type, object_id, access_type): model = get_model_from_type(object_type) - obj = get_object_or_404(model.get_by_id_and_org, object_id, - self.current_org) + obj = get_object_or_404(model.get_by_id_and_org, object_id, self.current_org) - has_access = AccessPermission.exists(obj, access_type, - self.current_user) + has_access = AccessPermission.exists(obj, access_type, self.current_user) - return {'response': has_access} + return {"response": has_access} diff --git a/redash/handlers/queries.py b/redash/handlers/queries.py index 1e53633aaa..129f7e3e6d 100644 --- a/redash/handlers/queries.py +++ b/redash/handlers/queries.py @@ -7,13 +7,25 @@ from redash import models, settings from redash.authentication.org_resolving import current_org -from redash.handlers.base import (BaseResource, filter_by_tags, get_object_or_404, - org_scoped_rule, paginate, routes, order_results as _order_results) +from redash.handlers.base import ( + BaseResource, + filter_by_tags, + get_object_or_404, + org_scoped_rule, + paginate, + routes, + order_results as _order_results, +) from redash.handlers.query_results import run_query -from redash.permissions import (can_modify, not_view_only, require_access, - require_admin_or_owner, - require_object_modify_permission, - require_permission, view_only) +from redash.permissions import ( + can_modify, + not_view_only, + require_access, + require_admin_or_owner, + require_object_modify_permission, + require_permission, + view_only, +) from redash.utils import collect_parameters_from_request from redash.serializers import QuerySerializer from redash.models.parameterized_query import ParameterizedQuery @@ -21,28 +33,26 @@ # Ordering map for relationships order_map = { - 'name': 'lowercase_name', - '-name': '-lowercase_name', - 'created_at': 'created_at', - '-created_at': '-created_at', - 'schedule': 'schedule', - '-schedule': '-schedule', - 'runtime': 'query_results-runtime', - '-runtime': '-query_results-runtime', - 'executed_at': 'query_results-retrieved_at', - '-executed_at': '-query_results-retrieved_at', - 'created_by': 'users-name', - '-created_by': '-users-name', + "name": "lowercase_name", + "-name": "-lowercase_name", + "created_at": "created_at", + "-created_at": "-created_at", + "schedule": "schedule", + "-schedule": "-schedule", + "runtime": "query_results-runtime", + "-runtime": "-query_results-runtime", + "executed_at": "query_results-retrieved_at", + "-executed_at": "-query_results-retrieved_at", + "created_by": "users-name", + "-created_by": "-users-name", } order_results = partial( - _order_results, - default_order='-created_at', - allowed_orders=order_map, + _order_results, default_order="-created_at", allowed_orders=order_map ) -@routes.route(org_scoped_rule('/api/queries/format'), methods=['POST']) +@routes.route(org_scoped_rule("/api/queries/format"), methods=["POST"]) @login_required def format_sql_query(org_slug=None): """ @@ -54,11 +64,13 @@ def format_sql_query(org_slug=None): arguments = request.get_json(force=True) query = arguments.get("query", "") - return jsonify({'query': sqlparse.format(query, **settings.SQLPARSE_FORMAT_OPTIONS)}) + return jsonify( + {"query": sqlparse.format(query, **settings.SQLPARSE_FORMAT_OPTIONS)} + ) class QuerySearchResource(BaseResource): - @require_permission('view_query') + @require_permission("view_query") def get(self): """ Search query text, names, and descriptions. @@ -68,30 +80,26 @@ def get(self): Responds with a list of :ref:`query ` objects. """ - term = request.args.get('q', '') + term = request.args.get("q", "") if not term: return [] - include_drafts = request.args.get('include_drafts') is not None + include_drafts = request.args.get("include_drafts") is not None - self.record_event({ - 'action': 'search', - 'object_type': 'query', - 'term': term, - }) + self.record_event({"action": "search", "object_type": "query", "term": term}) # this redirects to the new query list API that is aware of search new_location = url_for( - 'queries', + "queries", q=term, org_slug=current_org.slug, - drafts='true' if include_drafts else 'false', + drafts="true" if include_drafts else "false", ) - return {}, 301, {'Location': new_location} + return {}, 301, {"Location": new_location} class QueryRecentResource(BaseResource): - @require_permission('view_query') + @require_permission("view_query") def get(self): """ Retrieve up to 10 queries recently modified by the user. @@ -99,12 +107,17 @@ def get(self): Responds with a list of :ref:`query ` objects. """ - results = models.Query.by_user(self.current_user).order_by(models.Query.updated_at.desc()).limit(10) - return QuerySerializer(results, with_last_modified_by=False, with_user=False).serialize() + results = ( + models.Query.by_user(self.current_user) + .order_by(models.Query.updated_at.desc()) + .limit(10) + ) + return QuerySerializer( + results, with_last_modified_by=False, with_user=False + ).serialize() class BaseQueryListResource(BaseResource): - def get_queries(self, search_term): if search_term: results = models.Query.search( @@ -112,17 +125,15 @@ def get_queries(self, search_term): self.current_user.group_ids, self.current_user.id, include_drafts=True, - multi_byte_search=current_org.get_setting('multi_byte_search_enabled'), + multi_byte_search=current_org.get_setting("multi_byte_search_enabled"), ) else: results = models.Query.all_queries( - self.current_user.group_ids, - self.current_user.id, - include_drafts=True, + self.current_user.group_ids, self.current_user.id, include_drafts=True ) return filter_by_tags(results, models.Query.tags) - @require_permission('view_query') + @require_permission("view_query") def get(self): """ Retrieve a list of queries. @@ -135,7 +146,7 @@ def get(self): Responds with an array of :ref:`query ` objects. """ # See if we want to do full-text search or just regular queries - search_term = request.args.get('q', '') + search_term = request.args.get("q", "") queries = self.get_queries(search_term) @@ -146,8 +157,8 @@ def get(self): # provides an order by search rank ordered_results = order_results(results, fallback=not bool(search_term)) - page = request.args.get('page', 1, type=int) - page_size = request.args.get('page_size', 25, type=int) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 25, type=int) response = paginate( ordered_results, @@ -155,40 +166,40 @@ def get(self): page_size=page_size, serializer=QuerySerializer, with_stats=True, - with_last_modified_by=False + with_last_modified_by=False, ) if search_term: - self.record_event({ - 'action': 'search', - 'object_type': 'query', - 'term': search_term, - }) + self.record_event( + {"action": "search", "object_type": "query", "term": search_term} + ) else: - self.record_event({ - 'action': 'list', - 'object_type': 'query', - }) + self.record_event({"action": "list", "object_type": "query"}) return response def require_access_to_dropdown_queries(user, query_def): - parameters = query_def.get('options', {}).get('parameters', []) - dropdown_query_ids = set([str(p['queryId']) for p in parameters if p['type'] == 'query']) + parameters = query_def.get("options", {}).get("parameters", []) + dropdown_query_ids = set( + [str(p["queryId"]) for p in parameters if p["type"] == "query"] + ) if dropdown_query_ids: groups = models.Query.all_groups_for_query_ids(dropdown_query_ids) if len(groups) < len(dropdown_query_ids): - abort(400, message="You are trying to associate a dropdown query that does not have a matching group. " - "Please verify the dropdown query id you are trying to associate with this query.") + abort( + 400, + message="You are trying to associate a dropdown query that does not have a matching group. " + "Please verify the dropdown query id you are trying to associate with this query.", + ) require_access(dict(groups), user, view_only) class QueryListResource(BaseQueryListResource): - @require_permission('create_query') + @require_permission("create_query") def post(self): """ Create a new query. @@ -223,33 +234,39 @@ def post(self): :>json number runtime: Runtime of last query execution, in seconds (may be null) """ query_def = request.get_json(force=True) - data_source = models.DataSource.get_by_id_and_org(query_def.pop('data_source_id'), self.current_org) + data_source = models.DataSource.get_by_id_and_org( + query_def.pop("data_source_id"), self.current_org + ) require_access(data_source, self.current_user, not_view_only) require_access_to_dropdown_queries(self.current_user, query_def) - for field in ['id', 'created_at', 'api_key', 'visualizations', 'latest_query_data', 'last_modified_by']: + for field in [ + "id", + "created_at", + "api_key", + "visualizations", + "latest_query_data", + "last_modified_by", + ]: query_def.pop(field, None) - query_def['query_text'] = query_def.pop('query') - query_def['user'] = self.current_user - query_def['data_source'] = data_source - query_def['org'] = self.current_org - query_def['is_draft'] = True + query_def["query_text"] = query_def.pop("query") + query_def["user"] = self.current_user + query_def["data_source"] = data_source + query_def["org"] = self.current_org + query_def["is_draft"] = True query = models.Query.create(**query_def) models.db.session.add(query) models.db.session.commit() - self.record_event({ - 'action': 'create', - 'object_id': query.id, - 'object_type': 'query' - }) + self.record_event( + {"action": "create", "object_id": query.id, "object_type": "query"} + ) return QuerySerializer(query, with_visualizations=True).serialize() class QueryArchiveResource(BaseQueryListResource): - def get_queries(self, search_term): if search_term: return models.Query.search( @@ -258,7 +275,7 @@ def get_queries(self, search_term): self.current_user.id, include_drafts=False, include_archived=True, - multi_byte_search=current_org.get_setting('multi_byte_search_enabled'), + multi_byte_search=current_org.get_setting("multi_byte_search_enabled"), ) else: return models.Query.all_queries( @@ -270,7 +287,7 @@ def get_queries(self, search_term): class MyQueriesResource(BaseResource): - @require_permission('view_query') + @require_permission("view_query") def get(self): """ Retrieve a list of queries created by the current user. @@ -282,7 +299,7 @@ def get(self): Responds with an array of :ref:`query ` objects. """ - search_term = request.args.get('q', '') + search_term = request.args.get("q", "") if search_term: results = models.Query.search_by_user(search_term, self.current_user) else: @@ -295,8 +312,8 @@ def get(self): # provides an order by search rank ordered_results = order_results(results, fallback=not bool(search_term)) - page = request.args.get('page', 1, type=int) - page_size = request.args.get('page_size', 25, type=int) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 25, type=int) return paginate( ordered_results, page, @@ -308,7 +325,7 @@ def get(self): class QueryResource(BaseResource): - @require_permission('edit_query') + @require_permission("edit_query") def post(self, query_id): """ Modify a query. @@ -323,27 +340,38 @@ def post(self, query_id): Responds with the updated :ref:`query ` object. """ - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) query_def = request.get_json(force=True) require_object_modify_permission(query, self.current_user) require_access_to_dropdown_queries(self.current_user, query_def) - for field in ['id', 'created_at', 'api_key', 'visualizations', 'latest_query_data', 'user', 'last_modified_by', 'org']: + for field in [ + "id", + "created_at", + "api_key", + "visualizations", + "latest_query_data", + "user", + "last_modified_by", + "org", + ]: query_def.pop(field, None) - if 'query' in query_def: - query_def['query_text'] = query_def.pop('query') + if "query" in query_def: + query_def["query_text"] = query_def.pop("query") - if 'tags' in query_def: - query_def['tags'] = [tag for tag in query_def['tags'] if tag] + if "tags" in query_def: + query_def["tags"] = [tag for tag in query_def["tags"] if tag] - query_def['last_modified_by'] = self.current_user - query_def['changed_by'] = self.current_user + query_def["last_modified_by"] = self.current_user + query_def["changed_by"] = self.current_user # SQLAlchemy handles the case where a concurrent transaction beats us # to the update. But we still have to make sure that we're not starting # out behind. - if 'version' in query_def and query_def['version'] != query.version: + if "version" in query_def and query_def["version"] != query.version: abort(409) try: @@ -354,7 +382,7 @@ def post(self, query_id): return QuerySerializer(query, with_visualizations=True).serialize() - @require_permission('view_query') + @require_permission("view_query") def get(self, query_id): """ Retrieve a query. @@ -363,17 +391,17 @@ def get(self, query_id): Responds with the :ref:`query ` contents. """ - q = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + q = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_access(q, self.current_user, view_only) result = QuerySerializer(q, with_visualizations=True).serialize() - result['can_edit'] = can_modify(q, self.current_user) + result["can_edit"] = can_modify(q, self.current_user) - self.record_event({ - 'action': 'view', - 'object_id': query_id, - 'object_type': 'query', - }) + self.record_event( + {"action": "view", "object_id": query_id, "object_type": "query"} + ) return result @@ -384,32 +412,38 @@ def delete(self, query_id): :param query_id: ID of query to archive """ - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_admin_or_owner(query.user_id) query.archive(self.current_user) models.db.session.commit() class QueryRegenerateApiKeyResource(BaseResource): - @require_permission('edit_query') + @require_permission("edit_query") def post(self, query_id): - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_admin_or_owner(query.user_id) query.regenerate_api_key() models.db.session.commit() - self.record_event({ - 'action': 'regnerate_api_key', - 'object_id': query_id, - 'object_type': 'query', - }) + self.record_event( + { + "action": "regnerate_api_key", + "object_id": query_id, + "object_type": "query", + } + ) result = QuerySerializer(query).serialize() return result class QueryForkResource(BaseResource): - @require_permission('edit_query') + @require_permission("edit_query") def post(self, query_id): """ Creates a new query, copying the query text from an existing one. @@ -418,16 +452,16 @@ def post(self, query_id): Responds with created :ref:`query ` object. """ - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_access(query.data_source, self.current_user, not_view_only) forked_query = query.fork(self.current_user) models.db.session.commit() - self.record_event({ - 'action': 'fork', - 'object_id': query_id, - 'object_type': 'query', - }) + self.record_event( + {"action": "fork", "object_id": query_id, "object_type": "query"} + ) return QuerySerializer(forked_query, with_visualizations=True).serialize() @@ -447,13 +481,17 @@ def post(self, query_id): if self.current_user.is_api_user(): abort(403, message="Please use a user API key.") - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_access(query, self.current_user, not_view_only) parameter_values = collect_parameters_from_request(request.args) parameterized_query = ParameterizedQuery(query.query_text, org=self.current_org) - return run_query(parameterized_query, parameter_values, query.data_source, query.id) + return run_query( + parameterized_query, parameter_values, query.data_source, query.id + ) class QueryTagsResource(BaseResource): @@ -462,23 +500,20 @@ def get(self): Returns all query tags including those for drafts. """ tags = models.Query.all_tags(self.current_user, include_drafts=True) - return { - 'tags': [ - { - 'name': name, - 'count': count, - } - for name, count in tags - ] - } + return {"tags": [{"name": name, "count": count} for name, count in tags]} class QueryFavoriteListResource(BaseResource): def get(self): - search_term = request.args.get('q') + search_term = request.args.get("q") if search_term: - base_query = models.Query.search(search_term, self.current_user.group_ids, include_drafts=True, limit=None) + base_query = models.Query.search( + search_term, + self.current_user.group_ids, + include_drafts=True, + limit=None, + ) favorites = models.Query.favorites(self.current_user, base_query=base_query) else: favorites = models.Query.favorites(self.current_user) @@ -490,8 +525,8 @@ def get(self): # provides an order by search rank ordered_favorites = order_results(favorites, fallback=not bool(search_term)) - page = request.args.get('page', 1, type=int) - page_size = request.args.get('page_size', 25, type=int) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 25, type=int) response = paginate( ordered_favorites, page, @@ -501,14 +536,16 @@ def get(self): with_last_modified_by=False, ) - self.record_event({ - 'action': 'load_favorites', - 'object_type': 'query', - 'params': { - 'q': search_term, - 'tags': request.args.getlist('tags'), - 'page': page + self.record_event( + { + "action": "load_favorites", + "object_type": "query", + "params": { + "q": search_term, + "tags": request.args.getlist("tags"), + "page": page, + }, } - }) + ) return response diff --git a/redash/handlers/query_results.py b/redash/handlers/query_results.py index 28d458863e..166936f839 100644 --- a/redash/handlers/query_results.py +++ b/redash/handlers/query_results.py @@ -6,34 +6,65 @@ from flask_restful import abort from redash import models, settings from redash.handlers.base import BaseResource, get_object_or_404, record_event -from redash.permissions import (has_access, not_view_only, require_access, - require_permission, view_only) +from redash.permissions import ( + has_access, + not_view_only, + require_access, + require_permission, + view_only, +) from redash.tasks import QueryTask from redash.tasks.queries import enqueue_query -from redash.utils import (collect_parameters_from_request, gen_query_hash, json_dumps, utcnow, to_filename) -from redash.models.parameterized_query import (ParameterizedQuery, InvalidParameterError, - QueryDetachedFromDataSourceError, dropdown_values) -from redash.serializers import serialize_query_result, serialize_query_result_to_csv, serialize_query_result_to_xlsx +from redash.utils import ( + collect_parameters_from_request, + gen_query_hash, + json_dumps, + utcnow, + to_filename, +) +from redash.models.parameterized_query import ( + ParameterizedQuery, + InvalidParameterError, + QueryDetachedFromDataSourceError, + dropdown_values, +) +from redash.serializers import ( + serialize_query_result, + serialize_query_result_to_csv, + serialize_query_result_to_xlsx, +) def error_response(message, http_status=400): - return {'job': {'status': 4, 'error': message}}, http_status + return {"job": {"status": 4, "error": message}}, http_status error_messages = { - 'unsafe_when_shared': error_response('This query contains potentially unsafe parameters and cannot be executed on a shared dashboard or an embedded visualization.', 403), - 'unsafe_on_view_only': error_response('This query contains potentially unsafe parameters and cannot be executed with read-only access to this data source.', 403), - 'no_permission': error_response('You do not have permission to run queries with this data source.', 403), - 'select_data_source': error_response('Please select data source to run this query.', 401) + "unsafe_when_shared": error_response( + "This query contains potentially unsafe parameters and cannot be executed on a shared dashboard or an embedded visualization.", + 403, + ), + "unsafe_on_view_only": error_response( + "This query contains potentially unsafe parameters and cannot be executed with read-only access to this data source.", + 403, + ), + "no_permission": error_response( + "You do not have permission to run queries with this data source.", 403 + ), + "select_data_source": error_response( + "Please select data source to run this query.", 401 + ), } def run_query(query, parameters, data_source, query_id, max_age=0): if data_source.paused: if data_source.pause_reason: - message = '{} is paused ({}). Please try later.'.format(data_source.name, data_source.pause_reason) + message = "{} is paused ({}). Please try later.".format( + data_source.name, data_source.pause_reason + ) else: - message = '{} is paused. Please try later.'.format(data_source.name) + message = "{} is paused. Please try later.".format(data_source.name) return error_response(message) @@ -43,44 +74,62 @@ def run_query(query, parameters, data_source, query_id, max_age=0): abort(400, message=e.message) if query.missing_params: - return error_response('Missing parameter value for: {}'.format(", ".join(query.missing_params))) + return error_response( + "Missing parameter value for: {}".format(", ".join(query.missing_params)) + ) if max_age == 0: query_result = None else: query_result = models.QueryResult.get_latest(data_source, query.text, max_age) - record_event(current_user.org, current_user, { - 'action': 'execute_query', - 'cache': 'hit' if query_result else 'miss', - 'object_id': data_source.id, - 'object_type': 'data_source', - 'query': query.text, - 'query_id': query_id, - 'parameters': parameters - }) + record_event( + current_user.org, + current_user, + { + "action": "execute_query", + "cache": "hit" if query_result else "miss", + "object_id": data_source.id, + "object_type": "data_source", + "query": query.text, + "query_id": query_id, + "parameters": parameters, + }, + ) if query_result: - return {'query_result': serialize_query_result(query_result, current_user.is_api_user())} + return { + "query_result": serialize_query_result( + query_result, current_user.is_api_user() + ) + } else: - job = enqueue_query(query.text, data_source, current_user.id, current_user.is_api_user(), metadata={ - "Username": repr(current_user) if current_user.is_api_user() else current_user.email, - "Query ID": query_id - }) - return {'job': job.to_dict()} + job = enqueue_query( + query.text, + data_source, + current_user.id, + current_user.is_api_user(), + metadata={ + "Username": repr(current_user) + if current_user.is_api_user() + else current_user.email, + "Query ID": query_id, + }, + ) + return {"job": job.to_dict()} def get_download_filename(query_result, query, filetype): retrieved_at = query_result.retrieved_at.strftime("%Y_%m_%d") if query: - filename = to_filename(query.name) if query.name != '' else str(query.id) + filename = to_filename(query.name) if query.name != "" else str(query.id) else: filename = str(query_result.id) return "{}_{}.{}".format(filename, retrieved_at, filetype) class QueryResultListResource(BaseResource): - @require_permission('execute_query') + @require_permission("execute_query") def post(self): """ Execute a query (or retrieve recent results). @@ -96,27 +145,33 @@ def post(self): """ params = request.get_json(force=True) - query = params['query'] - max_age = params.get('max_age', -1) + query = params["query"] + max_age = params.get("max_age", -1) # max_age might have the value of None, in which case calling int(None) will fail if max_age is None: max_age = -1 max_age = int(max_age) - query_id = params.get('query_id', 'adhoc') - parameters = params.get('parameters', collect_parameters_from_request(request.args)) + query_id = params.get("query_id", "adhoc") + parameters = params.get( + "parameters", collect_parameters_from_request(request.args) + ) parameterized_query = ParameterizedQuery(query, org=self.current_org) - data_source_id = params.get('data_source_id') + data_source_id = params.get("data_source_id") if data_source_id: - data_source = models.DataSource.get_by_id_and_org(params.get('data_source_id'), self.current_org) + data_source = models.DataSource.get_by_id_and_org( + params.get("data_source_id"), self.current_org + ) else: - return error_messages['select_data_source'] + return error_messages["select_data_source"] if not has_access(data_source, self.current_user, not_view_only): - return error_messages['no_permission'] + return error_messages["no_permission"] - return run_query(parameterized_query, parameters, data_source, query_id, max_age) + return run_query( + parameterized_query, parameters, data_source, query_id, max_age + ) ONE_YEAR = 60 * 60 * 24 * 365.25 @@ -124,7 +179,9 @@ def post(self): class QueryResultDropdownResource(BaseResource): def get(self, query_id): - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_access(query.data_source, current_user, view_only) try: return dropdown_values(query_id, self.current_org) @@ -134,12 +191,18 @@ def get(self, query_id): class QueryDropdownsResource(BaseResource): def get(self, query_id, dropdown_query_id): - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) require_access(query, current_user, view_only) - related_queries_ids = [p['queryId'] for p in query.parameters if p['type'] == 'query'] + related_queries_ids = [ + p["queryId"] for p in query.parameters if p["type"] == "query" + ] if int(dropdown_query_id) not in related_queries_ids: - dropdown_query = get_object_or_404(models.Query.get_by_id_and_org, dropdown_query_id, self.current_org) + dropdown_query = get_object_or_404( + models.Query.get_by_id_and_org, dropdown_query_id, self.current_org + ) require_access(dropdown_query.data_source, current_user, view_only) return dropdown_values(dropdown_query_id, self.current_org) @@ -148,27 +211,33 @@ def get(self, query_id, dropdown_query_id): class QueryResultResource(BaseResource): @staticmethod def add_cors_headers(headers): - if 'Origin' in request.headers: - origin = request.headers['Origin'] + if "Origin" in request.headers: + origin = request.headers["Origin"] - if set(['*', origin]) & settings.ACCESS_CONTROL_ALLOW_ORIGIN: - headers['Access-Control-Allow-Origin'] = origin - headers['Access-Control-Allow-Credentials'] = str(settings.ACCESS_CONTROL_ALLOW_CREDENTIALS).lower() + if set(["*", origin]) & settings.ACCESS_CONTROL_ALLOW_ORIGIN: + headers["Access-Control-Allow-Origin"] = origin + headers["Access-Control-Allow-Credentials"] = str( + settings.ACCESS_CONTROL_ALLOW_CREDENTIALS + ).lower() - @require_permission('view_query') - def options(self, query_id=None, query_result_id=None, filetype='json'): + @require_permission("view_query") + def options(self, query_id=None, query_result_id=None, filetype="json"): headers = {} self.add_cors_headers(headers) if settings.ACCESS_CONTROL_REQUEST_METHOD: - headers['Access-Control-Request-Method'] = settings.ACCESS_CONTROL_REQUEST_METHOD + headers[ + "Access-Control-Request-Method" + ] = settings.ACCESS_CONTROL_REQUEST_METHOD if settings.ACCESS_CONTROL_ALLOW_HEADERS: - headers['Access-Control-Allow-Headers'] = settings.ACCESS_CONTROL_ALLOW_HEADERS + headers[ + "Access-Control-Allow-Headers" + ] = settings.ACCESS_CONTROL_ALLOW_HEADERS return make_response("", 200, headers) - @require_permission('view_query') + @require_permission("view_query") def post(self, query_id): """ Execute a saved query. @@ -181,31 +250,41 @@ def post(self, query_id): always execute. """ params = request.get_json(force=True, silent=True) or {} - parameter_values = params.get('parameters', {}) + parameter_values = params.get("parameters", {}) - max_age = params.get('max_age', -1) + max_age = params.get("max_age", -1) # max_age might have the value of None, in which case calling int(None) will fail if max_age is None: max_age = -1 max_age = int(max_age) - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) allow_executing_with_view_only_permissions = query.parameterized.is_safe - if has_access(query, self.current_user, allow_executing_with_view_only_permissions): - return run_query(query.parameterized, parameter_values, query.data_source, query_id, max_age) + if has_access( + query, self.current_user, allow_executing_with_view_only_permissions + ): + return run_query( + query.parameterized, + parameter_values, + query.data_source, + query_id, + max_age, + ) else: if not query.parameterized.is_safe: if current_user.is_api_user(): - return error_messages['unsafe_when_shared'] + return error_messages["unsafe_when_shared"] else: - return error_messages['unsafe_on_view_only'] + return error_messages["unsafe_on_view_only"] else: - return error_messages['no_permission'] + return error_messages["no_permission"] - @require_permission('view_query') - def get(self, query_id=None, query_result_id=None, filetype='json'): + @require_permission("view_query") + def get(self, query_id=None, query_result_id=None, filetype="json"): """ Retrieve query results. @@ -228,52 +307,66 @@ def get(self, query_id=None, query_result_id=None, filetype='json'): should_cache = query_result_id is not None parameter_values = collect_parameters_from_request(request.args) - max_age = int(request.args.get('maxAge', 0)) + max_age = int(request.args.get("maxAge", 0)) query_result = None query = None if query_result_id: - query_result = get_object_or_404(models.QueryResult.get_by_id_and_org, query_result_id, self.current_org) + query_result = get_object_or_404( + models.QueryResult.get_by_id_and_org, query_result_id, self.current_org + ) if query_id is not None: - query = get_object_or_404(models.Query.get_by_id_and_org, query_id, self.current_org) - - if query_result is None and query is not None and query.latest_query_data_id is not None: - query_result = get_object_or_404(models.QueryResult.get_by_id_and_org, - query.latest_query_data_id, - self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, query_id, self.current_org + ) - if query is not None and query_result is not None and self.current_user.is_api_user(): + if ( + query_result is None + and query is not None + and query.latest_query_data_id is not None + ): + query_result = get_object_or_404( + models.QueryResult.get_by_id_and_org, + query.latest_query_data_id, + self.current_org, + ) + + if ( + query is not None + and query_result is not None + and self.current_user.is_api_user() + ): if query.query_hash != query_result.query_hash: - abort(404, message='No cached result found for this query.') + abort(404, message="No cached result found for this query.") if query_result: require_access(query_result.data_source, self.current_user, view_only) if isinstance(self.current_user, models.ApiUser): event = { - 'user_id': None, - 'org_id': self.current_org.id, - 'action': 'api_get', - 'api_key': self.current_user.name, - 'file_type': filetype, - 'user_agent': request.user_agent.string, - 'ip': request.remote_addr + "user_id": None, + "org_id": self.current_org.id, + "action": "api_get", + "api_key": self.current_user.name, + "file_type": filetype, + "user_agent": request.user_agent.string, + "ip": request.remote_addr, } if query_id: - event['object_type'] = 'query' - event['object_id'] = query_id + event["object_type"] = "query" + event["object_id"] = query_id else: - event['object_type'] = 'query_result' - event['object_id'] = query_result_id + event["object_type"] = "query_result" + event["object_id"] = query_result_id self.record_event(event) - if filetype == 'json': + if filetype == "json": response = self.make_json_response(query_result) - elif filetype == 'xlsx': + elif filetype == "xlsx": response = self.make_excel_response(query_result) else: response = self.make_csv_response(query_result) @@ -282,33 +375,36 @@ def get(self, query_id=None, query_result_id=None, filetype='json'): self.add_cors_headers(response.headers) if should_cache: - response.headers.add_header('Cache-Control', 'private,max-age=%d' % ONE_YEAR) + response.headers.add_header( + "Cache-Control", "private,max-age=%d" % ONE_YEAR + ) filename = get_download_filename(query_result, query, filetype) response.headers.add_header( - "Content-Disposition", - 'attachment; filename="{}"'.format(filename) + "Content-Disposition", 'attachment; filename="{}"'.format(filename) ) return response else: - abort(404, message='No cached result found for this query.') + abort(404, message="No cached result found for this query.") def make_json_response(self, query_result): - data = json_dumps({'query_result': query_result.to_dict()}) - headers = {'Content-Type': "application/json"} + data = json_dumps({"query_result": query_result.to_dict()}) + headers = {"Content-Type": "application/json"} return make_response(data, 200, headers) @staticmethod def make_csv_response(query_result): - headers = {'Content-Type': "text/csv; charset=UTF-8"} + headers = {"Content-Type": "text/csv; charset=UTF-8"} return make_response(serialize_query_result_to_csv(query_result), 200, headers) @staticmethod def make_excel_response(query_result): - headers = {'Content-Type': "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"} + headers = { + "Content-Type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + } return make_response(serialize_query_result_to_xlsx(query_result), 200, headers) @@ -318,7 +414,7 @@ def get(self, job_id, query_id=None): Retrieve info about a running query job. """ job = QueryTask(job_id=job_id) - return {'job': job.to_dict()} + return {"job": job.to_dict()} def delete(self, job_id): """ diff --git a/redash/handlers/query_snippets.py b/redash/handlers/query_snippets.py index b89f2d6c8b..64808de522 100644 --- a/redash/handlers/query_snippets.py +++ b/redash/handlers/query_snippets.py @@ -3,82 +3,83 @@ from redash import models from redash.permissions import require_admin_or_owner -from redash.handlers.base import (BaseResource, require_fields, - get_object_or_404) +from redash.handlers.base import BaseResource, require_fields, get_object_or_404 class QuerySnippetResource(BaseResource): def get(self, snippet_id): - snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, - snippet_id, self.current_org) + snippet = get_object_or_404( + models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org + ) - self.record_event({ - 'action': 'view', - 'object_id': snippet_id, - 'object_type': 'query_snippet', - }) + self.record_event( + {"action": "view", "object_id": snippet_id, "object_type": "query_snippet"} + ) return snippet.to_dict() def post(self, snippet_id): req = request.get_json(True) - params = project(req, ('trigger', 'description', 'snippet')) - snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, - snippet_id, self.current_org) + params = project(req, ("trigger", "description", "snippet")) + snippet = get_object_or_404( + models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org + ) require_admin_or_owner(snippet.user.id) self.update_model(snippet, params) models.db.session.commit() - self.record_event({ - 'action': 'edit', - 'object_id': snippet.id, - 'object_type': 'query_snippet' - }) + self.record_event( + {"action": "edit", "object_id": snippet.id, "object_type": "query_snippet"} + ) return snippet.to_dict() def delete(self, snippet_id): - snippet = get_object_or_404(models.QuerySnippet.get_by_id_and_org, - snippet_id, self.current_org) + snippet = get_object_or_404( + models.QuerySnippet.get_by_id_and_org, snippet_id, self.current_org + ) require_admin_or_owner(snippet.user.id) models.db.session.delete(snippet) models.db.session.commit() - self.record_event({ - 'action': 'delete', - 'object_id': snippet.id, - 'object_type': 'query_snippet' - }) + self.record_event( + { + "action": "delete", + "object_id": snippet.id, + "object_type": "query_snippet", + } + ) class QuerySnippetListResource(BaseResource): def post(self): req = request.get_json(True) - require_fields(req, ('trigger', 'description', 'snippet')) + require_fields(req, ("trigger", "description", "snippet")) snippet = models.QuerySnippet( - trigger=req['trigger'], - description=req['description'], - snippet=req['snippet'], + trigger=req["trigger"], + description=req["description"], + snippet=req["snippet"], user=self.current_user, - org=self.current_org + org=self.current_org, ) models.db.session.add(snippet) models.db.session.commit() - self.record_event({ - 'action': 'create', - 'object_id': snippet.id, - 'object_type': 'query_snippet' - }) + self.record_event( + { + "action": "create", + "object_id": snippet.id, + "object_type": "query_snippet", + } + ) return snippet.to_dict() def get(self): - self.record_event({ - 'action': 'list', - 'object_type': 'query_snippet', - }) - return [snippet.to_dict() for snippet in - models.QuerySnippet.all(org=self.current_org)] + self.record_event({"action": "list", "object_type": "query_snippet"}) + return [ + snippet.to_dict() + for snippet in models.QuerySnippet.all(org=self.current_org) + ] diff --git a/redash/handlers/settings.py b/redash/handlers/settings.py index 84e35b5668..d684f42c35 100644 --- a/redash/handlers/settings.py +++ b/redash/handlers/settings.py @@ -7,7 +7,7 @@ def get_settings_with_defaults(defaults, org): - values = org.settings.get('settings', {}) + values = org.settings.get("settings", {}) settings = {} for setting, default_value in defaults.items(): @@ -20,7 +20,7 @@ def get_settings_with_defaults(defaults, org): else: settings[setting] = current_value - settings['auth_google_apps_domains'] = org.google_apps_domains + settings["auth_google_apps_domains"] = org.google_apps_domains return settings @@ -30,39 +30,39 @@ class OrganizationSettings(BaseResource): def get(self): settings = get_settings_with_defaults(org_settings, self.current_org) - return { - "settings": settings - } + return {"settings": settings} @require_admin def post(self): new_values = request.json - if self.current_org.settings.get('settings') is None: - self.current_org.settings['settings'] = {} + if self.current_org.settings.get("settings") is None: + self.current_org.settings["settings"] = {} previous_values = {} for k, v in new_values.items(): - if k == 'auth_google_apps_domains': + if k == "auth_google_apps_domains": previous_values[k] = self.current_org.google_apps_domains self.current_org.settings[Organization.SETTING_GOOGLE_APPS_DOMAINS] = v else: - previous_values[k] = self.current_org.get_setting(k, raise_on_missing=False) + previous_values[k] = self.current_org.get_setting( + k, raise_on_missing=False + ) self.current_org.set_setting(k, v) db.session.add(self.current_org) db.session.commit() - self.record_event({ - 'action': 'edit', - 'object_id': self.current_org.id, - 'object_type': 'settings', - 'new_values': new_values, - 'previous_values': previous_values - }) + self.record_event( + { + "action": "edit", + "object_id": self.current_org.id, + "object_type": "settings", + "new_values": new_values, + "previous_values": previous_values, + } + ) settings = get_settings_with_defaults(org_settings, self.current_org) - return { - "settings": settings - } + return {"settings": settings} diff --git a/redash/handlers/setup.py b/redash/handlers/setup.py index db2f6a1461..caa9be8641 100644 --- a/redash/handlers/setup.py +++ b/redash/handlers/setup.py @@ -11,26 +11,38 @@ class SetupForm(Form): - name = StringField('Name', validators=[validators.InputRequired()]) - email = EmailField('Email Address', validators=[validators.Email()]) - password = PasswordField('Password', validators=[validators.Length(6)]) + name = StringField("Name", validators=[validators.InputRequired()]) + email = EmailField("Email Address", validators=[validators.Email()]) + password = PasswordField("Password", validators=[validators.Length(6)]) org_name = StringField("Organization Name", validators=[validators.InputRequired()]) security_notifications = BooleanField() newsletter = BooleanField() def create_org(org_name, user_name, email, password): - default_org = Organization(name=org_name, slug='default', settings={}) - admin_group = Group(name='admin', permissions=['admin', 'super_admin'], org=default_org, type=Group.BUILTIN_GROUP) - default_group = Group(name='default', permissions=Group.DEFAULT_PERMISSIONS, org=default_org, type=Group.BUILTIN_GROUP) + default_org = Organization(name=org_name, slug="default", settings={}) + admin_group = Group( + name="admin", + permissions=["admin", "super_admin"], + org=default_org, + type=Group.BUILTIN_GROUP, + ) + default_group = Group( + name="default", + permissions=Group.DEFAULT_PERMISSIONS, + org=default_org, + type=Group.BUILTIN_GROUP, + ) db.session.add_all([default_org, admin_group, default_group]) db.session.commit() - user = User(org=default_org, - name=user_name, - email=email, - group_ids=[admin_group.id, default_group.id]) + user = User( + org=default_org, + name=user_name, + email=email, + group_ids=[admin_group.id, default_group.id], + ) user.hash_password(password) db.session.add(user) @@ -39,17 +51,19 @@ def create_org(org_name, user_name, email, password): return default_org, user -@routes.route('/setup', methods=['GET', 'POST']) +@routes.route("/setup", methods=["GET", "POST"]) def setup(): if current_org != None or settings.MULTI_ORG: - return redirect('/') + return redirect("/") form = SetupForm(request.form) form.newsletter.data = True form.security_notifications.data = True - if request.method == 'POST' and form.validate(): - default_org, user = create_org(form.org_name.data, form.name.data, form.email.data, form.password.data) + if request.method == "POST" and form.validate(): + default_org, user = create_org( + form.org_name.data, form.name.data, form.email.data, form.password.data + ) g.org = default_org login_user(user) @@ -58,6 +72,6 @@ def setup(): if form.newsletter.data or form.security_notifications: subscribe.delay(form.data) - return redirect(url_for('redash.index', org_slug=None)) + return redirect(url_for("redash.index", org_slug=None)) - return render_template('setup.html', form=form) + return render_template("setup.html", form=form) diff --git a/redash/handlers/static.py b/redash/handlers/static.py index be2154446b..1a02b66379 100644 --- a/redash/handlers/static.py +++ b/redash/handlers/static.py @@ -12,21 +12,21 @@ def render_index(): if settings.MULTI_ORG: response = render_template("multi_org.html", base_href=base_href()) else: - full_path = safe_join(settings.STATIC_ASSETS_PATH, 'index.html') + full_path = safe_join(settings.STATIC_ASSETS_PATH, "index.html") response = send_file(full_path, **dict(cache_timeout=0, conditional=True)) return response -@routes.route(org_scoped_rule('/dashboard/'), methods=['GET']) +@routes.route(org_scoped_rule("/dashboard/"), methods=["GET"]) @login_required @csp_allows_embeding def dashboard(slug, org_slug=None): return render_index() -@routes.route(org_scoped_rule('/')) -@routes.route(org_scoped_rule('/')) +@routes.route(org_scoped_rule("/")) +@routes.route(org_scoped_rule("/")) @login_required def index(**kwargs): return render_index() diff --git a/redash/handlers/users.py b/redash/handlers/users.py index 68bdc2d188..bc2e98c534 100644 --- a/redash/handlers/users.py +++ b/redash/handlers/users.py @@ -10,31 +10,45 @@ from funcy import partial from redash import models, limiter -from redash.permissions import require_permission, require_admin_or_owner, is_admin_or_owner, \ - require_permission_or_owner, require_admin -from redash.handlers.base import BaseResource, require_fields, get_object_or_404, paginate, order_results as _order_results +from redash.permissions import ( + require_permission, + require_admin_or_owner, + is_admin_or_owner, + require_permission_or_owner, + require_admin, +) +from redash.handlers.base import ( + BaseResource, + require_fields, + get_object_or_404, + paginate, + order_results as _order_results, +) -from redash.authentication.account import invite_link_for_user, send_invite_email, send_password_reset_email, send_verify_email +from redash.authentication.account import ( + invite_link_for_user, + send_invite_email, + send_password_reset_email, + send_verify_email, +) from redash.settings import parse_boolean from redash import settings # Ordering map for relationships order_map = { - 'name': 'name', - '-name': '-name', - 'active_at': 'active_at', - '-active_at': '-active_at', - 'created_at': 'created_at', - '-created_at': '-created_at', - 'groups': 'group_ids', - '-groups': '-group_ids', + "name": "name", + "-name": "-name", + "active_at": "active_at", + "-active_at": "-active_at", + "created_at": "created_at", + "-created_at": "-created_at", + "groups": "group_ids", + "-groups": "-group_ids", } order_results = partial( - _order_results, - default_order='-created_at', - allowed_orders=order_map, + _order_results, default_order="-created_at", allowed_orders=order_map ) @@ -45,14 +59,15 @@ def invite_user(org, inviter, user, send_email=True): if settings.email_server_is_configured() and send_email: send_invite_email(inviter, user, invite_url, org) else: - d['invite_link'] = invite_url + d["invite_link"] = invite_url return d class UserListResource(BaseResource): - decorators = BaseResource.decorators + \ - [limiter.limit('200/day;50/hour', methods=['POST'])] + decorators = BaseResource.decorators + [ + limiter.limit("200/day;50/hour", methods=["POST"]) + ] def get_users(self, disabled, pending, search_term): if disabled: @@ -65,50 +80,52 @@ def get_users(self, disabled, pending, search_term): if search_term: users = models.User.search(users, search_term) - self.record_event({ - 'action': 'search', - 'object_type': 'user', - 'term': search_term, - 'pending': pending, - }) + self.record_event( + { + "action": "search", + "object_type": "user", + "term": search_term, + "pending": pending, + } + ) else: - self.record_event({ - 'action': 'list', - 'object_type': 'user', - 'pending': pending, - }) + self.record_event( + {"action": "list", "object_type": "user", "pending": pending} + ) # order results according to passed order parameter, # special-casing search queries where the database # provides an order by search rank return order_results(users, fallback=not bool(search_term)) - @require_permission('list_users') + @require_permission("list_users") def get(self): - page = request.args.get('page', 1, type=int) - page_size = request.args.get('page_size', 25, type=int) + page = request.args.get("page", 1, type=int) + page_size = request.args.get("page_size", 25, type=int) groups = {group.id: group for group in models.Group.all(self.current_org)} def serialize_user(user): d = user.to_dict() user_groups = [] - for group_id in set(d['groups']): + for group_id in set(d["groups"]): group = groups.get(group_id) if group: - user_groups.append({'id': group.id, 'name': group.name}) + user_groups.append({"id": group.id, "name": group.name}) - d['groups'] = user_groups + d["groups"] = user_groups return d - search_term = request.args.get('q', '') + search_term = request.args.get("q", "") - disabled = request.args.get('disabled', 'false') # get enabled users by default + disabled = request.args.get("disabled", "false") # get enabled users by default disabled = parse_boolean(disabled) - pending = request.args.get('pending', None) # get both active and pending by default + pending = request.args.get( + "pending", None + ) # get both active and pending by default if pending is not None: pending = parse_boolean(pending) @@ -119,37 +136,39 @@ def serialize_user(user): @require_admin def post(self): req = request.get_json(force=True) - require_fields(req, ('name', 'email')) + require_fields(req, ("name", "email")) - if '@' not in req['email']: - abort(400, message='Bad email address.') - name, domain = req['email'].split('@', 1) + if "@" not in req["email"]: + abort(400, message="Bad email address.") + name, domain = req["email"].split("@", 1) - if domain.lower() in blacklist or domain.lower() == 'qq.com': - abort(400, message='Bad email address.') + if domain.lower() in blacklist or domain.lower() == "qq.com": + abort(400, message="Bad email address.") - user = models.User(org=self.current_org, - name=req['name'], - email=req['email'], - is_invitation_pending=True, - group_ids=[self.current_org.default_group.id]) + user = models.User( + org=self.current_org, + name=req["name"], + email=req["email"], + is_invitation_pending=True, + group_ids=[self.current_org.default_group.id], + ) try: models.db.session.add(user) models.db.session.commit() except IntegrityError as e: if "email" in str(e): - abort(400, message='Email already taken.') + abort(400, message="Email already taken.") abort(500) - self.record_event({ - 'action': 'create', - 'object_id': user.id, - 'object_type': 'user' - }) + self.record_event( + {"action": "create", "object_id": user.id, "object_type": "user"} + ) - should_send_invitation = 'no_invite' not in request.args - return invite_user(self.current_org, self.current_user, user, send_email=should_send_invitation) + should_send_invitation = "no_invite" not in request.args + return invite_user( + self.current_org, self.current_user, user, send_email=should_send_invitation + ) class UserInviteResource(BaseResource): @@ -164,47 +183,42 @@ class UserResetPasswordResource(BaseResource): def post(self, user_id): user = models.User.get_by_id_and_org(user_id, self.current_org) if user.is_disabled: - abort(404, message='Not found') + abort(404, message="Not found") reset_link = send_password_reset_email(user) - return { - 'reset_link': reset_link, - } + return {"reset_link": reset_link} class UserRegenerateApiKeyResource(BaseResource): def post(self, user_id): user = models.User.get_by_id_and_org(user_id, self.current_org) if user.is_disabled: - abort(404, message='Not found') + abort(404, message="Not found") if not is_admin_or_owner(user_id): abort(403) user.regenerate_api_key() models.db.session.commit() - self.record_event({ - 'action': 'regnerate_api_key', - 'object_id': user.id, - 'object_type': 'user' - }) + self.record_event( + {"action": "regnerate_api_key", "object_id": user.id, "object_type": "user"} + ) return user.to_dict(with_api_key=True) class UserResource(BaseResource): - decorators = BaseResource.decorators + \ - [limiter.limit('50/hour', methods=['POST'])] + decorators = BaseResource.decorators + [limiter.limit("50/hour", methods=["POST"])] def get(self, user_id): - require_permission_or_owner('list_users', user_id) - user = get_object_or_404(models.User.get_by_id_and_org, user_id, self.current_org) + require_permission_or_owner("list_users", user_id) + user = get_object_or_404( + models.User.get_by_id_and_org, user_id, self.current_org + ) - self.record_event({ - 'action': 'view', - 'object_id': user_id, - 'object_type': 'user', - }) + self.record_event( + {"action": "view", "object_id": user_id, "object_type": "user"} + ) return user.to_dict(with_api_key=is_admin_or_owner(user_id)) @@ -214,39 +228,45 @@ def post(self, user_id): req = request.get_json(True) - params = project(req, ('email', 'name', 'password', 'old_password', 'group_ids')) + params = project( + req, ("email", "name", "password", "old_password", "group_ids") + ) - if 'password' in params and 'old_password' not in params: + if "password" in params and "old_password" not in params: abort(403, message="Must provide current password to update password.") - if 'old_password' in params and not user.verify_password(params['old_password']): + if "old_password" in params and not user.verify_password( + params["old_password"] + ): abort(403, message="Incorrect current password.") - if 'password' in params: - user.hash_password(params.pop('password')) - params.pop('old_password') + if "password" in params: + user.hash_password(params.pop("password")) + params.pop("old_password") - if 'group_ids' in params: - if not self.current_user.has_permission('admin'): + if "group_ids" in params: + if not self.current_user.has_permission("admin"): abort(403, message="Must be admin to change groups membership.") - for group_id in params['group_ids']: + for group_id in params["group_ids"]: try: models.Group.get_by_id_and_org(group_id, self.current_org) except NoResultFound: abort(400, message="Group id {} is invalid.".format(group_id)) - if len(params['group_ids']) == 0: - params.pop('group_ids') + if len(params["group_ids"]) == 0: + params.pop("group_ids") - if 'email' in params: - _, domain = params['email'].split('@', 1) + if "email" in params: + _, domain = params["email"].split("@", 1) - if domain.lower() in blacklist or domain.lower() == 'qq.com': - abort(400, message='Bad email address.') + if domain.lower() in blacklist or domain.lower() == "qq.com": + abort(400, message="Bad email address.") - email_address_changed = 'email' in params and params['email'] != user.email - needs_to_verify_email = email_address_changed and settings.email_server_is_configured() + email_address_changed = "email" in params and params["email"] != user.email + needs_to_verify_email = ( + email_address_changed and settings.email_server_is_configured() + ) if needs_to_verify_email: user.is_email_verified = False @@ -270,12 +290,14 @@ def post(self, user_id): abort(400, message=message) - self.record_event({ - 'action': 'edit', - 'object_id': user.id, - 'object_type': 'user', - 'updated_fields': list(params.keys()) - }) + self.record_event( + { + "action": "edit", + "object_id": user.id, + "object_type": "user", + "updated_fields": list(params.keys()), + } + ) return user.to_dict(with_api_key=is_admin_or_owner(user_id)) @@ -285,11 +307,17 @@ def delete(self, user_id): # admin cannot delete self; current user is an admin (`@require_admin`) # so just check user id if user.id == current_user.id: - abort(403, message="You cannot delete your own account. " - "Please ask another admin to do this for you.") + abort( + 403, + message="You cannot delete your own account. " + "Please ask another admin to do this for you.", + ) elif not user.is_invitation_pending: - abort(403, message="You cannot delete activated users. " - "Please disable the user instead.") + abort( + 403, + message="You cannot delete activated users. " + "Please disable the user instead.", + ) models.db.session.delete(user) models.db.session.commit() @@ -303,8 +331,11 @@ def post(self, user_id): # admin cannot disable self; current user is an admin (`@require_admin`) # so just check user id if user.id == current_user.id: - abort(403, message="You cannot disable your own account. " - "Please ask another admin to do this for you.") + abort( + 403, + message="You cannot disable your own account. " + "Please ask another admin to do this for you.", + ) user.disable() models.db.session.commit() diff --git a/redash/handlers/visualizations.py b/redash/handlers/visualizations.py index 35c9d6610c..1621ea50cd 100644 --- a/redash/handlers/visualizations.py +++ b/redash/handlers/visualizations.py @@ -3,21 +3,22 @@ from redash import models from redash.handlers.base import BaseResource, get_object_or_404 from redash.serializers import serialize_visualization -from redash.permissions import (require_object_modify_permission, - require_permission) +from redash.permissions import require_object_modify_permission, require_permission from redash.utils import json_dumps class VisualizationListResource(BaseResource): - @require_permission('edit_query') + @require_permission("edit_query") def post(self): kwargs = request.get_json(force=True) - query = get_object_or_404(models.Query.get_by_id_and_org, kwargs.pop('query_id'), self.current_org) + query = get_object_or_404( + models.Query.get_by_id_and_org, kwargs.pop("query_id"), self.current_org + ) require_object_modify_permission(query, self.current_user) - kwargs['options'] = json_dumps(kwargs['options']) - kwargs['query_rel'] = query + kwargs["options"] = json_dumps(kwargs["options"]) + kwargs["query_rel"] = query vis = models.Visualization(**kwargs) models.db.session.add(vis) @@ -26,31 +27,37 @@ def post(self): class VisualizationResource(BaseResource): - @require_permission('edit_query') + @require_permission("edit_query") def post(self, visualization_id): - vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org) + vis = get_object_or_404( + models.Visualization.get_by_id_and_org, visualization_id, self.current_org + ) require_object_modify_permission(vis.query_rel, self.current_user) kwargs = request.get_json(force=True) - if 'options' in kwargs: - kwargs['options'] = json_dumps(kwargs['options']) + if "options" in kwargs: + kwargs["options"] = json_dumps(kwargs["options"]) - kwargs.pop('id', None) - kwargs.pop('query_id', None) + kwargs.pop("id", None) + kwargs.pop("query_id", None) self.update_model(vis, kwargs) d = serialize_visualization(vis, with_query=False) models.db.session.commit() return d - @require_permission('edit_query') + @require_permission("edit_query") def delete(self, visualization_id): - vis = get_object_or_404(models.Visualization.get_by_id_and_org, visualization_id, self.current_org) + vis = get_object_or_404( + models.Visualization.get_by_id_and_org, visualization_id, self.current_org + ) require_object_modify_permission(vis.query_rel, self.current_user) - self.record_event({ - 'action': 'delete', - 'object_id': visualization_id, - 'object_type': 'Visualization' - }) + self.record_event( + { + "action": "delete", + "object_id": visualization_id, + "object_type": "Visualization", + } + ) models.db.session.delete(vis) models.db.session.commit() diff --git a/redash/handlers/webpack.py b/redash/handlers/webpack.py index fec0d72abb..01a0342549 100644 --- a/redash/handlers/webpack.py +++ b/redash/handlers/webpack.py @@ -2,27 +2,27 @@ import simplejson from flask import url_for -WEBPACK_MANIFEST_PATH = os.path.join(os.path.dirname(__file__), '../../client/dist/', 'asset-manifest.json') +WEBPACK_MANIFEST_PATH = os.path.join( + os.path.dirname(__file__), "../../client/dist/", "asset-manifest.json" +) def configure_webpack(app): - app.extensions['webpack'] = {'assets': None} + app.extensions["webpack"] = {"assets": None} def get_asset(path): - assets = app.extensions['webpack']['assets'] + assets = app.extensions["webpack"]["assets"] # in debug we read in this file each request if assets is None or app.debug: try: with open(WEBPACK_MANIFEST_PATH) as fp: assets = simplejson.load(fp) except IOError: - app.logger.exception('Unable to load webpack manifest') + app.logger.exception("Unable to load webpack manifest") assets = {} - app.extensions['webpack']['assets'] = assets - return url_for('static', filename=assets.get(path, path)) + app.extensions["webpack"]["assets"] = assets + return url_for("static", filename=assets.get(path, path)) @app.context_processor def webpack_assets(): - return { - 'asset_url': get_asset, - } + return {"asset_url": get_asset} diff --git a/redash/handlers/widgets.py b/redash/handlers/widgets.py index 85be9c0cbb..6907943405 100644 --- a/redash/handlers/widgets.py +++ b/redash/handlers/widgets.py @@ -3,14 +3,17 @@ from redash import models from redash.handlers.base import BaseResource from redash.serializers import serialize_widget -from redash.permissions import (require_access, - require_object_modify_permission, - require_permission, view_only) +from redash.permissions import ( + require_access, + require_object_modify_permission, + require_permission, + view_only, +) from redash.utils import json_dumps class WidgetListResource(BaseResource): - @require_permission('edit_dashboard') + @require_permission("edit_dashboard") def post(self): """ Add a widget to a dashboard. @@ -24,20 +27,24 @@ def post(self): :>json object widget: The created widget """ widget_properties = request.get_json(force=True) - dashboard = models.Dashboard.get_by_id_and_org(widget_properties.get('dashboard_id'), self.current_org) + dashboard = models.Dashboard.get_by_id_and_org( + widget_properties.get("dashboard_id"), self.current_org + ) require_object_modify_permission(dashboard, self.current_user) - widget_properties['options'] = json_dumps(widget_properties['options']) - widget_properties.pop('id', None) + widget_properties["options"] = json_dumps(widget_properties["options"]) + widget_properties.pop("id", None) - visualization_id = widget_properties.pop('visualization_id') + visualization_id = widget_properties.pop("visualization_id") if visualization_id: - visualization = models.Visualization.get_by_id_and_org(visualization_id, self.current_org) + visualization = models.Visualization.get_by_id_and_org( + visualization_id, self.current_org + ) require_access(visualization.query_rel, self.current_user, view_only) else: visualization = None - widget_properties['visualization'] = visualization + widget_properties["visualization"] = visualization widget = models.Widget(**widget_properties) models.db.session.add(widget) @@ -48,7 +55,7 @@ def post(self): class WidgetResource(BaseResource): - @require_permission('edit_dashboard') + @require_permission("edit_dashboard") def post(self, widget_id): """ Updates a widget in a dashboard. @@ -61,12 +68,12 @@ def post(self, widget_id): widget = models.Widget.get_by_id_and_org(widget_id, self.current_org) require_object_modify_permission(widget.dashboard, self.current_user) widget_properties = request.get_json(force=True) - widget.text = widget_properties['text'] - widget.options = json_dumps(widget_properties['options']) + widget.text = widget_properties["text"] + widget.options = json_dumps(widget_properties["options"]) models.db.session.commit() return serialize_widget(widget) - @require_permission('edit_dashboard') + @require_permission("edit_dashboard") def delete(self, widget_id): """ Remove a widget from a dashboard. @@ -75,10 +82,8 @@ def delete(self, widget_id): """ widget = models.Widget.get_by_id_and_org(widget_id, self.current_org) require_object_modify_permission(widget.dashboard, self.current_user) - self.record_event({ - 'action': 'delete', - 'object_id': widget_id, - 'object_type': 'widget', - }) + self.record_event( + {"action": "delete", "object_id": widget_id, "object_type": "widget"} + ) models.db.session.delete(widget) models.db.session.commit() diff --git a/redash/metrics/celery.py b/redash/metrics/celery.py index b7e0e2b645..91e80bc03c 100644 --- a/redash/metrics/celery.py +++ b/redash/metrics/celery.py @@ -1,11 +1,10 @@ - - import logging import socket import time from redash import settings from celery.concurrency import asynpool + asynpool.PROC_ALIVE_TIMEOUT = settings.CELERY_INIT_TIMEOUT from celery.signals import task_postrun, task_prerun @@ -34,23 +33,29 @@ def metric_name(name, tags): @task_postrun.connect -def task_postrun_handler(signal, sender, task_id, task, args, kwargs, retval, state, **kw): +def task_postrun_handler( + signal, sender, task_id, task, args, kwargs, retval, state, **kw +): try: run_time = 1000 * (time.time() - tasks_start_time.pop(task_id)) - state = (state or 'unknown').lower() - tags = {'state': state, 'hostname': socket.gethostname()} - if task.name == 'redash.tasks.execute_query': + state = (state or "unknown").lower() + tags = {"state": state, "hostname": socket.gethostname()} + if task.name == "redash.tasks.execute_query": if isinstance(retval, Exception): - tags['state'] = 'exception' - state = 'exception' + tags["state"] = "exception" + state = "exception" - tags['data_source_id'] = args[1] + tags["data_source_id"] = args[1] - normalized_task_name = task.name.replace('redash.tasks.', '').replace('.', '_') + normalized_task_name = task.name.replace("redash.tasks.", "").replace(".", "_") metric = "celery.task_runtime.{}".format(normalized_task_name) - logging.debug("metric=%s", json_dumps({'metric': metric, 'tags': tags, 'value': run_time})) + logging.debug( + "metric=%s", json_dumps({"metric": metric, "tags": tags, "value": run_time}) + ) statsd_client.timing(metric_name(metric, tags), run_time) - statsd_client.incr(metric_name('celery.task.{}.{}'.format(normalized_task_name, state), tags)) + statsd_client.incr( + metric_name("celery.task.{}.{}".format(normalized_task_name, state), tags) + ) except Exception: logging.exception("Exception during task_postrun handler.") diff --git a/redash/metrics/database.py b/redash/metrics/database.py index 6ab1f4cb3b..8b12765fe9 100644 --- a/redash/metrics/database.py +++ b/redash/metrics/database.py @@ -26,21 +26,21 @@ def _table_name_from_select_element(elt): @listens_for(Engine, "before_execute") def before_execute(conn, elt, multiparams, params): - conn.info.setdefault('query_start_time', []).append(time.time()) + conn.info.setdefault("query_start_time", []).append(time.time()) @listens_for(Engine, "after_execute") def after_execute(conn, elt, multiparams, params, result): - duration = 1000 * (time.time() - conn.info['query_start_time'].pop(-1)) + duration = 1000 * (time.time() - conn.info["query_start_time"].pop(-1)) action = elt.__class__.__name__ - if action == 'Select': - name = 'unknown' + if action == "Select": + name = "unknown" try: name = _table_name_from_select_element(elt) except Exception: - logging.exception('Failed finding table name.') - elif action in ['Update', 'Insert', 'Delete']: + logging.exception("Failed finding table name.") + elif action in ["Update", "Insert", "Delete"]: name = elt.table.name else: # create/drop tables, sqlalchemy internal schema queries, etc @@ -48,13 +48,12 @@ def after_execute(conn, elt, multiparams, params, result): action = action.lower() - statsd_client.timing('db.{}.{}'.format(name, action), duration) - metrics_logger.debug("table=%s query=%s duration=%.2f", name, action, - duration) + statsd_client.timing("db.{}.{}".format(name, action), duration) + metrics_logger.debug("table=%s query=%s duration=%.2f", name, action, duration) if has_request_context(): - g.setdefault('queries_count', 0) - g.setdefault('queries_duration', 0) + g.setdefault("queries_count", 0) + g.setdefault("queries_duration", 0) g.queries_count += 1 g.queries_duration += duration diff --git a/redash/metrics/request.py b/redash/metrics/request.py index 02ad1ba493..7f94da4ad9 100644 --- a/redash/metrics/request.py +++ b/redash/metrics/request.py @@ -14,36 +14,42 @@ def record_requets_start_time(): def calculate_metrics(response): - if 'start_time' not in g: + if "start_time" not in g: return response request_duration = (time.time() - g.start_time) * 1000 - queries_duration = g.get('queries_duration', 0.0) - queries_count = g.get('queries_count', 0.0) - endpoint = (request.endpoint or 'unknown').replace('.', '_') - - metrics_logger.info("method=%s path=%s endpoint=%s status=%d content_type=%s content_length=%d duration=%.2f query_count=%d query_duration=%.2f", - request.method, - request.path, - endpoint, - response.status_code, - response.content_type, - response.content_length or -1, - request_duration, - queries_count, - queries_duration) - - statsd_client.timing('requests.{}.{}'.format(endpoint, request.method.lower()), request_duration) + queries_duration = g.get("queries_duration", 0.0) + queries_count = g.get("queries_count", 0.0) + endpoint = (request.endpoint or "unknown").replace(".", "_") + + metrics_logger.info( + "method=%s path=%s endpoint=%s status=%d content_type=%s content_length=%d duration=%.2f query_count=%d query_duration=%.2f", + request.method, + request.path, + endpoint, + response.status_code, + response.content_type, + response.content_length or -1, + request_duration, + queries_count, + queries_duration, + ) + + statsd_client.timing( + "requests.{}.{}".format(endpoint, request.method.lower()), request_duration + ) return response -MockResponse = namedtuple('MockResponse', ['status_code', 'content_type', 'content_length']) +MockResponse = namedtuple( + "MockResponse", ["status_code", "content_type", "content_length"] +) def calculate_metrics_on_exception(error): if error is not None: - calculate_metrics(MockResponse(500, '?', -1)) + calculate_metrics(MockResponse(500, "?", -1)) def init_app(app): diff --git a/redash/models/__init__.py b/redash/models/__init__.py index 4e7ca7bdf8..e6d909d7cf 100644 --- a/redash/models/__init__.py +++ b/redash/models/__init__.py @@ -19,12 +19,25 @@ from sqlalchemy_utils.types.encrypted.encrypted_type import FernetEngine from redash import redis_connection, utils, settings -from redash.destinations import (get_configuration_schema_for_destination_type, - get_destination) +from redash.destinations import ( + get_configuration_schema_for_destination_type, + get_destination, +) from redash.metrics import database # noqa: F401 -from redash.query_runner import (get_configuration_schema_for_query_runner_type, - get_query_runner, TYPE_BOOLEAN, TYPE_DATE, TYPE_DATETIME) -from redash.utils import generate_token, json_dumps, json_loads, mustache_render, base_url +from redash.query_runner import ( + get_configuration_schema_for_query_runner_type, + get_query_runner, + TYPE_BOOLEAN, + TYPE_DATE, + TYPE_DATETIME, +) +from redash.utils import ( + generate_token, + json_dumps, + json_loads, + mustache_render, + base_url, +) from redash.utils.configuration import ConfigurationContainer from redash.models.parameterized_query import ParameterizedQuery @@ -32,14 +45,20 @@ from .changes import ChangeTrackingMixin, Change # noqa from .mixins import BelongsToOrgMixin, TimestampMixin from .organizations import Organization -from .types import EncryptedConfiguration, Configuration, MutableDict, MutableList, PseudoJSON -from .users import (AccessPermission, AnonymousUser, ApiUser, Group, User) # noqa +from .types import ( + EncryptedConfiguration, + Configuration, + MutableDict, + MutableList, + PseudoJSON, +) +from .users import AccessPermission, AnonymousUser, ApiUser, Group, User # noqa logger = logging.getLogger(__name__) class ScheduledQueriesExecutions(object): - KEY_NAME = 'sq:executed_at' + KEY_NAME = "sq:executed_at" def __init__(self): self.executions = {} @@ -48,9 +67,7 @@ def refresh(self): self.executions = redis_connection.hgetall(self.KEY_NAME) def update(self, query_id): - redis_connection.hmset(self.KEY_NAME, { - query_id: time.time() - }) + redis_connection.hmset(self.KEY_NAME, {query_id: time.time()}) def get(self, query_id): timestamp = self.executions.get(str(query_id)) @@ -63,23 +80,31 @@ def get(self, query_id): scheduled_queries_executions = ScheduledQueriesExecutions() -@generic_repr('id', 'name', 'type', 'org_id', 'created_at') +@generic_repr("id", "name", "type", "org_id", "created_at") class DataSource(BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) - org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) org = db.relationship(Organization, backref="data_sources") name = Column(db.String(255)) type = Column(db.String(255)) - options = Column('encrypted_options', ConfigurationContainer.as_mutable(EncryptedConfiguration(db.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine))) + options = Column( + "encrypted_options", + ConfigurationContainer.as_mutable( + EncryptedConfiguration( + db.Text, settings.DATASOURCE_SECRET_KEY, FernetEngine + ) + ), + ) queue_name = Column(db.String(255), default="queries") scheduled_queue_name = Column(db.String(255), default="scheduled_queries") created_at = Column(db.DateTime(True), default=db.func.now()) - data_source_groups = db.relationship("DataSourceGroup", back_populates="data_source", - cascade="all") - __tablename__ = 'data_sources' - __table_args__ = (db.Index('data_sources_org_id_name', 'org_id', 'name'),) + data_source_groups = db.relationship( + "DataSourceGroup", back_populates="data_source", cascade="all" + ) + __tablename__ = "data_sources" + __table_args__ = (db.Index("data_sources_org_id_name", "org_id", "name"),) def __eq__(self, other): return self.id == other.id @@ -89,26 +114,31 @@ def __hash__(self): def to_dict(self, all=False, with_permissions_for=None): d = { - 'id': self.id, - 'name': self.name, - 'type': self.type, - 'syntax': self.query_runner.syntax, - 'paused': self.paused, - 'pause_reason': self.pause_reason + "id": self.id, + "name": self.name, + "type": self.type, + "syntax": self.query_runner.syntax, + "paused": self.paused, + "pause_reason": self.pause_reason, } if all: schema = get_configuration_schema_for_query_runner_type(self.type) self.options.set_schema(schema) - d['options'] = self.options.to_dict(mask_secrets=True) - d['queue_name'] = self.queue_name - d['scheduled_queue_name'] = self.scheduled_queue_name - d['groups'] = self.groups + d["options"] = self.options.to_dict(mask_secrets=True) + d["queue_name"] = self.queue_name + d["scheduled_queue_name"] = self.scheduled_queue_name + d["groups"] = self.groups if with_permissions_for is not None: - d['view_only'] = db.session.query(DataSourceGroup.view_only).filter( - DataSourceGroup.group == with_permissions_for, - DataSourceGroup.data_source == self).one()[0] + d["view_only"] = ( + db.session.query(DataSourceGroup.view_only) + .filter( + DataSourceGroup.group == with_permissions_for, + DataSourceGroup.data_source == self, + ) + .one()[0] + ) return d @@ -119,8 +149,8 @@ def __str__(self): def create_with_group(cls, *args, **kwargs): data_source = cls(*args, **kwargs) data_source_group = DataSourceGroup( - data_source=data_source, - group=data_source.org.default_group) + data_source=data_source, group=data_source.org.default_group + ) db.session.add_all([data_source, data_source_group]) return data_source @@ -130,7 +160,8 @@ def all(cls, org, group_ids=None): if group_ids: data_sources = data_sources.join(DataSourceGroup).filter( - DataSourceGroup.group_id.in_(group_ids)) + DataSourceGroup.group_id.in_(group_ids) + ) return data_sources.distinct() @@ -139,7 +170,9 @@ def get_by_id(cls, _id): return cls.query.filter(cls.id == _id).one() def delete(self): - Query.query.filter(Query.data_source == self).update(dict(data_source_id=None, latest_query_data_id=None)) + Query.query.filter(Query.data_source == self).update( + dict(data_source_id=None, latest_query_data_id=None) + ) QueryResult.query.filter(QueryResult.data_source == self).delete() res = db.session.delete(self) db.session.commit() @@ -155,7 +188,9 @@ def get_schema(self, refresh=False): if cache is None: query_runner = self.query_runner - schema = sorted(query_runner.get_schema(get_stats=refresh), key=lambda t: t['name']) + schema = sorted( + query_runner.get_schema(get_stats=refresh), key=lambda t: t["name"] + ) redis_connection.set(self._schema_key, json_dumps(schema)) else: @@ -169,7 +204,7 @@ def _schema_key(self): @property def _pause_key(self): - return 'ds:{}:pause'.format(self.id) + return "ds:{}:pause".format(self.id) @property def paused(self): @@ -180,7 +215,7 @@ def pause_reason(self): return redis_connection.get(self._pause_key) def pause(self, reason=None): - redis_connection.set(self._pause_key, reason or '') + redis_connection.set(self._pause_key, reason or "") def resume(self): redis_connection.delete(self._pause_key) @@ -192,15 +227,14 @@ def add_group(self, group, view_only=False): def remove_group(self, group): DataSourceGroup.query.filter( - DataSourceGroup.group == group, - DataSourceGroup.data_source == self + DataSourceGroup.group == group, DataSourceGroup.data_source == self ).delete() db.session.commit() def update_group_permission(self, group, view_only): dsg = DataSourceGroup.query.filter( - DataSourceGroup.group == group, - DataSourceGroup.data_source == self).one() + DataSourceGroup.group == group, DataSourceGroup.data_source == self + ).one() dsg.view_only = view_only db.session.add(dsg) return dsg @@ -216,13 +250,11 @@ def get_by_name(cls, name): # XXX examine call sites to see if a regular SQLA collection would work better @property def groups(self): - groups = DataSourceGroup.query.filter( - DataSourceGroup.data_source == self - ) + groups = DataSourceGroup.query.filter(DataSourceGroup.data_source == self) return dict([(group.group_id, group.view_only) for group in groups]) -@generic_repr('id', 'data_source_id', 'group_id', 'view_only') +@generic_repr("id", "data_source_id", "group_id", "view_only") class DataSourceGroup(db.Model): # XXX drop id, use datasource/group as PK id = Column(db.Integer, primary_key=True) @@ -235,7 +267,8 @@ class DataSourceGroup(db.Model): __tablename__ = "data_source_groups" -DESERIALIZED_DATA_ATTR = '_deserialized_data' +DESERIALIZED_DATA_ATTR = "_deserialized_data" + class DBPersistence(object): @property @@ -255,35 +288,38 @@ def data(self, data): self._data = data -QueryResultPersistence = settings.dynamic_settings.QueryResultPersistence or DBPersistence +QueryResultPersistence = ( + settings.dynamic_settings.QueryResultPersistence or DBPersistence +) + -@generic_repr('id', 'org_id', 'data_source_id', 'query_hash', 'runtime', 'retrieved_at') +@generic_repr("id", "org_id", "data_source_id", "query_hash", "runtime", "retrieved_at") class QueryResult(db.Model, QueryResultPersistence, BelongsToOrgMixin): id = Column(db.Integer, primary_key=True) - org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) org = db.relationship(Organization) data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) - data_source = db.relationship(DataSource, backref=backref('query_results')) + data_source = db.relationship(DataSource, backref=backref("query_results")) query_hash = Column(db.String(32), index=True) - query_text = Column('query', db.Text) - _data = Column('data', db.Text) + query_text = Column("query", db.Text) + _data = Column("data", db.Text) runtime = Column(postgresql.DOUBLE_PRECISION) retrieved_at = Column(db.DateTime(True)) - __tablename__ = 'query_results' + __tablename__ = "query_results" def __str__(self): return "%d | %s | %s" % (self.id, self.query_hash, self.retrieved_at) def to_dict(self): return { - 'id': self.id, - 'query_hash': self.query_hash, - 'query': self.query_text, - 'data': self.data, - 'data_source_id': self.data_source_id, - 'runtime': self.runtime, - 'retrieved_at': self.retrieved_at + "id": self.id, + "query_hash": self.query_hash, + "query": self.query_text, + "data": self.data, + "data_source_id": self.data_source_id, + "runtime": self.runtime, + "retrieved_at": self.retrieved_at, } @classmethod @@ -291,11 +327,9 @@ def unused(cls, days=7): age_threshold = datetime.datetime.now() - datetime.timedelta(days=days) return ( cls.query.filter( - Query.id.is_(None), - cls.retrieved_at < age_threshold - ) - .outerjoin(Query) - ).options(load_only('id')) + Query.id.is_(None), cls.retrieved_at < age_threshold + ).outerjoin(Query) + ).options(load_only("id")) @classmethod def get_latest(cls, data_source, query, max_age=0): @@ -303,32 +337,34 @@ def get_latest(cls, data_source, query, max_age=0): if max_age == -1: query = cls.query.filter( - cls.query_hash == query_hash, - cls.data_source == data_source + cls.query_hash == query_hash, cls.data_source == data_source ) else: query = cls.query.filter( cls.query_hash == query_hash, cls.data_source == data_source, ( - db.func.timezone('utc', cls.retrieved_at) + - datetime.timedelta(seconds=max_age) >= - db.func.timezone('utc', db.func.now()) - ) + db.func.timezone("utc", cls.retrieved_at) + + datetime.timedelta(seconds=max_age) + >= db.func.timezone("utc", db.func.now()) + ), ) return query.order_by(cls.retrieved_at.desc()).first() @classmethod - def store_result(cls, org, data_source, query_hash, query, data, run_time, retrieved_at): - query_result = cls(org_id=org, - query_hash=query_hash, - query_text=query, - runtime=run_time, - data_source=data_source, - retrieved_at=retrieved_at, - data=data) - + def store_result( + cls, org, data_source, query_hash, query, data, run_time, retrieved_at + ): + query_result = cls( + org_id=org, + query_hash=query_hash, + query_text=query, + runtime=run_time, + data_source=data_source, + retrieved_at=retrieved_at, + data=data, + ) db.session.add(query_result) logging.info("Inserted query (%s) data; id=%s", query_hash, query_result.id) @@ -340,53 +376,79 @@ def groups(self): return self.data_source.groups -def should_schedule_next(previous_iteration, now, interval, time=None, day_of_week=None, failures=0): +def should_schedule_next( + previous_iteration, now, interval, time=None, day_of_week=None, failures=0 +): # if time exists then interval > 23 hours (82800s) # if day_of_week exists then interval > 6 days (518400s) - if (time is None): + if time is None: ttl = int(interval) next_iteration = previous_iteration + datetime.timedelta(seconds=ttl) else: - hour, minute = time.split(':') + hour, minute = time.split(":") hour, minute = int(hour), int(minute) # The following logic is needed for cases like the following: # - The query scheduled to run at 23:59. # - The scheduler wakes up at 00:01. # - Using naive implementation of comparing timestamps, it will skip the execution. - normalized_previous_iteration = previous_iteration.replace(hour=hour, minute=minute) + normalized_previous_iteration = previous_iteration.replace( + hour=hour, minute=minute + ) if normalized_previous_iteration > previous_iteration: - previous_iteration = normalized_previous_iteration - datetime.timedelta(days=1) + previous_iteration = normalized_previous_iteration - datetime.timedelta( + days=1 + ) days_delay = int(interval) / 60 / 60 / 24 days_to_add = 0 - if (day_of_week is not None): - days_to_add = list(calendar.day_name).index(day_of_week) - normalized_previous_iteration.weekday() + if day_of_week is not None: + days_to_add = ( + list(calendar.day_name).index(day_of_week) + - normalized_previous_iteration.weekday() + ) - next_iteration = (previous_iteration + datetime.timedelta(days=days_delay) + - datetime.timedelta(days=days_to_add)).replace(hour=hour, minute=minute) + next_iteration = ( + previous_iteration + + datetime.timedelta(days=days_delay) + + datetime.timedelta(days=days_to_add) + ).replace(hour=hour, minute=minute) if failures: try: - next_iteration += datetime.timedelta(minutes=2**failures) + next_iteration += datetime.timedelta(minutes=2 ** failures) except OverflowError: return False return now > next_iteration @gfk_type -@generic_repr('id', 'name', 'query_hash', 'version', 'user_id', 'org_id', - 'data_source_id', 'query_hash', 'last_modified_by_id', - 'is_archived', 'is_draft', 'schedule', 'schedule_failures') +@generic_repr( + "id", + "name", + "query_hash", + "version", + "user_id", + "org_id", + "data_source_id", + "query_hash", + "last_modified_by_id", + "is_archived", + "is_draft", + "schedule", + "schedule_failures", +) class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) version = Column(db.Integer, default=1) - org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) org = db.relationship(Organization, backref="queries") data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id"), nullable=True) - data_source = db.relationship(DataSource, backref='queries') - latest_query_data_id = Column(db.Integer, db.ForeignKey("query_results.id"), nullable=True) + data_source = db.relationship(DataSource, backref="queries") + latest_query_data_id = Column( + db.Integer, db.ForeignKey("query_results.id"), nullable=True + ) latest_query_data = db.relationship(QueryResult) name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) @@ -395,29 +457,33 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): api_key = Column(db.String(40), default=lambda: generate_token(40)) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, foreign_keys=[user_id]) - last_modified_by_id = Column(db.Integer, db.ForeignKey('users.id'), nullable=True) - last_modified_by = db.relationship(User, backref="modified_queries", - foreign_keys=[last_modified_by_id]) + last_modified_by_id = Column(db.Integer, db.ForeignKey("users.id"), nullable=True) + last_modified_by = db.relationship( + User, backref="modified_queries", foreign_keys=[last_modified_by_id] + ) is_archived = Column(db.Boolean, default=False, index=True) is_draft = Column(db.Boolean, default=True, index=True) schedule = Column(MutableDict.as_mutable(PseudoJSON), nullable=True) schedule_failures = Column(db.Integer, default=0) visualizations = db.relationship("Visualization", cascade="all, delete-orphan") options = Column(MutableDict.as_mutable(PseudoJSON), default={}) - search_vector = Column(TSVectorType('id', 'name', 'description', 'query', - weights={'name': 'A', - 'id': 'B', - 'description': 'C', - 'query': 'D'}), - nullable=True) - tags = Column('tags', MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True) + search_vector = Column( + TSVectorType( + "id", + "name", + "description", + "query", + weights={"name": "A", "id": "B", "description": "C", "query": "D"}, + ), + nullable=True, + ) + tags = Column( + "tags", MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True + ) query_class = SearchBaseQuery - __tablename__ = 'queries' - __mapper_args__ = { - "version_id_col": version, - 'version_id_generator': False - } + __tablename__ = "queries" + __mapper_args__ = {"version_id_col": version, "version_id_generator": False} def __str__(self): return text_type(self.id) @@ -443,56 +509,48 @@ def regenerate_api_key(self): @classmethod def create(cls, **kwargs): query = cls(**kwargs) - db.session.add(Visualization(query_rel=query, - name="Table", - description='', - type="TABLE", - options="{}")) + db.session.add( + Visualization( + query_rel=query, + name="Table", + description="", + type="TABLE", + options="{}", + ) + ) return query @classmethod - def all_queries(cls, group_ids, user_id=None, include_drafts=False, include_archived=False): + def all_queries( + cls, group_ids, user_id=None, include_drafts=False, include_archived=False + ): query_ids = ( - db.session - .query(distinct(cls.id)) + db.session.query(distinct(cls.id)) .join( - DataSourceGroup, - Query.data_source_id == DataSourceGroup.data_source_id + DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id ) .filter(Query.is_archived.is_(include_archived)) .filter(DataSourceGroup.group_id.in_(group_ids)) ) queries = ( - cls - .query - .options( + cls.query.options( joinedload(Query.user), - joinedload( - Query.latest_query_data - ).load_only( - 'runtime', - 'retrieved_at', - ) + joinedload(Query.latest_query_data).load_only( + "runtime", "retrieved_at" + ), ) .filter(cls.id.in_(query_ids)) # Adding outer joins to be able to order by relationship .outerjoin(User, User.id == Query.user_id) - .outerjoin( - QueryResult, - QueryResult.id == Query.latest_query_data_id - ) + .outerjoin(QueryResult, QueryResult.id == Query.latest_query_data_id) .options( - contains_eager(Query.user), - contains_eager(Query.latest_query_data), + contains_eager(Query.user), contains_eager(Query.latest_query_data) ) ) if not include_drafts: queries = queries.filter( - or_( - Query.is_draft.is_(False), - Query.user_id == user_id - ) + or_(Query.is_draft.is_(False), Query.user_id == user_id) ) return queries @@ -500,30 +558,26 @@ def all_queries(cls, group_ids, user_id=None, include_drafts=False, include_arch def favorites(cls, user, base_query=None): if base_query is None: base_query = cls.all_queries(user.group_ids, user.id, include_drafts=True) - return base_query.join(( - Favorite, - and_( - Favorite.object_type == 'Query', - Favorite.object_id == Query.id + return base_query.join( + ( + Favorite, + and_(Favorite.object_type == "Query", Favorite.object_id == Query.id), ) - )).filter(Favorite.user_id == user.id) + ).filter(Favorite.user_id == user.id) @classmethod def all_tags(cls, user, include_drafts=False): queries = cls.all_queries( - group_ids=user.group_ids, - user_id=user.id, - include_drafts=include_drafts, + group_ids=user.group_ids, user_id=user.id, include_drafts=include_drafts ) - tag_column = func.unnest(cls.tags).label('tag') - usage_count = func.count(1).label('usage_count') + tag_column = func.unnest(cls.tags).label("tag") + usage_count = func.count(1).label("usage_count") query = ( - db.session - .query(tag_column, usage_count) + db.session.query(tag_column, usage_count) .group_by(tag_column) - .filter(Query.id.in_(queries.options(load_only('id')))) + .filter(Query.id.in_(queries.options(load_only("id")))) .order_by(usage_count.desc()) ) return query @@ -539,20 +593,23 @@ def by_api_key(cls, api_key): @classmethod def past_scheduled_queries(cls): now = utils.utcnow() - queries = ( - Query.query - .filter(Query.schedule.isnot(None)) - .order_by(Query.id) - ) - return [query for query in queries if query.schedule["until"] is not None and pytz.utc.localize( - datetime.datetime.strptime(query.schedule['until'], '%Y-%m-%d') - ) <= now] + queries = Query.query.filter(Query.schedule.isnot(None)).order_by(Query.id) + return [ + query + for query in queries + if query.schedule["until"] is not None + and pytz.utc.localize( + datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d") + ) + <= now + ] @classmethod def outdated_queries(cls): queries = ( - Query.query - .options(joinedload(Query.latest_query_data).load_only('retrieved_at')) + Query.query.options( + joinedload(Query.latest_query_data).load_only("retrieved_at") + ) .filter(Query.schedule.isnot(None)) .order_by(Query.id) ) @@ -562,11 +619,13 @@ def outdated_queries(cls): scheduled_queries_executions.refresh() for query in queries: - if query.schedule['interval'] is None: + if query.schedule["interval"] is None: continue - if query.schedule['until'] is not None: - schedule_until = pytz.utc.localize(datetime.datetime.strptime(query.schedule['until'], '%Y-%m-%d')) + if query.schedule["until"] is not None: + schedule_until = pytz.utc.localize( + datetime.datetime.strptime(query.schedule["until"], "%Y-%m-%d") + ) if schedule_until <= now: continue @@ -578,16 +637,30 @@ def outdated_queries(cls): retrieved_at = scheduled_queries_executions.get(query.id) or retrieved_at - if should_schedule_next(retrieved_at, now, query.schedule['interval'], query.schedule['time'], - query.schedule['day_of_week'], query.schedule_failures): + if should_schedule_next( + retrieved_at, + now, + query.schedule["interval"], + query.schedule["time"], + query.schedule["day_of_week"], + query.schedule_failures, + ): key = "{}:{}".format(query.query_hash, query.data_source_id) outdated_queries[key] = query return list(outdated_queries.values()) @classmethod - def search(cls, term, group_ids, user_id=None, include_drafts=False, - limit=None, include_archived=False, multi_byte_search=False): + def search( + cls, + term, + group_ids, + user_id=None, + include_drafts=False, + limit=None, + include_archived=False, + multi_byte_search=False, + ): all_queries = cls.all_queries( group_ids, user_id=user_id, @@ -597,13 +670,14 @@ def search(cls, term, group_ids, user_id=None, include_drafts=False, if multi_byte_search: # Since tsvector doesn't work well with CJK languages, use `ilike` too - pattern = '%{}%'.format(term) - return all_queries.filter( - or_( - cls.name.ilike(pattern), - cls.description.ilike(pattern) + pattern = "%{}%".format(term) + return ( + all_queries.filter( + or_(cls.name.ilike(pattern), cls.description.ilike(pattern)) ) - ).order_by(Query.id).limit(limit) + .order_by(Query.id) + .limit(limit) + ) # sort the result using the weight as defined in the search vector column return all_queries.search(term, sort=True).limit(limit) @@ -614,20 +688,25 @@ def search_by_user(cls, term, user, limit=None): @classmethod def recent(cls, group_ids, user_id=None, limit=20): - query = (cls.query - .filter(Event.created_at > (db.func.current_date() - 7)) - .join(Event, Query.id == Event.object_id.cast(db.Integer)) - .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) - .filter( - Event.action.in_(['edit', 'execute', 'edit_name', - 'edit_description', 'view_source']), - Event.object_id != None, - Event.object_type == 'query', - DataSourceGroup.group_id.in_(group_ids), - or_(Query.is_draft == False, Query.user_id == user_id), - Query.is_archived == False) - .group_by(Event.object_id, Query.id) - .order_by(db.desc(db.func.count(0)))) + query = ( + cls.query.filter(Event.created_at > (db.func.current_date() - 7)) + .join(Event, Query.id == Event.object_id.cast(db.Integer)) + .join( + DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id + ) + .filter( + Event.action.in_( + ["edit", "execute", "edit_name", "edit_description", "view_source"] + ), + Event.object_id != None, + Event.object_type == "query", + DataSourceGroup.group_id.in_(group_ids), + or_(Query.is_draft == False, Query.user_id == user_id), + Query.is_archived == False, + ) + .group_by(Event.object_id, Query.id) + .order_by(db.desc(db.func.count(0))) + ) if user_id: query = query.filter(Event.user_id == user_id) @@ -647,14 +726,14 @@ def all_groups_for_query_ids(cls, query_ids): JOIN data_source_groups ON queries.data_source_id = data_source_groups.data_source_id WHERE queries.id in :ids""" - return db.session.execute(query, {'ids': tuple(query_ids)}).fetchall() + return db.session.execute(query, {"ids": tuple(query_ids)}).fetchall() @classmethod def update_latest_result(cls, query_result): # TODO: Investigate how big an impact this select-before-update makes. queries = Query.query.filter( Query.query_hash == query_result.query_hash, - Query.data_source == query_result.data_source + Query.data_source == query_result.data_source, ) for q in queries: @@ -664,23 +743,37 @@ def update_latest_result(cls, query_result): db.session.add(q) query_ids = [q.id for q in queries] - logging.info("Updated %s queries with result (%s).", len(query_ids), query_result.query_hash) + logging.info( + "Updated %s queries with result (%s).", + len(query_ids), + query_result.query_hash, + ) return query_ids - def fork(self, user): - forked_list = ['org', 'data_source', 'latest_query_data', 'description', - 'query_text', 'query_hash', 'options'] + forked_list = [ + "org", + "data_source", + "latest_query_data", + "description", + "query_text", + "query_hash", + "options", + ] kwargs = {a: getattr(self, a) for a in forked_list} # Query.create will add default TABLE visualization, so use constructor to create bare copy of query - forked_query = Query(name='Copy of (#{}) {}'.format(self.id, self.name), user=user, **kwargs) + forked_query = Query( + name="Copy of (#{}) {}".format(self.id, self.name), user=user, **kwargs + ) for v in sorted(self.visualizations, key=lambda v: v.id): forked_v = v.copy() - forked_v['query_rel'] = forked_query - fv = Visualization(**forked_v) # it will magically add it to `forked_query.visualizations` + forked_v["query_rel"] = forked_query + fv = Visualization( + **forked_v + ) # it will magically add it to `forked_query.visualizations` db.session.add(fv) db.session.add(forked_query) @@ -730,22 +823,22 @@ def dashboard_api_keys(self): AND active=true AND visualizations.query_id = :id""" - api_keys = db.session.execute(query, {'id': self.id}).fetchall() + api_keys = db.session.execute(query, {"id": self.id}).fetchall() return [api_key[0] for api_key in api_keys] -@listens_for(Query.query_text, 'set') +@listens_for(Query.query_text, "set") def gen_query_hash(target, val, oldval, initiator): target.query_hash = utils.gen_query_hash(val) target.schedule_failures = 0 -@listens_for(Query.user_id, 'set') +@listens_for(Query.user_id, "set") def query_last_modified_by(target, val, oldval, initiator): target.last_modified_by_id = val -@generic_repr('id', 'object_type', 'object_id', 'user_id', 'org_id') +@generic_repr("id", "object_type", "object_id", "user_id", "org_id") class Favorite(TimestampMixin, db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -755,7 +848,7 @@ class Favorite(TimestampMixin, db.Model): object = generic_relationship(object_type, object_id) user_id = Column(db.Integer, db.ForeignKey("users.id")) - user = db.relationship(User, backref='favorites') + user = db.relationship(User, backref="favorites") __tablename__ = "favorites" __table_args__ = ( @@ -773,21 +866,27 @@ def are_favorites(cls, user, objects): return [] object_type = text_type(objects[0].__class__.__name__) - return [fav.object_id for fav in cls.query.filter(cls.object_id.in_([o.id for o in objects]), cls.object_type == object_type, cls.user_id == user)] + return [ + fav.object_id + for fav in cls.query.filter( + cls.object_id.in_([o.id for o in objects]), + cls.object_type == object_type, + cls.user_id == user, + ) + ] OPERATORS = { - '>': lambda v, t: v > t, - '>=': lambda v, t: v >= t, - '<': lambda v, t: v < t, - '<=': lambda v, t: v <= t, - '==': lambda v, t: v == t, - '!=': lambda v, t: v != t, - + ">": lambda v, t: v > t, + ">=": lambda v, t: v >= t, + "<": lambda v, t: v < t, + "<=": lambda v, t: v <= t, + "==": lambda v, t: v == t, + "!=": lambda v, t: v != t, # backward compatibility - 'greater than': lambda v, t: v > t, - 'less than': lambda v, t: v < t, - 'equals': lambda v, t: v == t, + "greater than": lambda v, t: v > t, + "less than": lambda v, t: v < t, + "equals": lambda v, t: v == t, } @@ -809,42 +908,39 @@ def next_state(op, value, threshold): new_state = Alert.TRIGGERED_STATE else: new_state = Alert.OK_STATE - + return new_state -@generic_repr('id', 'name', 'query_id', 'user_id', 'state', 'last_triggered_at', 'rearm') +@generic_repr( + "id", "name", "query_id", "user_id", "state", "last_triggered_at", "rearm" +) class Alert(TimestampMixin, BelongsToOrgMixin, db.Model): - UNKNOWN_STATE = 'unknown' - OK_STATE = 'ok' - TRIGGERED_STATE = 'triggered' + UNKNOWN_STATE = "unknown" + OK_STATE = "ok" + TRIGGERED_STATE = "triggered" id = Column(db.Integer, primary_key=True) name = Column(db.String(255)) query_id = Column(db.Integer, db.ForeignKey("queries.id")) - query_rel = db.relationship(Query, backref=backref('alerts', cascade="all")) + query_rel = db.relationship(Query, backref=backref("alerts", cascade="all")) user_id = Column(db.Integer, db.ForeignKey("users.id")) - user = db.relationship(User, backref='alerts') + user = db.relationship(User, backref="alerts") options = Column(MutableDict.as_mutable(PseudoJSON)) state = Column(db.String(255), default=UNKNOWN_STATE) subscriptions = db.relationship("AlertSubscription", cascade="all, delete-orphan") last_triggered_at = Column(db.DateTime(True), nullable=True) rearm = Column(db.Integer, nullable=True) - __tablename__ = 'alerts' + __tablename__ = "alerts" @classmethod def all(cls, group_ids): return ( - cls.query - .options( - joinedload(Alert.user), - joinedload(Alert.query_rel), - ) + cls.query.options(joinedload(Alert.user), joinedload(Alert.query_rel)) .join(Query) .join( - DataSourceGroup, - DataSourceGroup.data_source_id == Query.data_source_id + DataSourceGroup, DataSourceGroup.data_source_id == Query.data_source_id ) .filter(DataSourceGroup.group_id.in_(group_ids)) ) @@ -856,11 +952,11 @@ def get_by_id_and_org(cls, object_id, org): def evaluate(self): data = self.query_rel.latest_query_data.data - if data['rows'] and self.options['column'] in data['rows'][0]: - op = OPERATORS.get(self.options['op'], lambda v, t: False) + if data["rows"] and self.options["column"] in data["rows"][0]: + op = OPERATORS.get(self.options["op"], lambda v, t: False) - value = data['rows'][0][self.options['column']] - threshold = self.options['value'] + value = data["rows"][0][self.options["column"]] + threshold = self.options["value"] new_state = next_state(op, value, threshold) else: @@ -869,43 +965,47 @@ def evaluate(self): return new_state def subscribers(self): - return User.query.join(AlertSubscription).filter(AlertSubscription.alert == self) + return User.query.join(AlertSubscription).filter( + AlertSubscription.alert == self + ) def render_template(self, template): if template is None: - return '' + return "" data = self.query_rel.latest_query_data.data host = base_url(self.query_rel.org) - col_name = self.options['column'] - if data['rows'] and col_name in data['rows'][0]: - result_value = data['rows'][0][col_name] + col_name = self.options["column"] + if data["rows"] and col_name in data["rows"][0]: + result_value = data["rows"][0][col_name] else: result_value = None context = { - 'ALERT_NAME': self.name, - 'ALERT_URL': '{host}/alerts/{alert_id}'.format(host=host, alert_id=self.id), - 'ALERT_STATUS': self.state.upper(), - 'ALERT_CONDITION': self.options['op'], - 'ALERT_THRESHOLD': self.options['value'], - 'QUERY_NAME': self.query_rel.name, - 'QUERY_URL': '{host}/queries/{query_id}'.format(host=host, query_id=self.query_rel.id), - 'QUERY_RESULT_VALUE': result_value, - 'QUERY_RESULT_ROWS': data['rows'], - 'QUERY_RESULT_COLS': data['columns'], + "ALERT_NAME": self.name, + "ALERT_URL": "{host}/alerts/{alert_id}".format(host=host, alert_id=self.id), + "ALERT_STATUS": self.state.upper(), + "ALERT_CONDITION": self.options["op"], + "ALERT_THRESHOLD": self.options["value"], + "QUERY_NAME": self.query_rel.name, + "QUERY_URL": "{host}/queries/{query_id}".format( + host=host, query_id=self.query_rel.id + ), + "QUERY_RESULT_VALUE": result_value, + "QUERY_RESULT_ROWS": data["rows"], + "QUERY_RESULT_COLS": data["columns"], } return mustache_render(template, context) @property def custom_body(self): - template = self.options.get('custom_body', self.options.get('template')) + template = self.options.get("custom_body", self.options.get("template")) return self.render_template(template) @property def custom_subject(self): - template = self.options.get('custom_subject') + template = self.options.get("custom_subject") return self.render_template(template) @property @@ -914,20 +1014,22 @@ def groups(self): @property def muted(self): - return self.options.get('muted', False) + return self.options.get("muted", False) def generate_slug(ctx): - slug = utils.slugify(ctx.current_parameters['name']) + slug = utils.slugify(ctx.current_parameters["name"]) tries = 1 while Dashboard.query.filter(Dashboard.slug == slug).first() is not None: - slug = utils.slugify(ctx.current_parameters['name']) + "_" + str(tries) + slug = utils.slugify(ctx.current_parameters["name"]) + "_" + str(tries) tries += 1 return slug @gfk_type -@generic_repr('id', 'name', 'slug', 'user_id', 'org_id', 'version', 'is_archived', 'is_draft') +@generic_repr( + "id", "name", "slug", "user_id", "org_id", "version", "is_archived", "is_draft" +) class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) version = Column(db.Integer) @@ -942,13 +1044,13 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model dashboard_filters_enabled = Column(db.Boolean, default=False) is_archived = Column(db.Boolean, default=False, index=True) is_draft = Column(db.Boolean, default=True, index=True) - widgets = db.relationship('Widget', backref='dashboard', lazy='dynamic') - tags = Column('tags', MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True) + widgets = db.relationship("Widget", backref="dashboard", lazy="dynamic") + tags = Column( + "tags", MutableList.as_mutable(postgresql.ARRAY(db.Unicode)), nullable=True + ) - __tablename__ = 'dashboards' - __mapper_args__ = { - "version_id_col": version - } + __tablename__ = "dashboards" + __mapper_args__ = {"version_id_col": version} def __str__(self): return "%s=%s" % (self.id, self.name) @@ -956,43 +1058,51 @@ def __str__(self): @classmethod def all(cls, org, group_ids, user_id): query = ( - Dashboard.query - .options( - subqueryload(Dashboard.user).load_only('_profile_image_url', 'name'), + Dashboard.query.options( + subqueryload(Dashboard.user).load_only("_profile_image_url", "name") ) .outerjoin(Widget) .outerjoin(Visualization) .outerjoin(Query) - .outerjoin(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .outerjoin( + DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id + ) .filter( Dashboard.is_archived == False, - (DataSourceGroup.group_id.in_(group_ids) | - (Dashboard.user_id == user_id) | - ((Widget.dashboard != None) & (Widget.visualization == None))), - Dashboard.org == org) - .distinct()) + ( + DataSourceGroup.group_id.in_(group_ids) + | (Dashboard.user_id == user_id) + | ((Widget.dashboard != None) & (Widget.visualization == None)) + ), + Dashboard.org == org, + ) + .distinct() + ) - query = query.filter(or_(Dashboard.user_id == user_id, Dashboard.is_draft == False)) + query = query.filter( + or_(Dashboard.user_id == user_id, Dashboard.is_draft == False) + ) return query @classmethod def search(cls, org, groups_ids, user_id, search_term): # TODO: switch to FTS - return cls.all(org, groups_ids, user_id).filter(cls.name.ilike('%{}%'.format(search_term))) + return cls.all(org, groups_ids, user_id).filter( + cls.name.ilike("%{}%".format(search_term)) + ) @classmethod def all_tags(cls, org, user): dashboards = cls.all(org, user.group_ids, user.id) - tag_column = func.unnest(cls.tags).label('tag') - usage_count = func.count(1).label('usage_count') + tag_column = func.unnest(cls.tags).label("tag") + usage_count = func.count(1).label("usage_count") query = ( - db.session - .query(tag_column, usage_count) + db.session.query(tag_column, usage_count) .group_by(tag_column) - .filter(Dashboard.id.in_(dashboards.options(load_only('id')))) + .filter(Dashboard.id.in_(dashboards.options(load_only("id")))) .order_by(usage_count.desc()) ) return query @@ -1005,9 +1115,9 @@ def favorites(cls, user, base_query=None): ( Favorite, and_( - Favorite.object_type == 'Dashboard', - Favorite.object_id == Dashboard.id - ) + Favorite.object_type == "Dashboard", + Favorite.object_id == Dashboard.id, + ), ) ).filter(Favorite.user_id == user.id) @@ -1026,18 +1136,18 @@ def lowercase_name(cls): return func.lower(cls.name) -@generic_repr('id', 'name', 'type', 'query_id') +@generic_repr("id", "name", "type", "query_id") class Visualization(TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) type = Column(db.String(100)) query_id = Column(db.Integer, db.ForeignKey("queries.id")) # query_rel and not query, because db.Model already has query defined. - query_rel = db.relationship(Query, back_populates='visualizations') + query_rel = db.relationship(Query, back_populates="visualizations") name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) options = Column(db.Text) - __tablename__ = 'visualizations' + __tablename__ = "visualizations" def __str__(self): return "%s %s" % (self.id, self.type) @@ -1048,24 +1158,28 @@ def get_by_id_and_org(cls, object_id, org): def copy(self): return { - 'type': self.type, - 'name': self.name, - 'description': self.description, - 'options': self.options + "type": self.type, + "name": self.name, + "description": self.description, + "options": self.options, } -@generic_repr('id', 'visualization_id', 'dashboard_id') +@generic_repr("id", "visualization_id", "dashboard_id") class Widget(TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) - visualization_id = Column(db.Integer, db.ForeignKey('visualizations.id'), nullable=True) - visualization = db.relationship(Visualization, backref=backref('widgets', cascade='delete')) + visualization_id = Column( + db.Integer, db.ForeignKey("visualizations.id"), nullable=True + ) + visualization = db.relationship( + Visualization, backref=backref("widgets", cascade="delete") + ) text = Column(db.Text, nullable=True) width = Column(db.Integer) options = Column(db.Text) dashboard_id = Column(db.Integer, db.ForeignKey("dashboards.id"), index=True) - __tablename__ = 'widgets' + __tablename__ = "widgets" def __str__(self): return "%s" % self.id @@ -1075,7 +1189,9 @@ def get_by_id_and_org(cls, object_id, org): return super(Widget, cls).get_by_id_and_org(object_id, org, Dashboard) -@generic_repr('id', 'object_type', 'object_id', 'action', 'user_id', 'org_id', 'created_at') +@generic_repr( + "id", "object_type", "object_id", "action", "user_id", "org_id", "created_at" +) class Event(db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1085,44 +1201,56 @@ class Event(db.Model): action = Column(db.String(255)) object_type = Column(db.String(255)) object_id = Column(db.String(255), nullable=True) - additional_properties = Column(MutableDict.as_mutable(PseudoJSON), nullable=True, default={}) + additional_properties = Column( + MutableDict.as_mutable(PseudoJSON), nullable=True, default={} + ) created_at = Column(db.DateTime(True), default=db.func.now()) - __tablename__ = 'events' + __tablename__ = "events" def __str__(self): - return "%s,%s,%s,%s" % (self.user_id, self.action, self.object_type, self.object_id) + return "%s,%s,%s,%s" % ( + self.user_id, + self.action, + self.object_type, + self.object_id, + ) def to_dict(self): return { - 'org_id': self.org_id, - 'user_id': self.user_id, - 'action': self.action, - 'object_type': self.object_type, - 'object_id': self.object_id, - 'additional_properties': self.additional_properties, - 'created_at': self.created_at.isoformat() + "org_id": self.org_id, + "user_id": self.user_id, + "action": self.action, + "object_type": self.object_type, + "object_id": self.object_id, + "additional_properties": self.additional_properties, + "created_at": self.created_at.isoformat(), } @classmethod def record(cls, event): - org_id = event.pop('org_id') - user_id = event.pop('user_id', None) - action = event.pop('action') - object_type = event.pop('object_type') - object_id = event.pop('object_id', None) - - created_at = datetime.datetime.utcfromtimestamp(event.pop('timestamp')) - - event = cls(org_id=org_id, user_id=user_id, action=action, - object_type=object_type, object_id=object_id, - additional_properties=event, - created_at=created_at) + org_id = event.pop("org_id") + user_id = event.pop("user_id", None) + action = event.pop("action") + object_type = event.pop("object_type") + object_id = event.pop("object_id", None) + + created_at = datetime.datetime.utcfromtimestamp(event.pop("timestamp")) + + event = cls( + org_id=org_id, + user_id=user_id, + action=action, + object_type=object_type, + object_id=object_id, + additional_properties=event, + created_at=created_at, + ) db.session.add(event) return event -@generic_repr('id', 'created_by_id', 'org_id', 'active') +@generic_repr("id", "created_by_id", "org_id", "active") class ApiKey(TimestampMixin, GFKBase, db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1133,9 +1261,9 @@ class ApiKey(TimestampMixin, GFKBase, db.Model): created_by_id = Column(db.Integer, db.ForeignKey("users.id"), nullable=True) created_by = db.relationship(User) - __tablename__ = 'api_keys' + __tablename__ = "api_keys" __table_args__ = ( - db.Index('api_keys_object_type_object_id', 'object_type', 'object_id'), + db.Index("api_keys_object_type_object_id", "object_type", "object_id"), ) @classmethod @@ -1147,7 +1275,7 @@ def get_by_object(cls, object): return cls.query.filter( cls.object_type == object.__class__.__tablename__, cls.object_id == object.id, - cls.active == True + cls.active == True, ).first() @classmethod @@ -1157,7 +1285,7 @@ def create_for_object(cls, object, user): return k -@generic_repr('id', 'name', 'type', 'user_id', 'org_id', 'created_at') +@generic_repr("id", "name", "type", "user_id", "org_id", "created_at") class NotificationDestination(BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1169,10 +1297,10 @@ class NotificationDestination(BelongsToOrgMixin, db.Model): options = Column(ConfigurationContainer.as_mutable(Configuration)) created_at = Column(db.DateTime(True), default=db.func.now()) - __tablename__ = 'notification_destinations' + __tablename__ = "notification_destinations" __table_args__ = ( db.Index( - 'notification_destinations_org_id_name', 'org_id', 'name', unique=True + "notification_destinations_org_id_name", "org_id", "name", unique=True ), ) @@ -1181,16 +1309,16 @@ def __str__(self): def to_dict(self, all=False): d = { - 'id': self.id, - 'name': self.name, - 'type': self.type, - 'icon': self.destination.icon() + "id": self.id, + "name": self.name, + "type": self.type, + "icon": self.destination.icon(), } if all: schema = get_configuration_schema_for_destination_type(self.type) self.options.set_schema(schema) - d['options'] = self.options.to_dict(mask_secrets=True) + d["options"] = self.options.to_dict(mask_secrets=True) return d @@ -1200,67 +1328,69 @@ def destination(self): @classmethod def all(cls, org): - notification_destinations = cls.query.filter(cls.org == org).order_by(cls.id.asc()) + notification_destinations = cls.query.filter(cls.org == org).order_by( + cls.id.asc() + ) return notification_destinations def notify(self, alert, query, user, new_state, app, host): schema = get_configuration_schema_for_destination_type(self.type) self.options.set_schema(schema) - return self.destination.notify(alert, query, user, new_state, - app, host, self.options) + return self.destination.notify( + alert, query, user, new_state, app, host, self.options + ) -@generic_repr('id', 'user_id', 'destination_id', 'alert_id') +@generic_repr("id", "user_id", "destination_id", "alert_id") class AlertSubscription(TimestampMixin, db.Model): id = Column(db.Integer, primary_key=True) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User) - destination_id = Column(db.Integer, - db.ForeignKey("notification_destinations.id"), - nullable=True) + destination_id = Column( + db.Integer, db.ForeignKey("notification_destinations.id"), nullable=True + ) destination = db.relationship(NotificationDestination) alert_id = Column(db.Integer, db.ForeignKey("alerts.id")) alert = db.relationship(Alert, back_populates="subscriptions") - __tablename__ = 'alert_subscriptions' + __tablename__ = "alert_subscriptions" __table_args__ = ( db.Index( - 'alert_subscriptions_destination_id_alert_id', - 'destination_id', 'alert_id', unique=True + "alert_subscriptions_destination_id_alert_id", + "destination_id", + "alert_id", + unique=True, ), ) def to_dict(self): - d = { - 'id': self.id, - 'user': self.user.to_dict(), - 'alert_id': self.alert_id - } + d = {"id": self.id, "user": self.user.to_dict(), "alert_id": self.alert_id} if self.destination: - d['destination'] = self.destination.to_dict() + d["destination"] = self.destination.to_dict() return d @classmethod def all(cls, alert_id): - return AlertSubscription.query.join(User).filter(AlertSubscription.alert_id == alert_id) + return AlertSubscription.query.join(User).filter( + AlertSubscription.alert_id == alert_id + ) def notify(self, alert, query, user, new_state, app, host): if self.destination: - return self.destination.notify(alert, query, user, new_state, - app, host) + return self.destination.notify(alert, query, user, new_state, app, host) else: # User email subscription, so create an email destination object - config = {'addresses': self.user.email} - schema = get_configuration_schema_for_destination_type('email') + config = {"addresses": self.user.email} + schema = get_configuration_schema_for_destination_type("email") options = ConfigurationContainer(config, schema) - destination = get_destination('email', options) + destination = get_destination("email", options) return destination.notify(alert, query, user, new_state, app, host, options) -@generic_repr('id', 'trigger', 'user_id', 'org_id') +@generic_repr("id", "trigger", "user_id", "org_id") class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1271,7 +1401,7 @@ class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): user = db.relationship(User, backref="query_snippets") snippet = Column(db.Text) - __tablename__ = 'query_snippets' + __tablename__ = "query_snippets" @classmethod def all(cls, org): @@ -1279,22 +1409,32 @@ def all(cls, org): def to_dict(self): d = { - 'id': self.id, - 'trigger': self.trigger, - 'description': self.description, - 'snippet': self.snippet, - 'user': self.user.to_dict(), - 'updated_at': self.updated_at, - 'created_at': self.created_at + "id": self.id, + "trigger": self.trigger, + "description": self.description, + "snippet": self.snippet, + "user": self.user.to_dict(), + "updated_at": self.updated_at, + "created_at": self.created_at, } return d def init_db(): - default_org = Organization(name="Default", slug='default', settings={}) - admin_group = Group(name='admin', permissions=['admin', 'super_admin'], org=default_org, type=Group.BUILTIN_GROUP) - default_group = Group(name='default', permissions=Group.DEFAULT_PERMISSIONS, org=default_org, type=Group.BUILTIN_GROUP) + default_org = Organization(name="Default", slug="default", settings={}) + admin_group = Group( + name="admin", + permissions=["admin", "super_admin"], + org=default_org, + type=Group.BUILTIN_GROUP, + ) + default_group = Group( + name="default", + permissions=Group.DEFAULT_PERMISSIONS, + org=default_org, + type=Group.BUILTIN_GROUP, + ) db.session.add_all([default_org, admin_group, default_group]) # XXX remove after fixing User.group_ids diff --git a/redash/models/base.py b/redash/models/base.py index 9664825987..a33ba23c4a 100644 --- a/redash/models/base.py +++ b/redash/models/base.py @@ -17,14 +17,12 @@ def apply_driver_hacks(self, app, info, options): def apply_pool_defaults(self, app, options): super(RedashSQLAlchemy, self).apply_pool_defaults(app, options) if settings.SQLALCHEMY_DISABLE_POOL: - options['poolclass'] = NullPool + options["poolclass"] = NullPool # Remove options NullPool does not support: - options.pop('max_overflow', None) + options.pop("max_overflow", None) -db = RedashSQLAlchemy(session_options={ - 'expire_on_commit': False -}) +db = RedashSQLAlchemy(session_options={"expire_on_commit": False}) # Make sure the SQLAlchemy mappers are all properly configured first. # This is required by SQLAlchemy-Searchable as it adds DDL listeners # on the configuration phase of models. @@ -32,7 +30,7 @@ def apply_pool_defaults(self, app, options): # listen to a few database events to set up functions, trigger updates # and indexes for the full text search -make_searchable(options={'regconfig': 'pg_catalog.simple'}) +make_searchable(options={"regconfig": "pg_catalog.simple"}) class SearchBaseQuery(BaseQuery, SearchQueryMixin): @@ -63,6 +61,7 @@ class GFKBase(object): """ Compatibility with 'generic foreign key' approach Peewee used. """ + object_type = Column(db.String(255)) object_id = Column(db.Integer) @@ -75,8 +74,11 @@ def object(self): return self._object else: object_class = _gfk_types[self.object_type] - self._object = session.query(object_class).filter( - object_class.id == self.object_id).first() + self._object = ( + session.query(object_class) + .filter(object_class.id == self.object_id) + .first() + ) return self._object @object.setter diff --git a/redash/models/changes.py b/redash/models/changes.py index 1529c4b323..a670703514 100644 --- a/redash/models/changes.py +++ b/redash/models/changes.py @@ -5,48 +5,49 @@ from .types import PseudoJSON -@generic_repr('id', 'object_type', 'object_id', 'created_at') +@generic_repr("id", "object_type", "object_id", "created_at") class Change(GFKBase, db.Model): id = Column(db.Integer, primary_key=True) # 'object' defined in GFKBase object_version = Column(db.Integer, default=0) user_id = Column(db.Integer, db.ForeignKey("users.id")) - user = db.relationship("User", backref='changes') + user = db.relationship("User", backref="changes") change = Column(PseudoJSON) created_at = Column(db.DateTime(True), default=db.func.now()) - __tablename__ = 'changes' + __tablename__ = "changes" def to_dict(self, full=True): d = { - 'id': self.id, - 'object_id': self.object_id, - 'object_type': self.object_type, - 'change_type': self.change_type, - 'object_version': self.object_version, - 'change': self.change, - 'created_at': self.created_at + "id": self.id, + "object_id": self.object_id, + "object_type": self.object_type, + "change_type": self.change_type, + "object_version": self.object_version, + "change": self.change, + "created_at": self.created_at, } if full: - d['user'] = self.user.to_dict() + d["user"] = self.user.to_dict() else: - d['user_id'] = self.user_id + d["user_id"] = self.user_id return d @classmethod def last_change(cls, obj): - return cls.query.filter( - cls.object_id == obj.id, - cls.object_type == obj.__class__.__tablename__ - ).order_by( - cls.object_version.desc() - ).first() + return ( + cls.query.filter( + cls.object_id == obj.id, cls.object_type == obj.__class__.__tablename__ + ) + .order_by(cls.object_version.desc()) + .first() + ) class ChangeTrackingMixin(object): - skipped_fields = ('id', 'created_at', 'updated_at', 'version') + skipped_fields = ("id", "created_at", "updated_at", "version") _clean_values = None def __init__(self, *a, **kw): @@ -54,7 +55,7 @@ def __init__(self, *a, **kw): self.record_changes(self.user) def prep_cleanvalues(self): - self.__dict__['_clean_values'] = {} + self.__dict__["_clean_values"] = {} for attr in inspect(self.__class__).column_attrs: col, = attr.columns # 'query' is col name but not attr name @@ -77,10 +78,16 @@ def record_changes(self, changed_by): for attr in inspect(self.__class__).column_attrs: col, = attr.columns if attr.key not in self.skipped_fields: - changes[col.name] = {'previous': self._clean_values[col.name], - 'current': getattr(self, attr.key)} - - db.session.add(Change(object=self, - object_version=self.version, - user=changed_by, - change=changes)) + changes[col.name] = { + "previous": self._clean_values[col.name], + "current": getattr(self, attr.key), + } + + db.session.add( + Change( + object=self, + object_version=self.version, + user=changed_by, + change=changes, + ) + ) diff --git a/redash/models/mixins.py b/redash/models/mixins.py index a19f4717aa..9116fe46de 100644 --- a/redash/models/mixins.py +++ b/redash/models/mixins.py @@ -8,10 +8,10 @@ class TimestampMixin(object): created_at = Column(db.DateTime(True), default=db.func.now(), nullable=False) -@listens_for(TimestampMixin, 'before_update', propagate=True) +@listens_for(TimestampMixin, "before_update", propagate=True) def timestamp_before_update(mapper, connection, target): # Check if we really want to update the updated_at value - if hasattr(target, 'skip_updated_at'): + if hasattr(target, "skip_updated_at"): return target.updated_at = db.func.now() diff --git a/redash/models/organizations.py b/redash/models/organizations.py index 88799d767c..66c07007d4 100644 --- a/redash/models/organizations.py +++ b/redash/models/organizations.py @@ -9,9 +9,9 @@ from .users import User, Group -@generic_repr('id', 'name', 'slug') +@generic_repr("id", "name", "slug") class Organization(TimestampMixin, db.Model): - SETTING_GOOGLE_APPS_DOMAINS = 'google_apps_domains' + SETTING_GOOGLE_APPS_DOMAINS = "google_apps_domains" SETTING_IS_PUBLIC = "is_public" id = Column(db.Integer, primary_key=True) @@ -19,12 +19,12 @@ class Organization(TimestampMixin, db.Model): slug = Column(db.String(255), unique=True) settings = Column(MutableDict.as_mutable(PseudoJSON)) groups = db.relationship("Group", lazy="dynamic") - events = db.relationship("Event", lazy="dynamic", order_by="desc(Event.created_at)",) + events = db.relationship("Event", lazy="dynamic", order_by="desc(Event.created_at)") - __tablename__ = 'organizations' + __tablename__ = "organizations" def __str__(self): - return '%s (%s)' % (self.name, self.id) + return "%s (%s)" % (self.name, self.id) @classmethod def get_by_slug(cls, slug): @@ -36,7 +36,9 @@ def get_by_id(cls, _id): @property def default_group(self): - return self.groups.filter(Group.name == 'default', Group.type == Group.BUILTIN_GROUP).first() + return self.groups.filter( + Group.name == "default", Group.type == Group.BUILTIN_GROUP + ).first() @property def google_apps_domains(self): @@ -48,25 +50,25 @@ def is_public(self): @property def is_disabled(self): - return self.settings.get('is_disabled', False) + return self.settings.get("is_disabled", False) def disable(self): - self.settings['is_disabled'] = True + self.settings["is_disabled"] = True def enable(self): - self.settings['is_disabled'] = False + self.settings["is_disabled"] = False def set_setting(self, key, value): if key not in org_settings: raise KeyError(key) - self.settings.setdefault('settings', {}) - self.settings['settings'][key] = value - flag_modified(self, 'settings') + self.settings.setdefault("settings", {}) + self.settings["settings"][key] = value + flag_modified(self, "settings") def get_setting(self, key, raise_on_missing=True): - if key in self.settings.get('settings', {}): - return self.settings['settings'][key] + if key in self.settings.get("settings", {}): + return self.settings["settings"][key] if key in org_settings: return org_settings[key] @@ -78,7 +80,9 @@ def get_setting(self, key, raise_on_missing=True): @property def admin_group(self): - return self.groups.filter(Group.name == 'admin', Group.type == Group.BUILTIN_GROUP).first() + return self.groups.filter( + Group.name == "admin", Group.type == Group.BUILTIN_GROUP + ).first() def has_user(self, email): return self.users.filter(User.email == email).count() == 1 diff --git a/redash/models/parameterized_query.py b/redash/models/parameterized_query.py index 81ddde18c6..85225e0414 100644 --- a/redash/models/parameterized_query.py +++ b/redash/models/parameterized_query.py @@ -23,7 +23,9 @@ def _load_result(query_id, org): query = models.Query.get_by_id_and_org(query_id, org) if query.data_source: - query_result = models.QueryResult.get_by_id_and_org(query.latest_query_data_id, org) + query_result = models.QueryResult.get_by_id_and_org( + query.latest_query_data_id, org + ) return query_result.data else: raise QueryDetachedFromDataSourceError(query_id) @@ -40,12 +42,16 @@ def join_parameter_list_values(parameters, schema): updated_parameters = {} for (key, value) in parameters.items(): if isinstance(value, list): - definition = next((definition for definition in schema if definition["name"] == key), {}) - multi_values_options = definition.get('multiValuesOptions', {}) - separator = str(multi_values_options.get('separator', ',')) - prefix = str(multi_values_options.get('prefix', '')) - suffix = str(multi_values_options.get('suffix', '')) - updated_parameters[key] = separator.join([prefix + v + suffix for v in value]) + definition = next( + (definition for definition in schema if definition["name"] == key), {} + ) + multi_values_options = definition.get("multiValuesOptions", {}) + separator = str(multi_values_options.get("separator", ",")) + prefix = str(multi_values_options.get("prefix", "")) + suffix = str(multi_values_options.get("suffix", "")) + updated_parameters[key] = separator.join( + [prefix + v + suffix for v in value] + ) else: updated_parameters[key] = value return updated_parameters @@ -74,7 +80,7 @@ def _parameter_names(parameter_values): for key, value in parameter_values.items(): if isinstance(value, dict): for inner_key in value.keys(): - names.append('{}.{}'.format(key, inner_key)) + names.append("{}.{}".format(key, inner_key)) else: names.append(key) @@ -122,12 +128,16 @@ def __init__(self, template, schema=None, org=None): self.parameters = {} def apply(self, parameters): - invalid_parameter_names = [key for (key, value) in parameters.items() if not self._valid(key, value)] + invalid_parameter_names = [ + key for (key, value) in parameters.items() if not self._valid(key, value) + ] if invalid_parameter_names: raise InvalidParameterError(invalid_parameter_names) else: self.parameters.update(parameters) - self.query = mustache_render(self.template, join_parameter_list_values(parameters, self.schema)) + self.query = mustache_render( + self.template, join_parameter_list_values(parameters, self.schema) + ) return self @@ -135,27 +145,32 @@ def _valid(self, name, value): if not self.schema: return True - definition = next((definition for definition in self.schema if definition["name"] == name), None) + definition = next( + (definition for definition in self.schema if definition["name"] == name), + None, + ) if not definition: return False - enum_options = definition.get('enumOptions') - query_id = definition.get('queryId') - allow_multiple_values = isinstance(definition.get('multiValuesOptions'), dict) + enum_options = definition.get("enumOptions") + query_id = definition.get("queryId") + allow_multiple_values = isinstance(definition.get("multiValuesOptions"), dict) if isinstance(enum_options, string_types): - enum_options = enum_options.split('\n') + enum_options = enum_options.split("\n") validators = { "text": lambda value: isinstance(value, string_types), "number": _is_number, - "enum": lambda value: _is_value_within_options(value, - enum_options, - allow_multiple_values), - "query": lambda value: _is_value_within_options(value, - [v["value"] for v in dropdown_values(query_id, self.org)], - allow_multiple_values), + "enum": lambda value: _is_value_within_options( + value, enum_options, allow_multiple_values + ), + "query": lambda value: _is_value_within_options( + value, + [v["value"] for v in dropdown_values(query_id, self.org)], + allow_multiple_values, + ), "date": _is_date, "datetime-local": _is_date, "datetime-with-seconds": _is_date, @@ -186,7 +201,9 @@ def text(self): class InvalidParameterError(Exception): def __init__(self, parameters): parameter_names = ", ".join(parameters) - message = "The following parameter values are incompatible with their definitions: {}".format(parameter_names) + message = "The following parameter values are incompatible with their definitions: {}".format( + parameter_names + ) super(InvalidParameterError, self).__init__(message) @@ -194,4 +211,5 @@ class QueryDetachedFromDataSourceError(Exception): def __init__(self, query_id): self.query_id = query_id super(QueryDetachedFromDataSourceError, self).__init__( - "This query is detached from any data source. Please select a different query.") + "This query is detached from any data source. Please select a different query." + ) diff --git a/redash/models/types.py b/redash/models/types.py index 60a5916b60..b38a3ad108 100644 --- a/redash/models/types.py +++ b/redash/models/types.py @@ -22,10 +22,14 @@ def process_result_value(self, value, dialect): class EncryptedConfiguration(EncryptedType): def process_bind_param(self, value, dialect): - return super(EncryptedConfiguration, self).process_bind_param(value.to_json(), dialect) + return super(EncryptedConfiguration, self).process_bind_param( + value.to_json(), dialect + ) def process_result_value(self, value, dialect): - return ConfigurationContainer.from_json(super(EncryptedConfiguration, self).process_result_value(value, dialect)) + return ConfigurationContainer.from_json( + super(EncryptedConfiguration, self).process_result_value(value, dialect) + ) # XXX replace PseudoJSON and MutableDict with real JSON field diff --git a/redash/models/users.py b/redash/models/users.py index 724bfeafcf..6be4882915 100644 --- a/redash/models/users.py +++ b/redash/models/users.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) -LAST_ACTIVE_KEY = 'users:last_active_at' +LAST_ACTIVE_KEY = "users:last_active_at" def sync_last_active_at(): @@ -68,46 +68,57 @@ def has_permission(self, permission): return self.has_permissions((permission,)) def has_permissions(self, permissions): - has_permissions = reduce(lambda a, b: a and b, - [permission in self.permissions for permission in permissions], - True) + has_permissions = reduce( + lambda a, b: a and b, + [permission in self.permissions for permission in permissions], + True, + ) return has_permissions -@generic_repr('id', 'name', 'email') -class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): +@generic_repr("id", "name", "email") +class User( + TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin +): id = Column(db.Integer, primary_key=True) - org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) org = db.relationship("Organization", backref=db.backref("users", lazy="dynamic")) name = Column(db.String(320)) email = Column(EmailType) - _profile_image_url = Column('profile_image_url', db.String(320), nullable=True) + _profile_image_url = Column("profile_image_url", db.String(320), nullable=True) password_hash = Column(db.String(128), nullable=True) - group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True) - api_key = Column(db.String(40), - default=lambda: generate_token(40), - unique=True) + group_ids = Column( + "groups", MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True + ) + api_key = Column(db.String(40), default=lambda: generate_token(40), unique=True) disabled_at = Column(db.DateTime(True), default=None, nullable=True) - details = Column(MutableDict.as_mutable(postgresql.JSON), nullable=True, - server_default='{}', default={}) - active_at = json_cast_property(db.DateTime(True), 'details', 'active_at', - default=None) - is_invitation_pending = json_cast_property(db.Boolean(True), 'details', 'is_invitation_pending', default=False) - is_email_verified = json_cast_property(db.Boolean(True), 'details', 'is_email_verified', default=True) - - __tablename__ = 'users' - __table_args__ = ( - db.Index('users_org_id_email', 'org_id', 'email', unique=True), + details = Column( + MutableDict.as_mutable(postgresql.JSON), + nullable=True, + server_default="{}", + default={}, + ) + active_at = json_cast_property( + db.DateTime(True), "details", "active_at", default=None + ) + is_invitation_pending = json_cast_property( + db.Boolean(True), "details", "is_invitation_pending", default=False + ) + is_email_verified = json_cast_property( + db.Boolean(True), "details", "is_email_verified", default=True ) + __tablename__ = "users" + __table_args__ = (db.Index("users_org_id_email", "org_id", "email", unique=True),) + def __str__(self): - return '%s (%s)' % (self.name, self.email) + return "%s (%s)" % (self.name, self.email) def __init__(self, *args, **kwargs): - if kwargs.get('email') is not None: - kwargs['email'] = kwargs['email'].lower() + if kwargs.get("email") is not None: + kwargs["email"] = kwargs["email"].lower() super(User, self).__init__(*args, **kwargs) @property @@ -126,32 +137,32 @@ def regenerate_api_key(self): def to_dict(self, with_api_key=False): profile_image_url = self.profile_image_url if self.is_disabled: - assets = app.extensions['webpack']['assets'] or {} - path = 'images/avatar.svg' - profile_image_url = url_for('static', filename=assets.get(path, path)) + assets = app.extensions["webpack"]["assets"] or {} + path = "images/avatar.svg" + profile_image_url = url_for("static", filename=assets.get(path, path)) d = { - 'id': self.id, - 'name': self.name, - 'email': self.email, - 'profile_image_url': profile_image_url, - 'groups': self.group_ids, - 'updated_at': self.updated_at, - 'created_at': self.created_at, - 'disabled_at': self.disabled_at, - 'is_disabled': self.is_disabled, - 'active_at': self.active_at, - 'is_invitation_pending': self.is_invitation_pending, - 'is_email_verified': self.is_email_verified, + "id": self.id, + "name": self.name, + "email": self.email, + "profile_image_url": profile_image_url, + "groups": self.group_ids, + "updated_at": self.updated_at, + "created_at": self.created_at, + "disabled_at": self.disabled_at, + "is_disabled": self.is_disabled, + "active_at": self.active_at, + "is_invitation_pending": self.is_invitation_pending, + "is_email_verified": self.is_email_verified, } if self.password_hash is None: - d['auth_type'] = 'external' + d["auth_type"] = "external" else: - d['auth_type'] = 'password' + d["auth_type"] = "password" if with_api_key: - d['api_key'] = self.api_key + d["api_key"] = self.api_key return d @@ -169,8 +180,14 @@ def profile_image_url(self): @property def permissions(self): # TODO: this should be cached. - return list(itertools.chain(*[g.permissions for g in - Group.query.filter(Group.id.in_(self.group_ids))])) + return list( + itertools.chain( + *[ + g.permissions + for g in Group.query.filter(Group.id.in_(self.group_ids)) + ] + ) + ) @classmethod def get_by_org(cls, org): @@ -198,7 +215,7 @@ def all_disabled(cls, org): @classmethod def search(cls, base_query, term): - term = '%{}%'.format(term) + term = "%{}%".format(term) search_filter = or_(cls.name.ilike(term), cls.email.like(term)) return base_query.filter(search_filter) @@ -208,7 +225,9 @@ def pending(cls, base_query, pending): if pending: return base_query.filter(cls.is_invitation_pending.is_(True)) else: - return base_query.filter(cls.is_invitation_pending.isnot(True)) # check for both `false`/`null` + return base_query.filter( + cls.is_invitation_pending.isnot(True) + ) # check for both `false`/`null` @classmethod def find_by_email(cls, email): @@ -237,38 +256,49 @@ def get_id(self): return "{0}-{1}".format(self.id, identity) -@generic_repr('id', 'name', 'type', 'org_id') +@generic_repr("id", "name", "type", "org_id") class Group(db.Model, BelongsToOrgMixin): - DEFAULT_PERMISSIONS = ['create_dashboard', 'create_query', 'edit_dashboard', 'edit_query', - 'view_query', 'view_source', 'execute_query', 'list_users', 'schedule_query', - 'list_dashboards', 'list_alerts', 'list_data_sources'] - - BUILTIN_GROUP = 'builtin' - REGULAR_GROUP = 'regular' + DEFAULT_PERMISSIONS = [ + "create_dashboard", + "create_query", + "edit_dashboard", + "edit_query", + "view_query", + "view_source", + "execute_query", + "list_users", + "schedule_query", + "list_dashboards", + "list_alerts", + "list_data_sources", + ] + + BUILTIN_GROUP = "builtin" + REGULAR_GROUP = "regular" id = Column(db.Integer, primary_key=True) - data_sources = db.relationship("DataSourceGroup", back_populates="group", - cascade="all") - org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + data_sources = db.relationship( + "DataSourceGroup", back_populates="group", cascade="all" + ) + org_id = Column(db.Integer, db.ForeignKey("organizations.id")) org = db.relationship("Organization", back_populates="groups") type = Column(db.String(255), default=REGULAR_GROUP) name = Column(db.String(100)) - permissions = Column(postgresql.ARRAY(db.String(255)), - default=DEFAULT_PERMISSIONS) + permissions = Column(postgresql.ARRAY(db.String(255)), default=DEFAULT_PERMISSIONS) created_at = Column(db.DateTime(True), default=db.func.now()) - __tablename__ = 'groups' + __tablename__ = "groups" def __str__(self): return text_type(self.id) def to_dict(self): return { - 'id': self.id, - 'name': self.name, - 'permissions': self.permissions, - 'type': self.type, - 'created_at': self.created_at + "id": self.id, + "name": self.name, + "permissions": self.permissions, + "type": self.type, + "created_at": self.created_at, } @classmethod @@ -285,32 +315,38 @@ def find_by_name(cls, org, group_names): return list(result) -@generic_repr('id', 'object_type', 'object_id', 'access_type', 'grantor_id', 'grantee_id') +@generic_repr( + "id", "object_type", "object_id", "access_type", "grantor_id", "grantee_id" +) class AccessPermission(GFKBase, db.Model): id = Column(db.Integer, primary_key=True) # 'object' defined in GFKBase access_type = Column(db.String(255)) grantor_id = Column(db.Integer, db.ForeignKey("users.id")) - grantor = db.relationship(User, backref='grantor', foreign_keys=[grantor_id]) + grantor = db.relationship(User, backref="grantor", foreign_keys=[grantor_id]) grantee_id = Column(db.Integer, db.ForeignKey("users.id")) - grantee = db.relationship(User, backref='grantee', foreign_keys=[grantee_id]) + grantee = db.relationship(User, backref="grantee", foreign_keys=[grantee_id]) - __tablename__ = 'access_permissions' + __tablename__ = "access_permissions" @classmethod def grant(cls, obj, access_type, grantee, grantor): - grant = cls.query.filter(cls.object_type == obj.__tablename__, - cls.object_id == obj.id, - cls.access_type == access_type, - cls.grantee == grantee, - cls.grantor == grantor).one_or_none() + grant = cls.query.filter( + cls.object_type == obj.__tablename__, + cls.object_id == obj.id, + cls.access_type == access_type, + cls.grantee == grantee, + cls.grantor == grantor, + ).one_or_none() if not grant: - grant = cls(object_type=obj.__tablename__, - object_id=obj.id, - access_type=access_type, - grantee=grantee, - grantor=grantor) + grant = cls( + object_type=obj.__tablename__, + object_id=obj.id, + access_type=access_type, + grantee=grantee, + grantor=grantor, + ) db.session.add(grant) return grant @@ -330,8 +366,9 @@ def exists(cls, obj, access_type, grantee): @classmethod def _query(cls, obj, access_type=None, grantee=None, grantor=None): - q = cls.query.filter(cls.object_id == obj.id, - cls.object_type == obj.__tablename__) + q = cls.query.filter( + cls.object_id == obj.id, cls.object_type == obj.__tablename__ + ) if access_type: q = q.filter(AccessPermission.access_type == access_type) @@ -346,12 +383,12 @@ def _query(cls, obj, access_type=None, grantee=None, grantor=None): def to_dict(self): d = { - 'id': self.id, - 'object_id': self.object_id, - 'object_type': self.object_type, - 'access_type': self.access_type, - 'grantor': self.grantor_id, - 'grantee': self.grantee_id + "id": self.id, + "object_id": self.object_id, + "object_type": self.object_type, + "access_type": self.access_type, + "grantor": self.grantor_id, + "grantee": self.grantee_id, } return d @@ -392,7 +429,7 @@ def org_id(self): @property def permissions(self): - return ['view_query'] + return ["view_query"] def has_access(self, obj, access_type): return False diff --git a/redash/monitor.py b/redash/monitor.py index cd9471dff0..789903d724 100644 --- a/redash/monitor.py +++ b/redash/monitor.py @@ -12,17 +12,20 @@ def get_redis_status(): info = redis_connection.info() - return {'redis_used_memory': info['used_memory'], 'redis_used_memory_human': info['used_memory_human']} + return { + "redis_used_memory": info["used_memory"], + "redis_used_memory_human": info["used_memory_human"], + } def get_object_counts(): status = {} - status['queries_count'] = Query.query.count() + status["queries_count"] = Query.query.count() if settings.FEATURE_SHOW_QUERY_RESULTS_COUNT: - status['query_results_count'] = QueryResult.query.count() - status['unused_query_results_count'] = QueryResult.unused().count() - status['dashboards_count'] = Dashboard.query.count() - status['widgets_count'] = Widget.query.count() + status["query_results_count"] = QueryResult.query.count() + status["unused_query_results_count"] = QueryResult.unused().count() + status["dashboards_count"] = Dashboard.query.count() + status["widgets_count"] = Widget.query.count() return status @@ -31,19 +34,30 @@ def get_celery_queues(): scheduled_queue_names = db.session.query(DataSource.scheduled_queue_name).distinct() query = db.session.execute(union_all(queue_names, scheduled_queue_names)) - return ['celery'] + [row[0] for row in query] + return ["celery"] + [row[0] for row in query] def get_queues_status(): - return {**{queue: {'size': redis_connection.llen(queue)} for queue in get_celery_queues()}, - **{queue.name: {'size': len(queue)} for queue in Queue.all(connection=rq_redis_connection)}} + return { + **{ + queue: {"size": redis_connection.llen(queue)} + for queue in get_celery_queues() + }, + **{ + queue.name: {"size": len(queue)} + for queue in Queue.all(connection=rq_redis_connection) + }, + } def get_db_sizes(): database_metrics = [] queries = [ - ['Query Results Size', "select pg_total_relation_size('query_results') as size from (select 1) as a"], - ['Redash DB Size', "select pg_database_size('postgres') as size"] + [ + "Query Results Size", + "select pg_total_relation_size('query_results') as size from (select 1) as a", + ], + ["Redash DB Size", "select pg_database_size('postgres') as size"], ] for query_name, query in queries: result = db.session.execute(query).first() @@ -53,16 +67,13 @@ def get_db_sizes(): def get_status(): - status = { - 'version': __version__, - 'workers': [] - } + status = {"version": __version__, "workers": []} status.update(get_redis_status()) status.update(get_object_counts()) - status['manager'] = redis_connection.hgetall('redash:status') - status['manager']['queues'] = get_queues_status() - status['database_metrics'] = {} - status['database_metrics']['metrics'] = get_db_sizes() + status["manager"] = redis_connection.hgetall("redash:status") + status["manager"]["queues"] = get_queues_status() + status["database_metrics"] = {} + status["database_metrics"]["metrics"] = get_db_sizes() return status @@ -72,20 +83,20 @@ def get_waiting_in_queue(queue_name): for raw in redis_connection.lrange(queue_name, 0, -1): job = json_loads(raw) try: - args = json_loads(job['headers']['argsrepr']) - if args.get('query_id') == 'adhoc': - args['query_id'] = None + args = json_loads(job["headers"]["argsrepr"]) + if args.get("query_id") == "adhoc": + args["query_id"] = None except ValueError: args = {} job_row = { - 'state': 'waiting_in_queue', - 'task_name': job['headers']['task'], - 'worker': None, - 'worker_pid': None, - 'start_time': None, - 'task_id': job['headers']['id'], - 'queue': job['properties']['delivery_info']['routing_key'] + "state": "waiting_in_queue", + "task_name": job["headers"]["task"], + "worker": None, + "worker_pid": None, + "start_time": None, + "task_id": job["headers"]["id"], + "queue": job["properties"]["delivery_info"]["routing_key"], } job_row.update(args) @@ -99,23 +110,23 @@ def parse_tasks(task_lists, state): for task in itertools.chain(*task_lists.values()): task_row = { - 'state': state, - 'task_name': task['name'], - 'worker': task['hostname'], - 'queue': task['delivery_info']['routing_key'], - 'task_id': task['id'], - 'worker_pid': task['worker_pid'], - 'start_time': task['time_start'], + "state": state, + "task_name": task["name"], + "worker": task["hostname"], + "queue": task["delivery_info"]["routing_key"], + "task_id": task["id"], + "worker_pid": task["worker_pid"], + "start_time": task["time_start"], } - if task['name'] == 'redash.tasks.execute_query': + if task["name"] == "redash.tasks.execute_query": try: - args = json_loads(task['args']) + args = json_loads(task["args"]) except ValueError: args = {} - if args.get('query_id') == 'adhoc': - args['query_id'] = None + if args.get("query_id") == "adhoc": + args["query_id"] = None task_row.update(args) @@ -125,8 +136,8 @@ def parse_tasks(task_lists, state): def celery_tasks(): - tasks = parse_tasks(celery.control.inspect().active(), 'active') - tasks += parse_tasks(celery.control.inspect().reserved(), 'reserved') + tasks = parse_tasks(celery.control.inspect().active(), "active") + tasks += parse_tasks(celery.control.inspect().reserved(), "reserved") for queue_name in get_celery_queues(): tasks += get_waiting_in_queue(queue_name) @@ -135,46 +146,52 @@ def celery_tasks(): def fetch_jobs(queue, job_ids): - return [{ - 'id': job.id, - 'name': job.func_name, - 'queue': queue.name, - 'enqueued_at': job.enqueued_at, - 'started_at': job.started_at - } for job in Job.fetch_many(job_ids, connection=rq_redis_connection) if job is not None] + return [ + { + "id": job.id, + "name": job.func_name, + "queue": queue.name, + "enqueued_at": job.enqueued_at, + "started_at": job.started_at, + } + for job in Job.fetch_many(job_ids, connection=rq_redis_connection) + if job is not None + ] def rq_queues(): return { q.name: { - 'name': q.name, - 'started': fetch_jobs(q, StartedJobRegistry(queue=q).get_job_ids()), - 'queued': len(q.job_ids) - } for q in Queue.all(connection=rq_redis_connection)} + "name": q.name, + "started": fetch_jobs(q, StartedJobRegistry(queue=q).get_job_ids()), + "queued": len(q.job_ids), + } + for q in Queue.all(connection=rq_redis_connection) + } def describe_job(job): - return '{} ({})'.format(job.id, job.func_name.split(".").pop()) if job else None + return "{} ({})".format(job.id, job.func_name.split(".").pop()) if job else None def rq_workers(): - return [{ - 'name': w.name, - 'hostname': w.hostname, - 'pid': w.pid, - 'queues': ", ".join([q.name for q in w.queues]), - 'state': w.state, - 'last_heartbeat': w.last_heartbeat, - 'birth_date': w.birth_date, - 'current_job': describe_job(w.get_current_job()), - 'successful_jobs': w.successful_job_count, - 'failed_jobs': w.failed_job_count, - 'total_working_time': w.total_working_time - } for w in Worker.all(connection=rq_redis_connection)] + return [ + { + "name": w.name, + "hostname": w.hostname, + "pid": w.pid, + "queues": ", ".join([q.name for q in w.queues]), + "state": w.state, + "last_heartbeat": w.last_heartbeat, + "birth_date": w.birth_date, + "current_job": describe_job(w.get_current_job()), + "successful_jobs": w.successful_job_count, + "failed_jobs": w.failed_job_count, + "total_working_time": w.total_working_time, + } + for w in Worker.all(connection=rq_redis_connection) + ] def rq_status(): - return { - 'queues': rq_queues(), - 'workers': rq_workers() - } + return {"queues": rq_queues(), "workers": rq_workers()} diff --git a/redash/permissions.py b/redash/permissions.py index d928d918c9..eefc5132d1 100644 --- a/redash/permissions.py +++ b/redash/permissions.py @@ -7,15 +7,15 @@ view_only = True not_view_only = False -ACCESS_TYPE_VIEW = 'view' -ACCESS_TYPE_MODIFY = 'modify' -ACCESS_TYPE_DELETE = 'delete' +ACCESS_TYPE_VIEW = "view" +ACCESS_TYPE_MODIFY = "modify" +ACCESS_TYPE_DELETE = "delete" ACCESS_TYPES = (ACCESS_TYPE_VIEW, ACCESS_TYPE_MODIFY, ACCESS_TYPE_DELETE) def has_access(obj, user, need_view_only): - if hasattr(obj, 'api_key') and user.is_api_user(): + if hasattr(obj, "api_key") and user.is_api_user(): return has_access_to_object(obj, user.id, need_view_only) else: return has_access_to_groups(obj, user, need_view_only) @@ -24,7 +24,7 @@ def has_access(obj, user, need_view_only): def has_access_to_object(obj, api_key, need_view_only): if obj.api_key == api_key: return need_view_only - elif hasattr(obj, 'dashboard_api_keys'): + elif hasattr(obj, "dashboard_api_keys"): # check if api_key belongs to a dashboard containing this query return api_key in obj.dashboard_api_keys and need_view_only else: @@ -32,9 +32,9 @@ def has_access_to_object(obj, api_key, need_view_only): def has_access_to_groups(obj, user, need_view_only): - groups = obj.groups if hasattr(obj, 'groups') else obj + groups = obj.groups if hasattr(obj, "groups") else obj - if 'admin' in user.permissions: + if "admin" in user.permissions: return True matching_groups = set(groups.keys()).intersection(user.group_ids) @@ -76,19 +76,21 @@ def require_permission(permission): def require_admin(fn): - return require_permission('admin')(fn) + return require_permission("admin")(fn) def require_super_admin(fn): - return require_permission('super_admin')(fn) + return require_permission("super_admin")(fn) def has_permission_or_owner(permission, object_owner_id): - return int(object_owner_id) == current_user.id or current_user.has_permission(permission) + return int(object_owner_id) == current_user.id or current_user.has_permission( + permission + ) def is_admin_or_owner(object_owner_id): - return has_permission_or_owner('admin', object_owner_id) + return has_permission_or_owner("admin", object_owner_id) def require_permission_or_owner(permission, object_owner_id): diff --git a/redash/query_runner/__init__.py b/redash/query_runner/__init__.py index a3d08145d0..061b325144 100644 --- a/redash/query_runner/__init__.py +++ b/redash/query_runner/__init__.py @@ -11,39 +11,34 @@ logger = logging.getLogger(__name__) __all__ = [ - 'BaseQueryRunner', - 'BaseHTTPQueryRunner', - 'InterruptException', - 'BaseSQLQueryRunner', - 'TYPE_DATETIME', - 'TYPE_BOOLEAN', - 'TYPE_INTEGER', - 'TYPE_STRING', - 'TYPE_DATE', - 'TYPE_FLOAT', - 'SUPPORTED_COLUMN_TYPES', - 'register', - 'get_query_runner', - 'import_query_runners', - 'guess_type' + "BaseQueryRunner", + "BaseHTTPQueryRunner", + "InterruptException", + "BaseSQLQueryRunner", + "TYPE_DATETIME", + "TYPE_BOOLEAN", + "TYPE_INTEGER", + "TYPE_STRING", + "TYPE_DATE", + "TYPE_FLOAT", + "SUPPORTED_COLUMN_TYPES", + "register", + "get_query_runner", + "import_query_runners", + "guess_type", ] # Valid types of columns returned in results: -TYPE_INTEGER = 'integer' -TYPE_FLOAT = 'float' -TYPE_BOOLEAN = 'boolean' -TYPE_STRING = 'string' -TYPE_DATETIME = 'datetime' -TYPE_DATE = 'date' - -SUPPORTED_COLUMN_TYPES = set([ - TYPE_INTEGER, - TYPE_FLOAT, - TYPE_BOOLEAN, - TYPE_STRING, - TYPE_DATETIME, - TYPE_DATE -]) +TYPE_INTEGER = "integer" +TYPE_FLOAT = "float" +TYPE_BOOLEAN = "boolean" +TYPE_STRING = "string" +TYPE_DATETIME = "datetime" +TYPE_DATE = "date" + +SUPPORTED_COLUMN_TYPES = set( + [TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN, TYPE_STRING, TYPE_DATETIME, TYPE_DATE] +) class InterruptException(Exception): @@ -60,7 +55,7 @@ class BaseQueryRunner(object): noop_query = None def __init__(self, configuration): - self.syntax = 'sql' + self.syntax = "sql" self.configuration = configuration @classmethod @@ -110,9 +105,9 @@ def fetch_columns(self, columns): duplicates_counter += 1 column_names.append(column_name) - new_columns.append({'name': column_name, - 'friendly_name': column_name, - 'type': col[1]}) + new_columns.append( + {"name": column_name, "friendly_name": column_name, "type": col[1]} + ) return new_columns @@ -124,19 +119,18 @@ def _run_query_internal(self, query): if error is not None: raise Exception("Failed running query [%s]." % query) - return json_loads(results)['rows'] + return json_loads(results)["rows"] @classmethod def to_dict(cls): return { - 'name': cls.name(), - 'type': cls.type(), - 'configuration_schema': cls.configuration_schema() + "name": cls.name(), + "type": cls.type(), + "configuration_schema": cls.configuration_schema(), } class BaseSQLQueryRunner(BaseQueryRunner): - def get_schema(self, get_stats=False): schema_dict = {} self._get_tables(schema_dict) @@ -150,8 +144,8 @@ def _get_tables(self, schema_dict): def _get_tables_stats(self, tables_dict): for t in tables_dict.keys(): if type(tables_dict[t]) == dict: - res = self._run_query_internal('select count(*) as cnt from %s' % t) - tables_dict[t]['size'] = res[0]['cnt'] + res = self._run_query_internal("select count(*) as cnt from %s" % t) + tables_dict[t]["size"] = res[0]["cnt"] class BaseHTTPQueryRunner(BaseQueryRunner): @@ -159,45 +153,36 @@ class BaseHTTPQueryRunner(BaseQueryRunner): response_error = "Endpoint returned unexpected status code" requires_authentication = False requires_url = True - url_title = 'URL base path' - username_title = 'HTTP Basic Auth Username' - password_title = 'HTTP Basic Auth Password' + url_title = "URL base path" + username_title = "HTTP Basic Auth Username" + password_title = "HTTP Basic Auth Password" @classmethod def configuration_schema(cls): schema = { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string', - 'title': cls.url_title, - }, - 'username': { - 'type': 'string', - 'title': cls.username_title, - }, - 'password': { - 'type': 'string', - 'title': cls.password_title, - }, + "type": "object", + "properties": { + "url": {"type": "string", "title": cls.url_title}, + "username": {"type": "string", "title": cls.username_title}, + "password": {"type": "string", "title": cls.password_title}, }, - 'secret': ['password'], - 'order': ['url', 'username', 'password'] + "secret": ["password"], + "order": ["url", "username", "password"], } if cls.requires_url or cls.requires_authentication: - schema['required'] = [] + schema["required"] = [] if cls.requires_url: - schema['required'] += ['url'] + schema["required"] += ["url"] if cls.requires_authentication: - schema['required'] += ['username', 'password'] + schema["required"] += ["username", "password"] return schema def get_auth(self): - username = self.configuration.get('username') - password = self.configuration.get('password') + username = self.configuration.get("username") + password = self.configuration.get("password") if username and password: return (username, password) if self.requires_authentication: @@ -205,7 +190,7 @@ def get_auth(self): else: return None - def get_response(self, url, auth=None, http_method='get', **kwargs): + def get_response(self, url, auth=None, http_method="get", **kwargs): # Get authentication values if not given if auth is None: auth = self.get_auth() @@ -223,19 +208,12 @@ def get_response(self, url, auth=None, http_method='get', **kwargs): # Any other responses (e.g. 2xx and 3xx): if response.status_code != 200: - error = '{} ({}).'.format( - self.response_error, - response.status_code, - ) + error = "{} ({}).".format(self.response_error, response.status_code) except requests.HTTPError as exc: logger.exception(exc) - error = ( - "Failed to execute query. " - "Return Code: {} Reason: {}".format( - response.status_code, - response.text - ) + error = "Failed to execute query. " "Return Code: {} Reason: {}".format( + response.status_code, response.text ) except requests.RequestException as exc: # Catch all other requests exceptions and return the error. @@ -252,11 +230,18 @@ def get_response(self, url, auth=None, http_method='get', **kwargs): def register(query_runner_class): global query_runners if query_runner_class.enabled(): - logger.debug("Registering %s (%s) query runner.", query_runner_class.name(), query_runner_class.type()) + logger.debug( + "Registering %s (%s) query runner.", + query_runner_class.name(), + query_runner_class.type(), + ) query_runners[query_runner_class.type()] = query_runner_class else: - logger.debug("%s query runner enabled but not supported, not registering. Either disable or install missing " - "dependencies.", query_runner_class.name()) + logger.debug( + "%s query runner enabled but not supported, not registering. Either disable or install missing " + "dependencies.", + query_runner_class.name(), + ) def get_query_runner(query_runner_type, configuration): @@ -292,7 +277,7 @@ def guess_type(value): def guess_type_from_string(string_value): - if string_value == '' or string_value is None: + if string_value == "" or string_value is None: return TYPE_STRING try: @@ -307,7 +292,7 @@ def guess_type_from_string(string_value): except (ValueError, OverflowError): pass - if text_type(string_value).lower() in ('true', 'false'): + if text_type(string_value).lower() in ("true", "false"): return TYPE_BOOLEAN try: diff --git a/redash/query_runner/amazon_elasticsearch.py b/redash/query_runner/amazon_elasticsearch.py index 7d465de863..cf81969874 100644 --- a/redash/query_runner/amazon_elasticsearch.py +++ b/redash/query_runner/amazon_elasticsearch.py @@ -4,6 +4,7 @@ try: from requests_aws_sign import AWSV4Sign from botocore import session, credentials + enabled = True except ImportError: enabled = False @@ -25,45 +26,42 @@ def type(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'server': { - 'type': 'string', - 'title': 'Endpoint' - }, - 'region': { - 'type': 'string', - }, - 'access_key': { - 'type': 'string', - 'title': 'Access Key' - }, - 'secret_key': { - 'type': 'string', - 'title': 'Secret Key' + "type": "object", + "properties": { + "server": {"type": "string", "title": "Endpoint"}, + "region": {"type": "string"}, + "access_key": {"type": "string", "title": "Access Key"}, + "secret_key": {"type": "string", "title": "Secret Key"}, + "use_aws_iam_profile": { + "type": "boolean", + "title": "Use AWS IAM Profile", }, - 'use_aws_iam_profile': { - 'type': 'boolean', - 'title': 'Use AWS IAM Profile' - } }, "secret": ["secret_key"], - "order": ["server", "region", "access_key", "secret_key", "use_aws_iam_profile"], - "required": ['server', 'region'] + "order": [ + "server", + "region", + "access_key", + "secret_key", + "use_aws_iam_profile", + ], + "required": ["server", "region"], } def __init__(self, configuration): super(AmazonElasticsearchService, self).__init__(configuration) - region = configuration['region'] + region = configuration["region"] cred = None - if configuration.get('use_aws_iam_profile', False): + if configuration.get("use_aws_iam_profile", False): cred = credentials.get_credentials(session.Session()) else: - cred = credentials.Credentials(access_key=configuration.get('access_key', ''), - secret_key=configuration.get('secret_key', '')) + cred = credentials.Credentials( + access_key=configuration.get("access_key", ""), + secret_key=configuration.get("secret_key", ""), + ) - self.auth = AWSV4Sign(cred, region, 'es') + self.auth = AWSV4Sign(cred, region, "es") register(AmazonElasticsearchService) diff --git a/redash/query_runner/athena.py b/redash/query_runner/athena.py index 52f32d3f33..65b2c08c10 100644 --- a/redash/query_runner/athena.py +++ b/redash/query_runner/athena.py @@ -6,34 +6,39 @@ from redash.utils import json_dumps, json_loads logger = logging.getLogger(__name__) -ANNOTATE_QUERY = parse_boolean(os.environ.get('ATHENA_ANNOTATE_QUERY', 'true')) -SHOW_EXTRA_SETTINGS = parse_boolean(os.environ.get('ATHENA_SHOW_EXTRA_SETTINGS', 'true')) -ASSUME_ROLE = parse_boolean(os.environ.get('ATHENA_ASSUME_ROLE', 'false')) -OPTIONAL_CREDENTIALS = parse_boolean(os.environ.get('ATHENA_OPTIONAL_CREDENTIALS', 'true')) +ANNOTATE_QUERY = parse_boolean(os.environ.get("ATHENA_ANNOTATE_QUERY", "true")) +SHOW_EXTRA_SETTINGS = parse_boolean( + os.environ.get("ATHENA_SHOW_EXTRA_SETTINGS", "true") +) +ASSUME_ROLE = parse_boolean(os.environ.get("ATHENA_ASSUME_ROLE", "false")) +OPTIONAL_CREDENTIALS = parse_boolean( + os.environ.get("ATHENA_OPTIONAL_CREDENTIALS", "true") +) try: import pyathena import boto3 + enabled = True except ImportError: enabled = False _TYPE_MAPPINGS = { - 'boolean': TYPE_BOOLEAN, - 'tinyint': TYPE_INTEGER, - 'smallint': TYPE_INTEGER, - 'integer': TYPE_INTEGER, - 'bigint': TYPE_INTEGER, - 'double': TYPE_FLOAT, - 'varchar': TYPE_STRING, - 'timestamp': TYPE_DATETIME, - 'date': TYPE_DATE, - 'varbinary': TYPE_STRING, - 'array': TYPE_STRING, - 'map': TYPE_STRING, - 'row': TYPE_STRING, - 'decimal': TYPE_FLOAT, + "boolean": TYPE_BOOLEAN, + "tinyint": TYPE_INTEGER, + "smallint": TYPE_INTEGER, + "integer": TYPE_INTEGER, + "bigint": TYPE_INTEGER, + "double": TYPE_FLOAT, + "varchar": TYPE_STRING, + "timestamp": TYPE_DATETIME, + "date": TYPE_DATE, + "varbinary": TYPE_STRING, + "array": TYPE_STRING, + "map": TYPE_STRING, + "row": TYPE_STRING, + "decimal": TYPE_FLOAT, } @@ -43,7 +48,7 @@ def format(self, operation, parameters=None): class Athena(BaseQueryRunner): - noop_query = 'SELECT 1' + noop_query = "SELECT 1" @classmethod def name(cls): @@ -52,82 +57,68 @@ def name(cls): @classmethod def configuration_schema(cls): schema = { - 'type': 'object', - 'properties': { - 'region': { - 'type': 'string', - 'title': 'AWS Region' - }, - 'aws_access_key': { - 'type': 'string', - 'title': 'AWS Access Key' - }, - 'aws_secret_key': { - 'type': 'string', - 'title': 'AWS Secret Key' - }, - 's3_staging_dir': { - 'type': 'string', - 'title': 'S3 Staging (Query Results) Bucket Path' + "type": "object", + "properties": { + "region": {"type": "string", "title": "AWS Region"}, + "aws_access_key": {"type": "string", "title": "AWS Access Key"}, + "aws_secret_key": {"type": "string", "title": "AWS Secret Key"}, + "s3_staging_dir": { + "type": "string", + "title": "S3 Staging (Query Results) Bucket Path", }, - 'schema': { - 'type': 'string', - 'title': 'Schema Name', - 'default': 'default' + "schema": { + "type": "string", + "title": "Schema Name", + "default": "default", }, - 'glue': { - 'type': 'boolean', - 'title': 'Use Glue Data Catalog', - }, - 'work_group': { - 'type': 'string', - 'title': 'Athena Work Group', - 'default': 'primary' + "glue": {"type": "boolean", "title": "Use Glue Data Catalog"}, + "work_group": { + "type": "string", + "title": "Athena Work Group", + "default": "primary", }, }, - 'required': ['region', 's3_staging_dir'], - 'extra_options': ['glue'], - 'order': ['region', 's3_staging_dir', 'schema', 'work_group'], - 'secret': ['aws_secret_key'] + "required": ["region", "s3_staging_dir"], + "extra_options": ["glue"], + "order": ["region", "s3_staging_dir", "schema", "work_group"], + "secret": ["aws_secret_key"], } if SHOW_EXTRA_SETTINGS: - schema['properties'].update({ - 'encryption_option': { - 'type': 'string', - 'title': 'Encryption Option', - }, - 'kms_key': { - 'type': 'string', - 'title': 'KMS Key', - }, - }) - schema['extra_options'].append('encryption_option') - schema['extra_options'].append('kms_key') + schema["properties"].update( + { + "encryption_option": { + "type": "string", + "title": "Encryption Option", + }, + "kms_key": {"type": "string", "title": "KMS Key"}, + } + ) + schema["extra_options"].append("encryption_option") + schema["extra_options"].append("kms_key") if ASSUME_ROLE: - del schema['properties']['aws_access_key'] - del schema['properties']['aws_secret_key'] - schema['secret'] = [] - - schema['order'].insert(1, 'iam_role') - schema['order'].insert(2, 'external_id') - schema['properties'].update({ - 'iam_role': { - 'type': 'string', - 'title': 'IAM role to assume', - }, - 'external_id': { - 'type': 'string', - 'title': 'External ID to be used while STS assume role', - }, - }) + del schema["properties"]["aws_access_key"] + del schema["properties"]["aws_secret_key"] + schema["secret"] = [] + + schema["order"].insert(1, "iam_role") + schema["order"].insert(2, "external_id") + schema["properties"].update( + { + "iam_role": {"type": "string", "title": "IAM role to assume"}, + "external_id": { + "type": "string", + "title": "External ID to be used while STS assume role", + }, + } + ) else: - schema['order'].insert(1, 'aws_access_key') - schema['order'].insert(2, 'aws_secret_key') + schema["order"].insert(1, "aws_access_key") + schema["order"].insert(2, "aws_secret_key") if not OPTIONAL_CREDENTIALS and not ASSUME_ROLE: - schema['required'] += ['aws_access_key', 'aws_secret_key'] + schema["required"] += ["aws_access_key", "aws_secret_key"] return schema @@ -146,47 +137,50 @@ def type(cls): def _get_iam_credentials(self, user=None): if ASSUME_ROLE: - role_session_name = 'redash' if user is None else user.email - sts = boto3.client('sts') + role_session_name = "redash" if user is None else user.email + sts = boto3.client("sts") creds = sts.assume_role( - RoleArn=self.configuration.get('iam_role'), + RoleArn=self.configuration.get("iam_role"), RoleSessionName=role_session_name, - ExternalId=self.configuration.get('external_id') - ) + ExternalId=self.configuration.get("external_id"), + ) return { - 'aws_access_key_id': creds['Credentials']['AccessKeyId'], - 'aws_secret_access_key': creds['Credentials']['SecretAccessKey'], - 'aws_session_token': creds['Credentials']['SessionToken'], - 'region_name': self.configuration['region'] + "aws_access_key_id": creds["Credentials"]["AccessKeyId"], + "aws_secret_access_key": creds["Credentials"]["SecretAccessKey"], + "aws_session_token": creds["Credentials"]["SessionToken"], + "region_name": self.configuration["region"], } else: return { - 'aws_access_key_id': self.configuration.get('aws_access_key', None), - 'aws_secret_access_key': self.configuration.get('aws_secret_key', None), - 'region_name': self.configuration['region'] + "aws_access_key_id": self.configuration.get("aws_access_key", None), + "aws_secret_access_key": self.configuration.get("aws_secret_key", None), + "region_name": self.configuration["region"], } def __get_schema_from_glue(self): - client = boto3.client('glue', **self._get_iam_credentials()) + client = boto3.client("glue", **self._get_iam_credentials()) schema = {} - database_paginator = client.get_paginator('get_databases') - table_paginator = client.get_paginator('get_tables') + database_paginator = client.get_paginator("get_databases") + table_paginator = client.get_paginator("get_tables") for databases in database_paginator.paginate(): - for database in databases['DatabaseList']: - iterator = table_paginator.paginate(DatabaseName=database['Name']) - for table in iterator.search('TableList[]'): - table_name = '%s.%s' % (database['Name'], table['Name']) + for database in databases["DatabaseList"]: + iterator = table_paginator.paginate(DatabaseName=database["Name"]) + for table in iterator.search("TableList[]"): + table_name = "%s.%s" % (database["Name"], table["Name"]) if table_name not in schema: - column = [columns['Name'] for columns in table['StorageDescriptor']['Columns']] - schema[table_name] = {'name': table_name, 'columns': column} - for partition in table.get('PartitionKeys', []): - schema[table_name]['columns'].append(partition['Name']) + column = [ + columns["Name"] + for columns in table["StorageDescriptor"]["Columns"] + ] + schema[table_name] = {"name": table_name, "columns": column} + for partition in table.get("PartitionKeys", []): + schema[table_name]["columns"].append(partition["Name"]) return list(schema.values()) def get_schema(self, get_stats=False): - if self.configuration.get('glue', False): + if self.configuration.get("glue", False): return self.__get_schema_from_glue() schema = {} @@ -201,29 +195,35 @@ def get_schema(self, get_stats=False): raise Exception("Failed getting schema.") results = json_loads(results) - for row in results['rows']: - table_name = '{0}.{1}'.format(row['table_schema'], row['table_name']) + for row in results["rows"]: + table_name = "{0}.{1}".format(row["table_schema"], row["table_name"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name] = {"name": table_name, "columns": []} + schema[table_name]["columns"].append(row["column_name"]) return list(schema.values()) def run_query(self, query, user): cursor = pyathena.connect( - s3_staging_dir=self.configuration['s3_staging_dir'], - schema_name=self.configuration.get('schema', 'default'), - encryption_option=self.configuration.get('encryption_option', None), - kms_key=self.configuration.get('kms_key', None), - work_group=self.configuration.get('work_group', 'primary'), + s3_staging_dir=self.configuration["s3_staging_dir"], + schema_name=self.configuration.get("schema", "default"), + encryption_option=self.configuration.get("encryption_option", None), + kms_key=self.configuration.get("kms_key", None), + work_group=self.configuration.get("work_group", "primary"), formatter=SimpleFormatter(), - **self._get_iam_credentials(user=user)).cursor() + **self._get_iam_credentials(user=user) + ).cursor() try: cursor.execute(query) - column_tuples = [(i[0], _TYPE_MAPPINGS.get(i[1], None)) for i in cursor.description] + column_tuples = [ + (i[0], _TYPE_MAPPINGS.get(i[1], None)) for i in cursor.description + ] columns = self.fetch_columns(column_tuples) - rows = [dict(zip(([c['name'] for c in columns]), r)) for i, r in enumerate(cursor.fetchall())] + rows = [ + dict(zip(([c["name"] for c in columns]), r)) + for i, r in enumerate(cursor.fetchall()) + ] qbytes = None athena_query_id = None try: @@ -235,12 +235,12 @@ def run_query(self, query, user): except AttributeError as e: logger.debug("Athena Upstream can't get query_id: %s", e) data = { - 'columns': columns, - 'rows': rows, - 'metadata': { - 'data_scanned': qbytes, - 'athena_query_id': athena_query_id - } + "columns": columns, + "rows": rows, + "metadata": { + "data_scanned": qbytes, + "athena_query_id": athena_query_id, + }, } json_data = json_dumps(data, ignore_nan=True) error = None diff --git a/redash/query_runner/axibase_tsd.py b/redash/query_runner/axibase_tsd.py index 24aa5f3321..7d82d3fc37 100644 --- a/redash/query_runner/axibase_tsd.py +++ b/redash/query_runner/axibase_tsd.py @@ -13,24 +13,22 @@ import atsd_client from atsd_client.exceptions import SQLException from atsd_client.services import SQLService, MetricsService + enabled = True except ImportError: enabled = False types_map = { - 'long': TYPE_INTEGER, - - 'bigint': TYPE_INTEGER, - 'integer': TYPE_INTEGER, - 'smallint': TYPE_INTEGER, - - 'float': TYPE_FLOAT, - 'double': TYPE_FLOAT, - 'decimal': TYPE_FLOAT, - - 'string': TYPE_STRING, - 'date': TYPE_DATE, - 'xsd:dateTimeStamp': TYPE_DATETIME + "long": TYPE_INTEGER, + "bigint": TYPE_INTEGER, + "integer": TYPE_INTEGER, + "smallint": TYPE_INTEGER, + "float": TYPE_FLOAT, + "double": TYPE_FLOAT, + "decimal": TYPE_FLOAT, + "string": TYPE_STRING, + "date": TYPE_DATE, + "xsd:dateTimeStamp": TYPE_DATETIME, } @@ -41,7 +39,7 @@ def resolve_redash_type(type_in_atsd): :return: redash type constant """ if isinstance(type_in_atsd, dict): - type_in_redash = types_map.get(type_in_atsd['base']) + type_in_redash = types_map.get(type_in_atsd["base"]) else: type_in_redash = types_map.get(type_in_atsd) return type_in_redash @@ -53,22 +51,26 @@ def generate_rows_and_columns(csv_response): :param csv_response: `str` :return: prepared rows and columns """ - meta, data = csv_response.split('\n', 1) + meta, data = csv_response.split("\n", 1) meta = meta[1:] - meta_with_padding = meta + '=' * (4 - len(meta) % 4) - meta_decoded = meta_with_padding.decode('base64') + meta_with_padding = meta + "=" * (4 - len(meta) % 4) + meta_decoded = meta_with_padding.decode("base64") meta_json = json_loads(meta_decoded) - meta_columns = meta_json['tableSchema']['columns'] + meta_columns = meta_json["tableSchema"]["columns"] reader = csv.reader(data.splitlines()) next(reader) - columns = [{'friendly_name': i['titles'], - 'type': resolve_redash_type(i['datatype']), - 'name': i['name']} - for i in meta_columns] - column_names = [c['name'] for c in columns] + columns = [ + { + "friendly_name": i["titles"], + "type": resolve_redash_type(i["datatype"]), + "name": i["name"], + } + for i in meta_columns + ] + column_names = [c["name"] for c in columns] rows = [dict(zip(column_names, row)) for row in reader] return columns, rows @@ -87,80 +89,66 @@ def name(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'protocol': { - 'type': 'string', - 'title': 'Protocol', - 'default': 'http' - }, - 'hostname': { - 'type': 'string', - 'title': 'Host', - 'default': 'axibase_tsd_hostname' - }, - 'port': { - 'type': 'number', - 'title': 'Port', - 'default': 8088 - }, - 'username': { - 'type': 'string' - }, - 'password': { - 'type': 'string', - 'title': 'Password' - }, - 'timeout': { - 'type': 'number', - 'default': 600, - 'title': 'Connection Timeout' + "type": "object", + "properties": { + "protocol": {"type": "string", "title": "Protocol", "default": "http"}, + "hostname": { + "type": "string", + "title": "Host", + "default": "axibase_tsd_hostname", }, - 'min_insert_date': { - 'type': 'string', - 'title': 'Metric Minimum Insert Date' + "port": {"type": "number", "title": "Port", "default": 8088}, + "username": {"type": "string"}, + "password": {"type": "string", "title": "Password"}, + "timeout": { + "type": "number", + "default": 600, + "title": "Connection Timeout", }, - 'expression': { - 'type': 'string', - 'title': 'Metric Filter' + "min_insert_date": { + "type": "string", + "title": "Metric Minimum Insert Date", }, - 'limit': { - 'type': 'number', - 'default': 5000, - 'title': 'Metric Limit' + "expression": {"type": "string", "title": "Metric Filter"}, + "limit": {"type": "number", "default": 5000, "title": "Metric Limit"}, + "trust_certificate": { + "type": "boolean", + "title": "Trust SSL Certificate", }, - 'trust_certificate': { - 'type': 'boolean', - 'title': 'Trust SSL Certificate' - } }, - 'required': ['username', 'password', 'hostname', 'protocol', 'port'], - 'secret': ['password'] + "required": ["username", "password", "hostname", "protocol", "port"], + "secret": ["password"], } def __init__(self, configuration): super(AxibaseTSD, self).__init__(configuration) - self.url = '{0}://{1}:{2}'.format(self.configuration.get('protocol', 'http'), - self.configuration.get('hostname', 'localhost'), - self.configuration.get('port', 8088)) + self.url = "{0}://{1}:{2}".format( + self.configuration.get("protocol", "http"), + self.configuration.get("hostname", "localhost"), + self.configuration.get("port", 8088), + ) def run_query(self, query, user): - connection = atsd_client.connect_url(self.url, - self.configuration.get('username'), - self.configuration.get('password'), - verify=self.configuration.get('trust_certificate', False), - timeout=self.configuration.get('timeout', 600)) + connection = atsd_client.connect_url( + self.url, + self.configuration.get("username"), + self.configuration.get("password"), + verify=self.configuration.get("trust_certificate", False), + timeout=self.configuration.get("timeout", 600), + ) sql = SQLService(connection) query_id = str(uuid.uuid4()) try: logger.debug("SQL running query: %s", query) - data = sql.query_with_params(query, {'outputFormat': 'csv', 'metadataFormat': 'EMBED', - 'queryId': query_id}) + data = sql.query_with_params( + query, + {"outputFormat": "csv", "metadataFormat": "EMBED", "queryId": query_id}, + ) columns, rows = generate_rows_and_columns(data) - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None @@ -175,23 +163,38 @@ def run_query(self, query, user): return json_data, error def get_schema(self, get_stats=False): - connection = atsd_client.connect_url(self.url, - self.configuration.get('username'), - self.configuration.get('password'), - verify=self.configuration.get('trust_certificate', False), - timeout=self.configuration.get('timeout', 600)) + connection = atsd_client.connect_url( + self.url, + self.configuration.get("username"), + self.configuration.get("password"), + verify=self.configuration.get("trust_certificate", False), + timeout=self.configuration.get("timeout", 600), + ) metrics = MetricsService(connection) - ml = metrics.list(expression=self.configuration.get('expression', None), - minInsertDate=self.configuration.get('min_insert_date', None), - limit=self.configuration.get('limit', 5000)) - metrics_list = [i.name.encode('utf-8') for i in ml] - metrics_list.append('atsd_series') + ml = metrics.list( + expression=self.configuration.get("expression", None), + minInsertDate=self.configuration.get("min_insert_date", None), + limit=self.configuration.get("limit", 5000), + ) + metrics_list = [i.name.encode("utf-8") for i in ml] + metrics_list.append("atsd_series") schema = {} - default_columns = ['entity', 'datetime', 'time', 'metric', 'value', 'text', - 'tags', 'entity.tags', 'metric.tags'] + default_columns = [ + "entity", + "datetime", + "time", + "metric", + "value", + "text", + "tags", + "entity.tags", + "metric.tags", + ] for table_name in metrics_list: - schema[table_name] = {'name': "'{}'".format(table_name), - 'columns': default_columns} + schema[table_name] = { + "name": "'{}'".format(table_name), + "columns": default_columns, + } values = list(schema.values()) return values diff --git a/redash/query_runner/azure_kusto.py b/redash/query_runner/azure_kusto.py index fa8c4a85f3..26d229f776 100644 --- a/redash/query_runner/azure_kusto.py +++ b/redash/query_runner/azure_kusto.py @@ -1,27 +1,35 @@ from redash.query_runner import BaseQueryRunner, register -from redash.query_runner import TYPE_STRING, TYPE_DATE, TYPE_DATETIME, TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN +from redash.query_runner import ( + TYPE_STRING, + TYPE_DATE, + TYPE_DATETIME, + TYPE_INTEGER, + TYPE_FLOAT, + TYPE_BOOLEAN, +) from redash.utils import json_dumps, json_loads try: from azure.kusto.data.request import KustoClient, KustoConnectionStringBuilder from azure.kusto.data.exceptions import KustoServiceError + enabled = True except ImportError: enabled = False TYPES_MAP = { - 'boolean': TYPE_BOOLEAN, - 'datetime': TYPE_DATETIME, - 'date': TYPE_DATE, - 'dynamic': TYPE_STRING, - 'guid': TYPE_STRING, - 'int': TYPE_INTEGER, - 'long': TYPE_INTEGER, - 'real': TYPE_FLOAT, - 'string': TYPE_STRING, - 'timespan': TYPE_STRING, - 'decimal': TYPE_FLOAT + "boolean": TYPE_BOOLEAN, + "datetime": TYPE_DATETIME, + "date": TYPE_DATE, + "dynamic": TYPE_STRING, + "guid": TYPE_STRING, + "int": TYPE_INTEGER, + "long": TYPE_INTEGER, + "real": TYPE_FLOAT, + "string": TYPE_STRING, + "timespan": TYPE_STRING, + "decimal": TYPE_FLOAT, } @@ -31,41 +39,37 @@ class AzureKusto(BaseQueryRunner): def __init__(self, configuration): super(AzureKusto, self).__init__(configuration) - self.syntax = 'custom' + self.syntax = "custom" @classmethod def configuration_schema(cls): return { "type": "object", "properties": { - "cluster": { - "type": "string" - }, - "azure_ad_client_id": { - "type": "string", - "title": "Azure AD Client ID" - }, + "cluster": {"type": "string"}, + "azure_ad_client_id": {"type": "string", "title": "Azure AD Client ID"}, "azure_ad_client_secret": { "type": "string", - "title": "Azure AD Client Secret" - }, - "azure_ad_tenant_id": { - "type": "string", - "title": "Azure AD Tenant Id" + "title": "Azure AD Client Secret", }, - "database": { - "type": "string" - } + "azure_ad_tenant_id": {"type": "string", "title": "Azure AD Tenant Id"}, + "database": {"type": "string"}, }, "required": [ - "cluster", "azure_ad_client_id", "azure_ad_client_secret", - "azure_ad_tenant_id", "database" + "cluster", + "azure_ad_client_id", + "azure_ad_client_secret", + "azure_ad_tenant_id", + "database", ], "order": [ - "cluster", "azure_ad_client_id", "azure_ad_client_secret", - "azure_ad_tenant_id", "database" + "cluster", + "azure_ad_client_id", + "azure_ad_client_secret", + "azure_ad_tenant_id", + "database", ], - "secret": ["azure_ad_client_secret"] + "secret": ["azure_ad_client_secret"], } @classmethod @@ -83,14 +87,15 @@ def name(cls): def run_query(self, query, user): kcsb = KustoConnectionStringBuilder.with_aad_application_key_authentication( - connection_string=self.configuration['cluster'], - aad_app_id=self.configuration['azure_ad_client_id'], - app_key=self.configuration['azure_ad_client_secret'], - authority_id=self.configuration['azure_ad_tenant_id']) + connection_string=self.configuration["cluster"], + aad_app_id=self.configuration["azure_ad_client_id"], + app_key=self.configuration["azure_ad_client_secret"], + authority_id=self.configuration["azure_ad_tenant_id"], + ) client = KustoClient(kcsb) - db = self.configuration['database'] + db = self.configuration["database"] try: response = client.execute(db, query) @@ -100,24 +105,26 @@ def run_query(self, query, user): columns = [] rows = [] for c in result_cols: - columns.append({ - 'name': c.column_name, - 'friendly_name': c.column_name, - 'type': TYPES_MAP.get(c.column_type, None) - }) + columns.append( + { + "name": c.column_name, + "friendly_name": c.column_name, + "type": TYPES_MAP.get(c.column_type, None), + } + ) # rows must be [{'column1': value, 'column2': value}] for row in result_rows: rows.append(row.to_dict()) error = None - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) except KustoServiceError as err: json_data = None try: - error = err.args[1][0]['error']['@message'] + error = err.args[1][0]["error"]["@message"] except (IndexError, KeyError): error = err.args[1] except KeyboardInterrupt: @@ -136,19 +143,21 @@ def get_schema(self, get_stats=False): results = json_loads(results) - schema_as_json = json_loads(results['rows'][0]['DatabaseSchema']) - tables_list = schema_as_json['Databases'][self.configuration['database']]['Tables'].values() + schema_as_json = json_loads(results["rows"][0]["DatabaseSchema"]) + tables_list = schema_as_json["Databases"][self.configuration["database"]][ + "Tables" + ].values() schema = {} for table in tables_list: - table_name = table['Name'] + table_name = table["Name"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - for column in table['OrderedColumns']: - schema[table_name]['columns'].append(column['Name']) + for column in table["OrderedColumns"]: + schema[table_name]["columns"].append(column["Name"]) return list(schema.values()) diff --git a/redash/query_runner/big_query.py b/redash/query_runner/big_query.py index 03defe1fcd..710b00e2eb 100644 --- a/redash/query_runner/big_query.py +++ b/redash/query_runner/big_query.py @@ -24,24 +24,24 @@ enabled = False types_map = { - 'INTEGER': TYPE_INTEGER, - 'FLOAT': TYPE_FLOAT, - 'BOOLEAN': TYPE_BOOLEAN, - 'STRING': TYPE_STRING, - 'TIMESTAMP': TYPE_DATETIME, + "INTEGER": TYPE_INTEGER, + "FLOAT": TYPE_FLOAT, + "BOOLEAN": TYPE_BOOLEAN, + "STRING": TYPE_STRING, + "TIMESTAMP": TYPE_DATETIME, } def transform_cell(field_type, cell_value): if cell_value is None: return None - if field_type == 'INTEGER': + if field_type == "INTEGER": return int(cell_value) - elif field_type == 'FLOAT': + elif field_type == "FLOAT": return float(cell_value) - elif field_type == 'BOOLEAN': + elif field_type == "BOOLEAN": return cell_value.lower() == "true" - elif field_type == 'TIMESTAMP': + elif field_type == "TIMESTAMP": return datetime.datetime.fromtimestamp(float(cell_value)) return cell_value @@ -51,10 +51,12 @@ def transform_row(row, fields): for column_index, cell in enumerate(row["f"]): field = fields[column_index] - if field.get('mode') == 'REPEATED': - cell_value = [transform_cell(field['type'], item['v']) for item in cell['v']] + if field.get("mode") == "REPEATED": + cell_value = [ + transform_cell(field["type"], item["v"]) for item in cell["v"] + ] else: - cell_value = transform_cell(field['type'], cell['v']) + cell_value = transform_cell(field["type"], cell["v"]) row_data[field["name"]] = cell_value @@ -70,12 +72,11 @@ def _load_key(filename): def _get_query_results(jobs, project_id, location, job_id, start_index): - query_reply = jobs.getQueryResults(projectId=project_id, - location=location, - jobId=job_id, - startIndex=start_index).execute() - logging.debug('query_reply %s', query_reply) - if not query_reply['jobComplete']: + query_reply = jobs.getQueryResults( + projectId=project_id, location=location, jobId=job_id, startIndex=start_index + ).execute() + logging.debug("query_reply %s", query_reply) + if not query_reply["jobComplete"]: time.sleep(10) return _get_query_results(jobs, project_id, location, job_id, start_index) @@ -93,54 +94,51 @@ def enabled(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'projectId': { - 'type': 'string', - 'title': 'Project ID' - }, - 'jsonKeyFile': { - "type": "string", - 'title': 'JSON Key File' - }, - 'totalMBytesProcessedLimit': { + "type": "object", + "properties": { + "projectId": {"type": "string", "title": "Project ID"}, + "jsonKeyFile": {"type": "string", "title": "JSON Key File"}, + "totalMBytesProcessedLimit": { "type": "number", - 'title': 'Scanned Data Limit (MB)' + "title": "Scanned Data Limit (MB)", }, - 'userDefinedFunctionResourceUri': { + "userDefinedFunctionResourceUri": { "type": "string", - 'title': 'UDF Source URIs (i.e. gs://bucket/date_utils.js, gs://bucket/string_utils.js )' + "title": "UDF Source URIs (i.e. gs://bucket/date_utils.js, gs://bucket/string_utils.js )", }, - 'useStandardSql': { + "useStandardSql": { "type": "boolean", - 'title': "Use Standard SQL", + "title": "Use Standard SQL", "default": True, }, - 'location': { - "type": "string", - "title": "Processing Location", - }, - 'loadSchema': { - "type": "boolean", - "title": "Load Schema" - }, - 'maximumBillingTier': { + "location": {"type": "string", "title": "Processing Location"}, + "loadSchema": {"type": "boolean", "title": "Load Schema"}, + "maximumBillingTier": { "type": "number", - "title": "Maximum Billing Tier" - } + "title": "Maximum Billing Tier", + }, }, - 'required': ['jsonKeyFile', 'projectId'], - "order": ['projectId', 'jsonKeyFile', 'loadSchema', 'useStandardSql', 'location', 'totalMBytesProcessedLimit', 'maximumBillingTier', 'userDefinedFunctionResourceUri'], - 'secret': ['jsonKeyFile'] + "required": ["jsonKeyFile", "projectId"], + "order": [ + "projectId", + "jsonKeyFile", + "loadSchema", + "useStandardSql", + "location", + "totalMBytesProcessedLimit", + "maximumBillingTier", + "userDefinedFunctionResourceUri", + ], + "secret": ["jsonKeyFile"], } def _get_bigquery_service(self): scope = [ "https://www.googleapis.com/auth/bigquery", - "https://www.googleapis.com/auth/drive" + "https://www.googleapis.com/auth/drive", ] - key = json_loads(b64decode(self.configuration['jsonKeyFile'])) + key = json_loads(b64decode(self.configuration["jsonKeyFile"])) creds = ServiceAccountCredentials.from_json_keyfile_dict(key, scope) http = httplib2.Http(timeout=settings.BIGQUERY_HTTP_TIMEOUT) @@ -155,43 +153,38 @@ def _get_location(self): return self.configuration.get("location") def _get_total_bytes_processed(self, jobs, query): - job_data = { - "query": query, - "dryRun": True, - } + job_data = {"query": query, "dryRun": True} if self._get_location(): - job_data['location'] = self._get_location() + job_data["location"] = self._get_location() - if self.configuration.get('useStandardSql', False): - job_data['useLegacySql'] = False + if self.configuration.get("useStandardSql", False): + job_data["useLegacySql"] = False response = jobs.query(projectId=self._get_project_id(), body=job_data).execute() return int(response["totalBytesProcessed"]) def _get_job_data(self, query): - job_data = { - "configuration": { - "query": { - "query": query, - } - } - } + job_data = {"configuration": {"query": {"query": query}}} if self._get_location(): - job_data['jobReference'] = { - 'location': self._get_location() - } + job_data["jobReference"] = {"location": self._get_location()} - if self.configuration.get('useStandardSql', False): - job_data['configuration']['query']['useLegacySql'] = False + if self.configuration.get("useStandardSql", False): + job_data["configuration"]["query"]["useLegacySql"] = False - if self.configuration.get('userDefinedFunctionResourceUri'): - resource_uris = self.configuration["userDefinedFunctionResourceUri"].split(',') - job_data["configuration"]["query"]["userDefinedFunctionResources"] = [{"resourceUri": resource_uri} for resource_uri in resource_uris] + if self.configuration.get("userDefinedFunctionResourceUri"): + resource_uris = self.configuration["userDefinedFunctionResourceUri"].split( + "," + ) + job_data["configuration"]["query"]["userDefinedFunctionResources"] = [ + {"resourceUri": resource_uri} for resource_uri in resource_uris + ] if "maximumBillingTier" in self.configuration: - job_data["configuration"]["query"]["maximumBillingTier"] = self.configuration["maximumBillingTier"] + job_data["configuration"]["query"][ + "maximumBillingTier" + ] = self.configuration["maximumBillingTier"] return job_data @@ -200,90 +193,113 @@ def _get_query_result(self, jobs, query): job_data = self._get_job_data(query) insert_response = jobs.insert(projectId=project_id, body=job_data).execute() current_row = 0 - query_reply = _get_query_results(jobs, project_id=project_id, location=self._get_location(), - job_id=insert_response['jobReference']['jobId'], start_index=current_row) + query_reply = _get_query_results( + jobs, + project_id=project_id, + location=self._get_location(), + job_id=insert_response["jobReference"]["jobId"], + start_index=current_row, + ) logger.debug("bigquery replied: %s", query_reply) rows = [] - while ("rows" in query_reply) and current_row < query_reply['totalRows']: + while ("rows" in query_reply) and current_row < query_reply["totalRows"]: for row in query_reply["rows"]: rows.append(transform_row(row, query_reply["schema"]["fields"])) - current_row += len(query_reply['rows']) + current_row += len(query_reply["rows"]) query_result_request = { - 'projectId': project_id, - 'jobId': query_reply['jobReference']['jobId'], - 'startIndex': current_row + "projectId": project_id, + "jobId": query_reply["jobReference"]["jobId"], + "startIndex": current_row, } if self._get_location(): - query_result_request['location'] = self._get_location() + query_result_request["location"] = self._get_location() query_reply = jobs.getQueryResults(**query_result_request).execute() - columns = [{ - 'name': f["name"], - 'friendly_name': f["name"], - 'type': "string" if f.get('mode') == "REPEATED" - else types_map.get(f['type'], "string") - } for f in query_reply["schema"]["fields"]] + columns = [ + { + "name": f["name"], + "friendly_name": f["name"], + "type": "string" + if f.get("mode") == "REPEATED" + else types_map.get(f["type"], "string"), + } + for f in query_reply["schema"]["fields"] + ] data = { "columns": columns, "rows": rows, - 'metadata': {'data_scanned': int(query_reply['totalBytesProcessed'])} + "metadata": {"data_scanned": int(query_reply["totalBytesProcessed"])}, } return data def _get_columns_schema(self, table_data): columns = [] - for column in table_data.get('schema', {}).get('fields', []): + for column in table_data.get("schema", {}).get("fields", []): columns.extend(self._get_columns_schema_column(column)) project_id = self._get_project_id() - table_name = table_data['id'].replace("%s:" % project_id, "") - return {'name': table_name, 'columns': columns} + table_name = table_data["id"].replace("%s:" % project_id, "") + return {"name": table_name, "columns": columns} def _get_columns_schema_column(self, column): columns = [] - if column['type'] == 'RECORD': - for field in column['fields']: - columns.append("{}.{}".format(column['name'], field['name'])) + if column["type"] == "RECORD": + for field in column["fields"]: + columns.append("{}.{}".format(column["name"], field["name"])) else: - columns.append(column['name']) + columns.append(column["name"]) return columns def get_schema(self, get_stats=False): - if not self.configuration.get('loadSchema', False): + if not self.configuration.get("loadSchema", False): return [] service = self._get_bigquery_service() project_id = self._get_project_id() datasets = service.datasets().list(projectId=project_id).execute() schema = [] - for dataset in datasets.get('datasets', []): - dataset_id = dataset['datasetReference']['datasetId'] - tables = service.tables().list(projectId=project_id, datasetId=dataset_id).execute() + for dataset in datasets.get("datasets", []): + dataset_id = dataset["datasetReference"]["datasetId"] + tables = ( + service.tables() + .list(projectId=project_id, datasetId=dataset_id) + .execute() + ) while True: - for table in tables.get('tables', []): - table_data = service.tables().get(projectId=project_id, - datasetId=dataset_id, - tableId=table['tableReference']['tableId']).execute() + for table in tables.get("tables", []): + table_data = ( + service.tables() + .get( + projectId=project_id, + datasetId=dataset_id, + tableId=table["tableReference"]["tableId"], + ) + .execute() + ) table_schema = self._get_columns_schema(table_data) schema.append(table_schema) - next_token = tables.get('nextPageToken', None) + next_token = tables.get("nextPageToken", None) if next_token is None: break - tables = service.tables().list(projectId=project_id, - datasetId=dataset_id, - pageToken=next_token).execute() + tables = ( + service.tables() + .list( + projectId=project_id, datasetId=dataset_id, pageToken=next_token + ) + .execute() + ) return schema @@ -296,9 +312,15 @@ def run_query(self, query, user): try: if "totalMBytesProcessedLimit" in self.configuration: limitMB = self.configuration["totalMBytesProcessedLimit"] - processedMB = self._get_total_bytes_processed(jobs, query) / 1000.0 / 1000.0 + processedMB = ( + self._get_total_bytes_processed(jobs, query) / 1000.0 / 1000.0 + ) if limitMB < processedMB: - return None, "Larger than %d MBytes will be processed (%f MBytes)" % (limitMB, processedMB) + return ( + None, + "Larger than %d MBytes will be processed (%f MBytes)" + % (limitMB, processedMB), + ) data = self._get_query_result(jobs, query) error = None @@ -307,7 +329,7 @@ def run_query(self, query, user): except apiclient.errors.HttpError as e: json_data = None if e.resp.status == 400: - error = json_loads(e.content)['error']['message'] + error = json_loads(e.content)["error"]["message"] else: error = e.content except KeyboardInterrupt: diff --git a/redash/query_runner/big_query_gce.py b/redash/query_runner/big_query_gce.py index 2fb7d9db05..bc7a38d91d 100644 --- a/redash/query_runner/big_query_gce.py +++ b/redash/query_runner/big_query_gce.py @@ -25,7 +25,7 @@ def enabled(cls): try: # check if we're on a GCE instance - requests.get('http://metadata.google.internal') + requests.get("http://metadata.google.internal") except requests.exceptions.ConnectionError: return False @@ -34,38 +34,40 @@ def enabled(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'totalMBytesProcessedLimit': { + "type": "object", + "properties": { + "totalMBytesProcessedLimit": { "type": "number", - 'title': 'Total MByte Processed Limit' + "title": "Total MByte Processed Limit", }, - 'userDefinedFunctionResourceUri': { + "userDefinedFunctionResourceUri": { "type": "string", - 'title': 'UDF Source URIs (i.e. gs://bucket/date_utils.js, gs://bucket/string_utils.js )' + "title": "UDF Source URIs (i.e. gs://bucket/date_utils.js, gs://bucket/string_utils.js )", }, - 'useStandardSql': { + "useStandardSql": { "type": "boolean", - 'title': "Use Standard SQL", + "title": "Use Standard SQL", "default": True, }, - 'location': { + "location": { "type": "string", "title": "Processing Location", "default": "US", }, - 'loadSchema': { - "type": "boolean", - "title": "Load Schema" - } - } + "loadSchema": {"type": "boolean", "title": "Load Schema"}, + }, } def _get_project_id(self): - return requests.get('http://metadata/computeMetadata/v1/project/project-id', headers={'Metadata-Flavor': 'Google'}).content + return requests.get( + "http://metadata/computeMetadata/v1/project/project-id", + headers={"Metadata-Flavor": "Google"}, + ).content def _get_bigquery_service(self): - credentials = gce.AppAssertionCredentials(scope='https://www.googleapis.com/auth/bigquery') + credentials = gce.AppAssertionCredentials( + scope="https://www.googleapis.com/auth/bigquery" + ) http = httplib2.Http() http = credentials.authorize(http) diff --git a/redash/query_runner/cass.py b/redash/query_runner/cass.py index e56aad703a..4725101b1e 100644 --- a/redash/query_runner/cass.py +++ b/redash/query_runner/cass.py @@ -9,6 +9,7 @@ from cassandra.cluster import Cluster from cassandra.auth import PlainTextAuthProvider from cassandra.util import sortedset + enabled = True except ImportError: enabled = False @@ -31,39 +32,21 @@ def enabled(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'host': { - 'type': 'string', - }, - 'port': { - 'type': 'number', - 'default': 9042, - }, - 'keyspace': { - 'type': 'string', - 'title': 'Keyspace name' - }, - 'username': { - 'type': 'string', - 'title': 'Username' + "type": "object", + "properties": { + "host": {"type": "string"}, + "port": {"type": "number", "default": 9042}, + "keyspace": {"type": "string", "title": "Keyspace name"}, + "username": {"type": "string", "title": "Username"}, + "password": {"type": "string", "title": "Password"}, + "protocol": { + "type": "number", + "title": "Protocol Version", + "default": 3, }, - 'password': { - 'type': 'string', - 'title': 'Password' - }, - 'protocol': { - 'type': 'number', - 'title': 'Protocol Version', - 'default': 3 - }, - 'timeout': { - 'type': 'number', - 'title': 'Timeout', - 'default': 10 - } + "timeout": {"type": "number", "title": "Timeout", "default": 10}, }, - 'required': ['keyspace', 'host'] + "required": ["keyspace", "host"], } @classmethod @@ -76,61 +59,73 @@ def get_schema(self, get_stats=False): """ results, error = self.run_query(query, None) results = json_loads(results) - release_version = results['rows'][0]['release_version'] + release_version = results["rows"][0]["release_version"] query = """ SELECT table_name, column_name FROM system_schema.columns WHERE keyspace_name ='{}'; - """.format(self.configuration['keyspace']) + """.format( + self.configuration["keyspace"] + ) - if release_version.startswith('2'): - query = """ + if release_version.startswith("2"): + query = """ SELECT columnfamily_name AS table_name, column_name FROM system.schema_columns WHERE keyspace_name ='{}'; - """.format(self.configuration['keyspace']) + """.format( + self.configuration["keyspace"] + ) results, error = self.run_query(query, None) results = json_loads(results) schema = {} - for row in results['rows']: - table_name = row['table_name'] - column_name = row['column_name'] + for row in results["rows"]: + table_name = row["table_name"] + column_name = row["column_name"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} - schema[table_name]['columns'].append(column_name) + schema[table_name] = {"name": table_name, "columns": []} + schema[table_name]["columns"].append(column_name) return list(schema.values()) def run_query(self, query, user): connection = None try: - if self.configuration.get('username', '') and self.configuration.get('password', ''): - auth_provider = PlainTextAuthProvider(username='{}'.format(self.configuration.get('username', '')), - password='{}'.format(self.configuration.get('password', ''))) - connection = Cluster([self.configuration.get('host', '')], - auth_provider=auth_provider, - port=self.configuration.get('port', ''), - protocol_version=self.configuration.get('protocol', 3)) + if self.configuration.get("username", "") and self.configuration.get( + "password", "" + ): + auth_provider = PlainTextAuthProvider( + username="{}".format(self.configuration.get("username", "")), + password="{}".format(self.configuration.get("password", "")), + ) + connection = Cluster( + [self.configuration.get("host", "")], + auth_provider=auth_provider, + port=self.configuration.get("port", ""), + protocol_version=self.configuration.get("protocol", 3), + ) else: - connection = Cluster([self.configuration.get('host', '')], - port=self.configuration.get('port', ''), - protocol_version=self.configuration.get('protocol', 3)) + connection = Cluster( + [self.configuration.get("host", "")], + port=self.configuration.get("port", ""), + protocol_version=self.configuration.get("protocol", 3), + ) session = connection.connect() - session.set_keyspace(self.configuration['keyspace']) - session.default_timeout = self.configuration.get('timeout', 10) + session.set_keyspace(self.configuration["keyspace"]) + session.default_timeout = self.configuration.get("timeout", 10) logger.debug("Cassandra running query: %s", query) result = session.execute(query) column_names = result.column_names - columns = self.fetch_columns([(c, 'string') for c in column_names]) + columns = self.fetch_columns([(c, "string") for c in column_names]) rows = [dict(zip(column_names, row)) for row in result] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data, cls=CassandraJSONEncoder) error = None @@ -142,7 +137,6 @@ def run_query(self, query, user): class ScyllaDB(Cassandra): - @classmethod def type(cls): return "scylla" diff --git a/redash/query_runner/clickhouse.py b/redash/query_runner/clickhouse.py index 1217a9ec18..ae905589c1 100644 --- a/redash/query_runner/clickhouse.py +++ b/redash/query_runner/clickhouse.py @@ -17,30 +17,19 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "url": { - "type": "string", - "default": "http://127.0.0.1:8123" - }, - "user": { - "type": "string", - "default": "default" - }, - "password": { - "type": "string" - }, - "dbname": { - "type": "string", - "title": "Database Name" - }, + "url": {"type": "string", "default": "http://127.0.0.1:8123"}, + "user": {"type": "string", "default": "default"}, + "password": {"type": "string"}, + "dbname": {"type": "string", "title": "Database Name"}, "timeout": { "type": "number", "title": "Request Timeout", - "default": 30 - } + "default": 30, + }, }, "required": ["dbname"], "extra_options": ["timeout"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -57,29 +46,29 @@ def _get_tables(self, schema): results = json_loads(results) - for row in results['rows']: - table_name = '{}.{}'.format(row['database'], row['table']) + for row in results["rows"]: + table_name = "{}.{}".format(row["database"], row["table"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['name']) + schema[table_name]["columns"].append(row["name"]) return list(schema.values()) def _send_query(self, data, stream=False): - url = self.configuration.get('url', "http://127.0.0.1:8123") + url = self.configuration.get("url", "http://127.0.0.1:8123") try: r = requests.post( url, data=data.encode("utf-8"), stream=stream, - timeout=self.configuration.get('timeout', 30), + timeout=self.configuration.get("timeout", 30), params={ - 'user': self.configuration.get('user', "default"), - 'password': self.configuration.get('password', ""), - 'database': self.configuration['dbname'] - } + "user": self.configuration.get("user", "default"), + "password": self.configuration.get("password", ""), + "database": self.configuration["dbname"], + }, ) if r.status_code != 200: raise Exception(r.text) @@ -87,7 +76,9 @@ def _send_query(self, data, stream=False): return r.json() except requests.RequestException as e: if e.response: - details = "({}, Status Code: {})".format(e.__class__.__name__, e.response.status_code) + details = "({}, Status Code: {})".format( + e.__class__.__name__, e.response.status_code + ) else: details = "({})".format(e.__class__.__name__) raise Exception("Connection error to: {} {}.".format(url, details)) @@ -95,39 +86,43 @@ def _send_query(self, data, stream=False): @staticmethod def _define_column_type(column): c = column.lower() - f = re.search(r'^nullable\((.*)\)$', c) + f = re.search(r"^nullable\((.*)\)$", c) if f is not None: c = f.group(1) - if c.startswith('int') or c.startswith('uint'): + if c.startswith("int") or c.startswith("uint"): return TYPE_INTEGER - elif c.startswith('float'): + elif c.startswith("float"): return TYPE_FLOAT - elif c == 'datetime': + elif c == "datetime": return TYPE_DATETIME - elif c == 'date': + elif c == "date": return TYPE_DATE else: return TYPE_STRING def _clickhouse_query(self, query): - query += '\nFORMAT JSON' + query += "\nFORMAT JSON" result = self._send_query(query) columns = [] columns_int64 = [] # db converts value to string if its type equals UInt64 columns_totals = {} - for r in result['meta']: - column_name = r['name'] - column_type = self._define_column_type(r['type']) + for r in result["meta"]: + column_name = r["name"] + column_type = self._define_column_type(r["type"]) - if r['type'] in ('Int64', 'UInt64', 'Nullable(Int64)', 'Nullable(UInt64)'): + if r["type"] in ("Int64", "UInt64", "Nullable(Int64)", "Nullable(UInt64)"): columns_int64.append(column_name) else: - columns_totals[column_name] = 'Total' if column_type == TYPE_STRING else None + columns_totals[column_name] = ( + "Total" if column_type == TYPE_STRING else None + ) - columns.append({'name': column_name, 'friendly_name': column_name, 'type': column_type}) + columns.append( + {"name": column_name, "friendly_name": column_name, "type": column_type} + ) - rows = result['data'] + rows = result["data"] for row in rows: for column in columns_int64: try: @@ -135,13 +130,13 @@ def _clickhouse_query(self, query): except TypeError: row[column] = None - if 'totals' in result: - totals = result['totals'] + if "totals" in result: + totals = result["totals"] for column, value in columns_totals.items(): totals[column] = value rows.append(totals) - return {'columns': columns, 'rows': rows} + return {"columns": columns, "rows": rows} def run_query(self, query, user): logger.debug("Clickhouse is about to execute query: %s", query) diff --git a/redash/query_runner/couchbase.py b/redash/query_runner/couchbase.py index 1753bc0dd5..4e1de63fd4 100644 --- a/redash/query_runner/couchbase.py +++ b/redash/query_runner/couchbase.py @@ -13,7 +13,7 @@ import requests import httplib2 except ImportError as e: - logger.error('Failed to import: ' + str(e)) + logger.error("Failed to import: " + str(e)) TYPES_MAP = { @@ -23,7 +23,7 @@ float: TYPE_FLOAT, bool: TYPE_BOOLEAN, datetime.datetime: TYPE_DATETIME, - datetime.datetime: TYPE_STRING + datetime.datetime: TYPE_STRING, } @@ -43,23 +43,29 @@ def parse_results(results): for key in row: if isinstance(row[key], dict): for inner_key in row[key]: - column_name = '{}.{}'.format(key, inner_key) + column_name = "{}.{}".format(key, inner_key) if _get_column_by_name(columns, column_name) is None: - columns.append({ - "name": column_name, - "friendly_name": column_name, - "type": TYPES_MAP.get(type(row[key][inner_key]), TYPE_STRING) - }) + columns.append( + { + "name": column_name, + "friendly_name": column_name, + "type": TYPES_MAP.get( + type(row[key][inner_key]), TYPE_STRING + ), + } + ) parsed_row[column_name] = row[key][inner_key] else: if _get_column_by_name(columns, key) is None: - columns.append({ - "name": key, - "friendly_name": key, - "type": TYPES_MAP.get(type(row[key]), TYPE_STRING) - }) + columns.append( + { + "name": key, + "friendly_name": key, + "type": TYPES_MAP.get(type(row[key]), TYPE_STRING), + } + ) parsed_row[key] = row[key] @@ -69,35 +75,26 @@ def parse_results(results): class Couchbase(BaseQueryRunner): should_annotate_query = False - noop_query = 'Select 1' + noop_query = "Select 1" @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'protocol': { - 'type': 'string', - 'default': 'http' - }, - 'host': { - 'type': 'string', - }, - 'port': { - 'type': 'string', - 'title': 'Port (Defaults: 8095 - Analytics, 8093 - N1QL)', - 'default': '8095' - }, - 'user': { - 'type': 'string', - }, - 'password': { - 'type': 'string', + "type": "object", + "properties": { + "protocol": {"type": "string", "default": "http"}, + "host": {"type": "string"}, + "port": { + "type": "string", + "title": "Port (Defaults: 8095 - Analytics, 8093 - N1QL)", + "default": "8095", }, + "user": {"type": "string"}, + "password": {"type": "string"}, }, - 'required': ['host', 'user', 'password'], - 'order': ['protocol', 'host', 'port', 'user', 'password'], - 'secret': ['password'] + "required": ["host", "user", "password"], + "order": ["protocol", "host", "port", "user", "password"], + "secret": ["password"], } def __init__(self, configuration): @@ -108,17 +105,15 @@ def enabled(cls): return True def test_connection(self): - result = self.call_service(self.noop_query, '') + result = self.call_service(self.noop_query, "") def get_buckets(self, query, name_param): - defaultColumns = [ - 'meta().id' - ] - result = self.call_service(query, "").json()['results'] + defaultColumns = ["meta().id"] + result = self.call_service(query, "").json()["results"] schema = {} for row in result: table_name = row.get(name_param) - schema[table_name] = {'name': table_name, 'columns': defaultColumns} + schema[table_name] = {"name": table_name, "columns": defaultColumns} return list(schema.values()) @@ -127,7 +122,9 @@ def get_schema(self, get_stats=False): try: # Try fetch from Analytics return self.get_buckets( - "SELECT ds.GroupName as name FROM Metadata.`Dataset` ds where ds.DataverseName <> 'Metadata'", "name") + "SELECT ds.GroupName as name FROM Metadata.`Dataset` ds where ds.DataverseName <> 'Metadata'", + "name", + ) except Exception: # Try fetch from N1QL return self.get_buckets("select name from system:keyspaces", "name") @@ -139,7 +136,7 @@ def call_service(self, query, user): protocol = self.configuration.get("protocol", "http") host = self.configuration.get("host") port = self.configuration.get("port", 8095) - params = {'statement': query} + params = {"statement": query} url = "%s://%s:%s/query/service" % (protocol, host, port) @@ -147,7 +144,7 @@ def call_service(self, query, user): r.raise_for_status() return r except requests.exceptions.HTTPError as err: - if (err.response.status_code == 401): + if err.response.status_code == 401: raise Exception("Wrong username/password") raise Exception("Couchbase connection error") @@ -155,11 +152,8 @@ def run_query(self, query, user): try: result = self.call_service(query, user) - rows, columns = parse_results(result.json()['results']) - data = { - "columns": columns, - "rows": rows - } + rows, columns = parse_results(result.json()["results"]) + data = {"columns": columns, "rows": rows} return json_dumps(data), None except KeyboardInterrupt: diff --git a/redash/query_runner/databricks.py b/redash/query_runner/databricks.py index 04deb22e35..4756581c20 100644 --- a/redash/query_runner/databricks.py +++ b/redash/query_runner/databricks.py @@ -5,6 +5,7 @@ try: from pyhive import hive from thrift.transport import THttpClient + enabled = True except ImportError: enabled = False @@ -24,41 +25,31 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "host": { - "type": "string" - }, - "database": { - "type": "string" - }, - "http_path": { - "type": "string", - "title": "HTTP Path" - }, - "http_password": { - "type": "string", - "title": "Access Token" - }, + "host": {"type": "string"}, + "database": {"type": "string"}, + "http_path": {"type": "string", "title": "HTTP Path"}, + "http_password": {"type": "string", "title": "Access Token"}, }, "order": ["host", "http_path", "http_password", "database"], "secret": ["http_password"], - "required": ["host", "database", "http_path", "http_password"] + "required": ["host", "database", "http_path", "http_password"], } def _get_connection(self): - host = self.configuration['host'] + host = self.configuration["host"] # if path is set but is missing initial slash, append it - path = self.configuration.get('http_path', '') - if path and path[0] != '/': - path = '/' + path + path = self.configuration.get("http_path", "") + if path and path[0] != "/": + path = "/" + path http_uri = "https://{}{}".format(host, path) transport = THttpClient.THttpClient(http_uri) - password = self.configuration.get('http_password', '') - auth = base64.b64encode('token:' + password) - transport.setCustomHeaders({'Authorization': 'Basic ' + auth}) + password = self.configuration.get("http_password", "") + auth = base64.b64encode("token:" + password) + transport.setCustomHeaders({"Authorization": "Basic " + auth}) connection = hive.connect(thrift_transport=transport) return connection @@ -70,14 +61,32 @@ def _get_tables(self, schema): schemas = self._run_query_internal(schemas_query) - for schema_name in [a for a in [str(a['databaseName']) for a in schemas] if len(a) > 0]: - for table_name in [a for a in [str(a['tableName']) for a in self._run_query_internal(tables_query % schema_name)] if len(a) > 0]: - columns = [a for a in [str(a['col_name']) for a in self._run_query_internal(columns_query % (schema_name, table_name))] if len(a) > 0] - - if schema_name != 'default': - table_name = '{}.{}'.format(schema_name, table_name) - - schema[table_name] = {'name': table_name, 'columns': columns} + for schema_name in [ + a for a in [str(a["databaseName"]) for a in schemas] if len(a) > 0 + ]: + for table_name in [ + a + for a in [ + str(a["tableName"]) + for a in self._run_query_internal(tables_query % schema_name) + ] + if len(a) > 0 + ]: + columns = [ + a + for a in [ + str(a["col_name"]) + for a in self._run_query_internal( + columns_query % (schema_name, table_name) + ) + ] + if len(a) > 0 + ] + + if schema_name != "default": + table_name = "{}.{}".format(schema_name, table_name) + + schema[table_name] = {"name": table_name, "columns": columns} return list(schema.values()) diff --git a/redash/query_runner/db2.py b/redash/query_runner/db2.py index 8f2257a1e3..9aa4b93506 100644 --- a/redash/query_runner/db2.py +++ b/redash/query_runner/db2.py @@ -21,7 +21,7 @@ ibm_db_dbi.BINARY: TYPE_STRING, ibm_db_dbi.XML: TYPE_STRING, ibm_db_dbi.TEXT: TYPE_STRING, - ibm_db_dbi.STRING: TYPE_STRING + ibm_db_dbi.STRING: TYPE_STRING, } enabled = True @@ -37,28 +37,15 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "host": { - "type": "string", - "default": "127.0.0.1" - }, - "port": { - "type": "number", - "default": 50000 - }, - "dbname": { - "type": "string", - "title": "Database Name" - } + "user": {"type": "string"}, + "password": {"type": "string"}, + "host": {"type": "string", "default": "127.0.0.1"}, + "port": {"type": "number", "default": 50000}, + "dbname": {"type": "string", "title": "Database Name"}, }, - "order": ['host', 'port', 'user', 'password', 'dbname'], + "order": ["host", "port", "user", "password", "dbname"], "required": ["dbname"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -82,16 +69,16 @@ def _get_definitions(self, schema, query): results = json_loads(results) - for row in results['rows']: - if row['TABLE_SCHEMA'] != 'public': - table_name = '{}.{}'.format(row['TABLE_SCHEMA'], row['TABLE_NAME']) + for row in results["rows"]: + if row["TABLE_SCHEMA"] != "public": + table_name = "{}.{}".format(row["TABLE_SCHEMA"], row["TABLE_NAME"]) else: - table_name = row['TABLE_NAME'] + table_name = row["TABLE_NAME"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['COLUMN_NAME']) + schema[table_name]["columns"].append(row["COLUMN_NAME"]) def _get_tables(self, schema): query = """ @@ -109,7 +96,12 @@ def _get_tables(self, schema): def _get_connection(self): self.connection_string = "DATABASE={};HOSTNAME={};PORT={};PROTOCOL=TCPIP;UID={};PWD={};".format( - self.configuration["dbname"], self.configuration["host"], self.configuration["port"], self.configuration["user"], self.configuration["password"]) + self.configuration["dbname"], + self.configuration["host"], + self.configuration["port"], + self.configuration["user"], + self.configuration["password"], + ) connection = ibm_db_dbi.connect(self.connection_string, "", "") return connection @@ -122,14 +114,19 @@ def run_query(self, query, user): cursor.execute(query) if cursor.description is not None: - columns = self.fetch_columns([(i[0], types_map.get(i[1], None)) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) for row in cursor] - - data = {'columns': columns, 'rows': rows} + columns = self.fetch_columns( + [(i[0], types_map.get(i[1], None)) for i in cursor.description] + ) + rows = [ + dict(zip((column["name"] for column in columns), row)) + for row in cursor + ] + + data = {"columns": columns, "rows": rows} error = None json_data = json_dumps(data) else: - error = 'Query completed but it returned no data.' + error = "Query completed but it returned no data." json_data = None except (select.error, OSError) as e: error = "Query interrupted. Please retry." diff --git a/redash/query_runner/dgraph.py b/redash/query_runner/dgraph.py index 3bf68c82d2..f48f8d91f3 100644 --- a/redash/query_runner/dgraph.py +++ b/redash/query_runner/dgraph.py @@ -2,6 +2,7 @@ try: import pydgraph + enabled = True except ImportError: enabled = False @@ -15,13 +16,13 @@ def reduce_item(reduced_item, key, value): # Reduction Condition 1 if type(value) is list: for i, sub_item in enumerate(value): - reduce_item(reduced_item, '{}.{}'.format(key, i), sub_item) + reduce_item(reduced_item, "{}.{}".format(key, i), sub_item) # Reduction Condition 2 elif type(value) is dict: sub_keys = value.keys() for sub_key in sub_keys: - reduce_item(reduced_item, '{}.{}'.format(key, sub_key), value[sub_key]) + reduce_item(reduced_item, "{}.{}".format(key, sub_key), value[sub_key]) # Base Condition else: @@ -42,19 +43,13 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "servers": { - "type": "string" - } + "user": {"type": "string"}, + "password": {"type": "string"}, + "servers": {"type": "string"}, }, "order": ["servers", "user", "password"], "required": ["servers"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -66,7 +61,7 @@ def enabled(cls): return enabled def run_dgraph_query_raw(self, query): - servers = self.configuration.get('servers') + servers = self.configuration.get("servers") client_stub = pydgraph.DgraphClientStub(servers) client = pydgraph.DgraphClient(client_stub) @@ -111,10 +106,12 @@ def run_query(self, query, user): header = list(set(header)) - columns = [{'name': c, 'friendly_name': c, 'type': 'string'} for c in header] + columns = [ + {"name": c, "friendly_name": c, "type": "string"} for c in header + ] # finally, assemble both the columns and data - data = {'columns': columns, 'rows': processed_data} + data = {"columns": columns, "rows": processed_data} json_data = json_dumps(data) except Exception as e: @@ -132,11 +129,11 @@ def get_schema(self, get_stats=False): schema = {} - for row in results['schema']: - table_name = row['predicate'] + for row in results["schema"]: + table_name = row["predicate"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} return list(schema.values()) diff --git a/redash/query_runner/drill.py b/redash/query_runner/drill.py index 4c19fdfbc4..d3b754374d 100644 --- a/redash/query_runner/drill.py +++ b/redash/query_runner/drill.py @@ -7,9 +7,13 @@ from six import text_type from redash.query_runner import ( - BaseHTTPQueryRunner, register, - TYPE_DATETIME, TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN, - guess_type + BaseHTTPQueryRunner, + register, + TYPE_DATETIME, + TYPE_INTEGER, + TYPE_FLOAT, + TYPE_BOOLEAN, + guess_type, ) from redash.utils import json_dumps, json_loads @@ -18,8 +22,8 @@ # Convert Drill string value to actual type def convert_type(string_value, actual_type): - if string_value is None or string_value == '': - return '' + if string_value is None or string_value == "": + return "" if actual_type == TYPE_INTEGER: return int(string_value) @@ -28,7 +32,7 @@ def convert_type(string_value, actual_type): return float(string_value) if actual_type == TYPE_BOOLEAN: - return text_type(string_value).lower() == 'true' + return text_type(string_value).lower() == "true" if actual_type == TYPE_DATETIME: return parser.parse(string_value) @@ -38,41 +42,43 @@ def convert_type(string_value, actual_type): # Parse Drill API response and translate it to accepted format def parse_response(data): - cols = data['columns'] - rows = data['rows'] + cols = data["columns"] + rows = data["rows"] if len(cols) == 0: - return {'columns': [], 'rows': []} + return {"columns": [], "rows": []} first_row = rows[0] columns = [] types = {} for c in cols: - columns.append({'name': c, 'type': guess_type(first_row[c]), 'friendly_name': c}) + columns.append( + {"name": c, "type": guess_type(first_row[c]), "friendly_name": c} + ) for col in columns: - types[col['name']] = col['type'] + types[col["name"]] = col["type"] for row in rows: for key, value in row.items(): row[key] = convert_type(value, types[key]) - return {'columns': columns, 'rows': rows} + return {"columns": columns, "rows": rows} class Drill(BaseHTTPQueryRunner): - noop_query = 'select version from sys.version' + noop_query = "select version from sys.version" response_error = "Drill API returned unexpected status code" requires_authentication = False requires_url = True - url_title = 'Drill URL' - username_title = 'Username' - password_title = 'Password' + url_title = "Drill URL" + username_title = "Username" + password_title = "Password" @classmethod def name(cls): - return 'Apache Drill' + return "Apache Drill" @classmethod def configuration_schema(cls): @@ -80,20 +86,22 @@ def configuration_schema(cls): # Since Drill itself can act as aggregator of various datasources, # it can contain quite a lot of schemas in `INFORMATION_SCHEMA` # We added this to improve user experience and let users focus only on desired schemas. - schema['properties']['allowed_schemas'] = { - 'type': 'string', - 'title': 'List of schemas to use in schema browser (comma separated)' + schema["properties"]["allowed_schemas"] = { + "type": "string", + "title": "List of schemas to use in schema browser (comma separated)", } - schema['order'] += ['allowed_schemas'] + schema["order"] += ["allowed_schemas"] return schema def run_query(self, query, user): - drill_url = os.path.join(self.configuration['url'], 'query.json') + drill_url = os.path.join(self.configuration["url"], "query.json") try: - payload = {'queryType': 'SQL', 'query': query} + payload = {"queryType": "SQL", "query": query} - response, error = self.get_response(drill_url, http_method='post', json=payload) + response, error = self.get_response( + drill_url, http_method="post", json=payload + ) if error is not None: return None, error @@ -101,7 +109,7 @@ def run_query(self, query, user): return json_dumps(results), None except KeyboardInterrupt: - return None, 'Query cancelled by user.' + return None, "Query cancelled by user." def get_schema(self, get_stats=False): @@ -118,9 +126,16 @@ def get_schema(self, get_stats=False): and TABLE_SCHEMA not like '%.INFORMATION_SCHEMA' """ - allowed_schemas = self.configuration.get('allowed_schemas') + allowed_schemas = self.configuration.get("allowed_schemas") if allowed_schemas: - query += "and TABLE_SCHEMA in ({})".format(', '.join(["'{}'".format(re.sub('[^a-zA-Z0-9_.`]', '', allowed_schema)) for allowed_schema in allowed_schemas.split(',')])) + query += "and TABLE_SCHEMA in ({})".format( + ", ".join( + [ + "'{}'".format(re.sub("[^a-zA-Z0-9_.`]", "", allowed_schema)) + for allowed_schema in allowed_schemas.split(",") + ] + ) + ) results, error = self.run_query(query, None) @@ -131,13 +146,13 @@ def get_schema(self, get_stats=False): schema = {} - for row in results['rows']: - table_name = '{}.{}'.format(row['TABLE_SCHEMA'], row['TABLE_NAME']) + for row in results["rows"]: + table_name = "{}.{}".format(row["TABLE_SCHEMA"], row["TABLE_NAME"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['COLUMN_NAME']) + schema[table_name]["columns"].append(row["COLUMN_NAME"]) return list(schema.values()) diff --git a/redash/query_runner/druid.py b/redash/query_runner/druid.py index bec254de32..0790d5e1e6 100644 --- a/redash/query_runner/druid.py +++ b/redash/query_runner/druid.py @@ -1,5 +1,6 @@ try: from pydruid.db import connect + enabled = True except ImportError: enabled = False @@ -8,11 +9,7 @@ from redash.query_runner import TYPE_STRING, TYPE_INTEGER, TYPE_BOOLEAN from redash.utils import json_dumps, json_loads -TYPES_MAP = { - 1: TYPE_STRING, - 2: TYPE_INTEGER, - 3: TYPE_BOOLEAN, -} +TYPES_MAP = {1: TYPE_STRING, 2: TYPE_INTEGER, 3: TYPE_BOOLEAN} class Druid(BaseQueryRunner): @@ -23,28 +20,15 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "host": { - "type": "string", - "default": "localhost" - }, - "port": { - "type": "number", - "default": 8082 - }, - "scheme": { - "type": "string", - "default": "http" - }, - "user": { - "type": "string" - }, - "password": { - "type": "string" - } + "host": {"type": "string", "default": "localhost"}, + "port": {"type": "number", "default": 8082}, + "scheme": {"type": "string", "default": "http"}, + "user": {"type": "string"}, + "password": {"type": "string"}, }, - "order": ['scheme', 'host', 'port', 'user', 'password'], - "required": ['host'], - "secret": ['password'] + "order": ["scheme", "host", "port", "user", "password"], + "required": ["host"], + "secret": ["password"], } @classmethod @@ -52,21 +36,27 @@ def enabled(cls): return enabled def run_query(self, query, user): - connection = connect(host=self.configuration['host'], - port=self.configuration['port'], - path='/druid/v2/sql/', - scheme=(self.configuration.get('scheme') or 'http'), - user=(self.configuration.get('user') or None), - password=(self.configuration.get('password') or None)) + connection = connect( + host=self.configuration["host"], + port=self.configuration["port"], + path="/druid/v2/sql/", + scheme=(self.configuration.get("scheme") or "http"), + user=(self.configuration.get("user") or None), + password=(self.configuration.get("password") or None), + ) cursor = connection.cursor() try: cursor.execute(query) - columns = self.fetch_columns([(i[0], TYPES_MAP.get(i[1], None)) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) for row in cursor] - - data = {'columns': columns, 'rows': rows} + columns = self.fetch_columns( + [(i[0], TYPES_MAP.get(i[1], None)) for i in cursor.description] + ) + rows = [ + dict(zip((column["name"] for column in columns), row)) for row in cursor + ] + + data = {"columns": columns, "rows": rows} error = None json_data = json_dumps(data) print(json_data) @@ -92,13 +82,13 @@ def get_schema(self, get_stats=False): schema = {} results = json_loads(results) - for row in results['rows']: - table_name = '{}.{}'.format(row['TABLE_SCHEMA'], row['TABLE_NAME']) + for row in results["rows"]: + table_name = "{}.{}".format(row["TABLE_SCHEMA"], row["TABLE_NAME"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['COLUMN_NAME']) + schema[table_name]["columns"].append(row["COLUMN_NAME"]) return list(schema.values()) diff --git a/redash/query_runner/dynamodb_sql.py b/redash/query_runner/dynamodb_sql.py index f5fc7f0e3d..bc9cc8056a 100644 --- a/redash/query_runner/dynamodb_sql.py +++ b/redash/query_runner/dynamodb_sql.py @@ -10,25 +10,26 @@ from dql import Engine, FragmentEngine from dynamo3 import DynamoDBError from pyparsing import ParseException + enabled = True except ImportError as e: enabled = False types_map = { - 'UNICODE': TYPE_INTEGER, - 'TINYINT': TYPE_INTEGER, - 'SMALLINT': TYPE_INTEGER, - 'INT': TYPE_INTEGER, - 'DOUBLE': TYPE_FLOAT, - 'DECIMAL': TYPE_FLOAT, - 'FLOAT': TYPE_FLOAT, - 'REAL': TYPE_FLOAT, - 'BOOLEAN': TYPE_BOOLEAN, - 'TIMESTAMP': TYPE_DATETIME, - 'DATE': TYPE_DATETIME, - 'CHAR': TYPE_STRING, - 'STRING': TYPE_STRING, - 'VARCHAR': TYPE_STRING + "UNICODE": TYPE_INTEGER, + "TINYINT": TYPE_INTEGER, + "SMALLINT": TYPE_INTEGER, + "INT": TYPE_INTEGER, + "DOUBLE": TYPE_FLOAT, + "DECIMAL": TYPE_FLOAT, + "FLOAT": TYPE_FLOAT, + "REAL": TYPE_FLOAT, + "BOOLEAN": TYPE_BOOLEAN, + "TIMESTAMP": TYPE_DATETIME, + "DATE": TYPE_DATETIME, + "CHAR": TYPE_STRING, + "STRING": TYPE_STRING, + "VARCHAR": TYPE_STRING, } @@ -40,19 +41,12 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "region": { - "type": "string", - "default": "us-east-1" - }, - "access_key": { - "type": "string", - }, - "secret_key": { - "type": "string", - } + "region": {"type": "string", "default": "us-east-1"}, + "access_key": {"type": "string"}, + "secret_key": {"type": "string"}, }, "required": ["access_key", "secret_key"], - "secret": ["secret_key"] + "secret": ["secret_key"], } def test_connection(self): @@ -71,11 +65,11 @@ def _connect(self): engine = FragmentEngine() config = self.configuration.to_dict() - if not config.get('region'): - config['region'] = 'us-east-1' + if not config.get("region"): + config["region"] = "us-east-1" - if config.get('host') == '': - config['host'] = None + if config.get("host") == "": + config["host"] = None engine.connect(**config) @@ -90,8 +84,10 @@ def _get_tables(self, schema): for table_name in tables: try: table = engine.describe(table_name, True) - schema[table.name] = {'name': table.name, - 'columns': list(table.attrs.keys())} + schema[table.name] = { + "name": table.name, + "columns": list(table.attrs.keys()), + } except DynamoDBError: pass @@ -100,8 +96,8 @@ def run_query(self, query, user): try: engine = self._connect() - if not query.endswith(';'): - query = query + ';' + if not query.endswith(";"): + query = query + ";" result = engine.execute(query) @@ -120,19 +116,22 @@ def run_query(self, query, user): for item in result: if not columns: for k, v in item.items(): - columns.append({ - 'name': k, - 'friendly_name': k, - 'type': types_map.get(str(type(v)).upper(), None) - }) + columns.append( + { + "name": k, + "friendly_name": k, + "type": types_map.get(str(type(v)).upper(), None), + } + ) rows.append(item) - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None except ParseException as e: error = "Error parsing query at line {} (column {}):\n{}".format( - e.lineno, e.column, e.line) + e.lineno, e.column, e.line + ) json_data = None except (SyntaxError, RuntimeError) as e: error = e.message diff --git a/redash/query_runner/elasticsearch.py b/redash/query_runner/elasticsearch.py index a5f08c8e28..04ec193049 100644 --- a/redash/query_runner/elasticsearch.py +++ b/redash/query_runner/elasticsearch.py @@ -31,17 +31,14 @@ # "geo_point" TODO: Need to split to 2 fields somehow } -ELASTICSEARCH_BUILTIN_FIELDS_MAPPING = { - "_id": "Id", - "_score": "Score" -} +ELASTICSEARCH_BUILTIN_FIELDS_MAPPING = {"_id": "Id", "_score": "Score"} PYTHON_TYPES_MAPPING = { str: TYPE_STRING, text_type: TYPE_STRING, bool: TYPE_BOOLEAN, int: TYPE_INTEGER, - float: TYPE_FLOAT + float: TYPE_FLOAT, } @@ -52,24 +49,18 @@ class BaseElasticSearch(BaseQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'server': { - 'type': 'string', - 'title': 'Base URL' - }, - 'basic_auth_user': { - 'type': 'string', - 'title': 'Basic Auth User' + "type": "object", + "properties": { + "server": {"type": "string", "title": "Base URL"}, + "basic_auth_user": {"type": "string", "title": "Basic Auth User"}, + "basic_auth_password": { + "type": "string", + "title": "Basic Auth Password", }, - 'basic_auth_password': { - 'type': 'string', - 'title': 'Basic Auth Password' - } }, - "order": ['server', 'basic_auth_user', 'basic_auth_password'], + "order": ["server", "basic_auth_user", "basic_auth_password"], "secret": ["basic_auth_password"], - "required": ["server"] + "required": ["server"], } @classmethod @@ -112,7 +103,9 @@ def _get_mappings(self, url): mappings = r.json() except requests.HTTPError as e: logger.exception(e) - error = "Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text) + error = "Failed to execute query. Return Code: {0} Reason: {1}".format( + r.status_code, r.text + ) mappings = None except requests.exceptions.RequestException as e: logger.exception(e) @@ -133,29 +126,33 @@ def _get_query_mappings(self, url): if "properties" not in index_mappings["mappings"][m]: continue for property_name in index_mappings["mappings"][m]["properties"]: - property_data = index_mappings["mappings"][m]["properties"][property_name] + property_data = index_mappings["mappings"][m]["properties"][ + property_name + ] if property_name not in mappings: property_type = property_data.get("type", None) if property_type: if property_type in ELASTICSEARCH_TYPES_MAPPING: - mappings[property_name] = ELASTICSEARCH_TYPES_MAPPING[property_type] + mappings[property_name] = ELASTICSEARCH_TYPES_MAPPING[ + property_type + ] else: mappings[property_name] = TYPE_STRING - #raise Exception("Unknown property type: {0}".format(property_type)) + # raise Exception("Unknown property type: {0}".format(property_type)) return mappings, error def get_schema(self, *args, **kwargs): def parse_doc(doc, path=None): - '''Recursively parse a doc type dictionary - ''' + """Recursively parse a doc type dictionary + """ path = path or [] result = [] - for field, description in doc['properties'].items(): - if 'properties' in description: + for field, description in doc["properties"].items(): + if "properties" in description: result.extend(parse_doc(description, path + [field])) else: - result.append('.'.join(path + [field])) + result.append(".".join(path + [field])) return result schema = {} @@ -168,22 +165,29 @@ def parse_doc(doc, path=None): # in a hierarchical format for name, index in mappings.items(): columns = [] - schema[name] = {'name': name} - for doc, items in index['mappings'].items(): + schema[name] = {"name": name} + for doc, items in index["mappings"].items(): columns.extend(parse_doc(items)) # remove duplicates # sort alphabetically - schema[name]['columns'] = sorted(set(columns)) + schema[name]["columns"] = sorted(set(columns)) return list(schema.values()) - def _parse_results(self, mappings, result_fields, raw_result, result_columns, result_rows): - def add_column_if_needed(mappings, column_name, friendly_name, result_columns, result_columns_index): + def _parse_results( + self, mappings, result_fields, raw_result, result_columns, result_rows + ): + def add_column_if_needed( + mappings, column_name, friendly_name, result_columns, result_columns_index + ): if friendly_name not in result_columns_index: - result_columns.append({ - "name": friendly_name, - "friendly_name": friendly_name, - "type": mappings.get(column_name, "string")}) + result_columns.append( + { + "name": friendly_name, + "friendly_name": friendly_name, + "type": mappings.get(column_name, "string"), + } + ) result_columns_index[friendly_name] = result_columns[-1] def get_row(rows, row): @@ -197,37 +201,77 @@ def collect_value(mappings, row, key, value, type): return mappings[key] = type - add_column_if_needed(mappings, key, key, result_columns, result_columns_index) + add_column_if_needed( + mappings, key, key, result_columns, result_columns_index + ) row[key] = value - def collect_aggregations(mappings, rows, parent_key, data, row, result_columns, result_columns_index): + def collect_aggregations( + mappings, rows, parent_key, data, row, result_columns, result_columns_index + ): if isinstance(data, dict): for key, value in data.items(): - val = collect_aggregations(mappings, rows, parent_key if key == 'buckets' else key, value, row, result_columns, result_columns_index) + val = collect_aggregations( + mappings, + rows, + parent_key if key == "buckets" else key, + value, + row, + result_columns, + result_columns_index, + ) if val: row = get_row(rows, row) - collect_value(mappings, row, key, val, 'long') + collect_value(mappings, row, key, val, "long") - for data_key in ['value', 'doc_count']: + for data_key in ["value", "doc_count"]: if data_key not in data: continue - if 'key' in data and len(list(data.keys())) == 2: - key_is_string = 'key_as_string' in data - collect_value(mappings, row, data['key'] if not key_is_string else data['key_as_string'], data[data_key], 'long' if not key_is_string else 'string') + if "key" in data and len(list(data.keys())) == 2: + key_is_string = "key_as_string" in data + collect_value( + mappings, + row, + data["key"] if not key_is_string else data["key_as_string"], + data[data_key], + "long" if not key_is_string else "string", + ) else: return data[data_key] elif isinstance(data, list): for value in data: result_row = get_row(rows, row) - collect_aggregations(mappings, rows, parent_key, value, result_row, result_columns, result_columns_index) - if 'doc_count' in value: - collect_value(mappings, result_row, 'doc_count', value['doc_count'], 'integer') - if 'key' in value: - if 'key_as_string' in value: - collect_value(mappings, result_row, parent_key, value['key_as_string'], 'string') + collect_aggregations( + mappings, + rows, + parent_key, + value, + result_row, + result_columns, + result_columns_index, + ) + if "doc_count" in value: + collect_value( + mappings, + result_row, + "doc_count", + value["doc_count"], + "integer", + ) + if "key" in value: + if "key_as_string" in value: + collect_value( + mappings, + result_row, + parent_key, + value["key_as_string"], + "string", + ) else: - collect_value(mappings, result_row, parent_key, value['key'], 'string') + collect_value( + mappings, result_row, parent_key, value["key"], "string" + ) return None @@ -238,26 +282,38 @@ def collect_aggregations(mappings, rows, parent_key, data, row, result_columns, for r in result_fields: result_fields_index[r] = None - if 'error' in raw_result: - error = raw_result['error'] + if "error" in raw_result: + error = raw_result["error"] if len(error) > 10240: - error = error[:10240] + '... continues' + error = error[:10240] + "... continues" raise Exception(error) - elif 'aggregations' in raw_result: + elif "aggregations" in raw_result: if result_fields: for field in result_fields: - add_column_if_needed(mappings, field, field, result_columns, result_columns_index) + add_column_if_needed( + mappings, field, field, result_columns, result_columns_index + ) for key, data in raw_result["aggregations"].items(): - collect_aggregations(mappings, result_rows, key, data, None, result_columns, result_columns_index) + collect_aggregations( + mappings, + result_rows, + key, + data, + None, + result_columns, + result_columns_index, + ) logger.debug("result_rows %s", str(result_rows)) logger.debug("result_columns %s", str(result_columns)) - elif 'hits' in raw_result and 'hits' in raw_result['hits']: + elif "hits" in raw_result and "hits" in raw_result["hits"]: if result_fields: for field in result_fields: - add_column_if_needed(mappings, field, field, result_columns, result_columns_index) + add_column_if_needed( + mappings, field, field, result_columns, result_columns_index + ) for h in raw_result["hits"]["hits"]: row = {} @@ -267,22 +323,36 @@ def collect_aggregations(mappings, rows, parent_key, data, row, result_columns, if result_fields and column not in result_fields_index: continue - add_column_if_needed(mappings, column, column, result_columns, result_columns_index) + add_column_if_needed( + mappings, column, column, result_columns, result_columns_index + ) value = h[column_name][column] - row[column] = value[0] if isinstance(value, list) and len(value) == 1 else value + row[column] = ( + value[0] + if isinstance(value, list) and len(value) == 1 + else value + ) result_rows.append(row) else: - raise Exception("Redash failed to parse the results it got from Elasticsearch.") + raise Exception( + "Redash failed to parse the results it got from Elasticsearch." + ) def test_connection(self): try: - r = requests.get("{0}/_cluster/health".format(self.server_url), auth=self.auth) + r = requests.get( + "{0}/_cluster/health".format(self.server_url), auth=self.auth + ) r.raise_for_status() except requests.HTTPError as e: logger.exception(e) - raise Exception("Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text)) + raise Exception( + "Failed to execute query. Return Code: {0} Reason: {1}".format( + r.status_code, r.text + ) + ) except requests.exceptions.RequestException as e: logger.exception(e) raise Exception("Connection refused") @@ -293,14 +363,18 @@ class Kibana(BaseElasticSearch): def enabled(cls): return True - def _execute_simple_query(self, url, auth, _from, mappings, result_fields, result_columns, result_rows): + def _execute_simple_query( + self, url, auth, _from, mappings, result_fields, result_columns, result_rows + ): url += "&from={0}".format(_from) r = requests.get(url, auth=self.auth) r.raise_for_status() raw_result = r.json() - self._parse_results(mappings, result_fields, raw_result, result_columns, result_rows) + self._parse_results( + mappings, result_fields, raw_result, result_columns, result_rows + ) total = raw_result["hits"]["total"] result_size = len(raw_result["hits"]["hits"]) @@ -347,7 +421,15 @@ def run_query(self, query, user): _from = 0 while True: query_size = size if limit >= (_from + size) else (limit - _from) - total = self._execute_simple_query(url + "&size={0}".format(query_size), self.auth, _from, mappings, result_fields, result_columns, result_rows) + total = self._execute_simple_query( + url + "&size={0}".format(query_size), + self.auth, + _from, + mappings, + result_fields, + result_columns, + result_rows, + ) _from += size if _from >= limit: break @@ -355,16 +437,15 @@ def run_query(self, query, user): # TODO: Handle complete ElasticSearch queries (JSON based sent over HTTP POST) raise Exception("Advanced queries are not supported") - json_data = json_dumps({ - "columns": result_columns, - "rows": result_rows - }) + json_data = json_dumps({"columns": result_columns, "rows": result_rows}) except KeyboardInterrupt: error = "Query cancelled by user." json_data = None except requests.HTTPError as e: logger.exception(e) - error = "Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text) + error = "Failed to execute query. Return Code: {0} Reason: {1}".format( + r.status_code, r.text + ) json_data = None except requests.exceptions.RequestException as e: logger.exception(e) @@ -381,7 +462,7 @@ def enabled(cls): @classmethod def name(cls): - return 'Elasticsearch' + return "Elasticsearch" def run_query(self, query, user): try: @@ -412,19 +493,20 @@ def run_query(self, query, user): result_columns = [] result_rows = [] - self._parse_results(mappings, result_fields, r.json(), result_columns, result_rows) + self._parse_results( + mappings, result_fields, r.json(), result_columns, result_rows + ) - json_data = json_dumps({ - "columns": result_columns, - "rows": result_rows - }) + json_data = json_dumps({"columns": result_columns, "rows": result_rows}) except KeyboardInterrupt: logger.exception(e) error = "Query cancelled by user." json_data = None except requests.HTTPError as e: logger.exception(e) - error = "Failed to execute query. Return Code: {0} Reason: {1}".format(r.status_code, r.text) + error = "Failed to execute query. Return Code: {0} Reason: {1}".format( + r.status_code, r.text + ) json_data = None except requests.exceptions.RequestException as e: logger.exception(e) diff --git a/redash/query_runner/exasol.py b/redash/query_runner/exasol.py index a4dcd9dcae..ed045207ae 100644 --- a/redash/query_runner/exasol.py +++ b/redash/query_runner/exasol.py @@ -3,82 +3,86 @@ from redash.query_runner import * from redash.utils import json_dumps + def _exasol_type_mapper(val, data_type): if val is None: return None - elif data_type['type'] == 'DECIMAL': - if data_type['scale'] == 0 and data_type['precision'] < 16: + elif data_type["type"] == "DECIMAL": + if data_type["scale"] == 0 and data_type["precision"] < 16: return int(val) - elif data_type['scale'] == 0 and data_type['precision'] >= 16: + elif data_type["scale"] == 0 and data_type["precision"] >= 16: return val else: return float(val) - elif data_type['type'] == 'DATE': + elif data_type["type"] == "DATE": return datetime.date(int(val[0:4]), int(val[5:7]), int(val[8:10])) - elif data_type['type'] == 'TIMESTAMP': - return datetime.datetime(int(val[0:4]), int(val[5:7]), int(val[8:10]), # year, month, day - int(val[11:13]), int(val[14:16]), int(val[17:19]), # hour, minute, second - int(val[20:26].ljust(6, '0')) if len(val) > 20 else 0) # microseconds (if available) + elif data_type["type"] == "TIMESTAMP": + return datetime.datetime( + int(val[0:4]), + int(val[5:7]), + int(val[8:10]), # year, month, day + int(val[11:13]), + int(val[14:16]), + int(val[17:19]), # hour, minute, second + int(val[20:26].ljust(6, "0")) if len(val) > 20 else 0, + ) # microseconds (if available) else: return val + def _type_mapper(data_type): - if data_type['type'] == 'DECIMAL': - if data_type['scale'] == 0 and data_type['precision'] < 16: + if data_type["type"] == "DECIMAL": + if data_type["scale"] == 0 and data_type["precision"] < 16: return TYPE_INTEGER - elif data_type['scale'] == 0 and data_type['precision'] >= 16: + elif data_type["scale"] == 0 and data_type["precision"] >= 16: return TYPE_STRING else: return TYPE_FLOAT - elif data_type['type'] == 'DATE': + elif data_type["type"] == "DATE": return TYPE_DATE - elif data_type['type'] == 'TIMESTAMP': + elif data_type["type"] == "TIMESTAMP": return TYPE_DATETIME else: return TYPE_STRING + try: import pyexasol + enabled = True except ImportError: enabled = False class Exasol(BaseQueryRunner): - noop_query = 'SELECT 1 FROM DUAL' + noop_query = "SELECT 1 FROM DUAL" @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'user': { - 'type': 'string' - }, - 'password': { - 'type': 'string' - }, - 'host': { - 'type': 'string' - }, - 'port': { - 'type': 'number', - 'default': 8563 - }, + "type": "object", + "properties": { + "user": {"type": "string"}, + "password": {"type": "string"}, + "host": {"type": "string"}, + "port": {"type": "number", "default": 8563}, }, - 'required': ['host', 'port', 'user', 'password'], - 'order': ['host', 'port', 'user', 'password'], - 'secret': ['password'] + "required": ["host", "port", "user", "password"], + "order": ["host", "port", "user", "password"], + "secret": ["password"], } def _get_connection(self): - exahost = "%s:%s" % (self.configuration.get('host', None), self.configuration.get('port', 8563)) + exahost = "%s:%s" % ( + self.configuration.get("host", None), + self.configuration.get("port", 8563), + ) return pyexasol.connect( dsn=exahost, - user=self.configuration.get('user', None) , - password=self.configuration.get('password', None), + user=self.configuration.get("user", None), + password=self.configuration.get("password", None), compression=True, - json_lib='rapidjson', + json_lib="rapidjson", fetch_mapper=_exasol_type_mapper, ) @@ -88,21 +92,22 @@ def run_query(self, query, user): error = None try: statement = connection.execute(query) - columns = [{'name': n, 'friendly_name': n,'type': _type_mapper(t)} for (n, t) in statement.columns().items()] + columns = [ + {"name": n, "friendly_name": n, "type": _type_mapper(t)} + for (n, t) in statement.columns().items() + ] cnames = statement.column_names() - + rows = [dict(zip(cnames, row)) for row in statement] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) finally: if statement is not None: statement.close() - - connection.close() - - return json_data, error + connection.close() + return json_data, error def get_schema(self, get_stats=False): query = """ @@ -120,16 +125,19 @@ def get_schema(self, get_stats=False): result = {} for (schema, table_name, column) in statement: - table_name_with_schema = '%s.%s' % (schema,table_name) + table_name_with_schema = "%s.%s" % (schema, table_name) if table_name_with_schema not in result: - result[table_name_with_schema] = {'name': table_name_with_schema, 'columns': []} + result[table_name_with_schema] = { + "name": table_name_with_schema, + "columns": [], + } - result[table_name_with_schema]['columns'].append(column) + result[table_name_with_schema]["columns"].append(column) finally: if statement is not None: statement.close() - + connection.close() return result.values() @@ -139,4 +147,4 @@ def enabled(cls): return enabled -register(Exasol) \ No newline at end of file +register(Exasol) diff --git a/redash/query_runner/google_analytics.py b/redash/query_runner/google_analytics.py index ed519e6c1c..f32f180627 100644 --- a/redash/query_runner/google_analytics.py +++ b/redash/query_runner/google_analytics.py @@ -13,6 +13,7 @@ from apiclient.discovery import build from apiclient.errors import HttpError import httplib2 + enabled = True except ImportError as e: enabled = False @@ -23,56 +24,64 @@ INTEGER=TYPE_INTEGER, FLOAT=TYPE_FLOAT, DATE=TYPE_DATE, - DATETIME=TYPE_DATETIME + DATETIME=TYPE_DATETIME, ) def parse_ga_response(response): columns = [] - for h in response['columnHeaders']: - if h['name'] in ('ga:date', 'mcf:conversionDate'): - h['dataType'] = 'DATE' - elif h['name'] == 'ga:dateHour': - h['dataType'] = 'DATETIME' - columns.append({ - 'name': h['name'], - 'friendly_name': h['name'].split(':', 1)[1], - 'type': types_conv.get(h['dataType'], 'string') - }) + for h in response["columnHeaders"]: + if h["name"] in ("ga:date", "mcf:conversionDate"): + h["dataType"] = "DATE" + elif h["name"] == "ga:dateHour": + h["dataType"] = "DATETIME" + columns.append( + { + "name": h["name"], + "friendly_name": h["name"].split(":", 1)[1], + "type": types_conv.get(h["dataType"], "string"), + } + ) rows = [] - for r in response.get('rows', []): + for r in response.get("rows", []): d = {} for c, value in enumerate(r): - column_name = response['columnHeaders'][c]['name'] - column_type = [col for col in columns if col['name'] == column_name][0]['type'] + column_name = response["columnHeaders"][c]["name"] + column_type = [col for col in columns if col["name"] == column_name][0][ + "type" + ] # mcf results come a bit different than ga results: if isinstance(value, dict): - if 'primitiveValue' in value: - value = value['primitiveValue'] - elif 'conversionPathValue' in value: + if "primitiveValue" in value: + value = value["primitiveValue"] + elif "conversionPathValue" in value: steps = [] - for step in value['conversionPathValue']: - steps.append('{}:{}'.format(step['interactionType'], step['nodeValue'])) - value = ', '.join(steps) + for step in value["conversionPathValue"]: + steps.append( + "{}:{}".format(step["interactionType"], step["nodeValue"]) + ) + value = ", ".join(steps) else: raise Exception("Results format not supported") if column_type == TYPE_DATE: - value = datetime.strptime(value, '%Y%m%d') + value = datetime.strptime(value, "%Y%m%d") elif column_type == TYPE_DATETIME: if len(value) == 10: - value = datetime.strptime(value, '%Y%m%d%H') + value = datetime.strptime(value, "%Y%m%d%H") elif len(value) == 12: - value = datetime.strptime(value, '%Y%m%d%H%M') + value = datetime.strptime(value, "%Y%m%d%H%M") else: - raise Exception("Unknown date/time format in results: '{}'".format(value)) + raise Exception( + "Unknown date/time format in results: '{}'".format(value) + ) d[column_name] = value rows.append(d) - return {'columns': columns, 'rows': rows} + return {"columns": columns, "rows": rows} class GoogleAnalytics(BaseSQLQueryRunner): @@ -93,40 +102,50 @@ def enabled(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'jsonKeyFile': { - "type": "string", - 'title': 'JSON Key File' - } - }, - 'required': ['jsonKeyFile'], - 'secret': ['jsonKeyFile'] + "type": "object", + "properties": {"jsonKeyFile": {"type": "string", "title": "JSON Key File"}}, + "required": ["jsonKeyFile"], + "secret": ["jsonKeyFile"], } def __init__(self, configuration): super(GoogleAnalytics, self).__init__(configuration) - self.syntax = 'json' + self.syntax = "json" def _get_analytics_service(self): - scope = ['https://www.googleapis.com/auth/analytics.readonly'] - key = json_loads(b64decode(self.configuration['jsonKeyFile'])) + scope = ["https://www.googleapis.com/auth/analytics.readonly"] + key = json_loads(b64decode(self.configuration["jsonKeyFile"])) creds = ServiceAccountCredentials.from_json_keyfile_dict(key, scope) - return build('analytics', 'v3', http=creds.authorize(httplib2.Http())) + return build("analytics", "v3", http=creds.authorize(httplib2.Http())) def _get_tables(self, schema): - accounts = self._get_analytics_service().management().accounts().list().execute().get('items') + accounts = ( + self._get_analytics_service() + .management() + .accounts() + .list() + .execute() + .get("items") + ) if accounts is None: raise Exception("Failed getting accounts.") else: for account in accounts: - schema[account['name']] = {'name': account['name'], 'columns': []} - properties = self._get_analytics_service().management().webproperties().list( - accountId=account['id']).execute().get('items', []) + schema[account["name"]] = {"name": account["name"], "columns": []} + properties = ( + self._get_analytics_service() + .management() + .webproperties() + .list(accountId=account["id"]) + .execute() + .get("items", []) + ) for property_ in properties: - if 'defaultProfileId' in property_ and 'name' in property_: - schema[account['name']]['columns'].append( - '{0} (ga:{1})'.format(property_['name'], property_['defaultProfileId']) + if "defaultProfileId" in property_ and "name" in property_: + schema[account["name"]]["columns"].append( + "{0} (ga:{1})".format( + property_["name"], property_["defaultProfileId"] + ) ) return list(schema.values()) @@ -146,17 +165,19 @@ def run_query(self, query, user): except: params = parse_qs(urlparse(query).query, keep_blank_values=True) for key in params.keys(): - params[key] = ','.join(params[key]) - if '-' in key: - params[key.replace('-', '_')] = params.pop(key) + params[key] = ",".join(params[key]) + if "-" in key: + params[key.replace("-", "_")] = params.pop(key) - if 'mcf:' in params['metrics'] and 'ga:' in params['metrics']: + if "mcf:" in params["metrics"] and "ga:" in params["metrics"]: raise Exception("Can't mix mcf: and ga: metrics.") - if 'mcf:' in params.get('dimensions', '') and 'ga:' in params.get('dimensions', ''): + if "mcf:" in params.get("dimensions", "") and "ga:" in params.get( + "dimensions", "" + ): raise Exception("Can't mix mcf: and ga: dimensions.") - if 'mcf:' in params['metrics']: + if "mcf:" in params["metrics"]: api = self._get_analytics_service().data().mcf() else: api = self._get_analytics_service().data().ga() @@ -172,7 +193,7 @@ def run_query(self, query, user): error = e._get_reason() json_data = None else: - error = 'Wrong query format.' + error = "Wrong query format." json_data = None return json_data, error diff --git a/redash/query_runner/google_spreadsheets.py b/redash/query_runner/google_spreadsheets.py index 7584153665..6126ea3210 100644 --- a/redash/query_runner/google_spreadsheets.py +++ b/redash/query_runner/google_spreadsheets.py @@ -32,18 +32,16 @@ def _get_columns_and_column_names(row): for i, column_name in enumerate(row): if not column_name: - column_name = 'column_{}'.format(xl_col_to_name(i)) + column_name = "column_{}".format(xl_col_to_name(i)) if column_name in column_names: column_name = "{}{}".format(column_name, duplicate_counter) duplicate_counter += 1 column_names.append(column_name) - columns.append({ - 'name': column_name, - 'friendly_name': column_name, - 'type': TYPE_STRING - }) + columns.append( + {"name": column_name, "friendly_name": column_name, "type": TYPE_STRING} + ) return columns, column_names @@ -53,10 +51,10 @@ def _value_eval_list(row_values, col_types): raw_values = zip(col_types, row_values) for typ, rval in raw_values: try: - if rval is None or rval == '': + if rval is None or rval == "": val = None elif typ == TYPE_BOOLEAN: - val = True if str(rval).lower() == 'true' else False + val = True if str(rval).lower() == "true" else False elif typ == TYPE_DATETIME: val = parser.parse(rval) elif typ == TYPE_FLOAT: @@ -77,31 +75,38 @@ def _value_eval_list(row_values, col_types): class WorksheetNotFoundError(Exception): def __init__(self, worksheet_num, worksheet_count): - message = "Worksheet number {} not found. Spreadsheet has {} worksheets. Note that the worksheet count is zero based.".format(worksheet_num, worksheet_count) + message = "Worksheet number {} not found. Spreadsheet has {} worksheets. Note that the worksheet count is zero based.".format( + worksheet_num, worksheet_count + ) super(WorksheetNotFoundError, self).__init__(message) def parse_query(query): values = query.split("|") key = values[0] # key of the spreadsheet - worksheet_num = 0 if len(values) != 2 else int(values[1]) # if spreadsheet contains more than one worksheet - this is the number of it + worksheet_num = ( + 0 if len(values) != 2 else int(values[1]) + ) # if spreadsheet contains more than one worksheet - this is the number of it return key, worksheet_num def parse_worksheet(worksheet): if not worksheet: - return {'columns': [], 'rows': []} + return {"columns": [], "rows": []} columns, column_names = _get_columns_and_column_names(worksheet[HEADER_INDEX]) if len(worksheet) > 1: for j, value in enumerate(worksheet[HEADER_INDEX + 1]): - columns[j]['type'] = guess_type(value) + columns[j]["type"] = guess_type(value) - column_types = [c['type'] for c in columns] - rows = [dict(zip(column_names, _value_eval_list(row, column_types))) for row in worksheet[HEADER_INDEX + 1:]] - data = {'columns': columns, 'rows': rows} + column_types = [c["type"] for c in columns] + rows = [ + dict(zip(column_names, _value_eval_list(row, column_types))) + for row in worksheet[HEADER_INDEX + 1 :] + ] + data = {"columns": columns, "rows": rows} return data @@ -118,14 +123,14 @@ def parse_spreadsheet(spreadsheet, worksheet_num): def is_url_key(key): - return key.startswith('https://') + return key.startswith("https://") def parse_api_error(error): error_data = error.response.json() - if 'error' in error_data and 'message' in error_data['error']: - message = error_data['error']['message'] + if "error" in error_data and "message" in error_data["error"]: + message = error_data["error"]["message"] else: message = error.message @@ -134,7 +139,7 @@ def parse_api_error(error): class TimeoutSession(Session): def request(self, *args, **kwargs): - kwargs.setdefault('timeout', 300) + kwargs.setdefault("timeout", 300) return super(TimeoutSession, self).request(*args, **kwargs) @@ -143,7 +148,7 @@ class GoogleSpreadsheet(BaseQueryRunner): def __init__(self, configuration): super(GoogleSpreadsheet, self).__init__(configuration) - self.syntax = 'custom' + self.syntax = "custom" @classmethod def name(cls): @@ -160,23 +165,16 @@ def enabled(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'jsonKeyFile': { - "type": "string", - 'title': 'JSON Key File' - } - }, - 'required': ['jsonKeyFile'], - 'secret': ['jsonKeyFile'] + "type": "object", + "properties": {"jsonKeyFile": {"type": "string", "title": "JSON Key File"}}, + "required": ["jsonKeyFile"], + "secret": ["jsonKeyFile"], } def _get_spreadsheet_service(self): - scope = [ - 'https://spreadsheets.google.com/feeds', - ] + scope = ["https://spreadsheets.google.com/feeds"] - key = json_loads(b64decode(self.configuration['jsonKeyFile'])) + key = json_loads(b64decode(self.configuration["jsonKeyFile"])) creds = ServiceAccountCredentials.from_json_keyfile_dict(key, scope) timeout_session = Session() @@ -187,7 +185,7 @@ def _get_spreadsheet_service(self): def test_connection(self): service = self._get_spreadsheet_service() - test_spreadsheet_key = '1S0mld7LMbUad8LYlo13Os9f7eNjw57MqVC0YiCd1Jis' + test_spreadsheet_key = "1S0mld7LMbUad8LYlo13Os9f7eNjw57MqVC0YiCd1Jis" try: service.open_by_key(test_spreadsheet_key).worksheets() except APIError as e: @@ -210,7 +208,12 @@ def run_query(self, query, user): return json_dumps(data), None except gspread.SpreadsheetNotFound: - return None, "Spreadsheet ({}) not found. Make sure you used correct id.".format(key) + return ( + None, + "Spreadsheet ({}) not found. Make sure you used correct id.".format( + key + ), + ) except APIError as e: return None, parse_api_error(e) diff --git a/redash/query_runner/graphite.py b/redash/query_runner/graphite.py index 711584c70d..a04acf6c19 100644 --- a/redash/query_runner/graphite.py +++ b/redash/query_runner/graphite.py @@ -10,18 +10,26 @@ def _transform_result(response): - columns = ({'name': 'Time::x', 'type': TYPE_DATETIME}, - {'name': 'value::y', 'type': TYPE_FLOAT}, - {'name': 'name::series', 'type': TYPE_STRING}) + columns = ( + {"name": "Time::x", "type": TYPE_DATETIME}, + {"name": "value::y", "type": TYPE_FLOAT}, + {"name": "name::series", "type": TYPE_STRING}, + ) rows = [] for series in response.json(): - for values in series['datapoints']: + for values in series["datapoints"]: timestamp = datetime.datetime.fromtimestamp(int(values[1])) - rows.append({'Time::x': timestamp, 'name::series': series['target'], 'value::y': values[0]}) + rows.append( + { + "Time::x": timestamp, + "name::series": series["target"], + "value::y": values[0], + } + ) - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} return json_dumps(data) @@ -31,29 +39,20 @@ class Graphite(BaseQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string' - }, - 'username': { - 'type': 'string' - }, - 'password': { - 'type': 'string' - }, - 'verify': { - 'type': 'boolean', - 'title': 'Verify SSL certificate' - } + "type": "object", + "properties": { + "url": {"type": "string"}, + "username": {"type": "string"}, + "password": {"type": "string"}, + "verify": {"type": "boolean", "title": "Verify SSL certificate"}, }, - 'required': ['url'], - 'secret': ['password'] + "required": ["url"], + "secret": ["password"], } def __init__(self, configuration): super(Graphite, self).__init__(configuration) - self.syntax = 'custom' + self.syntax = "custom" if "username" in self.configuration and self.configuration["username"]: self.auth = (self.configuration["username"], self.configuration["password"]) @@ -61,12 +60,20 @@ def __init__(self, configuration): self.auth = None self.verify = self.configuration.get("verify", True) - self.base_url = "%s/render?format=json&" % self.configuration['url'] + self.base_url = "%s/render?format=json&" % self.configuration["url"] def test_connection(self): - r = requests.get("{}/render".format(self.configuration['url']), auth=self.auth, verify=self.verify) + r = requests.get( + "{}/render".format(self.configuration["url"]), + auth=self.auth, + verify=self.verify, + ) if r.status_code != 200: - raise Exception("Got invalid response from Graphite (http status code: {0}).".format(r.status_code)) + raise Exception( + "Got invalid response from Graphite (http status code: {0}).".format( + r.status_code + ) + ) def run_query(self, query, user): url = "%s%s" % (self.base_url, "&".join(query.split("\n"))) diff --git a/redash/query_runner/hive_ds.py b/redash/query_runner/hive_ds.py index 59555f6589..c4900ba60a 100644 --- a/redash/query_runner/hive_ds.py +++ b/redash/query_runner/hive_ds.py @@ -11,6 +11,7 @@ from pyhive import hive from pyhive.exc import DatabaseError from thrift.transport import THttpClient + enabled = True except ImportError: enabled = False @@ -19,20 +20,20 @@ COLUMN_TYPE = 1 types_map = { - 'BIGINT_TYPE': TYPE_INTEGER, - 'TINYINT_TYPE': TYPE_INTEGER, - 'SMALLINT_TYPE': TYPE_INTEGER, - 'INT_TYPE': TYPE_INTEGER, - 'DOUBLE_TYPE': TYPE_FLOAT, - 'DECIMAL_TYPE': TYPE_FLOAT, - 'FLOAT_TYPE': TYPE_FLOAT, - 'REAL_TYPE': TYPE_FLOAT, - 'BOOLEAN_TYPE': TYPE_BOOLEAN, - 'TIMESTAMP_TYPE': TYPE_DATETIME, - 'DATE_TYPE': TYPE_DATETIME, - 'CHAR_TYPE': TYPE_STRING, - 'STRING_TYPE': TYPE_STRING, - 'VARCHAR_TYPE': TYPE_STRING + "BIGINT_TYPE": TYPE_INTEGER, + "TINYINT_TYPE": TYPE_INTEGER, + "SMALLINT_TYPE": TYPE_INTEGER, + "INT_TYPE": TYPE_INTEGER, + "DOUBLE_TYPE": TYPE_FLOAT, + "DECIMAL_TYPE": TYPE_FLOAT, + "FLOAT_TYPE": TYPE_FLOAT, + "REAL_TYPE": TYPE_FLOAT, + "BOOLEAN_TYPE": TYPE_BOOLEAN, + "TIMESTAMP_TYPE": TYPE_DATETIME, + "DATE_TYPE": TYPE_DATETIME, + "CHAR_TYPE": TYPE_STRING, + "STRING_TYPE": TYPE_STRING, + "VARCHAR_TYPE": TYPE_STRING, } @@ -45,21 +46,13 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "host": { - "type": "string" - }, - "port": { - "type": "number" - }, - "database": { - "type": "string" - }, - "username": { - "type": "string" - }, + "host": {"type": "string"}, + "port": {"type": "number"}, + "database": {"type": "string"}, + "username": {"type": "string"}, }, "order": ["host", "port", "database", "username"], - "required": ["host"] + "required": ["host"], } @classmethod @@ -77,24 +70,46 @@ def _get_tables(self, schema): columns_query = "show columns in %s.%s" - for schema_name in [a for a in [str(a['database_name']) for a in self._run_query_internal(schemas_query)] if len(a) > 0]: - for table_name in [a for a in [str(a['tab_name']) for a in self._run_query_internal(tables_query % schema_name)] if len(a) > 0]: - columns = [a for a in [str(a['field']) for a in self._run_query_internal(columns_query % (schema_name, table_name))] if len(a) > 0] - - if schema_name != 'default': - table_name = '{}.{}'.format(schema_name, table_name) - - schema[table_name] = {'name': table_name, 'columns': columns} + for schema_name in [ + a + for a in [ + str(a["database_name"]) for a in self._run_query_internal(schemas_query) + ] + if len(a) > 0 + ]: + for table_name in [ + a + for a in [ + str(a["tab_name"]) + for a in self._run_query_internal(tables_query % schema_name) + ] + if len(a) > 0 + ]: + columns = [ + a + for a in [ + str(a["field"]) + for a in self._run_query_internal( + columns_query % (schema_name, table_name) + ) + ] + if len(a) > 0 + ] + + if schema_name != "default": + table_name = "{}.{}".format(schema_name, table_name) + + schema[table_name] = {"name": table_name, "columns": columns} return list(schema.values()) def _get_connection(self): - host = self.configuration['host'] + host = self.configuration["host"] connection = hive.connect( host=host, - port=self.configuration.get('port', None), - database=self.configuration.get('database', 'default'), - username=self.configuration.get('username', None), + port=self.configuration.get("port", None), + database=self.configuration.get("database", "default"), + username=self.configuration.get("username", None), ) return connection @@ -114,15 +129,17 @@ def run_query(self, query, user): column_name = column[COLUMN_NAME] column_names.append(column_name) - columns.append({ - 'name': column_name, - 'friendly_name': column_name, - 'type': types_map.get(column[COLUMN_TYPE], None) - }) + columns.append( + { + "name": column_name, + "friendly_name": column_name, + "type": types_map.get(column[COLUMN_TYPE], None), + } + ) rows = [dict(zip(column_names, row)) for row in cursor] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None except KeyboardInterrupt: @@ -150,58 +167,52 @@ def name(cls): @classmethod def type(cls): - return 'hive_http' + return "hive_http" @classmethod def configuration_schema(cls): return { "type": "object", "properties": { - "host": { - "type": "string" - }, - "port": { - "type": "number" - }, - "database": { - "type": "string" - }, - "username": { - "type": "string" - }, + "host": {"type": "string"}, + "port": {"type": "number"}, + "database": {"type": "string"}, + "username": {"type": "string"}, "http_scheme": { "type": "string", "title": "HTTP Scheme (http or https)", - "default": "https" - }, - "http_path": { - "type": "string", - "title": "HTTP Path" - }, - "http_password": { - "type": "string", - "title": "Password" + "default": "https", }, + "http_path": {"type": "string", "title": "HTTP Path"}, + "http_password": {"type": "string", "title": "Password"}, }, - "order": ["host", "port", "http_path", "username", "http_password", "database", "http_scheme"], + "order": [ + "host", + "port", + "http_path", + "username", + "http_password", + "database", + "http_scheme", + ], "secret": ["http_password"], - "required": ["host", "http_path"] + "required": ["host", "http_path"], } def _get_connection(self): - host = self.configuration['host'] + host = self.configuration["host"] - scheme = self.configuration.get('http_scheme', 'https') + scheme = self.configuration.get("http_scheme", "https") # if path is set but is missing initial slash, append it - path = self.configuration.get('http_path', '') - if path and path[0] != '/': - path = '/' + path + path = self.configuration.get("http_path", "") + if path and path[0] != "/": + path = "/" + path # if port is set prepend colon - port = self.configuration.get('port', '') + port = self.configuration.get("port", "") if port: - port = ':' + str(port) + port = ":" + str(port) http_uri = "{}://{}{}{}".format(scheme, host, port, path) @@ -209,11 +220,11 @@ def _get_connection(self): transport = THttpClient.THttpClient(http_uri) # if username or password is set, add Authorization header - username = self.configuration.get('username', '') - password = self.configuration.get('http_password', '') + username = self.configuration.get("username", "") + password = self.configuration.get("http_password", "") if username or password: - auth = base64.b64encode(username + ':' + password) - transport.setCustomHeaders({'Authorization': 'Basic ' + auth}) + auth = base64.b64encode(username + ":" + password) + transport.setCustomHeaders({"Authorization": "Basic " + auth}) # create connection connection = hive.connect(thrift_transport=transport) diff --git a/redash/query_runner/impala_ds.py b/redash/query_runner/impala_ds.py index 7f64164f73..c5ea952a87 100644 --- a/redash/query_runner/impala_ds.py +++ b/redash/query_runner/impala_ds.py @@ -8,6 +8,7 @@ try: from impala.dbapi import connect from impala.error import DatabaseError, RPCError + enabled = True except ImportError as e: enabled = False @@ -16,19 +17,19 @@ COLUMN_TYPE = 1 types_map = { - 'BIGINT': TYPE_INTEGER, - 'TINYINT': TYPE_INTEGER, - 'SMALLINT': TYPE_INTEGER, - 'INT': TYPE_INTEGER, - 'DOUBLE': TYPE_FLOAT, - 'DECIMAL': TYPE_FLOAT, - 'FLOAT': TYPE_FLOAT, - 'REAL': TYPE_FLOAT, - 'BOOLEAN': TYPE_BOOLEAN, - 'TIMESTAMP': TYPE_DATETIME, - 'CHAR': TYPE_STRING, - 'STRING': TYPE_STRING, - 'VARCHAR': TYPE_STRING + "BIGINT": TYPE_INTEGER, + "TINYINT": TYPE_INTEGER, + "SMALLINT": TYPE_INTEGER, + "INT": TYPE_INTEGER, + "DOUBLE": TYPE_FLOAT, + "DECIMAL": TYPE_FLOAT, + "FLOAT": TYPE_FLOAT, + "REAL": TYPE_FLOAT, + "BOOLEAN": TYPE_BOOLEAN, + "TIMESTAMP": TYPE_DATETIME, + "CHAR": TYPE_STRING, + "STRING": TYPE_STRING, + "VARCHAR": TYPE_STRING, } @@ -40,38 +41,24 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "host": { - "type": "string" - }, - "port": { - "type": "number" - }, + "host": {"type": "string"}, + "port": {"type": "number"}, "protocol": { "type": "string", "extendedEnum": [ {"value": "beeswax", "name": "Beeswax"}, - {"value": "hiveserver2", "name": "Hive Server 2"} + {"value": "hiveserver2", "name": "Hive Server 2"}, ], - "title": "Protocol" - }, - "database": { - "type": "string" - }, - "use_ldap": { - "type": "boolean" + "title": "Protocol", }, - "ldap_user": { - "type": "string" - }, - "ldap_password": { - "type": "string" - }, - "timeout": { - "type": "number" - } + "database": {"type": "string"}, + "use_ldap": {"type": "boolean"}, + "ldap_user": {"type": "string"}, + "ldap_password": {"type": "string"}, + "timeout": {"type": "number"}, }, "required": ["host"], - "secret": ["ldap_password"] + "secret": ["ldap_password"], } @classmethod @@ -83,14 +70,24 @@ def _get_tables(self, schema_dict): tables_query = "show tables in %s;" columns_query = "show column stats %s.%s;" - for schema_name in [str(a['name']) for a in self._run_query_internal(schemas_query)]: - for table_name in [str(a['name']) for a in self._run_query_internal(tables_query % schema_name)]: - columns = [str(a['Column']) for a in self._run_query_internal(columns_query % (schema_name, table_name))] - - if schema_name != 'default': - table_name = '{}.{}'.format(schema_name, table_name) - - schema_dict[table_name] = {'name': table_name, 'columns': columns} + for schema_name in [ + str(a["name"]) for a in self._run_query_internal(schemas_query) + ]: + for table_name in [ + str(a["name"]) + for a in self._run_query_internal(tables_query % schema_name) + ]: + columns = [ + str(a["Column"]) + for a in self._run_query_internal( + columns_query % (schema_name, table_name) + ) + ] + + if schema_name != "default": + table_name = "{}.{}".format(schema_name, table_name) + + schema_dict[table_name] = {"name": table_name, "columns": columns} return list(schema_dict.values()) @@ -111,15 +108,17 @@ def run_query(self, query, user): column_name = column[COLUMN_NAME] column_names.append(column_name) - columns.append({ - 'name': column_name, - 'friendly_name': column_name, - 'type': types_map.get(column[COLUMN_TYPE], None) - }) + columns.append( + { + "name": column_name, + "friendly_name": column_name, + "type": types_map.get(column[COLUMN_TYPE], None), + } + ) rows = [dict(zip(column_names, row)) for row in cursor] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None cursor.close() diff --git a/redash/query_runner/influx_db.py b/redash/query_runner/influx_db.py index bec53c8d27..c9cfe7567f 100644 --- a/redash/query_runner/influx_db.py +++ b/redash/query_runner/influx_db.py @@ -7,6 +7,7 @@ try: from influxdb import InfluxDBClusterClient + enabled = True except ImportError: @@ -18,33 +19,32 @@ def _transform_result(results): result_rows = [] for result in results: - for series in result.raw.get('series', []): - for column in series['columns']: + for series in result.raw.get("series", []): + for column in series["columns"]: if column not in result_columns: result_columns.append(column) - tags = series.get('tags', {}) + tags = series.get("tags", {}) for key in tags.keys(): if key not in result_columns: result_columns.append(key) for result in results: - for series in result.raw.get('series', []): - for point in series['values']: + for series in result.raw.get("series", []): + for point in series["values"]: result_row = {} for column in result_columns: - tags = series.get('tags', {}) + tags = series.get("tags", {}) if column in tags: result_row[column] = tags[column] - elif column in series['columns']: - index = series['columns'].index(column) + elif column in series["columns"]: + index = series["columns"].index(column) value = point[index] result_row[column] = value result_rows.append(result_row) - return json_dumps({ - "columns": [{'name': c} for c in result_columns], - "rows": result_rows - }) + return json_dumps( + {"columns": [{"name": c} for c in result_columns], "rows": result_rows} + ) class InfluxDB(BaseQueryRunner): @@ -54,13 +54,9 @@ class InfluxDB(BaseQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string' - } - }, - 'required': ['url'] + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], } @classmethod @@ -72,9 +68,9 @@ def type(cls): return "influxdb" def run_query(self, query, user): - client = InfluxDBClusterClient.from_DSN(self.configuration['url']) + client = InfluxDBClusterClient.from_DSN(self.configuration["url"]) - logger.debug("influxdb url: %s", self.configuration['url']) + logger.debug("influxdb url: %s", self.configuration["url"]) logger.debug("influxdb got query: %s", query) try: diff --git a/redash/query_runner/jql.py b/redash/query_runner/jql.py index 7b10387353..c696534695 100644 --- a/redash/query_runner/jql.py +++ b/redash/query_runner/jql.py @@ -19,10 +19,14 @@ def add_row(self, row): def add_column(self, column, column_type=TYPE_STRING): if column not in self.columns: - self.columns[column] = {'name': column, 'type': column_type, 'friendly_name': column} + self.columns[column] = { + "name": column, + "type": column_type, + "friendly_name": column, + } def to_json(self): - return json_dumps({'rows': self.rows, 'columns': list(self.columns.values())}) + return json_dumps({"rows": self.rows, "columns": list(self.columns.values())}) def merge(self, set): self.rows = self.rows + set.rows @@ -30,9 +34,9 @@ def merge(self, set): def parse_issue(issue, field_mapping): result = OrderedDict() - result['key'] = issue['key'] + result["key"] = issue["key"] - for k, v in issue['fields'].items():# + for k, v in issue["fields"].items(): # output_name = field_mapping.get_output_field_name(k) member_names = field_mapping.get_dict_members(k) @@ -41,20 +45,22 @@ def parse_issue(issue, field_mapping): # if field mapping with dict member mappings defined get value of each member for member_name in member_names: if member_name in v: - result[field_mapping.get_dict_output_field_name(k, member_name)] = v[member_name] + result[ + field_mapping.get_dict_output_field_name(k, member_name) + ] = v[member_name] else: # these special mapping rules are kept for backwards compatibility - if 'key' in v: - result['{}_key'.format(output_name)] = v['key'] - if 'name' in v: - result['{}_name'.format(output_name)] = v['name'] + if "key" in v: + result["{}_key".format(output_name)] = v["key"] + if "name" in v: + result["{}_name".format(output_name)] = v["name"] if k in v: result[output_name] = v[k] - if 'watchCount' in v: - result[output_name] = v['watchCount'] + if "watchCount" in v: + result[output_name] = v["watchCount"] elif isinstance(v, list): if len(member_names) > 0: @@ -66,7 +72,9 @@ def parse_issue(issue, field_mapping): if member_name in listItem: listValues.append(listItem[member_name]) if len(listValues) > 0: - result[field_mapping.get_dict_output_field_name(k, member_name)] = ','.join(listValues) + result[ + field_mapping.get_dict_output_field_name(k, member_name) + ] = ",".join(listValues) else: # otherwise support list values only for non-dict items @@ -75,7 +83,7 @@ def parse_issue(issue, field_mapping): if not isinstance(listItem, dict): listValues.append(listItem) if len(listValues) > 0: - result[output_name] = ','.join(listValues) + result[output_name] = ",".join(listValues) else: result[output_name] = v @@ -86,7 +94,7 @@ def parse_issue(issue, field_mapping): def parse_issues(data, field_mapping): results = ResultSet() - for issue in data['issues']: + for issue in data["issues"]: results.add_row(parse_issue(issue, field_mapping)) return results @@ -94,12 +102,11 @@ def parse_issues(data, field_mapping): def parse_count(data): results = ResultSet() - results.add_row({'count': data['total']}) + results.add_row({"count": data["total"]}) return results class FieldMapping: - def __init__(cls, query_field_mapping): cls.mapping = [] for k, v in query_field_mapping.items(): @@ -107,34 +114,36 @@ def __init__(cls, query_field_mapping): member_name = None # check for member name contained in field name - member_parser = re.search('(\w+)\.(\w+)', k) - if (member_parser): + member_parser = re.search("(\w+)\.(\w+)", k) + if member_parser: field_name = member_parser.group(1) member_name = member_parser.group(2) - cls.mapping.append({ - 'field_name': field_name, - 'member_name': member_name, - 'output_field_name': v - }) + cls.mapping.append( + { + "field_name": field_name, + "member_name": member_name, + "output_field_name": v, + } + ) def get_output_field_name(cls, field_name): for item in cls.mapping: - if item['field_name'] == field_name and not item['member_name']: - return item['output_field_name'] + if item["field_name"] == field_name and not item["member_name"]: + return item["output_field_name"] return field_name def get_dict_members(cls, field_name): member_names = [] for item in cls.mapping: - if item['field_name'] == field_name and item['member_name']: - member_names.append(item['member_name']) + if item["field_name"] == field_name and item["member_name"]: + member_names.append(item["member_name"]) return member_names def get_dict_output_field_name(cls, field_name, member_name): for item in cls.mapping: - if item['field_name'] == field_name and item['member_name'] == member_name: - return item['output_field_name'] + if item["field_name"] == field_name and item["member_name"] == member_name: + return item["output_field_name"] return None @@ -142,9 +151,9 @@ class JiraJQL(BaseHTTPQueryRunner): noop_query = '{"queryType": "count"}' response_error = "JIRA returned unexpected status code" requires_authentication = True - url_title = 'JIRA URL' - username_title = 'Username' - password_title = 'API Token' + url_title = "JIRA URL" + username_title = "Username" + password_title = "API Token" @classmethod def name(cls): @@ -152,21 +161,21 @@ def name(cls): def __init__(self, configuration): super(JiraJQL, self).__init__(configuration) - self.syntax = 'json' + self.syntax = "json" def run_query(self, query, user): - jql_url = '{}/rest/api/2/search'.format(self.configuration["url"]) + jql_url = "{}/rest/api/2/search".format(self.configuration["url"]) try: query = json_loads(query) - query_type = query.pop('queryType', 'select') - field_mapping = FieldMapping(query.pop('fieldMapping', {})) + query_type = query.pop("queryType", "select") + field_mapping = FieldMapping(query.pop("fieldMapping", {})) - if query_type == 'count': - query['maxResults'] = 1 - query['fields'] = '' + if query_type == "count": + query["maxResults"] = 1 + query["fields"] = "" else: - query['maxResults'] = query.get('maxResults', 1000) + query["maxResults"] = query.get("maxResults", 1000) response, error = self.get_response(jql_url, params=query) if error is not None: @@ -174,20 +183,20 @@ def run_query(self, query, user): data = response.json() - if query_type == 'count': + if query_type == "count": results = parse_count(data) else: results = parse_issues(data, field_mapping) - index = data['startAt'] + data['maxResults'] + index = data["startAt"] + data["maxResults"] - while data['total'] > index: - query['startAt'] = index + while data["total"] > index: + query["startAt"] = index response, error = self.get_response(jql_url, params=query) if error is not None: return None, error data = response.json() - index = data['startAt'] + data['maxResults'] + index = data["startAt"] + data["maxResults"] addl_results = parse_issues(data, field_mapping) results.merge(addl_results) diff --git a/redash/query_runner/json_ds.py b/redash/query_runner/json_ds.py index 7467a52288..2a9cba3508 100644 --- a/redash/query_runner/json_ds.py +++ b/redash/query_runner/json_ds.py @@ -7,9 +7,15 @@ from funcy import compact, project from six import text_type from redash.utils import json_dumps -from redash.query_runner import (BaseHTTPQueryRunner, register, - TYPE_BOOLEAN, TYPE_DATETIME, TYPE_FLOAT, - TYPE_INTEGER, TYPE_STRING) +from redash.query_runner import ( + BaseHTTPQueryRunner, + register, + TYPE_BOOLEAN, + TYPE_DATETIME, + TYPE_FLOAT, + TYPE_INTEGER, + TYPE_STRING, +) class QueryParseError(Exception): @@ -60,26 +66,23 @@ def _get_type(value): def add_column(columns, column_name, column_type): if _get_column_by_name(columns, column_name) is None: - columns.append({ - "name": column_name, - "friendly_name": column_name, - "type": column_type - }) + columns.append( + {"name": column_name, "friendly_name": column_name, "type": column_type} + ) def _apply_path_search(response, path): if path is None: return response - path_parts = path.split('.') + path_parts = path.split(".") path_parts.reverse() while len(path_parts) > 0: current_path = path_parts.pop() if current_path in response: response = response[current_path] else: - raise Exception( - "Couldn't find path {} in response.".format(path)) + raise Exception("Couldn't find path {} in response.".format(path)) return response @@ -88,15 +91,14 @@ def _normalize_json(data, path): data = _apply_path_search(data, path) if isinstance(data, dict): - data = [data, ] + data = [data] return data def _sort_columns_with_fields(columns, fields): if fields: - columns = compact( - [_get_column_by_name(columns, field) for field in fields]) + columns = compact([_get_column_by_name(columns, field) for field in fields]) return columns @@ -114,7 +116,7 @@ def parse_json(data, path, fields): for key in row: if isinstance(row[key], dict): for inner_key in row[key]: - column_name = '{}.{}'.format(key, inner_key) + column_name = "{}.{}".format(key, inner_key) if fields and key not in fields and column_name not in fields: continue @@ -133,7 +135,7 @@ def parse_json(data, path, fields): columns = _sort_columns_with_fields(columns, fields) - return {'rows': rows, 'columns': columns} + return {"rows": rows, "columns": columns} class JSON(BaseHTTPQueryRunner): @@ -142,24 +144,18 @@ class JSON(BaseHTTPQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'username': { - 'type': 'string', - 'title': cls.username_title, - }, - 'password': { - 'type': 'string', - 'title': cls.password_title, - }, + "type": "object", + "properties": { + "username": {"type": "string", "title": cls.username_title}, + "password": {"type": "string", "title": cls.password_title}, }, - 'secret': ['password'], - 'order': ['username', 'password'] + "secret": ["password"], + "order": ["username", "password"], } def __init__(self, configuration): super(JSON, self).__init__(configuration) - self.syntax = 'yaml' + self.syntax = "yaml" def test_connection(self): pass @@ -170,37 +166,42 @@ def run_query(self, query, user): if not isinstance(query, dict): raise QueryParseError( - "Query should be a YAML object describing the URL to query.") + "Query should be a YAML object describing the URL to query." + ) - if 'url' not in query: + if "url" not in query: raise QueryParseError("Query must include 'url' option.") - if is_private_address(query['url']): + if is_private_address(query["url"]): raise Exception("Can't query private addresses.") - method = query.get('method', 'get') + method = query.get("method", "get") request_options = project( - query, ('params', 'headers', 'data', 'auth', 'json',)) - - fields = query.get('fields') - path = query.get('path') - - if isinstance(request_options.get('auth', None), list): - request_options['auth'] = tuple(request_options['auth']) - elif self.configuration.get('username') or self.configuration.get('password'): - request_options['auth'] = (self.configuration.get( - 'username'), self.configuration.get('password')) - - if method not in ('get', 'post'): + query, ("params", "headers", "data", "auth", "json") + ) + + fields = query.get("fields") + path = query.get("path") + + if isinstance(request_options.get("auth", None), list): + request_options["auth"] = tuple(request_options["auth"]) + elif self.configuration.get("username") or self.configuration.get( + "password" + ): + request_options["auth"] = ( + self.configuration.get("username"), + self.configuration.get("password"), + ) + + if method not in ("get", "post"): raise QueryParseError("Only GET or POST methods are allowed.") if fields and not isinstance(fields, list): raise QueryParseError("'fields' needs to be a list.") response, error = self.get_response( - query['url'], - http_method=method, - **request_options) + query["url"], http_method=method, **request_options + ) if error is not None: return None, error @@ -210,7 +211,7 @@ def run_query(self, query, user): if data: return data, None else: - return None, "Got empty response from '{}'.".format(query['url']) + return None, "Got empty response from '{}'.".format(query["url"]) except KeyboardInterrupt: return None, "Query cancelled by user." diff --git a/redash/query_runner/kylin.py b/redash/query_runner/kylin.py index c08954d3a1..cfc02c671f 100644 --- a/redash/query_runner/kylin.py +++ b/redash/query_runner/kylin.py @@ -10,40 +10,40 @@ logger = logging.getLogger(__name__) types_map = { - 'tinyint': TYPE_INTEGER, - 'smallint': TYPE_INTEGER, - 'integer': TYPE_INTEGER, - 'bigint': TYPE_INTEGER, - 'int4': TYPE_INTEGER, - 'long8': TYPE_INTEGER, - 'int': TYPE_INTEGER, - 'short': TYPE_INTEGER, - 'long': TYPE_INTEGER, - 'byte': TYPE_INTEGER, - 'hllc10': TYPE_INTEGER, - 'hllc12': TYPE_INTEGER, - 'hllc14': TYPE_INTEGER, - 'hllc15': TYPE_INTEGER, - 'hllc16': TYPE_INTEGER, - 'hllc(10)': TYPE_INTEGER, - 'hllc(12)': TYPE_INTEGER, - 'hllc(14)': TYPE_INTEGER, - 'hllc(15)': TYPE_INTEGER, - 'hllc(16)': TYPE_INTEGER, - 'float': TYPE_FLOAT, - 'double': TYPE_FLOAT, - 'decimal': TYPE_FLOAT, - 'real': TYPE_FLOAT, - 'numeric': TYPE_FLOAT, - 'boolean': TYPE_BOOLEAN, - 'bool': TYPE_BOOLEAN, - 'date': TYPE_DATE, - 'datetime': TYPE_DATETIME, - 'timestamp': TYPE_DATETIME, - 'time': TYPE_DATETIME, - 'varchar': TYPE_STRING, - 'char': TYPE_STRING, - 'string': TYPE_STRING, + "tinyint": TYPE_INTEGER, + "smallint": TYPE_INTEGER, + "integer": TYPE_INTEGER, + "bigint": TYPE_INTEGER, + "int4": TYPE_INTEGER, + "long8": TYPE_INTEGER, + "int": TYPE_INTEGER, + "short": TYPE_INTEGER, + "long": TYPE_INTEGER, + "byte": TYPE_INTEGER, + "hllc10": TYPE_INTEGER, + "hllc12": TYPE_INTEGER, + "hllc14": TYPE_INTEGER, + "hllc15": TYPE_INTEGER, + "hllc16": TYPE_INTEGER, + "hllc(10)": TYPE_INTEGER, + "hllc(12)": TYPE_INTEGER, + "hllc(14)": TYPE_INTEGER, + "hllc(15)": TYPE_INTEGER, + "hllc(16)": TYPE_INTEGER, + "float": TYPE_FLOAT, + "double": TYPE_FLOAT, + "decimal": TYPE_FLOAT, + "real": TYPE_FLOAT, + "numeric": TYPE_FLOAT, + "boolean": TYPE_BOOLEAN, + "bool": TYPE_BOOLEAN, + "date": TYPE_DATE, + "datetime": TYPE_DATETIME, + "timestamp": TYPE_DATETIME, + "time": TYPE_DATETIME, + "varchar": TYPE_STRING, + "char": TYPE_STRING, + "string": TYPE_STRING, } @@ -53,34 +53,25 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string", - "title": "Kylin Username", - }, - "password": { - "type": "string", - "title": "Kylin Password", - }, + "user": {"type": "string", "title": "Kylin Username"}, + "password": {"type": "string", "title": "Kylin Password"}, "url": { "type": "string", "title": "Kylin API URL", "default": "http://kylin.example.com/kylin/", }, - "project": { - "type": "string", - "title": "Kylin Project", - }, + "project": {"type": "string", "title": "Kylin Project"}, }, - "order": ['url', 'project', 'user', 'password'], + "order": ["url", "project", "user", "password"], "required": ["url", "project", "user", "password"], - "secret": ["password"] + "secret": ["password"], } def run_query(self, query, user): - url = self.configuration['url'] - kylinuser = self.configuration['user'] - kylinpass = self.configuration['password'] - kylinproject = self.configuration['project'] + url = self.configuration["url"] + kylinuser = self.configuration["user"] + kylinpass = self.configuration["password"] + kylinproject = self.configuration["project"] resp = requests.post( os.path.join(url, "api/query"), @@ -90,24 +81,24 @@ def run_query(self, query, user): "offset": settings.KYLIN_OFFSET, "limit": settings.KYLIN_LIMIT, "acceptPartial": settings.KYLIN_ACCEPT_PARTIAL, - "project": kylinproject - } + "project": kylinproject, + }, ) if not resp.ok: return {}, resp.text or str(resp.reason) data = resp.json() - columns = self.get_columns(data['columnMetas']) - rows = self.get_rows(columns, data['results']) + columns = self.get_columns(data["columnMetas"]) + rows = self.get_rows(columns, data["results"]) - return json_dumps({'columns': columns, 'rows': rows}), None + return json_dumps({"columns": columns, "rows": rows}), None def get_schema(self, get_stats=False): - url = self.configuration['url'] - kylinuser = self.configuration['user'] - kylinpass = self.configuration['password'] - kylinproject = self.configuration['project'] + url = self.configuration["url"] + kylinuser = self.configuration["user"] + kylinpass = self.configuration["password"] + kylinproject = self.configuration["project"] resp = requests.get( os.path.join(url, "api/tables_and_columns"), @@ -121,24 +112,28 @@ def get_schema(self, get_stats=False): return [self.get_table_schema(table) for table in data] def test_connection(self): - url = self.configuration['url'] + url = self.configuration["url"] requests.get(url).raise_for_status() def get_columns(self, colmetas): - return self.fetch_columns([ - (meta['name'], types_map.get(meta['columnTypeName'].lower(), TYPE_STRING)) - for meta in colmetas - ]) + return self.fetch_columns( + [ + ( + meta["name"], + types_map.get(meta["columnTypeName"].lower(), TYPE_STRING), + ) + for meta in colmetas + ] + ) def get_rows(self, columns, results): return [ - dict(zip((column['name'] for column in columns), row)) - for row in results + dict(zip((column["name"] for column in columns), row)) for row in results ] def get_table_schema(self, table): - name = table['table_NAME'] - columns = [col['column_NAME'].lower() for col in table['columns']] + name = table["table_NAME"] + columns = [col["column_NAME"].lower() for col in table["columns"]] return {"name": name, "columns": columns} diff --git a/redash/query_runner/mapd.py b/redash/query_runner/mapd.py index d4e6eaef0d..45f77cc273 100644 --- a/redash/query_runner/mapd.py +++ b/redash/query_runner/mapd.py @@ -1,13 +1,19 @@ - - try: import pymapd + enabled = True except ImportError: enabled = False from redash.query_runner import BaseSQLQueryRunner, register -from redash.query_runner import TYPE_STRING, TYPE_DATE, TYPE_DATETIME, TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN +from redash.query_runner import ( + TYPE_STRING, + TYPE_DATE, + TYPE_DATETIME, + TYPE_INTEGER, + TYPE_FLOAT, + TYPE_BOOLEAN, +) from redash.utils import json_dumps TYPES_MAP = { @@ -23,42 +29,25 @@ 9: TYPE_DATE, 10: TYPE_BOOLEAN, 11: TYPE_DATE, - 12: TYPE_DATE + 12: TYPE_DATE, } class Mapd(BaseSQLQueryRunner): - @classmethod def configuration_schema(cls): return { "type": "object", "properties": { - "host": { - "type": "string", - "default": "localhost" - }, - "port": { - "type": "number", - "default": 9091 - }, - "user": { - "type": "string", - "default": "mapd", - "title": "username" - }, - "password": { - "type": "string", - "default": "HyperInteractive" - }, - "database": { - "type": "string", - "default": "mapd" - } + "host": {"type": "string", "default": "localhost"}, + "port": {"type": "number", "default": 9091}, + "user": {"type": "string", "default": "mapd", "title": "username"}, + "password": {"type": "string", "default": "HyperInteractive"}, + "database": {"type": "string", "default": "mapd"}, }, "order": ["user", "password", "host", "port", "database"], "required": ["host", "port", "user", "password", "database"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -67,12 +56,12 @@ def enabled(cls): def connect_database(self): connection = pymapd.connect( - user=self.configuration['user'], - password=self.configuration['password'], - host=self.configuration['host'], - port=self.configuration['port'], - dbname=self.configuration['database'] - ) + user=self.configuration["user"], + password=self.configuration["password"], + host=self.configuration["host"], + port=self.configuration["port"], + dbname=self.configuration["database"], + ) return connection def run_query(self, query, user): @@ -81,9 +70,13 @@ def run_query(self, query, user): try: cursor.execute(query) - columns = self.fetch_columns([(i[0], TYPES_MAP.get(i[1], None)) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) for row in cursor] - data = {'columns': columns, 'rows': rows} + columns = self.fetch_columns( + [(i[0], TYPES_MAP.get(i[1], None)) for i in cursor.description] + ) + rows = [ + dict(zip((column["name"] for column in columns), row)) for row in cursor + ] + data = {"columns": columns, "rows": rows} error = None json_data = json_dumps(data) finally: @@ -96,9 +89,9 @@ def _get_tables(self, schema): connection = self.connect_database() try: for table_name in connection.get_tables(): - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} for row_column in connection.get_table_details(table_name): - schema[table_name]['columns'].append(row_column[0]) + schema[table_name]["columns"].append(row_column[0]) finally: connection.close diff --git a/redash/query_runner/memsql_ds.py b/redash/query_runner/memsql_ds.py index 4525920e03..56443bac96 100644 --- a/redash/query_runner/memsql_ds.py +++ b/redash/query_runner/memsql_ds.py @@ -8,6 +8,7 @@ try: from memsql.common import database + enabled = True except ImportError: enabled = False @@ -16,51 +17,42 @@ COLUMN_TYPE = 1 types_map = { - 'BIGINT': TYPE_INTEGER, - 'TINYINT': TYPE_INTEGER, - 'SMALLINT': TYPE_INTEGER, - 'MEDIUMINT': TYPE_INTEGER, - 'INT': TYPE_INTEGER, - 'DOUBLE': TYPE_FLOAT, - 'DECIMAL': TYPE_FLOAT, - 'FLOAT': TYPE_FLOAT, - 'REAL': TYPE_FLOAT, - 'BOOL': TYPE_BOOLEAN, - 'BOOLEAN': TYPE_BOOLEAN, - 'TIMESTAMP': TYPE_DATETIME, - 'DATETIME': TYPE_DATETIME, - 'DATE': TYPE_DATETIME, - 'JSON': TYPE_STRING, - 'CHAR': TYPE_STRING, - 'VARCHAR': TYPE_STRING + "BIGINT": TYPE_INTEGER, + "TINYINT": TYPE_INTEGER, + "SMALLINT": TYPE_INTEGER, + "MEDIUMINT": TYPE_INTEGER, + "INT": TYPE_INTEGER, + "DOUBLE": TYPE_FLOAT, + "DECIMAL": TYPE_FLOAT, + "FLOAT": TYPE_FLOAT, + "REAL": TYPE_FLOAT, + "BOOL": TYPE_BOOLEAN, + "BOOLEAN": TYPE_BOOLEAN, + "TIMESTAMP": TYPE_DATETIME, + "DATETIME": TYPE_DATETIME, + "DATE": TYPE_DATETIME, + "JSON": TYPE_STRING, + "CHAR": TYPE_STRING, + "VARCHAR": TYPE_STRING, } class MemSQL(BaseSQLQueryRunner): should_annotate_query = False - noop_query = 'SELECT 1' + noop_query = "SELECT 1" @classmethod def configuration_schema(cls): return { "type": "object", "properties": { - "host": { - "type": "string" - }, - "port": { - "type": "number" - }, - "user": { - "type": "string" - }, - "password": { - "type": "string" - } - + "host": {"type": "string"}, + "port": {"type": "number"}, + "user": {"type": "string"}, + "password": {"type": "string"}, }, "required": ["host", "port"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -78,13 +70,32 @@ def _get_tables(self, schema): columns_query = "show columns in %s" - for schema_name in [a for a in [str(a['Database']) for a in self._run_query_internal(schemas_query)] if len(a) > 0]: - for table_name in [a for a in [str(a['Tables_in_%s' % schema_name]) for a in self._run_query_internal( - tables_query % schema_name)] if len(a) > 0]: - table_name = '.'.join((schema_name, table_name)) - columns = [a for a in [str(a['Field']) for a in self._run_query_internal(columns_query % table_name)] if len(a) > 0] - - schema[table_name] = {'name': table_name, 'columns': columns} + for schema_name in [ + a + for a in [ + str(a["Database"]) for a in self._run_query_internal(schemas_query) + ] + if len(a) > 0 + ]: + for table_name in [ + a + for a in [ + str(a["Tables_in_%s" % schema_name]) + for a in self._run_query_internal(tables_query % schema_name) + ] + if len(a) > 0 + ]: + table_name = ".".join((schema_name, table_name)) + columns = [ + a + for a in [ + str(a["Field"]) + for a in self._run_query_internal(columns_query % table_name) + ] + if len(a) > 0 + ] + + schema[table_name] = {"name": table_name, "columns": columns} return list(schema.values()) def run_query(self, query, user): @@ -117,13 +128,11 @@ def run_query(self, query, user): if column_names: for column in column_names: - columns.append({ - 'name': column, - 'friendly_name': column, - 'type': TYPE_STRING - }) + columns.append( + {"name": column, "friendly_name": column, "type": TYPE_STRING} + ) - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None except KeyboardInterrupt: diff --git a/redash/query_runner/mongodb.py b/redash/query_runner/mongodb.py index 558942197d..6da92620e6 100644 --- a/redash/query_runner/mongodb.py +++ b/redash/query_runner/mongodb.py @@ -17,6 +17,7 @@ from bson.decimal128 import Decimal128 from bson.son import SON from bson.json_util import object_hook as bson_object_hook + enabled = True except ImportError: @@ -44,14 +45,14 @@ def default(self, o): return super(MongoDBJSONEncoder, self).default(o) -date_regex = re.compile("ISODate\(\"(.*)\"\)", re.IGNORECASE) +date_regex = re.compile('ISODate\("(.*)"\)', re.IGNORECASE) def parse_oids(oids): if not isinstance(oids, list): raise Exception("$oids takes an array as input.") - return [bson_object_hook({'$oid': oid}) for oid in oids] + return [bson_object_hook({"$oid": oid}) for oid in oids] def datetime_parser(dct): @@ -61,11 +62,11 @@ def datetime_parser(dct): if len(m) > 0: dct[k] = parse(m[0], yearfirst=True) - if '$humanTime' in dct: - return parse_human_time(dct['$humanTime']) + if "$humanTime" in dct: + return parse_human_time(dct["$humanTime"]) - if '$oids' in dct: - return parse_oids(dct['$oids']) + if "$oids" in dct: + return parse_oids(dct["$oids"]) return bson_object_hook(dct) @@ -93,23 +94,29 @@ def parse_results(results): for key in row: if isinstance(row[key], dict): for inner_key in row[key]: - column_name = '{}.{}'.format(key, inner_key) + column_name = "{}.{}".format(key, inner_key) if _get_column_by_name(columns, column_name) is None: - columns.append({ - "name": column_name, - "friendly_name": column_name, - "type": TYPES_MAP.get(type(row[key][inner_key]), TYPE_STRING) - }) + columns.append( + { + "name": column_name, + "friendly_name": column_name, + "type": TYPES_MAP.get( + type(row[key][inner_key]), TYPE_STRING + ), + } + ) parsed_row[column_name] = row[key][inner_key] else: if _get_column_by_name(columns, key) is None: - columns.append({ - "name": key, - "friendly_name": key, - "type": TYPES_MAP.get(type(row[key]), TYPE_STRING) - }) + columns.append( + { + "name": key, + "friendly_name": key, + "type": TYPES_MAP.get(type(row[key]), TYPE_STRING), + } + ) parsed_row[key] = row[key] @@ -124,22 +131,13 @@ class MongoDB(BaseQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'connectionString': { - 'type': 'string', - 'title': 'Connection String' - }, - 'dbName': { - 'type': 'string', - 'title': "Database Name" - }, - 'replicaSetName': { - 'type': 'string', - 'title': 'Replica Set Name' - }, + "type": "object", + "properties": { + "connectionString": {"type": "string", "title": "Connection String"}, + "dbName": {"type": "string", "title": "Database Name"}, + "replicaSetName": {"type": "string", "title": "Replica Set Name"}, }, - 'required': ['connectionString', 'dbName'] + "required": ["connectionString", "dbName"], } @classmethod @@ -149,16 +147,23 @@ def enabled(cls): def __init__(self, configuration): super(MongoDB, self).__init__(configuration) - self.syntax = 'json' + self.syntax = "json" self.db_name = self.configuration["dbName"] - self.is_replica_set = True if "replicaSetName" in self.configuration and self.configuration["replicaSetName"] else False + self.is_replica_set = ( + True + if "replicaSetName" in self.configuration + and self.configuration["replicaSetName"] + else False + ) def _get_db(self): if self.is_replica_set: - db_connection = pymongo.MongoClient(self.configuration["connectionString"], - replicaSet=self.configuration["replicaSetName"]) + db_connection = pymongo.MongoClient( + self.configuration["connectionString"], + replicaSet=self.configuration["replicaSetName"], + ) else: db_connection = pymongo.MongoClient(self.configuration["connectionString"]) @@ -171,11 +176,11 @@ def test_connection(self): def _merge_property_names(self, columns, document): for property in document: - if property not in columns: - columns.append(property) + if property not in columns: + columns.append(property) def _is_collection_a_view(self, db, collection_name): - if 'viewOn' in db[collection_name].options(): + if "viewOn" in db[collection_name].options(): return True else: return False @@ -210,18 +215,22 @@ def get_schema(self, get_stats=False): schema = {} db = self._get_db() for collection_name in db.collection_names(): - if collection_name.startswith('system.'): + if collection_name.startswith("system."): continue columns = self._get_collection_fields(db, collection_name) schema[collection_name] = { - "name": collection_name, "columns": sorted(columns)} + "name": collection_name, + "columns": sorted(columns), + } return list(schema.values()) def run_query(self, query, user): db = self._get_db() - logger.debug("mongodb connection string: %s", self.configuration['connectionString']) + logger.debug( + "mongodb connection string: %s", self.configuration["connectionString"] + ) logger.debug("mongodb got query: %s", query) try: @@ -283,7 +292,7 @@ def run_query(self, query, user): cursor = cursor.count() elif aggregate: - allow_disk_use = query_data.get('allowDiskUse', False) + allow_disk_use = query_data.get("allowDiskUse", False) r = db[collection].aggregate(aggregate, allowDiskUse=allow_disk_use) # Backwards compatibility with older pymongo versions. @@ -297,11 +306,9 @@ def run_query(self, query, user): cursor = r if "count" in query_data: - columns.append({ - "name": "count", - "friendly_name": "count", - "type": TYPE_INTEGER - }) + columns.append( + {"name": "count", "friendly_name": "count", "type": TYPE_INTEGER} + ) rows.append({"count": cursor}) else: @@ -316,14 +323,11 @@ def run_query(self, query, user): columns = ordered_columns - if query_data.get('sortColumns'): - reverse = query_data['sortColumns'] == 'desc' - columns = sorted(columns, key=lambda col: col['name'], reverse=reverse) + if query_data.get("sortColumns"): + reverse = query_data["sortColumns"] == "desc" + columns = sorted(columns, key=lambda col: col["name"], reverse=reverse) - data = { - "columns": columns, - "rows": rows - } + data = {"columns": columns, "rows": rows} error = None json_data = json_dumps(data, cls=MongoDBJSONEncoder) diff --git a/redash/query_runner/mssql.py b/redash/query_runner/mssql.py index 541c736747..0f62fdb892 100644 --- a/redash/query_runner/mssql.py +++ b/redash/query_runner/mssql.py @@ -9,6 +9,7 @@ try: import pymssql + enabled = True except ImportError: enabled = False @@ -34,37 +35,24 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "server": { - "type": "string", - "default": "127.0.0.1" - }, - "port": { - "type": "number", - "default": 1433 - }, + "user": {"type": "string"}, + "password": {"type": "string"}, + "server": {"type": "string", "default": "127.0.0.1"}, + "port": {"type": "number", "default": 1433}, "tds_version": { "type": "string", "default": "7.0", - "title": "TDS Version" + "title": "TDS Version", }, "charset": { "type": "string", "default": "UTF-8", - "title": "Character Set" + "title": "Character Set", }, - "db": { - "type": "string", - "title": "Database Name" - } + "db": {"type": "string", "title": "Database Name"}, }, "required": ["db"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -96,16 +84,16 @@ def _get_tables(self, schema): results = json_loads(results) - for row in results['rows']: - if row['table_schema'] != self.configuration['db']: - table_name = '{}.{}'.format(row['table_schema'], row['table_name']) + for row in results["rows"]: + if row["table_schema"] != self.configuration["db"]: + table_name = "{}.{}".format(row["table_schema"], row["table_name"]) else: - table_name = row['table_name'] + table_name = row["table_name"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name]["columns"].append(row["column_name"]) return list(schema.values()) @@ -113,18 +101,25 @@ def run_query(self, query, user): connection = None try: - server = self.configuration.get('server', '') - user = self.configuration.get('user', '') - password = self.configuration.get('password', '') - db = self.configuration['db'] - port = self.configuration.get('port', 1433) - tds_version = self.configuration.get('tds_version', '7.0') - charset = self.configuration.get('charset', 'UTF-8') + server = self.configuration.get("server", "") + user = self.configuration.get("user", "") + password = self.configuration.get("password", "") + db = self.configuration["db"] + port = self.configuration.get("port", 1433) + tds_version = self.configuration.get("tds_version", "7.0") + charset = self.configuration.get("charset", "UTF-8") if port != 1433: - server = server + ':' + str(port) + server = server + ":" + str(port) - connection = pymssql.connect(server=server, user=user, password=password, database=db, tds_version=tds_version, charset=charset) + connection = pymssql.connect( + server=server, + user=user, + password=password, + database=db, + tds_version=tds_version, + charset=charset, + ) if isinstance(query, str): query = query.encode(charset) @@ -136,10 +131,15 @@ def run_query(self, query, user): data = cursor.fetchall() if cursor.description is not None: - columns = self.fetch_columns([(i[0], types_map.get(i[1], None)) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) for row in data] - - data = {'columns': columns, 'rows': rows} + columns = self.fetch_columns( + [(i[0], types_map.get(i[1], None)) for i in cursor.description] + ) + rows = [ + dict(zip((column["name"] for column in columns), row)) + for row in data + ] + + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None else: diff --git a/redash/query_runner/mssql_odbc.py b/redash/query_runner/mssql_odbc.py index 48c18cb930..6c574f0287 100644 --- a/redash/query_runner/mssql_odbc.py +++ b/redash/query_runner/mssql_odbc.py @@ -10,6 +10,7 @@ try: import pyodbc + enabled = True except ImportError: enabled = False @@ -24,37 +25,24 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "server": { - "type": "string", - "default": "127.0.0.1" - }, - "port": { - "type": "number", - "default": 1433 - }, + "user": {"type": "string"}, + "password": {"type": "string"}, + "server": {"type": "string", "default": "127.0.0.1"}, + "port": {"type": "number", "default": 1433}, "charset": { "type": "string", "default": "UTF-8", - "title": "Character Set" - }, - "db": { - "type": "string", - "title": "Database Name" + "title": "Character Set", }, + "db": {"type": "string", "title": "Database Name"}, "driver": { "type": "string", "title": "Driver Identifier", - "default": "{ODBC Driver 13 for SQL Server}" - } + "default": "{ODBC Driver 13 for SQL Server}", + }, }, "required": ["db"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -86,16 +74,16 @@ def _get_tables(self, schema): results = json_loads(results) - for row in results['rows']: - if row['table_schema'] != self.configuration['db']: - table_name = '{}.{}'.format(row['table_schema'], row['table_name']) + for row in results["rows"]: + if row["table_schema"] != self.configuration["db"]: + table_name = "{}.{}".format(row["table_schema"], row["table_name"]) else: - table_name = row['table_name'] + table_name = row["table_name"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name]["columns"].append(row["column_name"]) return list(schema.values()) @@ -103,21 +91,20 @@ def run_query(self, query, user): connection = None try: - server = self.configuration.get('server', '') - user = self.configuration.get('user', '') - password = self.configuration.get('password', '') - db = self.configuration['db'] - port = self.configuration.get('port', 1433) - charset = self.configuration.get('charset', 'UTF-8') - driver = self.configuration.get('driver', '{ODBC Driver 13 for SQL Server}') - - connection_string_fmt = 'DRIVER={};PORT={};SERVER={};DATABASE={};UID={};PWD={}' - connection_string = connection_string_fmt.format(driver, - port, - server, - db, - user, - password) + server = self.configuration.get("server", "") + user = self.configuration.get("user", "") + password = self.configuration.get("password", "") + db = self.configuration["db"] + port = self.configuration.get("port", 1433) + charset = self.configuration.get("charset", "UTF-8") + driver = self.configuration.get("driver", "{ODBC Driver 13 for SQL Server}") + + connection_string_fmt = ( + "DRIVER={};PORT={};SERVER={};DATABASE={};UID={};PWD={}" + ) + connection_string = connection_string_fmt.format( + driver, port, server, db, user, password + ) connection = pyodbc.connect(connection_string) cursor = connection.cursor() logger.debug("SQLServerODBC running query: %s", query) @@ -125,10 +112,15 @@ def run_query(self, query, user): data = cursor.fetchall() if cursor.description is not None: - columns = self.fetch_columns([(i[0], types_map.get(i[1], None)) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) for row in data] - - data = {'columns': columns, 'rows': rows} + columns = self.fetch_columns( + [(i[0], types_map.get(i[1], None)) for i in cursor.description] + ) + rows = [ + dict(zip((column["name"] for column in columns), row)) + for row in data + ] + + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None else: diff --git a/redash/query_runner/mysql.py b/redash/query_runner/mysql.py index a6f5d5c480..05552eef7f 100644 --- a/redash/query_runner/mysql.py +++ b/redash/query_runner/mysql.py @@ -2,12 +2,22 @@ import os import threading -from redash.query_runner import TYPE_FLOAT, TYPE_INTEGER, TYPE_DATETIME, TYPE_STRING, TYPE_DATE, BaseSQLQueryRunner, InterruptException, register +from redash.query_runner import ( + TYPE_FLOAT, + TYPE_INTEGER, + TYPE_DATETIME, + TYPE_STRING, + TYPE_DATE, + BaseSQLQueryRunner, + InterruptException, + register, +) from redash.settings import parse_boolean from redash.utils import json_dumps, json_loads try: import MySQLdb + enabled = True except ImportError: enabled = False @@ -44,57 +54,41 @@ class Mysql(BaseSQLQueryRunner): @classmethod def configuration_schema(cls): show_ssl_settings = parse_boolean( - os.environ.get('MYSQL_SHOW_SSL_SETTINGS', 'true')) + os.environ.get("MYSQL_SHOW_SSL_SETTINGS", "true") + ) schema = { - 'type': 'object', - 'properties': { - 'host': { - 'type': 'string', - 'default': '127.0.0.1' - }, - 'user': { - 'type': 'string' - }, - 'passwd': { - 'type': 'string', - 'title': 'Password' - }, - 'db': { - 'type': 'string', - 'title': 'Database name' - }, - 'port': { - 'type': 'number', - 'default': 3306, - } + "type": "object", + "properties": { + "host": {"type": "string", "default": "127.0.0.1"}, + "user": {"type": "string"}, + "passwd": {"type": "string", "title": "Password"}, + "db": {"type": "string", "title": "Database name"}, + "port": {"type": "number", "default": 3306}, }, - "order": ['host', 'port', 'user', 'passwd', 'db'], - 'required': ['db'], - 'secret': ['passwd'] + "order": ["host", "port", "user", "passwd", "db"], + "required": ["db"], + "secret": ["passwd"], } if show_ssl_settings: - schema['properties'].update({ - 'use_ssl': { - 'type': 'boolean', - 'title': 'Use SSL' - }, - 'ssl_cacert': { - 'type': - 'string', - 'title': - 'Path to CA certificate file to verify peer against (SSL)' - }, - 'ssl_cert': { - 'type': 'string', - 'title': 'Path to client certificate file (SSL)' - }, - 'ssl_key': { - 'type': 'string', - 'title': 'Path to private key file (SSL)' + schema["properties"].update( + { + "use_ssl": {"type": "boolean", "title": "Use SSL"}, + "ssl_cacert": { + "type": "string", + "title": "Path to CA certificate file to verify peer against (SSL)", + }, + "ssl_cert": { + "type": "string", + "title": "Path to client certificate file (SSL)", + }, + "ssl_key": { + "type": "string", + "title": "Path to private key file (SSL)", + }, } - }) + ) return schema @@ -107,19 +101,21 @@ def enabled(cls): return enabled def _connection(self): - params = dict(host=self.configuration.get('host', ''), - user=self.configuration.get('user', ''), - passwd=self.configuration.get('passwd', ''), - db=self.configuration['db'], - port=self.configuration.get('port', 3306), - charset='utf8', - use_unicode=True, - connect_timeout=60) + params = dict( + host=self.configuration.get("host", ""), + user=self.configuration.get("user", ""), + passwd=self.configuration.get("passwd", ""), + db=self.configuration["db"], + port=self.configuration.get("port", 3306), + charset="utf8", + use_unicode=True, + connect_timeout=60, + ) ssl_options = self._get_ssl_parameters() if ssl_options: - params['ssl'] = ssl_options + params["ssl"] = ssl_options connection = MySQLdb.connect(**params) @@ -141,17 +137,16 @@ def _get_tables(self, schema): results = json_loads(results) - for row in results['rows']: - if row['table_schema'] != self.configuration['db']: - table_name = '{}.{}'.format(row['table_schema'], - row['table_name']) + for row in results["rows"]: + if row["table_schema"] != self.configuration["db"]: + table_name = "{}.{}".format(row["table_schema"], row["table_name"]) else: - table_name = row['table_name'] + table_name = row["table_name"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name]["columns"].append(row["column_name"]) return list(schema.values()) @@ -163,8 +158,9 @@ def run_query(self, query, user): try: connection = self._connection() thread_id = connection.thread_id() - t = threading.Thread(target=self._run_query, - args=(query, user, connection, r, ev)) + t = threading.Thread( + target=self._run_query, args=(query, user, connection, r, ev) + ) t.start() while not ev.wait(1): pass @@ -194,14 +190,15 @@ def _run_query(self, query, user, connection, r, ev): # TODO - very similar to pg.py if desc is not None: - columns = self.fetch_columns([(i[0], types_map.get(i[1], None)) - for i in desc]) + columns = self.fetch_columns( + [(i[0], types_map.get(i[1], None)) for i in desc] + ) rows = [ - dict(zip((column['name'] for column in columns), row)) + dict(zip((column["name"] for column in columns), row)) for row in data ] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} r.json_data = json_dumps(data) r.error = None else: @@ -220,17 +217,13 @@ def _run_query(self, query, user, connection, r, ev): connection.close() def _get_ssl_parameters(self): - if not self.configuration.get('use_ssl'): + if not self.configuration.get("use_ssl"): return None ssl_params = {} - if self.configuration.get('use_ssl'): - config_map = { - "ssl_cacert": "ca", - "ssl_cert": "cert", - "ssl_key": "key", - } + if self.configuration.get("use_ssl"): + config_map = {"ssl_cacert": "ca", "ssl_cert": "cert", "ssl_key": "key"} for key, cfg in config_map.items(): val = self.configuration.get(key) if val: @@ -267,46 +260,31 @@ def name(cls): @classmethod def type(cls): - return 'rds_mysql' + return "rds_mysql" @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'host': { - 'type': 'string', - }, - 'user': { - 'type': 'string' - }, - 'passwd': { - 'type': 'string', - 'title': 'Password' - }, - 'db': { - 'type': 'string', - 'title': 'Database name' - }, - 'port': { - 'type': 'number', - 'default': 3306, - }, - 'use_ssl': { - 'type': 'boolean', - 'title': 'Use SSL' - } + "type": "object", + "properties": { + "host": {"type": "string"}, + "user": {"type": "string"}, + "passwd": {"type": "string", "title": "Password"}, + "db": {"type": "string", "title": "Database name"}, + "port": {"type": "number", "default": 3306}, + "use_ssl": {"type": "boolean", "title": "Use SSL"}, }, - "order": ['host', 'port', 'user', 'passwd', 'db'], - 'required': ['db', 'user', 'passwd', 'host'], - 'secret': ['passwd'] + "order": ["host", "port", "user", "passwd", "db"], + "required": ["db", "user", "passwd", "host"], + "secret": ["passwd"], } def _get_ssl_parameters(self): - if self.configuration.get('use_ssl'): - ca_path = os.path.join(os.path.dirname(__file__), - './files/rds-combined-ca-bundle.pem') - return {'ca': ca_path} + if self.configuration.get("use_ssl"): + ca_path = os.path.join( + os.path.dirname(__file__), "./files/rds-combined-ca-bundle.pem" + ) + return {"ca": ca_path} return None diff --git a/redash/query_runner/oracle.py b/redash/query_runner/oracle.py index 11a382c35b..01daf7feeb 100644 --- a/redash/query_runner/oracle.py +++ b/redash/query_runner/oracle.py @@ -48,25 +48,14 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "host": { - "type": "string" - }, - "port": { - "type": "number" - }, - "servicename": { - "type": "string", - "title": "DSN Service Name" - } + "user": {"type": "string"}, + "password": {"type": "string"}, + "host": {"type": "string"}, + "port": {"type": "number"}, + "servicename": {"type": "string", "title": "DSN Service Name"}, }, "required": ["servicename", "user", "password", "host", "port"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -79,9 +68,12 @@ def __init__(self, configuration): dsn = cx_Oracle.makedsn( self.configuration["host"], self.configuration["port"], - service_name=self.configuration["servicename"]) + service_name=self.configuration["servicename"], + ) - self.connection_string = "{}/{}@{}".format(self.configuration["user"], self.configuration["password"], dsn) + self.connection_string = "{}/{}@{}".format( + self.configuration["user"], self.configuration["password"], dsn + ) def _get_tables(self, schema): query = """ @@ -100,16 +92,16 @@ def _get_tables(self, schema): results = json_loads(results) - for row in results['rows']: - if row['OWNER'] != None: - table_name = '{}.{}'.format(row['OWNER'], row['TABLE_NAME']) + for row in results["rows"]: + if row["OWNER"] != None: + table_name = "{}.{}".format(row["OWNER"], row["TABLE_NAME"]) else: - table_name = row['TABLE_NAME'] + table_name = row["TABLE_NAME"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['COLUMN_NAME']) + schema[table_name]["columns"].append(row["COLUMN_NAME"]) return list(schema.values()) @@ -130,7 +122,12 @@ def output_handler(cls, cursor, name, default_type, length, precision, scale): if default_type == cx_Oracle.NUMBER: if scale <= 0: - return cursor.var(cx_Oracle.STRING, 255, outconverter=Oracle._convert_number, arraysize=cursor.arraysize) + return cursor.var( + cx_Oracle.STRING, + 255, + outconverter=Oracle._convert_number, + arraysize=cursor.arraysize, + ) def run_query(self, query, user): connection = cx_Oracle.connect(self.connection_string) @@ -142,15 +139,23 @@ def run_query(self, query, user): cursor.execute(query) rows_count = cursor.rowcount if cursor.description is not None: - columns = self.fetch_columns([(i[0], Oracle.get_col_type(i[1], i[5])) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) for row in cursor] - data = {'columns': columns, 'rows': rows} + columns = self.fetch_columns( + [ + (i[0], Oracle.get_col_type(i[1], i[5])) + for i in cursor.description + ] + ) + rows = [ + dict(zip((column["name"] for column in columns), row)) + for row in cursor + ] + data = {"columns": columns, "rows": rows} error = None json_data = json_dumps(data) else: - columns = [{'name': 'Row(s) Affected', 'type': 'TYPE_INTEGER'}] - rows = [{'Row(s) Affected': rows_count}] - data = {'columns': columns, 'rows': rows} + columns = [{"name": "Row(s) Affected", "type": "TYPE_INTEGER"}] + rows = [{"Row(s) Affected": rows_count}] + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) connection.commit() except cx_Oracle.DatabaseError as err: diff --git a/redash/query_runner/pg.py b/redash/query_runner/pg.py index 99787730bb..3dfb538735 100644 --- a/redash/query_runner/pg.py +++ b/redash/query_runner/pg.py @@ -25,7 +25,7 @@ 1015: TYPE_STRING, 1008: TYPE_STRING, 1009: TYPE_STRING, - 2951: TYPE_STRING + 2951: TYPE_STRING, } @@ -34,15 +34,11 @@ def default(self, o): if isinstance(o, Range): # From: https://github.com/psycopg/psycopg2/pull/779 if o._bounds is None: - return '' + return "" - items = [ - o._bounds[0], - str(o._lower), ', ', - str(o._upper), o._bounds[1] - ] + items = [o._bounds[0], str(o._lower), ", ", str(o._upper), o._bounds[1]] - return ''.join(items) + return "".join(items) return super(PostgreSQLJSONEncoder, self).default(o) @@ -64,10 +60,10 @@ def _wait(conn, timeout=None): def full_table_name(schema, name): - if '.' in name: + if "." in name: name = u'"{}"'.format(name) - return u'{}.{}'.format(schema, name) + return u"{}.{}".format(schema, name) def build_schema(query_result, schema): @@ -78,21 +74,26 @@ def build_schema(query_result, schema): # (while this feels unlikely, this actually happened) # In this case if we omit the schema name for the public table, we will have # a conflict. - table_names = set(map(lambda r: full_table_name(r['table_schema'], r['table_name']), query_result['rows'])) + table_names = set( + map( + lambda r: full_table_name(r["table_schema"], r["table_name"]), + query_result["rows"], + ) + ) - for row in query_result['rows']: - if row['table_schema'] != 'public': - table_name = full_table_name(row['table_schema'], row['table_name']) + for row in query_result["rows"]: + if row["table_schema"] != "public": + table_name = full_table_name(row["table_schema"], row["table_name"]) else: - if row['table_name'] in table_names: - table_name = full_table_name(row['table_schema'], row['table_name']) + if row["table_name"] in table_names: + table_name = full_table_name(row["table_schema"], row["table_name"]) else: - table_name = row['table_name'] + table_name = row["table_name"] if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name]["columns"].append(row["column_name"]) class PostgreSQL(BaseSQLQueryRunner): @@ -103,33 +104,16 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "host": { - "type": "string", - "default": "127.0.0.1" - }, - "port": { - "type": "number", - "default": 5432 - }, - "dbname": { - "type": "string", - "title": "Database Name" - }, - "sslmode": { - "type": "string", - "title": "SSL Mode", - "default": "prefer" - } + "user": {"type": "string"}, + "password": {"type": "string"}, + "host": {"type": "string", "default": "127.0.0.1"}, + "port": {"type": "number", "default": 5432}, + "dbname": {"type": "string", "title": "Database Name"}, + "sslmode": {"type": "string", "title": "SSL Mode", "default": "prefer"}, }, - "order": ['host', 'port', 'user', 'password'], + "order": ["host", "port", "user", "password"], "required": ["dbname"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -147,7 +131,7 @@ def _get_definitions(self, schema, query): build_schema(results, schema) def _get_tables(self, schema): - ''' + """ relkind constants per https://www.postgresql.org/docs/10/static/catalog-pg-class.html r = regular table v = view @@ -159,7 +143,7 @@ def _get_tables(self, schema): S = sequence t = TOAST table c = composite type - ''' + """ query = """ SELECT s.nspname as table_schema, @@ -190,13 +174,14 @@ def _get_tables(self, schema): def _get_connection(self): connection = psycopg2.connect( - user=self.configuration.get('user'), - password=self.configuration.get('password'), - host=self.configuration.get('host'), - port=self.configuration.get('port'), - dbname=self.configuration.get('dbname'), - sslmode=self.configuration.get('sslmode'), - async_=True) + user=self.configuration.get("user"), + password=self.configuration.get("password"), + host=self.configuration.get("host"), + port=self.configuration.get("port"), + dbname=self.configuration.get("dbname"), + sslmode=self.configuration.get("sslmode"), + async_=True, + ) return connection @@ -211,20 +196,19 @@ def run_query(self, query, user): _wait(connection) if cursor.description is not None: - columns = self.fetch_columns([(i[0], types_map.get(i[1], None)) - for i in cursor.description]) + columns = self.fetch_columns( + [(i[0], types_map.get(i[1], None)) for i in cursor.description] + ) rows = [ - dict(zip((column['name'] for column in columns), row)) + dict(zip((column["name"] for column in columns), row)) for row in cursor ] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} error = None - json_data = json_dumps(data, - ignore_nan=True, - cls=PostgreSQLJSONEncoder) + json_data = json_dumps(data, ignore_nan=True, cls=PostgreSQLJSONEncoder) else: - error = 'Query completed but it returned no data.' + error = "Query completed but it returned no data." json_data = None except (select.error, OSError) as e: error = "Query interrupted. Please retry." @@ -248,18 +232,20 @@ def type(cls): return "redshift" def _get_connection(self): - sslrootcert_path = os.path.join(os.path.dirname(__file__), - './files/redshift-ca-bundle.crt') + sslrootcert_path = os.path.join( + os.path.dirname(__file__), "./files/redshift-ca-bundle.crt" + ) connection = psycopg2.connect( - user=self.configuration.get('user'), - password=self.configuration.get('password'), - host=self.configuration.get('host'), - port=self.configuration.get('port'), - dbname=self.configuration.get('dbname'), - sslmode=self.configuration.get('sslmode', 'prefer'), + user=self.configuration.get("user"), + password=self.configuration.get("password"), + host=self.configuration.get("host"), + port=self.configuration.get("port"), + dbname=self.configuration.get("dbname"), + sslmode=self.configuration.get("sslmode", "prefer"), sslrootcert=sslrootcert_path, - async_=True) + async_=True, + ) return connection @@ -268,54 +254,48 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "host": { - "type": "string" - }, - "port": { - "type": "number" - }, - "dbname": { - "type": "string", - "title": "Database Name" - }, - "sslmode": { - "type": "string", - "title": "SSL Mode", - "default": "prefer" - }, + "user": {"type": "string"}, + "password": {"type": "string"}, + "host": {"type": "string"}, + "port": {"type": "number"}, + "dbname": {"type": "string", "title": "Database Name"}, + "sslmode": {"type": "string", "title": "SSL Mode", "default": "prefer"}, "adhoc_query_group": { "type": "string", "title": "Query Group for Adhoc Queries", - "default": "default" + "default": "default", }, "scheduled_query_group": { "type": "string", "title": "Query Group for Scheduled Queries", - "default": "default" + "default": "default", }, }, - "order": ['host', 'port', 'user', 'password', 'dbname', 'sslmode', 'adhoc_query_group', 'scheduled_query_group'], + "order": [ + "host", + "port", + "user", + "password", + "dbname", + "sslmode", + "adhoc_query_group", + "scheduled_query_group", + ], "required": ["dbname", "user", "password", "host", "port"], - "secret": ["password"] + "secret": ["password"], } def annotate_query(self, query, metadata): annotated = super(Redshift, self).annotate_query(query, metadata) - if metadata.get('Scheduled', False): - query_group = self.configuration.get('scheduled_query_group') + if metadata.get("Scheduled", False): + query_group = self.configuration.get("scheduled_query_group") else: - query_group = self.configuration.get('adhoc_query_group') + query_group = self.configuration.get("adhoc_query_group") if query_group: - set_query_group = 'set query_group to {};'.format(query_group) - annotated = '{}\n{}'.format(set_query_group, annotated) + set_query_group = "set query_group to {};".format(query_group) + annotated = "{}\n{}".format(set_query_group, annotated) return annotated diff --git a/redash/query_runner/phoenix.py b/redash/query_runner/phoenix.py index 22fc7c996c..daf33835a3 100644 --- a/redash/query_runner/phoenix.py +++ b/redash/query_runner/phoenix.py @@ -2,57 +2,55 @@ from redash.utils import json_dumps, json_loads import logging + logger = logging.getLogger(__name__) try: import phoenixdb from phoenixdb.errors import * + enabled = True except ImportError: enabled = False TYPES_MAPPING = { - 'VARCHAR': TYPE_STRING, - 'CHAR': TYPE_STRING, - 'BINARY': TYPE_STRING, - 'VARBINARY': TYPE_STRING, - 'BOOLEAN': TYPE_BOOLEAN, - 'TIME': TYPE_DATETIME, - 'DATE': TYPE_DATETIME, - 'TIMESTAMP': TYPE_DATETIME, - 'UNSIGNED_TIME': TYPE_DATETIME, - 'UNSIGNED_DATE': TYPE_DATETIME, - 'UNSIGNED_TIMESTAMP': TYPE_DATETIME, - 'INTEGER': TYPE_INTEGER, - 'UNSIGNED_INT': TYPE_INTEGER, - 'BIGINT': TYPE_INTEGER, - 'UNSIGNED_LONG': TYPE_INTEGER, - 'TINYINT': TYPE_INTEGER, - 'UNSIGNED_TINYINT': TYPE_INTEGER, - 'SMALLINT': TYPE_INTEGER, - 'UNSIGNED_SMALLINT': TYPE_INTEGER, - 'FLOAT': TYPE_FLOAT, - 'UNSIGNED_FLOAT': TYPE_FLOAT, - 'DOUBLE': TYPE_FLOAT, - 'UNSIGNED_DOUBLE': TYPE_FLOAT, - 'DECIMAL': TYPE_FLOAT + "VARCHAR": TYPE_STRING, + "CHAR": TYPE_STRING, + "BINARY": TYPE_STRING, + "VARBINARY": TYPE_STRING, + "BOOLEAN": TYPE_BOOLEAN, + "TIME": TYPE_DATETIME, + "DATE": TYPE_DATETIME, + "TIMESTAMP": TYPE_DATETIME, + "UNSIGNED_TIME": TYPE_DATETIME, + "UNSIGNED_DATE": TYPE_DATETIME, + "UNSIGNED_TIMESTAMP": TYPE_DATETIME, + "INTEGER": TYPE_INTEGER, + "UNSIGNED_INT": TYPE_INTEGER, + "BIGINT": TYPE_INTEGER, + "UNSIGNED_LONG": TYPE_INTEGER, + "TINYINT": TYPE_INTEGER, + "UNSIGNED_TINYINT": TYPE_INTEGER, + "SMALLINT": TYPE_INTEGER, + "UNSIGNED_SMALLINT": TYPE_INTEGER, + "FLOAT": TYPE_FLOAT, + "UNSIGNED_FLOAT": TYPE_FLOAT, + "DOUBLE": TYPE_FLOAT, + "UNSIGNED_DOUBLE": TYPE_FLOAT, + "DECIMAL": TYPE_FLOAT, } class Phoenix(BaseQueryRunner): - noop_query = 'select 1' + noop_query = "select 1" @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string' - } - }, - 'required': ['url'] + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], } @classmethod @@ -78,35 +76,42 @@ def get_schema(self, get_stats=False): results = json_loads(results) - for row in results['rows']: - table_name = '{}.{}'.format(row['TABLE_SCHEM'], row['TABLE_NAME']) + for row in results["rows"]: + table_name = "{}.{}".format(row["TABLE_SCHEM"], row["TABLE_NAME"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['COLUMN_NAME']) + schema[table_name]["columns"].append(row["COLUMN_NAME"]) return list(schema.values()) def run_query(self, query, user): connection = phoenixdb.connect( - url=self.configuration.get('url', ''), - autocommit=True) + url=self.configuration.get("url", ""), autocommit=True + ) cursor = connection.cursor() try: cursor.execute(query) - column_tuples = [(i[0], TYPES_MAPPING.get(i[1], None)) for i in cursor.description] + column_tuples = [ + (i[0], TYPES_MAPPING.get(i[1], None)) for i in cursor.description + ] columns = self.fetch_columns(column_tuples) - rows = [dict(zip(([column['name'] for column in columns]), r)) for i, r in enumerate(cursor.fetchall())] - data = {'columns': columns, 'rows': rows} + rows = [ + dict(zip(([column["name"] for column in columns]), r)) + for i, r in enumerate(cursor.fetchall()) + ] + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None cursor.close() except Error as e: json_data = None - error = 'code: {}, sql state:{}, message: {}'.format(e.code, e.sqlstate, e.message) + error = "code: {}, sql state:{}, message: {}".format( + e.code, e.sqlstate, e.message + ) except (KeyboardInterrupt, InterruptException) as e: error = "Query cancelled by user." json_data = None diff --git a/redash/query_runner/presto.py b/redash/query_runner/presto.py index 56369903ef..f6accd3c90 100644 --- a/redash/query_runner/presto.py +++ b/redash/query_runner/presto.py @@ -3,12 +3,14 @@ from redash.utils import json_dumps, json_loads import logging + logger = logging.getLogger(__name__) try: from pyhive import presto from pyhive.exc import DatabaseError + enabled = True except ImportError: @@ -30,38 +32,31 @@ class Presto(BaseQueryRunner): - noop_query = 'SHOW TABLES' + noop_query = "SHOW TABLES" @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'host': { - 'type': 'string' - }, - 'protocol': { - 'type': 'string', - 'default': 'http' - }, - 'port': { - 'type': 'number' - }, - 'schema': { - 'type': 'string' - }, - 'catalog': { - 'type': 'string' - }, - 'username': { - 'type': 'string' - }, - 'password': { - 'type': 'string' - }, + "type": "object", + "properties": { + "host": {"type": "string"}, + "protocol": {"type": "string", "default": "http"}, + "port": {"type": "number"}, + "schema": {"type": "string"}, + "catalog": {"type": "string"}, + "username": {"type": "string"}, + "password": {"type": "string"}, }, - 'order': ['host', 'protocol', 'port', 'username', 'password', 'schema', 'catalog'], - 'required': ['host'] + "order": [ + "host", + "protocol", + "port", + "username", + "password", + "schema", + "catalog", + ], + "required": ["host"], } @classmethod @@ -87,45 +82,49 @@ def get_schema(self, get_stats=False): results = json_loads(results) - for row in results['rows']: - table_name = '{}.{}'.format(row['table_schema'], row['table_name']) + for row in results["rows"]: + table_name = "{}.{}".format(row["table_schema"], row["table_name"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name]["columns"].append(row["column_name"]) return list(schema.values()) def run_query(self, query, user): connection = presto.connect( - host=self.configuration.get('host', ''), - port=self.configuration.get('port', 8080), - protocol=self.configuration.get('protocol', 'http'), - username=self.configuration.get('username', 'redash'), - password=(self.configuration.get('password') or None), - catalog=self.configuration.get('catalog', 'hive'), - schema=self.configuration.get('schema', 'default')) + host=self.configuration.get("host", ""), + port=self.configuration.get("port", 8080), + protocol=self.configuration.get("protocol", "http"), + username=self.configuration.get("username", "redash"), + password=(self.configuration.get("password") or None), + catalog=self.configuration.get("catalog", "hive"), + schema=self.configuration.get("schema", "default"), + ) cursor = connection.cursor() try: cursor.execute(query) - column_tuples = [(i[0], PRESTO_TYPES_MAPPING.get(i[1], None)) - for i in cursor.description] + column_tuples = [ + (i[0], PRESTO_TYPES_MAPPING.get(i[1], None)) for i in cursor.description + ] columns = self.fetch_columns(column_tuples) - rows = [dict(zip(([column['name'] for column in columns]), r)) - for i, r in enumerate(cursor.fetchall())] - data = {'columns': columns, 'rows': rows} + rows = [ + dict(zip(([column["name"] for column in columns]), r)) + for i, r in enumerate(cursor.fetchall()) + ] + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None except DatabaseError as db: json_data = None - default_message = 'Unspecified DatabaseError: {0}'.format( - db.message) + default_message = "Unspecified DatabaseError: {0}".format(db.message) if isinstance(db.message, dict): - message = db.message.get( - 'failureInfo', {'message', None}).get('message') + message = db.message.get("failureInfo", {"message", None}).get( + "message" + ) else: message = None error = default_message if message is None else message diff --git a/redash/query_runner/prometheus.py b/redash/query_runner/prometheus.py index 180fa7bc7d..26ffc55602 100644 --- a/redash/query_runner/prometheus.py +++ b/redash/query_runner/prometheus.py @@ -11,9 +11,9 @@ def get_instant_rows(metrics_data): rows = [] for metric in metrics_data: - row_data = metric['metric'] + row_data = metric["metric"] - timestamp, value = metric['value'] + timestamp, value = metric["value"] date_time = datetime.fromtimestamp(timestamp) row_data.update({"timestamp": date_time, "value": value}) @@ -25,8 +25,8 @@ def get_range_rows(metrics_data): rows = [] for metric in metrics_data: - ts_values = metric['values'] - metric_labels = metric['metric'] + ts_values = metric["values"] + metric_labels = metric["metric"] for values in ts_values: row_data = metric_labels.copy() @@ -34,7 +34,7 @@ def get_range_rows(metrics_data): timestamp, value = values date_time = datetime.fromtimestamp(timestamp) - row_data.update({'timestamp': date_time, 'value': value}) + row_data.update({"timestamp": date_time, "value": value}) rows.append(row_data) return rows @@ -43,7 +43,7 @@ def get_range_rows(metrics_data): def convert_query_range(payload): query_range = {} - for key in ['start', 'end']: + for key in ["start", "end"]: if key not in payload.keys(): continue value = payload[key][0] @@ -69,14 +69,9 @@ class Prometheus(BaseQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string', - 'title': 'Prometheus API URL' - } - }, - "required": ["url"] + "type": "object", + "properties": {"url": {"type": "string", "title": "Prometheus API URL"}}, + "required": ["url"], } def test_connection(self): @@ -85,14 +80,14 @@ def test_connection(self): def get_schema(self, get_stats=False): base_url = self.configuration["url"] - metrics_path = '/api/v1/label/__name__/values' + metrics_path = "/api/v1/label/__name__/values" response = requests.get(base_url + metrics_path) response.raise_for_status() - data = response.json()['data'] + data = response.json()["data"] schema = {} for name in data: - schema[name] = {'name': name, 'columns': []} + schema[name] = {"name": name, "columns": []} return list(schema.values()) def run_query(self, query, user): @@ -115,64 +110,57 @@ def run_query(self, query, user): base_url = self.configuration["url"] columns = [ - { - 'friendly_name': 'timestamp', - 'type': TYPE_DATETIME, - 'name': 'timestamp' - }, - { - 'friendly_name': 'value', - 'type': TYPE_STRING, - 'name': 'value' - }, + {"friendly_name": "timestamp", "type": TYPE_DATETIME, "name": "timestamp"}, + {"friendly_name": "value", "type": TYPE_STRING, "name": "value"}, ] try: error = None query = query.strip() # for backward compatibility - query = 'query={}'.format(query) if not query.startswith('query=') else query + query = ( + "query={}".format(query) if not query.startswith("query=") else query + ) payload = parse_qs(query) - query_type = 'query_range' if 'step' in payload.keys() else 'query' + query_type = "query_range" if "step" in payload.keys() else "query" # for the range of until now - if query_type == 'query_range' and ('end' not in payload.keys() or 'now' in payload['end']): + if query_type == "query_range" and ( + "end" not in payload.keys() or "now" in payload["end"] + ): date_now = datetime.now() - payload.update({'end': [date_now]}) + payload.update({"end": [date_now]}) convert_query_range(payload) - api_endpoint = base_url + '/api/v1/{}'.format(query_type) + api_endpoint = base_url + "/api/v1/{}".format(query_type) response = requests.get(api_endpoint, params=payload) response.raise_for_status() - metrics = response.json()['data']['result'] + metrics = response.json()["data"]["result"] if len(metrics) == 0: - return None, 'query result is empty.' + return None, "query result is empty." - metric_labels = metrics[0]['metric'].keys() + metric_labels = metrics[0]["metric"].keys() for label_name in metric_labels: - columns.append({ - 'friendly_name': label_name, - 'type': TYPE_STRING, - 'name': label_name - }) - - if query_type == 'query_range': + columns.append( + { + "friendly_name": label_name, + "type": TYPE_STRING, + "name": label_name, + } + ) + + if query_type == "query_range": rows = get_range_rows(metrics) else: rows = get_instant_rows(metrics) - json_data = json_dumps( - { - 'rows': rows, - 'columns': columns - } - ) + json_data = json_dumps({"rows": rows, "columns": columns}) except requests.RequestException as e: return None, str(e) diff --git a/redash/query_runner/python.py b/redash/query_runner/python.py index 7c8657c6d5..e54a91bb33 100644 --- a/redash/query_runner/python.py +++ b/redash/query_runner/python.py @@ -15,6 +15,7 @@ class CustomPrint(object): """CustomPrint redirect "print" calls to be sent as "log" on the result object.""" + def __init__(self): self.enabled = True self.lines = [] @@ -22,7 +23,9 @@ def __init__(self): def write(self, text): if self.enabled: if text and text.strip(): - log_line = "[{0}] {1}".format(datetime.datetime.utcnow().isoformat(), text) + log_line = "[{0}] {1}".format( + datetime.datetime.utcnow().isoformat(), text + ) self.lines.append(log_line) def enable(self): @@ -39,25 +42,43 @@ class Python(BaseQueryRunner): should_annotate_query = False safe_builtins = ( - 'sorted', 'reversed', 'map', 'any', 'all', - 'slice', 'filter', 'len', 'next', 'enumerate', - 'sum', 'abs', 'min', 'max', 'round', 'divmod', - 'str', 'int', 'float', 'complex', - 'tuple', 'set', 'list', 'dict', 'bool', + "sorted", + "reversed", + "map", + "any", + "all", + "slice", + "filter", + "len", + "next", + "enumerate", + "sum", + "abs", + "min", + "max", + "round", + "divmod", + "str", + "int", + "float", + "complex", + "tuple", + "set", + "list", + "dict", + "bool", ) @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'allowedImportModules': { - 'type': 'string', - 'title': 'Modules to import prior to running the script' + "type": "object", + "properties": { + "allowedImportModules": { + "type": "string", + "title": "Modules to import prior to running the script", }, - 'additionalModulesPaths': { - 'type': 'string' - } + "additionalModulesPaths": {"type": "string"}, }, } @@ -95,7 +116,9 @@ def custom_import(self, name, globals=None, locals=None, fromlist=(), level=0): return m - raise Exception("'{0}' is not configured as a supported import module".format(name)) + raise Exception( + "'{0}' is not configured as a supported import module".format(name) + ) @staticmethod def custom_write(obj): @@ -129,11 +152,9 @@ def add_result_column(result, column_name, friendly_name, column_type): if "columns" not in result: result["columns"] = [] - result["columns"].append({ - "name": column_name, - "friendly_name": friendly_name, - "type": column_type - }) + result["columns"].append( + {"name": column_name, "friendly_name": friendly_name, "type": column_type} + ) @staticmethod def add_result_row(result, values): @@ -221,7 +242,7 @@ def run_query(self, query, user): try: error = None - code = compile_restricted(query, '', 'exec') + code = compile_restricted(query, "", "exec") builtins = safe_builtins.copy() builtins["_write_"] = self.custom_write @@ -263,8 +284,8 @@ def run_query(self, query, user): exec(code, restricted_globals, self._script_locals) - result = self._script_locals['result'] - result['log'] = self._custom_print.lines + result = self._script_locals["result"] + result["log"] = self._custom_print.lines json_data = json_dumps(result) except KeyboardInterrupt: error = "Query cancelled by user." diff --git a/redash/query_runner/qubole.py b/redash/query_runner/qubole.py index c5ea676d6f..d35c12326e 100644 --- a/redash/query_runner/qubole.py +++ b/redash/query_runner/qubole.py @@ -1,4 +1,3 @@ - import time import requests import logging @@ -13,6 +12,7 @@ from qds_sdk.qubole import Qubole as qbol from qds_sdk.commands import Command, HiveCommand from qds_sdk.commands import SqlCommand, PrestoCommand + enabled = True except ImportError: enabled = False @@ -29,26 +29,23 @@ def configuration_schema(cls): "query_type": { "type": "string", "title": "Query Type (quantum / presto / hive)", - "default": "hive" + "default": "hive", }, "endpoint": { "type": "string", "title": "API Endpoint", - "default": "https://api.qubole.com" - }, - "token": { - "type": "string", - "title": "Auth Token" + "default": "https://api.qubole.com", }, + "token": {"type": "string", "title": "Auth Token"}, "cluster": { "type": "string", "title": "Cluster Label", - "default": "default" - } + "default": "default", + }, }, "order": ["query_type", "endpoint", "token", "cluster"], "required": ["endpoint", "token"], - "secret": ["token"] + "secret": ["token"], } @classmethod @@ -65,27 +62,40 @@ def enabled(cls): def test_connection(self): headers = self._get_header() - r = requests.head("%s/api/latest/users" % self.configuration.get('endpoint'), headers=headers) + r = requests.head( + "%s/api/latest/users" % self.configuration.get("endpoint"), headers=headers + ) r.status_code == 200 def run_query(self, query, user): - qbol.configure(api_token=self.configuration.get('token'), - api_url='%s/api' % self.configuration.get('endpoint')) + qbol.configure( + api_token=self.configuration.get("token"), + api_url="%s/api" % self.configuration.get("endpoint"), + ) try: - query_type = self.configuration.get('query_type', 'hive') + query_type = self.configuration.get("query_type", "hive") - if query_type == 'quantum': + if query_type == "quantum": cmd = SqlCommand.create(query=query) - elif query_type == 'hive': - cmd = HiveCommand.create(query=query, label=self.configuration.get('cluster')) - elif query_type == 'presto': - cmd = PrestoCommand.create(query=query, label=self.configuration.get('cluster')) + elif query_type == "hive": + cmd = HiveCommand.create( + query=query, label=self.configuration.get("cluster") + ) + elif query_type == "presto": + cmd = PrestoCommand.create( + query=query, label=self.configuration.get("cluster") + ) else: - raise Exception("Invalid Query Type:%s.\ - It must be : hive / presto / quantum." % self.configuration.get('query_type')) + raise Exception( + "Invalid Query Type:%s.\ + It must be : hive / presto / quantum." + % self.configuration.get("query_type") + ) - logging.info("Qubole command created with Id: %s and Status: %s", cmd.id, cmd.status) + logging.info( + "Qubole command created with Id: %s and Status: %s", cmd.id, cmd.status + ) while not Command.is_done(cmd.status): time.sleep(qbol.poll_interval) @@ -96,21 +106,32 @@ def run_query(self, query, user): columns = [] error = None - if cmd.status == 'done': + if cmd.status == "done": fp = StringIO() - cmd.get_results(fp=fp, inline=True, delim='\t', fetch=False, - qlog=None, arguments=['true']) + cmd.get_results( + fp=fp, + inline=True, + delim="\t", + fetch=False, + qlog=None, + arguments=["true"], + ) results = fp.getvalue() fp.close() - data = results.split('\r\n') - columns = self.fetch_columns([(i, TYPE_STRING) for i in data.pop(0).split('\t')]) - rows = [dict(zip((column['name'] for column in columns), row.split('\t'))) for row in data] + data = results.split("\r\n") + columns = self.fetch_columns( + [(i, TYPE_STRING) for i in data.pop(0).split("\t")] + ) + rows = [ + dict(zip((column["name"] for column in columns), row.split("\t"))) + for row in data + ] - json_data = json_dumps({'columns': columns, 'rows': rows}) + json_data = json_dumps({"columns": columns, "rows": rows}) except KeyboardInterrupt: - logging.info('Sending KILL signal to Qubole Command Id: %s', cmd.id) + logging.info("Sending KILL signal to Qubole Command Id: %s", cmd.id) cmd.cancel() error = "Query cancelled by user." json_data = None @@ -121,29 +142,37 @@ def get_schema(self, get_stats=False): schemas = {} try: headers = self._get_header() - content = requests.get("%s/api/latest/hive?describe=true&per_page=10000" % - self.configuration.get('endpoint'), headers=headers) + content = requests.get( + "%s/api/latest/hive?describe=true&per_page=10000" + % self.configuration.get("endpoint"), + headers=headers, + ) data = content.json() - for schema in data['schemas']: - tables = data['schemas'][schema] + for schema in data["schemas"]: + tables = data["schemas"][schema] for table in tables: table_name = list(table.keys())[0] - columns = [f['name'] for f in table[table_name]['columns']] + columns = [f["name"] for f in table[table_name]["columns"]] - if schema != 'default': - table_name = '{}.{}'.format(schema, table_name) + if schema != "default": + table_name = "{}.{}".format(schema, table_name) - schemas[table_name] = {'name': table_name, 'columns': columns} + schemas[table_name] = {"name": table_name, "columns": columns} except Exception as e: - logging.error("Failed to get schema information from Qubole. Error {}".format(str(e))) + logging.error( + "Failed to get schema information from Qubole. Error {}".format(str(e)) + ) return list(schemas.values()) def _get_header(self): - return {"Content-type": "application/json", "Accept": "application/json", - "X-AUTH-TOKEN": self.configuration.get('token')} + return { + "Content-type": "application/json", + "Accept": "application/json", + "X-AUTH-TOKEN": self.configuration.get("token"), + } register(Qubole) diff --git a/redash/query_runner/query_results.py b/redash/query_runner/query_results.py index 564a234470..20ff9f98dc 100644 --- a/redash/query_runner/query_results.py +++ b/redash/query_runner/query_results.py @@ -19,13 +19,12 @@ class CreateTableError(Exception): def extract_query_ids(query): - queries = re.findall(r'(?:join|from)\s+query_(\d+)', query, re.IGNORECASE) + queries = re.findall(r"(?:join|from)\s+query_(\d+)", query, re.IGNORECASE) return [int(q) for q in queries] def extract_cached_query_ids(query): - queries = re.findall(r'(?:join|from)\s+cached_query_(\d+)', query, - re.IGNORECASE) + queries = re.findall(r"(?:join|from)\s+cached_query_(\d+)", query, re.IGNORECASE) return [int(q) for q in queries] @@ -38,8 +37,7 @@ def _load_query(user, query_id): # TODO: this duplicates some of the logic we already have in the redash.handlers.query_results. # We should merge it so it's consistent. if not has_access(query.data_source, user, view_only): - raise PermissionError("You do not have access to query id {}.".format( - query.id)) + raise PermissionError("You do not have access to query id {}.".format(query.id)) return query @@ -50,35 +48,31 @@ def get_query_results(user, query_id, bring_from_cache): if query.latest_query_data_id is not None: results = query.latest_query_data.data else: - raise Exception("No cached result available for query {}.".format( - query.id)) + raise Exception("No cached result available for query {}.".format(query.id)) else: results, error = query.data_source.query_runner.run_query( - query.query_text, user) + query.query_text, user + ) if error: - raise Exception("Failed loading results for query id {}.".format( - query.id)) + raise Exception("Failed loading results for query id {}.".format(query.id)) return json_loads(results) -def create_tables_from_query_ids(user, - connection, - query_ids, - cached_query_ids=[]): +def create_tables_from_query_ids(user, connection, query_ids, cached_query_ids=[]): for query_id in set(cached_query_ids): results = get_query_results(user, query_id, True) - table_name = 'cached_query_{query_id}'.format(query_id=query_id) + table_name = "cached_query_{query_id}".format(query_id=query_id) create_table(connection, table_name, results) for query_id in set(query_ids): results = get_query_results(user, query_id, False) - table_name = 'query_{query_id}'.format(query_id=query_id) + table_name = "query_{query_id}".format(query_id=query_id) create_table(connection, table_name, results) def fix_column_name(name): - return '"{}"'.format(re.sub('[:."\s]', '_', name, flags=re.UNICODE)) + return '"{}"'.format(re.sub('[:."\s]', "_", name, flags=re.UNICODE)) def flatten(value): @@ -90,31 +84,34 @@ def flatten(value): def create_table(connection, table_name, query_results): try: - columns = [column['name'] for column in query_results['columns']] + columns = [column["name"] for column in query_results["columns"]] safe_columns = [fix_column_name(column) for column in columns] column_list = ", ".join(safe_columns) create_table = "CREATE TABLE {table_name} ({column_list})".format( - table_name=table_name, column_list=column_list) + table_name=table_name, column_list=column_list + ) logger.debug("CREATE TABLE query: %s", create_table) connection.execute(create_table) except sqlite3.OperationalError as exc: - raise CreateTableError("Error creating table {}: {}".format( - table_name, str(exc))) + raise CreateTableError( + "Error creating table {}: {}".format(table_name, str(exc)) + ) insert_template = "insert into {table_name} ({column_list}) values ({place_holders})".format( table_name=table_name, column_list=column_list, - place_holders=','.join(['?'] * len(columns))) + place_holders=",".join(["?"] * len(columns)), + ) - for row in query_results['rows']: + for row in query_results["rows"]: values = [flatten(row.get(column)) for column in columns] connection.execute(insert_template, values) class Results(BaseQueryRunner): should_annotate_query = False - noop_query = 'SELECT 1' + noop_query = "SELECT 1" @classmethod def configuration_schema(cls): @@ -125,12 +122,11 @@ def name(cls): return "Query Results" def run_query(self, query, user): - connection = sqlite3.connect(':memory:') + connection = sqlite3.connect(":memory:") query_ids = extract_query_ids(query) cached_query_ids = extract_cached_query_ids(query) - create_tables_from_query_ids(user, connection, query_ids, - cached_query_ids) + create_tables_from_query_ids(user, connection, query_ids, cached_query_ids) cursor = connection.cursor() @@ -138,28 +134,27 @@ def run_query(self, query, user): cursor.execute(query) if cursor.description is not None: - columns = self.fetch_columns([(i[0], None) - for i in cursor.description]) + columns = self.fetch_columns([(i[0], None) for i in cursor.description]) rows = [] - column_names = [c['name'] for c in columns] + column_names = [c["name"] for c in columns] for i, row in enumerate(cursor): for j, col in enumerate(row): guess = guess_type(col) - if columns[j]['type'] is None: - columns[j]['type'] = guess - elif columns[j]['type'] != guess: - columns[j]['type'] = TYPE_STRING + if columns[j]["type"] is None: + columns[j]["type"] = guess + elif columns[j]["type"] != guess: + columns[j]["type"] = TYPE_STRING rows.append(dict(zip(column_names, row))) - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} error = None json_data = json_dumps(data) else: - error = 'Query completed but it returned no data.' + error = "Query completed but it returned no data." json_data = None except KeyboardInterrupt: connection.cancel() diff --git a/redash/query_runner/rockset.py b/redash/query_runner/rockset.py index 5d3a6a332e..567db6d7e0 100644 --- a/redash/query_runner/rockset.py +++ b/redash/query_runner/rockset.py @@ -22,29 +22,29 @@ def __init__(self, api_key, api_server): self.api_key = api_key self.api_server = api_server - def _request(self, endpoint, method='GET', body=None): - headers = {'Authorization': 'ApiKey {}'.format(self.api_key)} - url = '{}/v1/orgs/self/{}'.format(self.api_server, endpoint) + def _request(self, endpoint, method="GET", body=None): + headers = {"Authorization": "ApiKey {}".format(self.api_key)} + url = "{}/v1/orgs/self/{}".format(self.api_server, endpoint) - if method == 'GET': + if method == "GET": r = requests.get(url, headers=headers) return r.json() - elif method == 'POST': + elif method == "POST": r = requests.post(url, headers=headers, json=body) return r.json() else: - raise 'Unknown method: {}'.format(method) + raise "Unknown method: {}".format(method) def list(self): - response = self._request('ws/commons/collections') - return response['data'] + response = self._request("ws/commons/collections") + return response["data"] def query(self, sql): - return self._request('queries', 'POST', {'sql': {'query': sql}}) + return self._request("queries", "POST", {"sql": {"query": sql}}) class Rockset(BaseSQLQueryRunner): - noop_query = 'SELECT 1' + noop_query = "SELECT 1" @classmethod def configuration_schema(cls): @@ -54,16 +54,13 @@ def configuration_schema(cls): "api_server": { "type": "string", "title": "API Server", - "default": "https://api.rs2.usw2.rockset.com" - }, - "api_key": { - "title": "API Key", - "type": "string", + "default": "https://api.rs2.usw2.rockset.com", }, + "api_key": {"title": "API Key", "type": "string"}, }, "order": ["api_key", "api_server"], "required": ["api_server", "api_key"], - "secret": ["api_key"] + "secret": ["api_key"], } @classmethod @@ -72,33 +69,37 @@ def type(cls): def __init__(self, configuration): super(Rockset, self).__init__(configuration) - self.api = RocksetAPI(self.configuration.get('api_key'), self.configuration.get( - 'api_server', "https://api.rs2.usw2.rockset.com")) + self.api = RocksetAPI( + self.configuration.get("api_key"), + self.configuration.get("api_server", "https://api.rs2.usw2.rockset.com"), + ) def _get_tables(self, schema): for col in self.api.list(): - table_name = col['name'] + table_name = col["name"] describe = self.api.query('DESCRIBE "{}"'.format(table_name)) - columns = list(set([result['field'][0] for result in describe['results']])) - schema[table_name] = {'name': table_name, 'columns': columns} + columns = list(set([result["field"][0] for result in describe["results"]])) + schema[table_name] = {"name": table_name, "columns": columns} return list(schema.values()) def run_query(self, query, user): results = self.api.query(query) - if 'code' in results and results['code'] != 200: - return None, '{}: {}'.format(results['type'], results['message']) + if "code" in results and results["code"] != 200: + return None, "{}: {}".format(results["type"], results["message"]) - if 'results' not in results: - message = results.get('message', "Unknown response from Rockset.") + if "results" not in results: + message = results.get("message", "Unknown response from Rockset.") return None, message - rows = results['results'] + rows = results["results"] columns = [] if len(rows) > 0: columns = [] for k in rows[0]: - columns.append({'name': k, 'friendly_name': k, 'type': _get_type(rows[0][k])}) - data = json_dumps({'columns': columns, 'rows': rows}) + columns.append( + {"name": k, "friendly_name": k, "type": _get_type(rows[0][k])} + ) + data = json_dumps({"columns": columns, "rows": rows}) return data, None diff --git a/redash/query_runner/salesforce.py b/redash/query_runner/salesforce.py index 3a2af3ab6c..3091f09943 100644 --- a/redash/query_runner/salesforce.py +++ b/redash/query_runner/salesforce.py @@ -2,13 +2,22 @@ import logging from collections import OrderedDict from redash.query_runner import BaseQueryRunner, register -from redash.query_runner import TYPE_STRING, TYPE_DATE, TYPE_DATETIME, TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN +from redash.query_runner import ( + TYPE_STRING, + TYPE_DATE, + TYPE_DATETIME, + TYPE_INTEGER, + TYPE_FLOAT, + TYPE_BOOLEAN, +) from redash.utils import json_dumps + logger = logging.getLogger(__name__) try: from simple_salesforce import Salesforce as SimpleSalesforce from simple_salesforce.api import SalesforceError, DEFAULT_API_VERSION + enabled = True except ImportError as e: enabled = False @@ -39,7 +48,7 @@ combobox=TYPE_STRING, calculated=TYPE_STRING, anyType=TYPE_STRING, - address=TYPE_STRING + address=TYPE_STRING, ) # Query Runner for Salesforce SOQL Queries @@ -59,27 +68,18 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "username": { - "type": "string" - }, - "password": { - "type": "string" - }, - "token": { - "type": "string", - "title": "Security Token" - }, - "sandbox": { - "type": "boolean" - }, + "username": {"type": "string"}, + "password": {"type": "string"}, + "token": {"type": "string", "title": "Security Token"}, + "sandbox": {"type": "boolean"}, "api_version": { "type": "string", "title": "Salesforce API Version", - "default": DEFAULT_API_VERSION - } + "default": DEFAULT_API_VERSION, + }, }, "required": ["username", "password", "token"], - "secret": ["password", "token"] + "secret": ["password", "token"], } def test_connection(self): @@ -89,23 +89,25 @@ def test_connection(self): pass def _get_sf(self): - sf = SimpleSalesforce(username=self.configuration['username'], - password=self.configuration['password'], - security_token=self.configuration['token'], - sandbox=self.configuration.get('sandbox', False), - version=self.configuration.get('api_version', DEFAULT_API_VERSION), - client_id='Redash') + sf = SimpleSalesforce( + username=self.configuration["username"], + password=self.configuration["password"], + security_token=self.configuration["token"], + sandbox=self.configuration.get("sandbox", False), + version=self.configuration.get("api_version", DEFAULT_API_VERSION), + client_id="Redash", + ) return sf def _clean_value(self, value): - if isinstance(value, OrderedDict) and 'records' in value: - value = value['records'] + if isinstance(value, OrderedDict) and "records" in value: + value = value["records"] for row in value: - row.pop('attributes', None) + row.pop("attributes", None) return value def _get_value(self, dct, dots): - for key in dots.split('.'): + for key in dots.split("."): if dct is not None and key in dct: dct = dct.get(key) else: @@ -113,20 +115,20 @@ def _get_value(self, dct, dots): return dct def _get_column_name(self, key, parents=[]): - return '.'.join(parents + [key]) + return ".".join(parents + [key]) def _build_columns(self, sf, child, parents=[]): - child_type = child['attributes']['type'] + child_type = child["attributes"]["type"] child_desc = sf.__getattr__(child_type).describe() - child_type_map = dict((f['name'], f['type'])for f in child_desc['fields']) + child_type_map = dict((f["name"], f["type"]) for f in child_desc["fields"]) columns = [] for key in child.keys(): - if key != 'attributes': - if isinstance(child[key], OrderedDict) and 'attributes' in child[key]: + if key != "attributes": + if isinstance(child[key], OrderedDict) and "attributes" in child[key]: columns.extend(self._build_columns(sf, child[key], parents + [key])) else: column_name = self._get_column_name(key, parents) - key_type = child_type_map.get(key, 'string') + key_type = child_type_map.get(key, "string") column_type = TYPES_MAP.get(key_type, TYPE_STRING) columns.append((column_name, column_type)) return columns @@ -134,7 +136,7 @@ def _build_columns(self, sf, child, parents=[]): def _build_rows(self, columns, records): rows = [] for record in records: - record.pop('attributes', None) + record.pop("attributes", None) row = dict() for column in columns: key = column[0] @@ -151,16 +153,16 @@ def run_query(self, query, user): rows = [] sf = self._get_sf() response = sf.query_all(query) - records = response['records'] - if response['totalSize'] > 0 and len(records) == 0: - columns = self.fetch_columns([('Count', TYPE_INTEGER)]) - rows = [{'Count': response['totalSize']}] + records = response["records"] + if response["totalSize"] > 0 and len(records) == 0: + columns = self.fetch_columns([("Count", TYPE_INTEGER)]) + rows = [{"Count": response["totalSize"]}] elif len(records) > 0: cols = self._build_columns(sf, records[0]) rows = self._build_rows(cols, records) columns = self.fetch_columns(cols) error = None - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) except SalesforceError as err: error = err.content @@ -174,12 +176,15 @@ def get_schema(self, get_stats=False): raise Exception("Failed describing objects.") schema = {} - for sobject in response['sobjects']: - table_name = sobject['name'] - if sobject['queryable'] is True and table_name not in schema: - desc = sf.__getattr__(sobject['name']).describe() - fields = desc['fields'] - schema[table_name] = {'name': table_name, 'columns': [f['name'] for f in fields]} + for sobject in response["sobjects"]: + table_name = sobject["name"] + if sobject["queryable"] is True and table_name not in schema: + desc = sf.__getattr__(sobject["name"]).describe() + fields = desc["fields"] + schema[table_name] = { + "name": table_name, + "columns": [f["name"] for f in fields], + } return list(schema.values()) diff --git a/redash/query_runner/script.py b/redash/query_runner/script.py index 6c529e9e39..4324961a79 100644 --- a/redash/query_runner/script.py +++ b/redash/query_runner/script.py @@ -38,18 +38,15 @@ def enabled(cls): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'path': { - 'type': 'string', - 'title': 'Scripts path' + "type": "object", + "properties": { + "path": {"type": "string", "title": "Scripts path"}, + "shell": { + "type": "boolean", + "title": "Execute command through the shell", }, - 'shell': { - 'type': 'boolean', - 'title': 'Execute command through the shell' - } }, - 'required': ['path'] + "required": ["path"], } @classmethod @@ -65,7 +62,9 @@ def __init__(self, configuration): # Poor man's protection against running scripts from outside the scripts directory if self.configuration["path"].find("../") > -1: - raise ValueError("Scripts can only be run from the configured scripts directory") + raise ValueError( + "Scripts can only be run from the configured scripts directory" + ) def test_connection(self): pass @@ -73,7 +72,7 @@ def test_connection(self): def run_query(self, query, user): try: script = query_to_script_path(self.configuration["path"], query) - return run_script(script, self.configuration['shell']) + return run_script(script, self.configuration["shell"]) except IOError as e: return None, e.message except subprocess.CalledProcessError as e: diff --git a/redash/query_runner/snowflake.py b/redash/query_runner/snowflake.py index 148f58086a..bac5102e36 100644 --- a/redash/query_runner/snowflake.py +++ b/redash/query_runner/snowflake.py @@ -1,14 +1,20 @@ - - try: import snowflake.connector + enabled = True except ImportError: enabled = False from redash.query_runner import BaseQueryRunner, register -from redash.query_runner import TYPE_STRING, TYPE_DATE, TYPE_DATETIME, TYPE_INTEGER, TYPE_FLOAT, TYPE_BOOLEAN +from redash.query_runner import ( + TYPE_STRING, + TYPE_DATE, + TYPE_DATETIME, + TYPE_INTEGER, + TYPE_FLOAT, + TYPE_BOOLEAN, +) from redash.utils import json_dumps, json_loads TYPES_MAP = { @@ -19,7 +25,7 @@ 4: TYPE_DATETIME, 5: TYPE_STRING, 6: TYPE_DATETIME, - 13: TYPE_BOOLEAN + 13: TYPE_BOOLEAN, } @@ -31,29 +37,16 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "account": { - "type": "string" - }, - "user": { - "type": "string" - }, - "password": { - "type": "string" - }, - "warehouse": { - "type": "string" - }, - "database": { - "type": "string" - }, - "region": { - "type": "string", - "default": "us-west" - } + "account": {"type": "string"}, + "user": {"type": "string"}, + "password": {"type": "string"}, + "warehouse": {"type": "string"}, + "database": {"type": "string"}, + "region": {"type": "string", "default": "us-west"}, }, "order": ["account", "user", "password", "warehouse", "database", "region"], "required": ["user", "password", "account", "database", "warehouse"], - "secret": ["password"] + "secret": ["password"], } @classmethod @@ -68,28 +61,30 @@ def determine_type(cls, data_type, scale): return t def _get_connection(self): - region = self.configuration.get('region') + region = self.configuration.get("region") # for us-west we don't need to pass a region (and if we do, it fails to connect) - if region == 'us-west': + if region == "us-west": region = None connection = snowflake.connector.connect( - user=self.configuration['user'], - password=self.configuration['password'], - account=self.configuration['account'], - region=region + user=self.configuration["user"], + password=self.configuration["password"], + account=self.configuration["account"], + region=region, ) return connection def _parse_results(self, cursor): columns = self.fetch_columns( - [(i[0], self.determine_type(i[1], i[5])) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) - for row in cursor] + [(i[0], self.determine_type(i[1], i[5])) for i in cursor.description] + ) + rows = [ + dict(zip((column["name"] for column in columns), row)) for row in cursor + ] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} return data def run_query(self, query, user): @@ -97,9 +92,8 @@ def run_query(self, query, user): cursor = connection.cursor() try: - cursor.execute("USE WAREHOUSE {}".format( - self.configuration['warehouse'])) - cursor.execute("USE {}".format(self.configuration['database'])) + cursor.execute("USE WAREHOUSE {}".format(self.configuration["warehouse"])) + cursor.execute("USE {}".format(self.configuration["database"])) cursor.execute(query) @@ -117,7 +111,7 @@ def _run_query_without_warehouse(self, query): cursor = connection.cursor() try: - cursor.execute("USE {}".format(self.configuration['database'])) + cursor.execute("USE {}".format(self.configuration["database"])) cursor.execute(query) @@ -132,7 +126,9 @@ def _run_query_without_warehouse(self, query): def get_schema(self, get_stats=False): query = """ SHOW COLUMNS IN DATABASE {database} - """.format(database=self.configuration['database']) + """.format( + database=self.configuration["database"] + ) results, error = self._run_query_without_warehouse(query) @@ -140,14 +136,14 @@ def get_schema(self, get_stats=False): raise Exception("Failed getting schema.") schema = {} - for row in results['rows']: - if row['kind'] == 'COLUMN': - table_name = '{}.{}'.format(row['schema_name'], row['table_name']) + for row in results["rows"]: + if row["kind"] == "COLUMN": + table_name = "{}.{}".format(row["schema_name"], row["table_name"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name]["columns"].append(row["column_name"]) return list(schema.values()) diff --git a/redash/query_runner/sqlite.py b/redash/query_runner/sqlite.py index 3881e569f6..f0ccf6b8e6 100644 --- a/redash/query_runner/sqlite.py +++ b/redash/query_runner/sqlite.py @@ -14,12 +14,7 @@ class Sqlite(BaseSQLQueryRunner): def configuration_schema(cls): return { "type": "object", - "properties": { - "dbpath": { - "type": "string", - "title": "Database Path" - } - }, + "properties": {"dbpath": {"type": "string", "title": "Database Path"}}, "required": ["dbpath"], } @@ -30,7 +25,7 @@ def type(cls): def __init__(self, configuration): super(Sqlite, self).__init__(configuration) - self._dbpath = self.configuration['dbpath'] + self._dbpath = self.configuration["dbpath"] def _get_tables(self, schema): query_table = "select tbl_name from sqlite_master where type='table'" @@ -43,16 +38,16 @@ def _get_tables(self, schema): results = json_loads(results) - for row in results['rows']: - table_name = row['tbl_name'] - schema[table_name] = {'name': table_name, 'columns': []} + for row in results["rows"]: + table_name = row["tbl_name"] + schema[table_name] = {"name": table_name, "columns": []} results_table, error = self.run_query(query_columns % (table_name,), None) if error is not None: raise Exception("Failed getting schema.") results_table = json_loads(results_table) - for row_column in results_table['rows']: - schema[table_name]['columns'].append(row_column['name']) + for row_column in results_table["rows"]: + schema[table_name]["columns"].append(row_column["name"]) return list(schema.values()) @@ -66,13 +61,16 @@ def run_query(self, query, user): if cursor.description is not None: columns = self.fetch_columns([(i[0], None) for i in cursor.description]) - rows = [dict(zip((column['name'] for column in columns), row)) for row in cursor] + rows = [ + dict(zip((column["name"] for column in columns), row)) + for row in cursor + ] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} error = None json_data = json_dumps(data) else: - error = 'Query completed but it returned no data.' + error = "Query completed but it returned no data." json_data = None except KeyboardInterrupt: connection.cancel() diff --git a/redash/query_runner/treasuredata.py b/redash/query_runner/treasuredata.py index 88637ec957..ae3cd7776d 100644 --- a/redash/query_runner/treasuredata.py +++ b/redash/query_runner/treasuredata.py @@ -8,28 +8,29 @@ try: import tdclient from tdclient import errors + enabled = True except ImportError: enabled = False TD_TYPES_MAPPING = { - 'bigint': TYPE_INTEGER, - 'tinyint': TYPE_INTEGER, - 'smallint': TYPE_INTEGER, - 'int': TYPE_INTEGER, - 'integer': TYPE_INTEGER, - 'long': TYPE_INTEGER, - 'double': TYPE_FLOAT, - 'decimal': TYPE_FLOAT, - 'float': TYPE_FLOAT, - 'real': TYPE_FLOAT, - 'boolean': TYPE_BOOLEAN, - 'timestamp': TYPE_DATETIME, - 'date': TYPE_DATETIME, - 'char': TYPE_STRING, - 'string': TYPE_STRING, - 'varchar': TYPE_STRING, + "bigint": TYPE_INTEGER, + "tinyint": TYPE_INTEGER, + "smallint": TYPE_INTEGER, + "int": TYPE_INTEGER, + "integer": TYPE_INTEGER, + "long": TYPE_INTEGER, + "double": TYPE_FLOAT, + "decimal": TYPE_FLOAT, + "float": TYPE_FLOAT, + "real": TYPE_FLOAT, + "boolean": TYPE_BOOLEAN, + "timestamp": TYPE_DATETIME, + "date": TYPE_DATETIME, + "char": TYPE_STRING, + "string": TYPE_STRING, + "varchar": TYPE_STRING, } @@ -40,28 +41,19 @@ class TreasureData(BaseQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'endpoint': { - 'type': 'string' - }, - 'apikey': { - 'type': 'string' - }, - 'type': { - 'type': 'string' - }, - 'db': { - 'type': 'string', - 'title': 'Database Name' + "type": "object", + "properties": { + "endpoint": {"type": "string"}, + "apikey": {"type": "string"}, + "type": {"type": "string"}, + "db": {"type": "string", "title": "Database Name"}, + "get_schema": { + "type": "boolean", + "title": "Auto Schema Retrieval", + "default": False, }, - 'get_schema': { - 'type': 'boolean', - 'title': 'Auto Schema Retrieval', - 'default': False - } }, - 'required': ['apikey', 'db'] + "required": ["apikey", "db"], } @classmethod @@ -74,15 +66,17 @@ def type(cls): def get_schema(self, get_stats=False): schema = {} - if self.configuration.get('get_schema', False): + if self.configuration.get("get_schema", False): try: - with tdclient.Client(self.configuration.get('apikey')) as client: - for table in client.tables(self.configuration.get('db')): - table_name = '{}.{}'.format(self.configuration.get('db'), table.name) + with tdclient.Client(self.configuration.get("apikey")) as client: + for table in client.tables(self.configuration.get("db")): + table_name = "{}.{}".format( + self.configuration.get("db"), table.name + ) for table_schema in table.schema: schema[table_name] = { - 'name': table_name, - 'columns': [column[0] for column in table.schema], + "name": table_name, + "columns": [column[0] for column in table.schema], } except Exception as ex: raise Exception("Failed getting schema") @@ -90,27 +84,39 @@ def get_schema(self, get_stats=False): def run_query(self, query, user): connection = tdclient.connect( - endpoint=self.configuration.get('endpoint', 'https://api.treasuredata.com'), - apikey=self.configuration.get('apikey'), - type=self.configuration.get('type', 'hive').lower(), - db=self.configuration.get('db')) + endpoint=self.configuration.get("endpoint", "https://api.treasuredata.com"), + apikey=self.configuration.get("apikey"), + type=self.configuration.get("type", "hive").lower(), + db=self.configuration.get("db"), + ) cursor = connection.cursor() try: cursor.execute(query) - columns_tuples = [(i[0], TD_TYPES_MAPPING.get(i[1], None)) for i in cursor.show_job()['hive_result_schema']] + columns_tuples = [ + (i[0], TD_TYPES_MAPPING.get(i[1], None)) + for i in cursor.show_job()["hive_result_schema"] + ] columns = self.fetch_columns(columns_tuples) if cursor.rowcount == 0: rows = [] else: - rows = [dict(zip(([column['name'] for column in columns]), r)) for r in cursor.fetchall()] - data = {'columns': columns, 'rows': rows} + rows = [ + dict(zip(([column["name"] for column in columns]), r)) + for r in cursor.fetchall() + ] + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None except errors.InternalError as e: json_data = None - error = "%s: %s" % (e.message, cursor.show_job().get('debug', {}).get('stderr', 'No stderr message in the response')) + error = "%s: %s" % ( + e.message, + cursor.show_job() + .get("debug", {}) + .get("stderr", "No stderr message in the response"), + ) return json_data, error diff --git a/redash/query_runner/uptycs.py b/redash/query_runner/uptycs.py index 3f5a37a209..15c9b30bc0 100644 --- a/redash/query_runner/uptycs.py +++ b/redash/query_runner/uptycs.py @@ -18,37 +18,29 @@ def configuration_schema(cls): return { "type": "object", "properties": { - "url": { - "type": "string" - }, - "customer_id": { - "type": "string" - }, - "key": { - "type": "string" - }, + "url": {"type": "string"}, + "customer_id": {"type": "string"}, + "key": {"type": "string"}, "verify_ssl": { "type": "boolean", "default": True, "title": "Verify SSL Certificates", }, - "secret": { - "type": "string", - }, + "secret": {"type": "string"}, }, - "order": ['url', 'customer_id', 'key', 'secret'], + "order": ["url", "customer_id", "key", "secret"], "required": ["url", "customer_id", "key", "secret"], - "secret": ["secret", "key"] + "secret": ["secret", "key"], } def generate_header(self, key, secret): header = {} utcnow = datetime.datetime.utcnow() date = utcnow.strftime("%a, %d %b %Y %H:%M:%S GMT") - auth_var = jwt.encode({'iss': key}, secret, algorithm='HS256') + auth_var = jwt.encode({"iss": key}, secret, algorithm="HS256") authorization = "Bearer %s" % (auth_var) - header['date'] = date - header['Authorization'] = authorization + header["date"] = date + header["Authorization"] = authorization return header def transformed_to_redash_json(self, data): @@ -56,42 +48,44 @@ def transformed_to_redash_json(self, data): rows = [] # convert all type to JSON string # In future we correct data type mapping later - if 'columns' in data: - for json_each in data['columns']: - name = json_each['name'] - new_json = {"name": name, - "type": "string", - "friendly_name": name} + if "columns" in data: + for json_each in data["columns"]: + name = json_each["name"] + new_json = {"name": name, "type": "string", "friendly_name": name} transformed_columns.append(new_json) # Transfored items into rows. - if 'items' in data: - rows = data['items'] + if "items" in data: + rows = data["items"] - redash_json_data = {"columns": transformed_columns, - "rows": rows} + redash_json_data = {"columns": transformed_columns, "rows": rows} return redash_json_data def api_call(self, sql): # JWT encoded header - header = self.generate_header(self.configuration.get('key'), - self.configuration.get('secret')) + header = self.generate_header( + self.configuration.get("key"), self.configuration.get("secret") + ) # URL form using API key file based on GLOBAL - url = ("%s/public/api/customers/%s/query" % - (self.configuration.get('url'), - self.configuration.get('customer_id'))) + url = "%s/public/api/customers/%s/query" % ( + self.configuration.get("url"), + self.configuration.get("customer_id"), + ) # post data base sql post_data_json = {"query": sql} - response = requests.post(url, headers=header, json=post_data_json, - verify=self.configuration.get('verify_ssl', - True)) + response = requests.post( + url, + headers=header, + json=post_data_json, + verify=self.configuration.get("verify_ssl", True), + ) if response.status_code == 200: response_output = json_loads(response.content) else: - error = 'status_code ' + str(response.status_code) + '\n' + error = "status_code " + str(response.status_code) + "\n" error = error + "failed to connect" json_data = {} return json_data, error @@ -99,9 +93,9 @@ def api_call(self, sql): json_data = self.transformed_to_redash_json(response_output) error = None # if we got error from Uptycs include error information - if 'error' in response_output: - error = response_output['error']['message']['brief'] - error = error + '\n' + response_output['error']['message']['detail'] + if "error" in response_output: + error = response_output["error"]["message"]["brief"] + error = error + "\n" + response_output["error"]["message"]["detail"] return json_data, error def run_query(self, query, user): @@ -111,21 +105,23 @@ def run_query(self, query, user): return json_data, error def get_schema(self, get_stats=False): - header = self.generate_header(self.configuration.get('key'), - self.configuration.get('secret')) - url = ("%s/public/api/customers/%s/schema/global" % - (self.configuration.get('url'), - self.configuration.get('customer_id'))) - response = requests.get(url, headers=header, - verify=self.configuration.get('verify_ssl', - True)) + header = self.generate_header( + self.configuration.get("key"), self.configuration.get("secret") + ) + url = "%s/public/api/customers/%s/schema/global" % ( + self.configuration.get("url"), + self.configuration.get("customer_id"), + ) + response = requests.get( + url, headers=header, verify=self.configuration.get("verify_ssl", True) + ) redash_json = [] schema = json_loads(response.content) - for each_def in schema['tables']: - table_name = each_def['name'] + for each_def in schema["tables"]: + table_name = each_def["name"] columns = [] - for col in each_def['columns']: - columns.append(col['name']) + for col in each_def["columns"]: + columns.append(col["name"]) table_json = {"name": table_name, "columns": columns} redash_json.append(table_json) diff --git a/redash/query_runner/vertica.py b/redash/query_runner/vertica.py index 5c5e80e7d6..fd075d248f 100644 --- a/redash/query_runner/vertica.py +++ b/redash/query_runner/vertica.py @@ -23,7 +23,7 @@ 114: TYPE_DATETIME, 115: TYPE_STRING, 116: TYPE_STRING, - 117: TYPE_STRING + 117: TYPE_STRING, } @@ -33,37 +33,27 @@ class Vertica(BaseSQLQueryRunner): @classmethod def configuration_schema(cls): return { - 'type': 'object', - 'properties': { - 'host': { - 'type': 'string' - }, - 'user': { - 'type': 'string' - }, - 'password': { - 'type': 'string', - 'title': 'Password' - }, - 'database': { - 'type': 'string', - 'title': 'Database name' - }, - "port": { - "type": "number" - }, - "read_timeout": { - "type": "number", - "title": "Read Timeout" - }, - "connection_timeout": { - "type": "number", - "title": "Connection Timeout" - }, + "type": "object", + "properties": { + "host": {"type": "string"}, + "user": {"type": "string"}, + "password": {"type": "string", "title": "Password"}, + "database": {"type": "string", "title": "Database name"}, + "port": {"type": "number"}, + "read_timeout": {"type": "number", "title": "Read Timeout"}, + "connection_timeout": {"type": "number", "title": "Connection Timeout"}, }, - 'required': ['database'], - 'order': ['host', 'port', 'user', 'password', 'database', 'read_timeout', 'connection_timeout'], - 'secret': ['password'] + "required": ["database"], + "order": [ + "host", + "port", + "user", + "password", + "database", + "read_timeout", + "connection_timeout", + ], + "secret": ["password"], } @classmethod @@ -89,13 +79,13 @@ def _get_tables(self, schema): results = json_loads(results) - for row in results['rows']: - table_name = '{}.{}'.format(row['table_schema'], row['table_name']) + for row in results["rows"]: + table_name = "{}.{}".format(row["table_schema"], row["table_name"]) if table_name not in schema: - schema[table_name] = {'name': table_name, 'columns': []} + schema[table_name] = {"name": table_name, "columns": []} - schema[table_name]['columns'].append(row['column_name']) + schema[table_name]["columns"].append(row["column_name"]) return list(schema.values()) @@ -110,16 +100,18 @@ def run_query(self, query, user): connection = None try: conn_info = { - 'host': self.configuration.get('host', ''), - 'port': self.configuration.get('port', 5433), - 'user': self.configuration.get('user', ''), - 'password': self.configuration.get('password', ''), - 'database': self.configuration.get('database', ''), - 'read_timeout': self.configuration.get('read_timeout', 600) + "host": self.configuration.get("host", ""), + "port": self.configuration.get("port", 5433), + "user": self.configuration.get("user", ""), + "password": self.configuration.get("password", ""), + "database": self.configuration.get("database", ""), + "read_timeout": self.configuration.get("read_timeout", 600), } - if self.configuration.get('connection_timeout'): - conn_info['connection_timeout'] = self.configuration.get('connection_timeout') + if self.configuration.get("connection_timeout"): + conn_info["connection_timeout"] = self.configuration.get( + "connection_timeout" + ) connection = vertica_python.connect(**conn_info) cursor = connection.cursor() @@ -127,13 +119,17 @@ def run_query(self, query, user): cursor.execute(query) if cursor.description is not None: - columns_data = [(i[0], types_map.get(i[1], None)) for i in cursor.description] + columns_data = [ + (i[0], types_map.get(i[1], None)) for i in cursor.description + ] columns = self.fetch_columns(columns_data) - rows = [dict(zip(([c['name'] for c in columns]), r)) - for r in cursor.fetchall()] + rows = [ + dict(zip(([c["name"] for c in columns]), r)) + for r in cursor.fetchall() + ] - data = {'columns': columns, 'rows': rows} + data = {"columns": columns, "rows": rows} json_data = json_dumps(data) error = None else: diff --git a/redash/query_runner/yandex_metrica.py b/redash/query_runner/yandex_metrica.py index 3736eb1290..948f42f700 100644 --- a/redash/query_runner/yandex_metrica.py +++ b/redash/query_runner/yandex_metrica.py @@ -10,55 +10,66 @@ logger = logging.getLogger(__name__) COLUMN_TYPES = { - 'date': ( - 'firstVisitDate', 'firstVisitStartOfYear', 'firstVisitStartOfQuarter', - 'firstVisitStartOfMonth', 'firstVisitStartOfWeek', + "date": ( + "firstVisitDate", + "firstVisitStartOfYear", + "firstVisitStartOfQuarter", + "firstVisitStartOfMonth", + "firstVisitStartOfWeek", ), - 'datetime': ( - 'firstVisitStartOfHour', 'firstVisitStartOfDekaminute', 'firstVisitStartOfMinute', - 'firstVisitDateTime', 'firstVisitHour', 'firstVisitHourMinute' - + "datetime": ( + "firstVisitStartOfHour", + "firstVisitStartOfDekaminute", + "firstVisitStartOfMinute", + "firstVisitDateTime", + "firstVisitHour", + "firstVisitHourMinute", ), - 'int': ( - 'pageViewsInterval', 'pageViews', 'firstVisitYear', 'firstVisitMonth', - 'firstVisitDayOfMonth', 'firstVisitDayOfWeek', 'firstVisitMinute', - 'firstVisitDekaminute', + "int": ( + "pageViewsInterval", + "pageViews", + "firstVisitYear", + "firstVisitMonth", + "firstVisitDayOfMonth", + "firstVisitDayOfWeek", + "firstVisitMinute", + "firstVisitDekaminute", ), } for type_, elements in COLUMN_TYPES.items(): for el in elements: - if 'first' in el: - el = el.replace('first', 'last') - COLUMN_TYPES[type_] += (el, ) + if "first" in el: + el = el.replace("first", "last") + COLUMN_TYPES[type_] += (el,) def parse_ym_response(response): columns = [] - dimensions_len = len(response['query']['dimensions']) + dimensions_len = len(response["query"]["dimensions"]) - for h in response['query']['dimensions'] + response['query']['metrics']: - friendly_name = h.split(':')[-1] - if friendly_name in COLUMN_TYPES['date']: + for h in response["query"]["dimensions"] + response["query"]["metrics"]: + friendly_name = h.split(":")[-1] + if friendly_name in COLUMN_TYPES["date"]: data_type = TYPE_DATE - elif friendly_name in COLUMN_TYPES['datetime']: + elif friendly_name in COLUMN_TYPES["datetime"]: data_type = TYPE_DATETIME else: data_type = TYPE_STRING - columns.append({'name': h, 'friendly_name': friendly_name, 'type': data_type}) + columns.append({"name": h, "friendly_name": friendly_name, "type": data_type}) rows = [] - for num, row in enumerate(response['data']): + for num, row in enumerate(response["data"]): res = {} - for i, d in enumerate(row['dimensions']): - res[columns[i]['name']] = d['name'] - for i, d in enumerate(row['metrics']): - res[columns[dimensions_len + i]['name']] = d + for i, d in enumerate(row["dimensions"]): + res[columns[i]["name"]] = d["name"] + for i, d in enumerate(row["metrics"]): + res[columns[dimensions_len + i]["name"]] = d if num == 0 and isinstance(d, float): - columns[dimensions_len + i]['type'] = TYPE_FLOAT + columns[dimensions_len + i]["type"] = TYPE_FLOAT rows.append(res) - return {'columns': columns, 'rows': rows} + return {"columns": columns, "rows": rows} class YandexMetrica(BaseSQLQueryRunner): @@ -77,46 +88,41 @@ def name(cls): def configuration_schema(cls): return { "type": "object", - "properties": { - "token": { - "type": "string", - "title": "OAuth Token" - } - }, + "properties": {"token": {"type": "string", "title": "OAuth Token"}}, "required": ["token"], } def __init__(self, configuration): super(YandexMetrica, self).__init__(configuration) - self.syntax = 'yaml' - self.host = 'https://api-metrica.yandex.com' - self.list_path = 'counters' + self.syntax = "yaml" + self.host = "https://api-metrica.yandex.com" + self.list_path = "counters" def _get_tables(self, schema): - counters = self._send_query('management/v1/{0}'.format(self.list_path)) + counters = self._send_query("management/v1/{0}".format(self.list_path)) for row in counters[self.list_path]: - owner = row.get('owner_login') - counter = '{0} | {1}'.format( - row.get('name', 'Unknown').encode('utf-8'), row.get('id', 'Unknown') + owner = row.get("owner_login") + counter = "{0} | {1}".format( + row.get("name", "Unknown").encode("utf-8"), row.get("id", "Unknown") ) if owner not in schema: - schema[owner] = {'name': owner, 'columns': []} + schema[owner] = {"name": owner, "columns": []} - schema[owner]['columns'].append(counter) + schema[owner]["columns"].append(counter) return list(schema.values()) def test_connection(self): - self._send_query('management/v1/{0}'.format(self.list_path)) + self._send_query("management/v1/{0}".format(self.list_path)) - def _send_query(self, path='stat/v1/data', **kwargs): - token = kwargs.pop('oauth_token', self.configuration['token']) + def _send_query(self, path="stat/v1/data", **kwargs): + token = kwargs.pop("oauth_token", self.configuration["token"]) r = requests.get( - '{0}/{1}'.format(self.host, path), - headers={'Authorization': 'OAuth {}'.format(token)}, - params=kwargs + "{0}/{1}".format(self.host, path), + headers={"Authorization": "OAuth {}".format(token)}, + params=kwargs, ) if r.status_code != 200: raise Exception(r.text) @@ -137,10 +143,10 @@ def run_query(self, query, user): return data, error if isinstance(params, dict): - if 'url' in params: - params = parse_qs(urlparse(params['url']).query, keep_blank_values=True) + if "url" in params: + params = parse_qs(urlparse(params["url"]).query, keep_blank_values=True) else: - error = 'The query format must be JSON or YAML' + error = "The query format must be JSON or YAML" return data, error try: @@ -164,8 +170,8 @@ def name(cls): def __init__(self, configuration): super(YandexAppMetrica, self).__init__(configuration) - self.host = 'https://api.appmetrica.yandex.com' - self.list_path = 'applications' + self.host = "https://api.appmetrica.yandex.com" + self.list_path = "applications" register(YandexMetrica) diff --git a/redash/schedule.py b/redash/schedule.py index 8a577e4c39..2577c8db09 100644 --- a/redash/schedule.py +++ b/redash/schedule.py @@ -10,31 +10,38 @@ from rq_scheduler import Scheduler from redash import settings, rq_redis_connection -from redash.tasks import (sync_user_details, refresh_queries, - empty_schedules, refresh_schemas, - cleanup_query_results, purge_failed_jobs, - version_check, send_aggregated_errors) +from redash.tasks import ( + sync_user_details, + refresh_queries, + empty_schedules, + refresh_schemas, + cleanup_query_results, + purge_failed_jobs, + version_check, + send_aggregated_errors, +) logger = logging.getLogger(__name__) -rq_scheduler = Scheduler(connection=rq_redis_connection, - queue_name="periodic", - interval=5) +rq_scheduler = Scheduler( + connection=rq_redis_connection, queue_name="periodic", interval=5 +) + def job_id(kwargs): metadata = kwargs.copy() - metadata['func'] = metadata['func'].__name__ + metadata["func"] = metadata["func"].__name__ return hashlib.sha1(json.dumps(metadata, sort_keys=True).encode()).hexdigest() def prep(kwargs): - interval = kwargs['interval'] + interval = kwargs["interval"] if isinstance(interval, timedelta): interval = int(interval.total_seconds()) - kwargs['interval'] = interval - kwargs['result_ttl'] = kwargs.get('result_ttl', interval * 2) + kwargs["interval"] = interval + kwargs["result_ttl"] = kwargs.get("result_ttl", interval * 2) return kwargs @@ -47,10 +54,21 @@ def periodic_job_definitions(): jobs = [ {"func": refresh_queries, "interval": 30, "result_ttl": 600}, {"func": empty_schedules, "interval": timedelta(minutes=60)}, - {"func": refresh_schemas, "interval": timedelta(minutes=settings.SCHEMAS_REFRESH_SCHEDULE)}, - {"func": sync_user_details, "timeout": 60, "ttl": 45, "interval": timedelta(minutes=1)}, + { + "func": refresh_schemas, + "interval": timedelta(minutes=settings.SCHEMAS_REFRESH_SCHEDULE), + }, + { + "func": sync_user_details, + "timeout": 60, + "ttl": 45, + "interval": timedelta(minutes=1), + }, {"func": purge_failed_jobs, "interval": timedelta(days=1)}, - {"func": send_aggregated_errors, "interval": timedelta(minutes=settings.SEND_FAILURE_EMAIL_INTERVAL)} + { + "func": send_aggregated_errors, + "interval": timedelta(minutes=settings.SEND_FAILURE_EMAIL_INTERVAL), + }, ] if settings.VERSION_CHECK: @@ -69,10 +87,14 @@ def schedule_periodic_jobs(jobs): job_definitions = [prep(job) for job in jobs] jobs_to_clean_up = Job.fetch_many( - set([job.id for job in rq_scheduler.get_jobs()]) - set([job_id(job) for job in job_definitions]), - rq_redis_connection) + set([job.id for job in rq_scheduler.get_jobs()]) + - set([job_id(job) for job in job_definitions]), + rq_redis_connection, + ) - jobs_to_schedule = [job for job in job_definitions if job_id(job) not in rq_scheduler] + jobs_to_schedule = [ + job for job in job_definitions if job_id(job) not in rq_scheduler + ] for job in jobs_to_clean_up: logger.info("Removing %s (%s) from schedule.", job.id, job.func_name) @@ -80,5 +102,10 @@ def schedule_periodic_jobs(jobs): job.delete() for job in jobs_to_schedule: - logger.info("Scheduling %s (%s) with interval %s.", job_id(job), job['func'].__name__, job.get('interval')) + logger.info( + "Scheduling %s (%s) with interval %s.", + job_id(job), + job["func"].__name__, + job.get("interval"), + ) schedule(job) diff --git a/redash/security.py b/redash/security.py index dc753008c2..dd14441385 100644 --- a/redash/security.py +++ b/redash/security.py @@ -8,16 +8,14 @@ def csp_allows_embeding(fn): - @functools.wraps(fn) def decorated(*args, **kwargs): return fn(*args, **kwargs) embedable_csp = talisman.content_security_policy + "frame-ancestors *;" - return talisman( - content_security_policy=embedable_csp, - frame_options=None, - )(decorated) + return talisman(content_security_policy=embedable_csp, frame_options=None)( + decorated + ) def init_app(app): diff --git a/redash/serializers/__init__.py b/redash/serializers/__init__.py index 0dc8b6a9f9..755f8fa01e 100644 --- a/redash/serializers/__init__.py +++ b/redash/serializers/__init__.py @@ -12,51 +12,56 @@ from redash.utils import json_loads from redash.models.parameterized_query import ParameterizedQuery -from .query_result import serialize_query_result, serialize_query_result_to_csv, serialize_query_result_to_xlsx +from .query_result import ( + serialize_query_result, + serialize_query_result_to_csv, + serialize_query_result_to_xlsx, +) def public_widget(widget): res = { - 'id': widget.id, - 'width': widget.width, - 'options': json_loads(widget.options), - 'text': widget.text, - 'updated_at': widget.updated_at, - 'created_at': widget.created_at + "id": widget.id, + "width": widget.width, + "options": json_loads(widget.options), + "text": widget.text, + "updated_at": widget.updated_at, + "created_at": widget.created_at, } v = widget.visualization if v and v.id: - res['visualization'] = { - 'type': v.type, - 'name': v.name, - 'description': v.description, - 'options': json_loads(v.options), - 'updated_at': v.updated_at, - 'created_at': v.created_at, - 'query': { - 'id': v.query_rel.id, - 'name': v.query_rel.name, - 'description': v.query_rel.description, - 'options': v.query_rel.options - } + res["visualization"] = { + "type": v.type, + "name": v.name, + "description": v.description, + "options": json_loads(v.options), + "updated_at": v.updated_at, + "created_at": v.created_at, + "query": { + "id": v.query_rel.id, + "name": v.query_rel.name, + "description": v.query_rel.description, + "options": v.query_rel.options, + }, } return res def public_dashboard(dashboard): - dashboard_dict = project(serialize_dashboard(dashboard, with_favorite_state=False), ( - 'name', 'layout', 'dashboard_filters_enabled', 'updated_at', - 'created_at' - )) - - widget_list = (models.Widget.query - .filter(models.Widget.dashboard_id == dashboard.id) - .outerjoin(models.Visualization) - .outerjoin(models.Query)) - - dashboard_dict['widgets'] = [public_widget(w) for w in widget_list] + dashboard_dict = project( + serialize_dashboard(dashboard, with_favorite_state=False), + ("name", "layout", "dashboard_filters_enabled", "updated_at", "created_at"), + ) + + widget_list = ( + models.Widget.query.filter(models.Widget.dashboard_id == dashboard.id) + .outerjoin(models.Visualization) + .outerjoin(models.Query) + ) + + dashboard_dict["widgets"] = [public_widget(w) for w in widget_list] return dashboard_dict @@ -72,116 +77,137 @@ def __init__(self, object_or_list, **kwargs): def serialize(self): if isinstance(self.object_or_list, models.Query): result = serialize_query(self.object_or_list, **self.options) - if self.options.get('with_favorite_state', True) and not current_user.is_api_user(): - result['is_favorite'] = models.Favorite.is_favorite(current_user.id, self.object_or_list) + if ( + self.options.get("with_favorite_state", True) + and not current_user.is_api_user() + ): + result["is_favorite"] = models.Favorite.is_favorite( + current_user.id, self.object_or_list + ) else: - result = [serialize_query(query, **self.options) for query in self.object_or_list] - if self.options.get('with_favorite_state', True): - favorite_ids = models.Favorite.are_favorites(current_user.id, self.object_or_list) + result = [ + serialize_query(query, **self.options) for query in self.object_or_list + ] + if self.options.get("with_favorite_state", True): + favorite_ids = models.Favorite.are_favorites( + current_user.id, self.object_or_list + ) for query in result: - query['is_favorite'] = query['id'] in favorite_ids + query["is_favorite"] = query["id"] in favorite_ids return result -def serialize_query(query, with_stats=False, with_visualizations=False, with_user=True, with_last_modified_by=True): +def serialize_query( + query, + with_stats=False, + with_visualizations=False, + with_user=True, + with_last_modified_by=True, +): d = { - 'id': query.id, - 'latest_query_data_id': query.latest_query_data_id, - 'name': query.name, - 'description': query.description, - 'query': query.query_text, - 'query_hash': query.query_hash, - 'schedule': query.schedule, - 'api_key': query.api_key, - 'is_archived': query.is_archived, - 'is_draft': query.is_draft, - 'updated_at': query.updated_at, - 'created_at': query.created_at, - 'data_source_id': query.data_source_id, - 'options': query.options, - 'version': query.version, - 'tags': query.tags or [], - 'is_safe': query.parameterized.is_safe, + "id": query.id, + "latest_query_data_id": query.latest_query_data_id, + "name": query.name, + "description": query.description, + "query": query.query_text, + "query_hash": query.query_hash, + "schedule": query.schedule, + "api_key": query.api_key, + "is_archived": query.is_archived, + "is_draft": query.is_draft, + "updated_at": query.updated_at, + "created_at": query.created_at, + "data_source_id": query.data_source_id, + "options": query.options, + "version": query.version, + "tags": query.tags or [], + "is_safe": query.parameterized.is_safe, } if with_user: - d['user'] = query.user.to_dict() + d["user"] = query.user.to_dict() else: - d['user_id'] = query.user_id + d["user_id"] = query.user_id if with_last_modified_by: - d['last_modified_by'] = query.last_modified_by.to_dict() if query.last_modified_by is not None else None + d["last_modified_by"] = ( + query.last_modified_by.to_dict() + if query.last_modified_by is not None + else None + ) else: - d['last_modified_by_id'] = query.last_modified_by_id + d["last_modified_by_id"] = query.last_modified_by_id if with_stats: if query.latest_query_data is not None: - d['retrieved_at'] = query.retrieved_at - d['runtime'] = query.runtime + d["retrieved_at"] = query.retrieved_at + d["runtime"] = query.runtime else: - d['retrieved_at'] = None - d['runtime'] = None + d["retrieved_at"] = None + d["runtime"] = None if with_visualizations: - d['visualizations'] = [serialize_visualization(vis, with_query=False) - for vis in query.visualizations] + d["visualizations"] = [ + serialize_visualization(vis, with_query=False) + for vis in query.visualizations + ] return d def serialize_visualization(object, with_query=True): d = { - 'id': object.id, - 'type': object.type, - 'name': object.name, - 'description': object.description, - 'options': json_loads(object.options), - 'updated_at': object.updated_at, - 'created_at': object.created_at + "id": object.id, + "type": object.type, + "name": object.name, + "description": object.description, + "options": json_loads(object.options), + "updated_at": object.updated_at, + "created_at": object.created_at, } if with_query: - d['query'] = serialize_query(object.query_rel) + d["query"] = serialize_query(object.query_rel) return d def serialize_widget(object): d = { - 'id': object.id, - 'width': object.width, - 'options': json_loads(object.options), - 'dashboard_id': object.dashboard_id, - 'text': object.text, - 'updated_at': object.updated_at, - 'created_at': object.created_at + "id": object.id, + "width": object.width, + "options": json_loads(object.options), + "dashboard_id": object.dashboard_id, + "text": object.text, + "updated_at": object.updated_at, + "created_at": object.created_at, } if object.visualization and object.visualization.id: - d['visualization'] = serialize_visualization(object.visualization) + d["visualization"] = serialize_visualization(object.visualization) return d def serialize_alert(alert, full=True): d = { - 'id': alert.id, - 'name': alert.name, - 'options': alert.options, - 'state': alert.state, - 'last_triggered_at': alert.last_triggered_at, - 'updated_at': alert.updated_at, - 'created_at': alert.created_at, - 'rearm': alert.rearm + "id": alert.id, + "name": alert.name, + "options": alert.options, + "state": alert.state, + "last_triggered_at": alert.last_triggered_at, + "updated_at": alert.updated_at, + "created_at": alert.created_at, + "rearm": alert.rearm, } if full: - d['query'] = serialize_query(alert.query_rel) - d['user'] = alert.user.to_dict() + d["query"] = serialize_query(alert.query_rel) + d["user"] = alert.user.to_dict() else: - d['query_id'] = alert.query_id - d['user_id'] = alert.user_id + d["query_id"] = alert.query_id + d["user_id"] = alert.user_id return d @@ -198,33 +224,42 @@ def serialize_dashboard(obj, with_widgets=False, user=None, with_favorite_state= elif user and has_access(w.visualization.query_rel, user, view_only): widgets.append(serialize_widget(w)) else: - widget = project(serialize_widget(w), - ('id', 'width', 'dashboard_id', 'options', 'created_at', 'updated_at')) - widget['restricted'] = True + widget = project( + serialize_widget(w), + ( + "id", + "width", + "dashboard_id", + "options", + "created_at", + "updated_at", + ), + ) + widget["restricted"] = True widgets.append(widget) else: widgets = None d = { - 'id': obj.id, - 'slug': obj.slug, - 'name': obj.name, - 'user_id': obj.user_id, + "id": obj.id, + "slug": obj.slug, + "name": obj.name, + "user_id": obj.user_id, # TODO: we should properly load the users - 'user': obj.user.to_dict(), - 'layout': layout, - 'dashboard_filters_enabled': obj.dashboard_filters_enabled, - 'widgets': widgets, - 'is_archived': obj.is_archived, - 'is_draft': obj.is_draft, - 'tags': obj.tags or [], + "user": obj.user.to_dict(), + "layout": layout, + "dashboard_filters_enabled": obj.dashboard_filters_enabled, + "widgets": widgets, + "is_archived": obj.is_archived, + "is_draft": obj.is_draft, + "tags": obj.tags or [], # TODO: bulk load favorites - 'updated_at': obj.updated_at, - 'created_at': obj.created_at, - 'version': obj.version + "updated_at": obj.updated_at, + "created_at": obj.created_at, + "version": obj.version, } if with_favorite_state: - d['is_favorite'] = models.Favorite.is_favorite(current_user.id, obj) + d["is_favorite"] = models.Favorite.is_favorite(current_user.id, obj) return d diff --git a/redash/serializers/query_result.py b/redash/serializers/query_result.py index 5d43734406..92fff23871 100644 --- a/redash/serializers/query_result.py +++ b/redash/serializers/query_result.py @@ -4,12 +4,20 @@ from funcy import rpartial, project from dateutil.parser import isoparse as parse_date from redash.utils import json_loads, UnicodeWriter -from redash.query_runner import (TYPE_BOOLEAN, TYPE_DATE, TYPE_DATETIME) +from redash.query_runner import TYPE_BOOLEAN, TYPE_DATE, TYPE_DATETIME from redash.authentication.org_resolving import current_org def _convert_format(fmt): - return fmt.replace('DD', '%d').replace('MM', '%m').replace('YYYY', '%Y').replace('YY', '%y').replace('HH', '%H').replace('mm', '%M').replace('ss', '%s') + return ( + fmt.replace("DD", "%d") + .replace("MM", "%m") + .replace("YYYY", "%Y") + .replace("YY", "%y") + .replace("HH", "%H") + .replace("mm", "%M") + .replace("ss", "%s") + ) def _convert_bool(value): @@ -35,8 +43,13 @@ def _convert_datetime(value, fmt): def _get_column_lists(columns): - date_format = _convert_format(current_org.get_setting('date_format')) - datetime_format = _convert_format('{} {}'.format(current_org.get_setting('date_format'), current_org.get_setting('time_format'))) + date_format = _convert_format(current_org.get_setting("date_format")) + datetime_format = _convert_format( + "{} {}".format( + current_org.get_setting("date_format"), + current_org.get_setting("time_format"), + ) + ) special_types = { TYPE_BOOLEAN: _convert_bool, @@ -48,18 +61,18 @@ def _get_column_lists(columns): special_columns = dict() for col in columns: - fieldnames.append(col['name']) + fieldnames.append(col["name"]) for col_type in special_types.keys(): - if col['type'] == col_type: - special_columns[col['name']] = special_types[col_type] + if col["type"] == col_type: + special_columns[col["name"]] = special_types[col_type] return fieldnames, special_columns def serialize_query_result(query_result, is_api_user): if is_api_user: - publicly_needed_keys = ['data', 'retrieved_at'] + publicly_needed_keys = ["data", "retrieved_at"] return project(query_result.to_dict(), publicly_needed_keys) else: return query_result.to_dict() @@ -70,12 +83,12 @@ def serialize_query_result_to_csv(query_result): query_data = query_result.data - fieldnames, special_columns = _get_column_lists(query_data['columns'] or []) + fieldnames, special_columns = _get_column_lists(query_data["columns"] or []) writer = csv.DictWriter(s, extrasaction="ignore", fieldnames=fieldnames) writer.writeheader() - for row in query_data['rows']: + for row in query_data["rows"]: for col_name, converter in special_columns.items(): if col_name in row: row[col_name] = converter(row[col_name]) @@ -89,15 +102,15 @@ def serialize_query_result_to_xlsx(query_result): output = io.BytesIO() query_data = query_result.data - book = xlsxwriter.Workbook(output, {'constant_memory': True}) + book = xlsxwriter.Workbook(output, {"constant_memory": True}) sheet = book.add_worksheet("result") column_names = [] - for c, col in enumerate(query_data['columns']): - sheet.write(0, c, col['name']) - column_names.append(col['name']) + for c, col in enumerate(query_data["columns"]): + sheet.write(0, c, col["name"]) + column_names.append(col["name"]) - for r, row in enumerate(query_data['rows']): + for r, row in enumerate(query_data["rows"]): for c, name in enumerate(column_names): v = row.get(name) if isinstance(v, (dict, list)): diff --git a/redash/settings/__init__.py b/redash/settings/__init__.py index 025fe4dcb5..d317a2d3f1 100644 --- a/redash/settings/__init__.py +++ b/redash/settings/__init__.py @@ -4,25 +4,38 @@ from funcy import distinct, remove from flask_talisman import talisman -from .helpers import fix_assets_path, array_from_string, parse_boolean, int_or_none, set_from_string, add_decode_responses_to_redis_url +from .helpers import ( + fix_assets_path, + array_from_string, + parse_boolean, + int_or_none, + set_from_string, + add_decode_responses_to_redis_url, +) from .organization import DATE_FORMAT, TIME_FORMAT # noqa # _REDIS_URL is the unchanged REDIS_URL we get from env vars, to be used later with Celery -_REDIS_URL = os.environ.get('REDASH_REDIS_URL', os.environ.get('REDIS_URL', "redis://localhost:6379/0")) +_REDIS_URL = os.environ.get( + "REDASH_REDIS_URL", os.environ.get("REDIS_URL", "redis://localhost:6379/0") +) # This is the one to use for Redash' own connection: REDIS_URL = add_decode_responses_to_redis_url(_REDIS_URL) -PROXIES_COUNT = int(os.environ.get('REDASH_PROXIES_COUNT', "1")) +PROXIES_COUNT = int(os.environ.get("REDASH_PROXIES_COUNT", "1")) -STATSD_HOST = os.environ.get('REDASH_STATSD_HOST', "127.0.0.1") -STATSD_PORT = int(os.environ.get('REDASH_STATSD_PORT', "8125")) -STATSD_PREFIX = os.environ.get('REDASH_STATSD_PREFIX', "redash") -STATSD_USE_TAGS = parse_boolean(os.environ.get('REDASH_STATSD_USE_TAGS', "false")) +STATSD_HOST = os.environ.get("REDASH_STATSD_HOST", "127.0.0.1") +STATSD_PORT = int(os.environ.get("REDASH_STATSD_PORT", "8125")) +STATSD_PREFIX = os.environ.get("REDASH_STATSD_PREFIX", "redash") +STATSD_USE_TAGS = parse_boolean(os.environ.get("REDASH_STATSD_USE_TAGS", "false")) # Connection settings for Redash's own database (where we store the queries, results, etc) -SQLALCHEMY_DATABASE_URI = os.environ.get("REDASH_DATABASE_URL", os.environ.get('DATABASE_URL', "postgresql:///postgres")) +SQLALCHEMY_DATABASE_URI = os.environ.get( + "REDASH_DATABASE_URL", os.environ.get("DATABASE_URL", "postgresql:///postgres") +) SQLALCHEMY_MAX_OVERFLOW = int_or_none(os.environ.get("SQLALCHEMY_MAX_OVERFLOW")) SQLALCHEMY_POOL_SIZE = int_or_none(os.environ.get("SQLALCHEMY_POOL_SIZE")) -SQLALCHEMY_DISABLE_POOL = parse_boolean(os.environ.get("SQLALCHEMY_DISABLE_POOL", "false")) +SQLALCHEMY_DISABLE_POOL = parse_boolean( + os.environ.get("SQLALCHEMY_DISABLE_POOL", "false") +) SQLALCHEMY_TRACK_MODIFICATIONS = False SQLALCHEMY_ECHO = False @@ -32,82 +45,108 @@ CELERY_BROKER = os.environ.get("REDASH_CELERY_BROKER", _REDIS_URL) CELERY_RESULT_BACKEND = os.environ.get( "REDASH_CELERY_RESULT_BACKEND", - os.environ.get("REDASH_CELERY_BACKEND", CELERY_BROKER)) -CELERY_RESULT_EXPIRES = int(os.environ.get( - "REDASH_CELERY_RESULT_EXPIRES", - os.environ.get("REDASH_CELERY_TASK_RESULT_EXPIRES", 3600 * 4))) -CELERY_INIT_TIMEOUT = int(os.environ.get( - "REDASH_CELERY_INIT_TIMEOUT", 10)) -CELERY_BROKER_USE_SSL = CELERY_BROKER.startswith('rediss') -CELERY_SSL_CONFIG = { - 'ssl_cert_reqs': int(os.environ.get("REDASH_CELERY_BROKER_SSL_CERT_REQS", ssl.CERT_OPTIONAL)), - 'ssl_ca_certs': os.environ.get("REDASH_CELERY_BROKER_SSL_CA_CERTS"), - 'ssl_certfile': os.environ.get("REDASH_CELERY_BROKER_SSL_CERTFILE"), - 'ssl_keyfile': os.environ.get("REDASH_CELERY_BROKER_SSL_KEYFILE"), -} if CELERY_BROKER_USE_SSL else None - -CELERY_WORKER_PREFETCH_MULTIPLIER = int(os.environ.get("REDASH_CELERY_WORKER_PREFETCH_MULTIPLIER", 1)) -CELERY_ACCEPT_CONTENT = os.environ.get("REDASH_CELERY_ACCEPT_CONTENT", "json").split(",") + os.environ.get("REDASH_CELERY_BACKEND", CELERY_BROKER), +) +CELERY_RESULT_EXPIRES = int( + os.environ.get( + "REDASH_CELERY_RESULT_EXPIRES", + os.environ.get("REDASH_CELERY_TASK_RESULT_EXPIRES", 3600 * 4), + ) +) +CELERY_INIT_TIMEOUT = int(os.environ.get("REDASH_CELERY_INIT_TIMEOUT", 10)) +CELERY_BROKER_USE_SSL = CELERY_BROKER.startswith("rediss") +CELERY_SSL_CONFIG = ( + { + "ssl_cert_reqs": int( + os.environ.get("REDASH_CELERY_BROKER_SSL_CERT_REQS", ssl.CERT_OPTIONAL) + ), + "ssl_ca_certs": os.environ.get("REDASH_CELERY_BROKER_SSL_CA_CERTS"), + "ssl_certfile": os.environ.get("REDASH_CELERY_BROKER_SSL_CERTFILE"), + "ssl_keyfile": os.environ.get("REDASH_CELERY_BROKER_SSL_KEYFILE"), + } + if CELERY_BROKER_USE_SSL + else None +) + +CELERY_WORKER_PREFETCH_MULTIPLIER = int( + os.environ.get("REDASH_CELERY_WORKER_PREFETCH_MULTIPLIER", 1) +) +CELERY_ACCEPT_CONTENT = os.environ.get("REDASH_CELERY_ACCEPT_CONTENT", "json").split( + "," +) CELERY_TASK_SERIALIZER = os.environ.get("REDASH_CELERY_TASK_SERIALIZER", "json") CELERY_RESULT_SERIALIZER = os.environ.get("REDASH_CELERY_RESULT_SERIALIZER", "json") # The following enables periodic job (every 5 minutes) of removing unused query results. -QUERY_RESULTS_CLEANUP_ENABLED = parse_boolean(os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_ENABLED", "true")) -QUERY_RESULTS_CLEANUP_COUNT = int(os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_COUNT", "100")) -QUERY_RESULTS_CLEANUP_MAX_AGE = int(os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_MAX_AGE", "7")) +QUERY_RESULTS_CLEANUP_ENABLED = parse_boolean( + os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_ENABLED", "true") +) +QUERY_RESULTS_CLEANUP_COUNT = int( + os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_COUNT", "100") +) +QUERY_RESULTS_CLEANUP_MAX_AGE = int( + os.environ.get("REDASH_QUERY_RESULTS_CLEANUP_MAX_AGE", "7") +) SCHEMAS_REFRESH_SCHEDULE = int(os.environ.get("REDASH_SCHEMAS_REFRESH_SCHEDULE", 30)) AUTH_TYPE = os.environ.get("REDASH_AUTH_TYPE", "api_key") -INVITATION_TOKEN_MAX_AGE = int(os.environ.get("REDASH_INVITATION_TOKEN_MAX_AGE", 60 * 60 * 24 * 7)) +INVITATION_TOKEN_MAX_AGE = int( + os.environ.get("REDASH_INVITATION_TOKEN_MAX_AGE", 60 * 60 * 24 * 7) +) # The secret key to use in the Flask app for various cryptographic features SECRET_KEY = os.environ.get("REDASH_COOKIE_SECRET", "c292a0a3aa32397cdb050e233733900f") # The secret key to use when encrypting data source options -DATASOURCE_SECRET_KEY = os.environ.get('REDASH_SECRET_KEY', SECRET_KEY) +DATASOURCE_SECRET_KEY = os.environ.get("REDASH_SECRET_KEY", SECRET_KEY) # Whether and how to redirect non-HTTP requests to HTTPS. Disabled by default. ENFORCE_HTTPS = parse_boolean(os.environ.get("REDASH_ENFORCE_HTTPS", "false")) ENFORCE_HTTPS_PERMANENT = parse_boolean( - os.environ.get("REDASH_ENFORCE_HTTPS_PERMANENT", "false")) + os.environ.get("REDASH_ENFORCE_HTTPS_PERMANENT", "false") +) # Whether file downloads are enforced or not. -ENFORCE_FILE_SAVE = parse_boolean( - os.environ.get("REDASH_ENFORCE_FILE_SAVE", "true")) +ENFORCE_FILE_SAVE = parse_boolean(os.environ.get("REDASH_ENFORCE_FILE_SAVE", "true")) # Whether to use secure cookies by default. COOKIES_SECURE = parse_boolean( - os.environ.get("REDASH_COOKIES_SECURE", str(ENFORCE_HTTPS))) + os.environ.get("REDASH_COOKIES_SECURE", str(ENFORCE_HTTPS)) +) # Whether the session cookie is set to secure. SESSION_COOKIE_SECURE = parse_boolean( - os.environ.get("REDASH_SESSION_COOKIE_SECURE") or str(COOKIES_SECURE)) + os.environ.get("REDASH_SESSION_COOKIE_SECURE") or str(COOKIES_SECURE) +) # Whether the session cookie is set HttpOnly. SESSION_COOKIE_HTTPONLY = parse_boolean( - os.environ.get("REDASH_SESSION_COOKIE_HTTPONLY", "true")) + os.environ.get("REDASH_SESSION_COOKIE_HTTPONLY", "true") +) # Whether the session cookie is set to secure. REMEMBER_COOKIE_SECURE = parse_boolean( - os.environ.get("REDASH_REMEMBER_COOKIE_SECURE") or str(COOKIES_SECURE)) + os.environ.get("REDASH_REMEMBER_COOKIE_SECURE") or str(COOKIES_SECURE) +) # Whether the remember cookie is set HttpOnly. REMEMBER_COOKIE_HTTPONLY = parse_boolean( - os.environ.get("REDASH_REMEMBER_COOKIE_HTTPONLY", "true")) + os.environ.get("REDASH_REMEMBER_COOKIE_HTTPONLY", "true") +) # Doesn't set X-Frame-Options by default since it's highly dependent # on the specific deployment. # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Frame-Options # for more information. FRAME_OPTIONS = os.environ.get("REDASH_FRAME_OPTIONS", "deny") -FRAME_OPTIONS_ALLOW_FROM = os.environ.get( - "REDASH_FRAME_OPTIONS_ALLOW_FROM", "") +FRAME_OPTIONS_ALLOW_FROM = os.environ.get("REDASH_FRAME_OPTIONS_ALLOW_FROM", "") # Whether and how to send Strict-Transport-Security response headers. # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Strict-Transport-Security # for more information. HSTS_ENABLED = parse_boolean( - os.environ.get("REDASH_HSTS_ENABLED") or str(ENFORCE_HTTPS)) + os.environ.get("REDASH_HSTS_ENABLED") or str(ENFORCE_HTTPS) +) HSTS_PRELOAD = parse_boolean(os.environ.get("REDASH_HSTS_PRELOAD", "false")) -HSTS_MAX_AGE = int( - os.environ.get("REDASH_HSTS_MAX_AGE", talisman.ONE_YEAR_IN_SECS)) +HSTS_MAX_AGE = int(os.environ.get("REDASH_HSTS_MAX_AGE", talisman.ONE_YEAR_IN_SECS)) HSTS_INCLUDE_SUBDOMAINS = parse_boolean( - os.environ.get("REDASH_HSTS_INCLUDE_SUBDOMAINS", "false")) + os.environ.get("REDASH_HSTS_INCLUDE_SUBDOMAINS", "false") +) # Whether and how to send Content-Security-Policy response headers. # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy @@ -118,21 +157,25 @@ # for more information. E.g.: CONTENT_SECURITY_POLICY = os.environ.get( "REDASH_CONTENT_SECURITY_POLICY", - "default-src 'self'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-eval'; font-src 'self' data:; img-src 'self' http: https: data:; object-src 'none'; frame-ancestors 'none'; frame-src redash.io;" + "default-src 'self'; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-eval'; font-src 'self' data:; img-src 'self' http: https: data:; object-src 'none'; frame-ancestors 'none'; frame-src redash.io;", ) CONTENT_SECURITY_POLICY_REPORT_URI = os.environ.get( - "REDASH_CONTENT_SECURITY_POLICY_REPORT_URI", "") + "REDASH_CONTENT_SECURITY_POLICY_REPORT_URI", "" +) CONTENT_SECURITY_POLICY_REPORT_ONLY = parse_boolean( - os.environ.get("REDASH_CONTENT_SECURITY_POLICY_REPORT_ONLY", "false")) + os.environ.get("REDASH_CONTENT_SECURITY_POLICY_REPORT_ONLY", "false") +) CONTENT_SECURITY_POLICY_NONCE_IN = array_from_string( - os.environ.get("REDASH_CONTENT_SECURITY_POLICY_NONCE_IN", "")) + os.environ.get("REDASH_CONTENT_SECURITY_POLICY_NONCE_IN", "") +) # Whether and how to send Referrer-Policy response headers. Defaults to # 'strict-origin-when-cross-origin'. # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Referrer-Policy # for more information. REFERRER_POLICY = os.environ.get( - "REDASH_REFERRER_POLICY", "strict-origin-when-cross-origin") + "REDASH_REFERRER_POLICY", "strict-origin-when-cross-origin" +) # Whether and how to send Feature-Policy response headers. Defaults to # an empty value. # See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Feature-Policy @@ -168,191 +211,291 @@ # If you also set the organization setting auth_password_login_enabled to false, # then your authentication will be seamless. Otherwise a link will be presented # on the login page to trigger remote user auth. -REMOTE_USER_LOGIN_ENABLED = parse_boolean(os.environ.get("REDASH_REMOTE_USER_LOGIN_ENABLED", "false")) -REMOTE_USER_HEADER = os.environ.get("REDASH_REMOTE_USER_HEADER", "X-Forwarded-Remote-User") +REMOTE_USER_LOGIN_ENABLED = parse_boolean( + os.environ.get("REDASH_REMOTE_USER_LOGIN_ENABLED", "false") +) +REMOTE_USER_HEADER = os.environ.get( + "REDASH_REMOTE_USER_HEADER", "X-Forwarded-Remote-User" +) # If the organization setting auth_password_login_enabled is not false, then users will still be # able to login through Redash instead of the LDAP server -LDAP_LOGIN_ENABLED = parse_boolean(os.environ.get('REDASH_LDAP_LOGIN_ENABLED', 'false')) +LDAP_LOGIN_ENABLED = parse_boolean(os.environ.get("REDASH_LDAP_LOGIN_ENABLED", "false")) # Bind LDAP using SSL. Default is False -LDAP_SSL = parse_boolean(os.environ.get('REDASH_LDAP_USE_SSL', 'false')) +LDAP_SSL = parse_boolean(os.environ.get("REDASH_LDAP_USE_SSL", "false")) # Choose authentication method(SIMPLE, ANONYMOUS or NTLM). Default is SIMPLE -LDAP_AUTH_METHOD = os.environ.get('REDASH_LDAP_AUTH_METHOD', 'SIMPLE') +LDAP_AUTH_METHOD = os.environ.get("REDASH_LDAP_AUTH_METHOD", "SIMPLE") # The LDAP directory address (ex. ldap://10.0.10.1:389) -LDAP_HOST_URL = os.environ.get('REDASH_LDAP_URL', None) +LDAP_HOST_URL = os.environ.get("REDASH_LDAP_URL", None) # The DN & password used to connect to LDAP to determine the identity of the user being authenticated. # For AD this should be "org\\user". -LDAP_BIND_DN = os.environ.get('REDASH_LDAP_BIND_DN', None) -LDAP_BIND_DN_PASSWORD = os.environ.get('REDASH_LDAP_BIND_DN_PASSWORD', '') +LDAP_BIND_DN = os.environ.get("REDASH_LDAP_BIND_DN", None) +LDAP_BIND_DN_PASSWORD = os.environ.get("REDASH_LDAP_BIND_DN_PASSWORD", "") # AD/LDAP email and display name keys -LDAP_DISPLAY_NAME_KEY = os.environ.get('REDASH_LDAP_DISPLAY_NAME_KEY', 'displayName') -LDAP_EMAIL_KEY = os.environ.get('REDASH_LDAP_EMAIL_KEY', "mail") +LDAP_DISPLAY_NAME_KEY = os.environ.get("REDASH_LDAP_DISPLAY_NAME_KEY", "displayName") +LDAP_EMAIL_KEY = os.environ.get("REDASH_LDAP_EMAIL_KEY", "mail") # Prompt that should be shown above username/email field. -LDAP_CUSTOM_USERNAME_PROMPT = os.environ.get('REDASH_LDAP_CUSTOM_USERNAME_PROMPT', 'LDAP/AD/SSO username:') +LDAP_CUSTOM_USERNAME_PROMPT = os.environ.get( + "REDASH_LDAP_CUSTOM_USERNAME_PROMPT", "LDAP/AD/SSO username:" +) # LDAP Search DN TEMPLATE (for AD this should be "(sAMAccountName=%(username)s)"") -LDAP_SEARCH_TEMPLATE = os.environ.get('REDASH_LDAP_SEARCH_TEMPLATE', '(cn=%(username)s)') +LDAP_SEARCH_TEMPLATE = os.environ.get( + "REDASH_LDAP_SEARCH_TEMPLATE", "(cn=%(username)s)" +) # The schema to bind to (ex. cn=users,dc=ORG,dc=local) -LDAP_SEARCH_DN = os.environ.get('REDASH_LDAP_SEARCH_DN', os.environ.get('REDASH_SEARCH_DN')) +LDAP_SEARCH_DN = os.environ.get( + "REDASH_LDAP_SEARCH_DN", os.environ.get("REDASH_SEARCH_DN") +) -STATIC_ASSETS_PATH = fix_assets_path(os.environ.get("REDASH_STATIC_ASSETS_PATH", "../client/dist/")) +STATIC_ASSETS_PATH = fix_assets_path( + os.environ.get("REDASH_STATIC_ASSETS_PATH", "../client/dist/") +) JOB_EXPIRY_TIME = int(os.environ.get("REDASH_JOB_EXPIRY_TIME", 3600 * 12)) -JOB_DEFAULT_FAILURE_TTL = int(os.environ.get("REDASH_JOB_DEFAULT_FAILURE_TTL", 7 * 24 * 60 * 60)) +JOB_DEFAULT_FAILURE_TTL = int( + os.environ.get("REDASH_JOB_DEFAULT_FAILURE_TTL", 7 * 24 * 60 * 60) +) LOG_LEVEL = os.environ.get("REDASH_LOG_LEVEL", "INFO") -LOG_STDOUT = parse_boolean(os.environ.get('REDASH_LOG_STDOUT', 'false')) -LOG_PREFIX = os.environ.get('REDASH_LOG_PREFIX', '') -LOG_FORMAT = os.environ.get('REDASH_LOG_FORMAT', LOG_PREFIX + '[%(asctime)s][PID:%(process)d][%(levelname)s][%(name)s] %(message)s') +LOG_STDOUT = parse_boolean(os.environ.get("REDASH_LOG_STDOUT", "false")) +LOG_PREFIX = os.environ.get("REDASH_LOG_PREFIX", "") +LOG_FORMAT = os.environ.get( + "REDASH_LOG_FORMAT", + LOG_PREFIX + "[%(asctime)s][PID:%(process)d][%(levelname)s][%(name)s] %(message)s", +) CELERYD_WORKER_LOG_FORMAT = os.environ.get( "REDASH_CELERYD_WORKER_LOG_FORMAT", - os.environ.get('REDASH_CELERYD_LOG_FORMAT', - LOG_PREFIX + '[%(asctime)s][PID:%(process)d][%(levelname)s][%(processName)s] %(message)s')) + os.environ.get( + "REDASH_CELERYD_LOG_FORMAT", + LOG_PREFIX + + "[%(asctime)s][PID:%(process)d][%(levelname)s][%(processName)s] %(message)s", + ), +) CELERYD_WORKER_TASK_LOG_FORMAT = os.environ.get( "REDASH_CELERYD_WORKER_TASK_LOG_FORMAT", - os.environ.get('REDASH_CELERYD_TASK_LOG_FORMAT', - (LOG_PREFIX + '[%(asctime)s][PID:%(process)d][%(levelname)s][%(processName)s] ' - 'task_name=%(task_name)s ' - 'task_id=%(task_id)s %(message)s'))) -RQ_WORKER_JOB_LOG_FORMAT = os.environ.get("REDASH_RQ_WORKER_JOB_LOG_FORMAT", - (LOG_PREFIX + '[%(asctime)s][PID:%(process)d][%(levelname)s][%(name)s] ' - 'job.func_name=%(job_func_name)s ' - 'job.id=%(job_id)s %(message)s')) + os.environ.get( + "REDASH_CELERYD_TASK_LOG_FORMAT", + ( + LOG_PREFIX + + "[%(asctime)s][PID:%(process)d][%(levelname)s][%(processName)s] " + "task_name=%(task_name)s " + "task_id=%(task_id)s %(message)s" + ), + ), +) +RQ_WORKER_JOB_LOG_FORMAT = os.environ.get( + "REDASH_RQ_WORKER_JOB_LOG_FORMAT", + ( + LOG_PREFIX + "[%(asctime)s][PID:%(process)d][%(levelname)s][%(name)s] " + "job.func_name=%(job_func_name)s " + "job.id=%(job_id)s %(message)s" + ), +) # Mail settings: -MAIL_SERVER = os.environ.get('REDASH_MAIL_SERVER', 'localhost') -MAIL_PORT = int(os.environ.get('REDASH_MAIL_PORT', 25)) -MAIL_USE_TLS = parse_boolean(os.environ.get('REDASH_MAIL_USE_TLS', 'false')) -MAIL_USE_SSL = parse_boolean(os.environ.get('REDASH_MAIL_USE_SSL', 'false')) -MAIL_USERNAME = os.environ.get('REDASH_MAIL_USERNAME', None) -MAIL_PASSWORD = os.environ.get('REDASH_MAIL_PASSWORD', None) -MAIL_DEFAULT_SENDER = os.environ.get('REDASH_MAIL_DEFAULT_SENDER', None) -MAIL_MAX_EMAILS = os.environ.get('REDASH_MAIL_MAX_EMAILS', None) -MAIL_ASCII_ATTACHMENTS = parse_boolean(os.environ.get('REDASH_MAIL_ASCII_ATTACHMENTS', 'false')) +MAIL_SERVER = os.environ.get("REDASH_MAIL_SERVER", "localhost") +MAIL_PORT = int(os.environ.get("REDASH_MAIL_PORT", 25)) +MAIL_USE_TLS = parse_boolean(os.environ.get("REDASH_MAIL_USE_TLS", "false")) +MAIL_USE_SSL = parse_boolean(os.environ.get("REDASH_MAIL_USE_SSL", "false")) +MAIL_USERNAME = os.environ.get("REDASH_MAIL_USERNAME", None) +MAIL_PASSWORD = os.environ.get("REDASH_MAIL_PASSWORD", None) +MAIL_DEFAULT_SENDER = os.environ.get("REDASH_MAIL_DEFAULT_SENDER", None) +MAIL_MAX_EMAILS = os.environ.get("REDASH_MAIL_MAX_EMAILS", None) +MAIL_ASCII_ATTACHMENTS = parse_boolean( + os.environ.get("REDASH_MAIL_ASCII_ATTACHMENTS", "false") +) def email_server_is_configured(): return MAIL_DEFAULT_SENDER is not None -HOST = os.environ.get('REDASH_HOST', '') +HOST = os.environ.get("REDASH_HOST", "") -SEND_FAILURE_EMAIL_INTERVAL = int(os.environ.get('REDASH_SEND_FAILURE_EMAIL_INTERVAL', 60)) -MAX_FAILURE_REPORTS_PER_QUERY = int(os.environ.get('REDASH_MAX_FAILURE_REPORTS_PER_QUERY', 100)) +SEND_FAILURE_EMAIL_INTERVAL = int( + os.environ.get("REDASH_SEND_FAILURE_EMAIL_INTERVAL", 60) +) +MAX_FAILURE_REPORTS_PER_QUERY = int( + os.environ.get("REDASH_MAX_FAILURE_REPORTS_PER_QUERY", 100) +) -ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE = os.environ.get('REDASH_ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE', "({state}) {alert_name}") +ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE = os.environ.get( + "REDASH_ALERTS_DEFAULT_MAIL_SUBJECT_TEMPLATE", "({state}) {alert_name}" +) # How many requests are allowed per IP to the login page before # being throttled? # See https://flask-limiter.readthedocs.io/en/stable/#rate-limit-string-notation -RATELIMIT_ENABLED = parse_boolean(os.environ.get('REDASH_RATELIMIT_ENABLED', 'true')) -THROTTLE_LOGIN_PATTERN = os.environ.get('REDASH_THROTTLE_LOGIN_PATTERN', '50/hour') +RATELIMIT_ENABLED = parse_boolean(os.environ.get("REDASH_RATELIMIT_ENABLED", "true")) +THROTTLE_LOGIN_PATTERN = os.environ.get("REDASH_THROTTLE_LOGIN_PATTERN", "50/hour") LIMITER_STORAGE = os.environ.get("REDASH_LIMITER_STORAGE", REDIS_URL) # CORS settings for the Query Result API (and possbily future external APIs). # In most cases all you need to do is set REDASH_CORS_ACCESS_CONTROL_ALLOW_ORIGIN # to the calling domain (or domains in a comma separated list). -ACCESS_CONTROL_ALLOW_ORIGIN = set_from_string(os.environ.get("REDASH_CORS_ACCESS_CONTROL_ALLOW_ORIGIN", "")) -ACCESS_CONTROL_ALLOW_CREDENTIALS = parse_boolean(os.environ.get("REDASH_CORS_ACCESS_CONTROL_ALLOW_CREDENTIALS", "false")) -ACCESS_CONTROL_REQUEST_METHOD = os.environ.get("REDASH_CORS_ACCESS_CONTROL_REQUEST_METHOD", "GET, POST, PUT") -ACCESS_CONTROL_ALLOW_HEADERS = os.environ.get("REDASH_CORS_ACCESS_CONTROL_ALLOW_HEADERS", "Content-Type") +ACCESS_CONTROL_ALLOW_ORIGIN = set_from_string( + os.environ.get("REDASH_CORS_ACCESS_CONTROL_ALLOW_ORIGIN", "") +) +ACCESS_CONTROL_ALLOW_CREDENTIALS = parse_boolean( + os.environ.get("REDASH_CORS_ACCESS_CONTROL_ALLOW_CREDENTIALS", "false") +) +ACCESS_CONTROL_REQUEST_METHOD = os.environ.get( + "REDASH_CORS_ACCESS_CONTROL_REQUEST_METHOD", "GET, POST, PUT" +) +ACCESS_CONTROL_ALLOW_HEADERS = os.environ.get( + "REDASH_CORS_ACCESS_CONTROL_ALLOW_HEADERS", "Content-Type" +) # Query Runners default_query_runners = [ - 'redash.query_runner.athena', - 'redash.query_runner.big_query', - 'redash.query_runner.google_spreadsheets', - 'redash.query_runner.graphite', - 'redash.query_runner.mongodb', - 'redash.query_runner.couchbase', - 'redash.query_runner.mysql', - 'redash.query_runner.pg', - 'redash.query_runner.url', - 'redash.query_runner.influx_db', - 'redash.query_runner.elasticsearch', - 'redash.query_runner.amazon_elasticsearch', - 'redash.query_runner.presto', - 'redash.query_runner.databricks', - 'redash.query_runner.hive_ds', - 'redash.query_runner.impala_ds', - 'redash.query_runner.vertica', - 'redash.query_runner.clickhouse', - 'redash.query_runner.yandex_metrica', - 'redash.query_runner.rockset', - 'redash.query_runner.treasuredata', - 'redash.query_runner.sqlite', - 'redash.query_runner.dynamodb_sql', - 'redash.query_runner.mssql', - 'redash.query_runner.memsql_ds', - 'redash.query_runner.mapd', - 'redash.query_runner.jql', - 'redash.query_runner.google_analytics', - 'redash.query_runner.axibase_tsd', - 'redash.query_runner.salesforce', - 'redash.query_runner.query_results', - 'redash.query_runner.prometheus', - 'redash.query_runner.qubole', - 'redash.query_runner.db2', - 'redash.query_runner.druid', - 'redash.query_runner.kylin', - 'redash.query_runner.drill', - 'redash.query_runner.uptycs', - 'redash.query_runner.snowflake', - 'redash.query_runner.phoenix', - 'redash.query_runner.json_ds', - 'redash.query_runner.cass', - 'redash.query_runner.dgraph', - 'redash.query_runner.azure_kusto', - 'redash.query_runner.exasol', - 'redash.query_runner.cloudwatch', - 'redash.query_runner.cloudwatch_insights', + "redash.query_runner.athena", + "redash.query_runner.big_query", + "redash.query_runner.google_spreadsheets", + "redash.query_runner.graphite", + "redash.query_runner.mongodb", + "redash.query_runner.couchbase", + "redash.query_runner.mysql", + "redash.query_runner.pg", + "redash.query_runner.url", + "redash.query_runner.influx_db", + "redash.query_runner.elasticsearch", + "redash.query_runner.amazon_elasticsearch", + "redash.query_runner.presto", + "redash.query_runner.databricks", + "redash.query_runner.hive_ds", + "redash.query_runner.impala_ds", + "redash.query_runner.vertica", + "redash.query_runner.clickhouse", + "redash.query_runner.yandex_metrica", + "redash.query_runner.rockset", + "redash.query_runner.treasuredata", + "redash.query_runner.sqlite", + "redash.query_runner.dynamodb_sql", + "redash.query_runner.mssql", + "redash.query_runner.memsql_ds", + "redash.query_runner.mapd", + "redash.query_runner.jql", + "redash.query_runner.google_analytics", + "redash.query_runner.axibase_tsd", + "redash.query_runner.salesforce", + "redash.query_runner.query_results", + "redash.query_runner.prometheus", + "redash.query_runner.qubole", + "redash.query_runner.db2", + "redash.query_runner.druid", + "redash.query_runner.kylin", + "redash.query_runner.drill", + "redash.query_runner.uptycs", + "redash.query_runner.snowflake", + "redash.query_runner.phoenix", + "redash.query_runner.json_ds", + "redash.query_runner.cass", + "redash.query_runner.dgraph", + "redash.query_runner.azure_kusto", + "redash.query_runner.exasol", + "redash.query_runner.cloudwatch", + "redash.query_runner.cloudwatch_insights", ] -enabled_query_runners = array_from_string(os.environ.get("REDASH_ENABLED_QUERY_RUNNERS", ",".join(default_query_runners))) -additional_query_runners = array_from_string(os.environ.get("REDASH_ADDITIONAL_QUERY_RUNNERS", "")) -disabled_query_runners = array_from_string(os.environ.get("REDASH_DISABLED_QUERY_RUNNERS", "")) +enabled_query_runners = array_from_string( + os.environ.get("REDASH_ENABLED_QUERY_RUNNERS", ",".join(default_query_runners)) +) +additional_query_runners = array_from_string( + os.environ.get("REDASH_ADDITIONAL_QUERY_RUNNERS", "") +) +disabled_query_runners = array_from_string( + os.environ.get("REDASH_DISABLED_QUERY_RUNNERS", "") +) -QUERY_RUNNERS = remove(set(disabled_query_runners), distinct(enabled_query_runners + additional_query_runners)) +QUERY_RUNNERS = remove( + set(disabled_query_runners), + distinct(enabled_query_runners + additional_query_runners), +) -dynamic_settings = importlib.import_module(os.environ.get('REDASH_DYNAMIC_SETTINGS_MODULE', 'redash.settings.dynamic_settings')) +dynamic_settings = importlib.import_module( + os.environ.get("REDASH_DYNAMIC_SETTINGS_MODULE", "redash.settings.dynamic_settings") +) # Destinations default_destinations = [ - 'redash.destinations.email', - 'redash.destinations.slack', - 'redash.destinations.webhook', - 'redash.destinations.hipchat', - 'redash.destinations.mattermost', - 'redash.destinations.chatwork', - 'redash.destinations.pagerduty', - 'redash.destinations.hangoutschat' + "redash.destinations.email", + "redash.destinations.slack", + "redash.destinations.webhook", + "redash.destinations.hipchat", + "redash.destinations.mattermost", + "redash.destinations.chatwork", + "redash.destinations.pagerduty", + "redash.destinations.hangoutschat", ] -enabled_destinations = array_from_string(os.environ.get("REDASH_ENABLED_DESTINATIONS", ",".join(default_destinations))) -additional_destinations = array_from_string(os.environ.get("REDASH_ADDITIONAL_DESTINATIONS", "")) +enabled_destinations = array_from_string( + os.environ.get("REDASH_ENABLED_DESTINATIONS", ",".join(default_destinations)) +) +additional_destinations = array_from_string( + os.environ.get("REDASH_ADDITIONAL_DESTINATIONS", "") +) DESTINATIONS = distinct(enabled_destinations + additional_destinations) -EVENT_REPORTING_WEBHOOKS = array_from_string(os.environ.get("REDASH_EVENT_REPORTING_WEBHOOKS", "")) +EVENT_REPORTING_WEBHOOKS = array_from_string( + os.environ.get("REDASH_EVENT_REPORTING_WEBHOOKS", "") +) # Support for Sentry (https://getsentry.com/). Just set your Sentry DSN to enable it: SENTRY_DSN = os.environ.get("REDASH_SENTRY_DSN", "") # Client side toggles: -ALLOW_SCRIPTS_IN_USER_INPUT = parse_boolean(os.environ.get("REDASH_ALLOW_SCRIPTS_IN_USER_INPUT", "false")) -DASHBOARD_REFRESH_INTERVALS = list(map(int, array_from_string(os.environ.get("REDASH_DASHBOARD_REFRESH_INTERVALS", "60,300,600,1800,3600,43200,86400")))) -QUERY_REFRESH_INTERVALS = list(map(int, array_from_string(os.environ.get("REDASH_QUERY_REFRESH_INTERVALS", "60, 300, 600, 900, 1800, 3600, 7200, 10800, 14400, 18000, 21600, 25200, 28800, 32400, 36000, 39600, 43200, 86400, 604800, 1209600, 2592000")))) -PAGE_SIZE = int(os.environ.get('REDASH_PAGE_SIZE', 20)) -PAGE_SIZE_OPTIONS = list(map(int, array_from_string(os.environ.get("REDASH_PAGE_SIZE_OPTIONS", "5,10,20,50,100")))) -TABLE_CELL_MAX_JSON_SIZE = int(os.environ.get('REDASH_TABLE_CELL_MAX_JSON_SIZE', 50000)) +ALLOW_SCRIPTS_IN_USER_INPUT = parse_boolean( + os.environ.get("REDASH_ALLOW_SCRIPTS_IN_USER_INPUT", "false") +) +DASHBOARD_REFRESH_INTERVALS = list( + map( + int, + array_from_string( + os.environ.get( + "REDASH_DASHBOARD_REFRESH_INTERVALS", "60,300,600,1800,3600,43200,86400" + ) + ), + ) +) +QUERY_REFRESH_INTERVALS = list( + map( + int, + array_from_string( + os.environ.get( + "REDASH_QUERY_REFRESH_INTERVALS", + "60, 300, 600, 900, 1800, 3600, 7200, 10800, 14400, 18000, 21600, 25200, 28800, 32400, 36000, 39600, 43200, 86400, 604800, 1209600, 2592000", + ) + ), + ) +) +PAGE_SIZE = int(os.environ.get("REDASH_PAGE_SIZE", 20)) +PAGE_SIZE_OPTIONS = list( + map( + int, + array_from_string(os.environ.get("REDASH_PAGE_SIZE_OPTIONS", "5,10,20,50,100")), + ) +) +TABLE_CELL_MAX_JSON_SIZE = int(os.environ.get("REDASH_TABLE_CELL_MAX_JSON_SIZE", 50000)) # Features: VERSION_CHECK = parse_boolean(os.environ.get("REDASH_VERSION_CHECK", "true")) -FEATURE_DISABLE_REFRESH_QUERIES = parse_boolean(os.environ.get("REDASH_FEATURE_DISABLE_REFRESH_QUERIES", "false")) -FEATURE_SHOW_QUERY_RESULTS_COUNT = parse_boolean(os.environ.get("REDASH_FEATURE_SHOW_QUERY_RESULTS_COUNT", "true")) -FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS = parse_boolean(os.environ.get("REDASH_FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS", "false")) -FEATURE_AUTO_PUBLISH_NAMED_QUERIES = parse_boolean(os.environ.get("REDASH_FEATURE_AUTO_PUBLISH_NAMED_QUERIES", "true")) -FEATURE_EXTENDED_ALERT_OPTIONS = parse_boolean(os.environ.get("REDASH_FEATURE_EXTENDED_ALERT_OPTIONS", "false")) +FEATURE_DISABLE_REFRESH_QUERIES = parse_boolean( + os.environ.get("REDASH_FEATURE_DISABLE_REFRESH_QUERIES", "false") +) +FEATURE_SHOW_QUERY_RESULTS_COUNT = parse_boolean( + os.environ.get("REDASH_FEATURE_SHOW_QUERY_RESULTS_COUNT", "true") +) +FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS = parse_boolean( + os.environ.get("REDASH_FEATURE_ALLOW_CUSTOM_JS_VISUALIZATIONS", "false") +) +FEATURE_AUTO_PUBLISH_NAMED_QUERIES = parse_boolean( + os.environ.get("REDASH_FEATURE_AUTO_PUBLISH_NAMED_QUERIES", "true") +) +FEATURE_EXTENDED_ALERT_OPTIONS = parse_boolean( + os.environ.get("REDASH_FEATURE_EXTENDED_ALERT_OPTIONS", "false") +) # BigQuery BIGQUERY_HTTP_TIMEOUT = int(os.environ.get("REDASH_BIGQUERY_HTTP_TIMEOUT", "600")) @@ -360,18 +503,24 @@ def email_server_is_configured(): # Allow Parameters in Embeds # WARNING: Deprecated! # See https://discuss.redash.io/t/support-for-parameters-in-embedded-visualizations/3337 for more details. -ALLOW_PARAMETERS_IN_EMBEDS = parse_boolean(os.environ.get("REDASH_ALLOW_PARAMETERS_IN_EMBEDS", "false")) +ALLOW_PARAMETERS_IN_EMBEDS = parse_boolean( + os.environ.get("REDASH_ALLOW_PARAMETERS_IN_EMBEDS", "false") +) # Enhance schema fetching -SCHEMA_RUN_TABLE_SIZE_CALCULATIONS = parse_boolean(os.environ.get("REDASH_SCHEMA_RUN_TABLE_SIZE_CALCULATIONS", "false")) +SCHEMA_RUN_TABLE_SIZE_CALCULATIONS = parse_boolean( + os.environ.get("REDASH_SCHEMA_RUN_TABLE_SIZE_CALCULATIONS", "false") +) # kylin -KYLIN_OFFSET = int(os.environ.get('REDASH_KYLIN_OFFSET', 0)) -KYLIN_LIMIT = int(os.environ.get('REDASH_KYLIN_LIMIT', 50000)) -KYLIN_ACCEPT_PARTIAL = parse_boolean(os.environ.get("REDASH_KYLIN_ACCEPT_PARTIAL", "false")) +KYLIN_OFFSET = int(os.environ.get("REDASH_KYLIN_OFFSET", 0)) +KYLIN_LIMIT = int(os.environ.get("REDASH_KYLIN_LIMIT", 50000)) +KYLIN_ACCEPT_PARTIAL = parse_boolean( + os.environ.get("REDASH_KYLIN_ACCEPT_PARTIAL", "false") +) # sqlparse SQLPARSE_FORMAT_OPTIONS = { - 'reindent': parse_boolean(os.environ.get('SQLPARSE_FORMAT_REINDENT', 'true')), - 'keyword_case': os.environ.get('SQLPARSE_FORMAT_KEYWORD_CASE', 'upper'), + "reindent": parse_boolean(os.environ.get("SQLPARSE_FORMAT_REINDENT", "true")), + "keyword_case": os.environ.get("SQLPARSE_FORMAT_KEYWORD_CASE", "upper"), } diff --git a/redash/settings/dynamic_settings.py b/redash/settings/dynamic_settings.py index 3577d06772..c58de011d8 100644 --- a/redash/settings/dynamic_settings.py +++ b/redash/settings/dynamic_settings.py @@ -4,8 +4,12 @@ # Replace this method with your own implementation in case you want to limit the time limit on certain queries or users. def query_time_limit(is_scheduled, user_id, org_id): - scheduled_time_limit = int_or_none(os.environ.get('REDASH_SCHEDULED_QUERY_TIME_LIMIT', None)) - adhoc_time_limit = int_or_none(os.environ.get('REDASH_ADHOC_QUERY_TIME_LIMIT', None)) + scheduled_time_limit = int_or_none( + os.environ.get("REDASH_SCHEDULED_QUERY_TIME_LIMIT", None) + ) + adhoc_time_limit = int_or_none( + os.environ.get("REDASH_ADHOC_QUERY_TIME_LIMIT", None) + ) return scheduled_time_limit if is_scheduled else adhoc_time_limit @@ -27,4 +31,3 @@ def periodic_jobs(): # This provides the ability to override the way we store QueryResult's data column. # Reference implementation: redash.models.DBPersistence QueryResultPersistence = None - diff --git a/redash/settings/helpers.py b/redash/settings/helpers.py index 03fd83c49f..c69f326f98 100644 --- a/redash/settings/helpers.py +++ b/redash/settings/helpers.py @@ -8,7 +8,7 @@ def fix_assets_path(path): def array_from_string(s): - array = s.split(',') + array = s.split(",") if "" in array: array.remove("") @@ -22,12 +22,12 @@ def set_from_string(s): def parse_boolean(s): """Takes a string and returns the equivalent as a boolean value.""" s = s.strip().lower() - if s in ('yes', 'true', 'on', '1'): + if s in ("yes", "true", "on", "1"): return True - elif s in ('no', 'false', 'off', '0', 'none'): + elif s in ("no", "false", "off", "0", "none"): return False else: - raise ValueError('Invalid boolean value %r' % s) + raise ValueError("Invalid boolean value %r" % s) def int_or_none(value): @@ -41,10 +41,19 @@ def add_decode_responses_to_redis_url(url): """Make sure that the Redis URL includes the `decode_responses` option.""" parsed = urlparse(url) - query = 'decode_responses=True' - if parsed.query and 'decode_responses' not in parsed.query: + query = "decode_responses=True" + if parsed.query and "decode_responses" not in parsed.query: query = "{}&{}".format(parsed.query, query) - elif 'decode_responses' in parsed.query: + elif "decode_responses" in parsed.query: query = parsed.query - return urlunparse([parsed.scheme, parsed.netloc, parsed.path, parsed.params, query, parsed.fragment]) + return urlunparse( + [ + parsed.scheme, + parsed.netloc, + parsed.path, + parsed.params, + query, + parsed.fragment, + ] + ) diff --git a/redash/settings/organization.py b/redash/settings/organization.py index 4aa9b4b1f2..1559886326 100644 --- a/redash/settings/organization.py +++ b/redash/settings/organization.py @@ -1,15 +1,18 @@ - import os from .helpers import parse_boolean if os.environ.get("REDASH_SAML_LOCAL_METADATA_PATH") is not None: print("DEPRECATION NOTICE:\n") - print("SAML_LOCAL_METADATA_PATH is no longer supported. Only URL metadata is supported now, please update") + print( + "SAML_LOCAL_METADATA_PATH is no longer supported. Only URL metadata is supported now, please update" + ) print("your configuration and reload.") raise SystemExit(1) -PASSWORD_LOGIN_ENABLED = parse_boolean(os.environ.get("REDASH_PASSWORD_LOGIN_ENABLED", "true")) +PASSWORD_LOGIN_ENABLED = parse_boolean( + os.environ.get("REDASH_PASSWORD_LOGIN_ENABLED", "true") +) SAML_METADATA_URL = os.environ.get("REDASH_SAML_METADATA_URL", "") SAML_ENTITY_ID = os.environ.get("REDASH_SAML_ENTITY_ID", "") @@ -20,19 +23,26 @@ TIME_FORMAT = os.environ.get("REDASH_TIME_FORMAT", "HH:mm") INTEGER_FORMAT = os.environ.get("REDASH_INTEGER_FORMAT", "0,0") FLOAT_FORMAT = os.environ.get("REDASH_FLOAT_FORMAT", "0,0.00") -MULTI_BYTE_SEARCH_ENABLED = parse_boolean(os.environ.get("MULTI_BYTE_SEARCH_ENABLED", "false")) +MULTI_BYTE_SEARCH_ENABLED = parse_boolean( + os.environ.get("MULTI_BYTE_SEARCH_ENABLED", "false") +) JWT_LOGIN_ENABLED = parse_boolean(os.environ.get("REDASH_JWT_LOGIN_ENABLED", "false")) JWT_AUTH_ISSUER = os.environ.get("REDASH_JWT_AUTH_ISSUER", "") JWT_AUTH_PUBLIC_CERTS_URL = os.environ.get("REDASH_JWT_AUTH_PUBLIC_CERTS_URL", "") JWT_AUTH_AUDIENCE = os.environ.get("REDASH_JWT_AUTH_AUDIENCE", "") -JWT_AUTH_ALGORITHMS = os.environ.get("REDASH_JWT_AUTH_ALGORITHMS", "HS256,RS256,ES256").split(',') +JWT_AUTH_ALGORITHMS = os.environ.get( + "REDASH_JWT_AUTH_ALGORITHMS", "HS256,RS256,ES256" +).split(",") JWT_AUTH_COOKIE_NAME = os.environ.get("REDASH_JWT_AUTH_COOKIE_NAME", "") JWT_AUTH_HEADER_NAME = os.environ.get("REDASH_JWT_AUTH_HEADER_NAME", "") -FEATURE_SHOW_PERMISSIONS_CONTROL = parse_boolean(os.environ.get("REDASH_FEATURE_SHOW_PERMISSIONS_CONTROL", "false")) +FEATURE_SHOW_PERMISSIONS_CONTROL = parse_boolean( + os.environ.get("REDASH_FEATURE_SHOW_PERMISSIONS_CONTROL", "false") +) SEND_EMAIL_ON_FAILED_SCHEDULED_QUERIES = parse_boolean( - os.environ.get('REDASH_SEND_EMAIL_ON_FAILED_SCHEDULED_QUERIES', 'false')) + os.environ.get("REDASH_SEND_EMAIL_ON_FAILED_SCHEDULED_QUERIES", "false") +) settings = { "beacon_consent": None, @@ -54,5 +64,5 @@ "auth_jwt_auth_cookie_name": JWT_AUTH_COOKIE_NAME, "auth_jwt_auth_header_name": JWT_AUTH_HEADER_NAME, "feature_show_permissions_control": FEATURE_SHOW_PERMISSIONS_CONTROL, - "send_email_on_failed_scheduled_queries": SEND_EMAIL_ON_FAILED_SCHEDULED_QUERIES + "send_email_on_failed_scheduled_queries": SEND_EMAIL_ON_FAILED_SCHEDULED_QUERIES, } diff --git a/redash/tasks/__init__.py b/redash/tasks/__init__.py index 6f0d590514..e485c3ebf1 100644 --- a/redash/tasks/__init__.py +++ b/redash/tasks/__init__.py @@ -1,5 +1,18 @@ -from .general import record_event, version_check, send_mail, sync_user_details, purge_failed_jobs -from .queries import (QueryTask, enqueue_query, execute_query, refresh_queries, - refresh_schemas, cleanup_query_results, empty_schedules) +from .general import ( + record_event, + version_check, + send_mail, + sync_user_details, + purge_failed_jobs, +) +from .queries import ( + QueryTask, + enqueue_query, + execute_query, + refresh_queries, + refresh_schemas, + cleanup_query_results, + empty_schedules, +) from .alerts import check_alerts_for_query from .failure_report import send_aggregated_errors diff --git a/redash/tasks/alerts.py b/redash/tasks/alerts.py index ea2d26fb48..f2dcbb02ed 100644 --- a/redash/tasks/alerts.py +++ b/redash/tasks/alerts.py @@ -11,7 +11,9 @@ def notify_subscriptions(alert, new_state): host = utils.base_url(alert.query_rel.org) for subscription in alert.subscriptions: try: - subscription.notify(alert, alert.query_rel, subscription.user, new_state, current_app, host) + subscription.notify( + alert, alert.query_rel, subscription.user, new_state, current_app, host + ) except Exception as e: logger.exception("Error with processing destination") @@ -19,12 +21,17 @@ def notify_subscriptions(alert, new_state): def should_notify(alert, new_state): passed_rearm_threshold = False if alert.rearm and alert.last_triggered_at: - passed_rearm_threshold = alert.last_triggered_at + datetime.timedelta(seconds=alert.rearm) < utils.utcnow() + passed_rearm_threshold = ( + alert.last_triggered_at + datetime.timedelta(seconds=alert.rearm) + < utils.utcnow() + ) - return new_state != alert.state or (alert.state == models.Alert.TRIGGERED_STATE and passed_rearm_threshold) + return new_state != alert.state or ( + alert.state == models.Alert.TRIGGERED_STATE and passed_rearm_threshold + ) -@job('default', timeout=300) +@job("default", timeout=300) def check_alerts_for_query(query_id): logger.debug("Checking query %d for alerts", query_id) @@ -42,8 +49,13 @@ def check_alerts_for_query(query_id): alert.last_triggered_at = utils.utcnow() models.db.session.commit() - if old_state == models.Alert.UNKNOWN_STATE and new_state == models.Alert.OK_STATE: - logger.debug("Skipping notification (previous state was unknown and now it's ok).") + if ( + old_state == models.Alert.UNKNOWN_STATE + and new_state == models.Alert.OK_STATE + ): + logger.debug( + "Skipping notification (previous state was unknown and now it's ok)." + ) continue if alert.muted: diff --git a/redash/tasks/failure_report.py b/redash/tasks/failure_report.py index 398803ac3b..3e543372d2 100644 --- a/redash/tasks/failure_report.py +++ b/redash/tasks/failure_report.py @@ -8,22 +8,22 @@ def key(user_id): - return 'aggregated_failures:{}'.format(user_id) + return "aggregated_failures:{}".format(user_id) def comment_for(failure): - schedule_failures = failure.get('schedule_failures') + schedule_failures = failure.get("schedule_failures") if schedule_failures > settings.MAX_FAILURE_REPORTS_PER_QUERY * 0.75: return """NOTICE: This query has failed a total of {schedule_failures} times. Reporting may stop when the query exceeds {max_failure_reports} overall failures.""".format( schedule_failures=schedule_failures, - max_failure_reports=settings.MAX_FAILURE_REPORTS_PER_QUERY + max_failure_reports=settings.MAX_FAILURE_REPORTS_PER_QUERY, ) def send_aggregated_errors(): for k in redis_connection.scan_iter(key("*")): - user_id = re.search(r'\d+', k).group() + user_id = re.search(r"\d+", k).group() send_failure_report(user_id) @@ -33,23 +33,31 @@ def send_failure_report(user_id): if errors: errors.reverse() - occurrences = Counter((e.get('id'), e.get('message')) for e in errors) - unique_errors = {(e.get('id'), e.get('message')): e for e in errors} + occurrences = Counter((e.get("id"), e.get("message")) for e in errors) + unique_errors = {(e.get("id"), e.get("message")): e for e in errors} context = { - 'failures': [{ - 'id': v.get('id'), - 'name': v.get('name'), - 'failed_at': v.get('failed_at'), - 'failure_reason': v.get('message'), - 'failure_count': occurrences[k], - 'comment': comment_for(v) - } for k, v in unique_errors.items()], - 'base_url': base_url(user.org) + "failures": [ + { + "id": v.get("id"), + "name": v.get("name"), + "failed_at": v.get("failed_at"), + "failure_reason": v.get("message"), + "failure_count": occurrences[k], + "comment": comment_for(v), + } + for k, v in unique_errors.items() + ], + "base_url": base_url(user.org), } - subject = "Redash failed to execute {} of your scheduled queries".format(len(unique_errors.keys())) - html, text = [render_template('emails/failures.{}'.format(f), context) for f in ['html', 'txt']] + subject = "Redash failed to execute {} of your scheduled queries".format( + len(unique_errors.keys()) + ) + html, text = [ + render_template("emails/failures.{}".format(f), context) + for f in ["html", "txt"] + ] send_mail.delay([user.email], subject, html, text) @@ -57,17 +65,26 @@ def send_failure_report(user_id): def notify_of_failure(message, query): - subscribed = query.org.get_setting('send_email_on_failed_scheduled_queries') - exceeded_threshold = query.schedule_failures >= settings.MAX_FAILURE_REPORTS_PER_QUERY + subscribed = query.org.get_setting("send_email_on_failed_scheduled_queries") + exceeded_threshold = ( + query.schedule_failures >= settings.MAX_FAILURE_REPORTS_PER_QUERY + ) if subscribed and not query.user.is_disabled and not exceeded_threshold: - redis_connection.lpush(key(query.user.id), json_dumps({ - 'id': query.id, - 'name': query.name, - 'message': message, - 'schedule_failures': query.schedule_failures, - 'failed_at': datetime.datetime.utcnow().strftime("%B %d, %Y %I:%M%p UTC") - })) + redis_connection.lpush( + key(query.user.id), + json_dumps( + { + "id": query.id, + "name": query.name, + "message": message, + "schedule_failures": query.schedule_failures, + "failed_at": datetime.datetime.utcnow().strftime( + "%B %d, %Y %I:%M%p UTC" + ), + } + ), + ) def track_failure(query, error): diff --git a/redash/tasks/general.py b/redash/tasks/general.py index f4927ba11b..49a5e2d0b1 100644 --- a/redash/tasks/general.py +++ b/redash/tasks/general.py @@ -13,7 +13,7 @@ logger = get_job_logger(__name__) -@job('default') +@job("default") def record_event(raw_event): event = models.Event.record(raw_event) models.db.session.commit() @@ -23,7 +23,7 @@ def record_event(raw_event): try: data = { "schema": "iglu:io.redash.webhooks/event/jsonschema/1-0-0", - "data": event.to_dict() + "data": event.to_dict(), } response = requests.post(hook, json=data) if response.status_code != 200: @@ -36,30 +36,31 @@ def version_check(): run_version_check() -@job('default') +@job("default") def subscribe(form): - logger.info("Subscribing to: [security notifications=%s], [newsletter=%s]", form['security_notifications'], form['newsletter']) + logger.info( + "Subscribing to: [security notifications=%s], [newsletter=%s]", + form["security_notifications"], + form["newsletter"], + ) data = { - 'admin_name': form['name'], - 'admin_email': form['email'], - 'org_name': form['org_name'], - 'security_notifications': form['security_notifications'], - 'newsletter': form['newsletter'] + "admin_name": form["name"], + "admin_email": form["email"], + "org_name": form["org_name"], + "security_notifications": form["security_notifications"], + "newsletter": form["newsletter"], } - requests.post('https://beacon.redash.io/subscribe', json=data) + requests.post("https://beacon.redash.io/subscribe", json=data) -@job('emails') +@job("emails") def send_mail(to, subject, html, text): try: - message = Message(recipients=to, - subject=subject, - html=html, - body=text) + message = Message(recipients=to, subject=subject, html=html, body=text) mail.send(message) except Exception: - logger.exception('Failed sending message: %s', message.subject) + logger.exception("Failed sending message: %s", message.subject) def sync_user_details(): @@ -71,10 +72,20 @@ def purge_failed_jobs(): for queue in Queue.all(): failed_job_ids = FailedJobRegistry(queue=queue).get_job_ids() failed_jobs = Job.fetch_many(failed_job_ids, rq_redis_connection) - stale_jobs = [job for job in failed_jobs if job and (datetime.utcnow() - job.ended_at).seconds > settings.JOB_DEFAULT_FAILURE_TTL] + stale_jobs = [ + job + for job in failed_jobs + if job + and (datetime.utcnow() - job.ended_at).seconds + > settings.JOB_DEFAULT_FAILURE_TTL + ] for job in stale_jobs: job.delete() if stale_jobs: - logger.info('Purged %d old failed jobs from the %s queue.', len(stale_jobs), queue.name) + logger.info( + "Purged %d old failed jobs from the %s queue.", + len(stale_jobs), + queue.name, + ) diff --git a/redash/tasks/queries/__init__.py b/redash/tasks/queries/__init__.py index 120091952a..2eb3fe8529 100644 --- a/redash/tasks/queries/__init__.py +++ b/redash/tasks/queries/__init__.py @@ -1,2 +1,7 @@ -from .maintenance import refresh_queries, refresh_schemas, cleanup_query_results, empty_schedules +from .maintenance import ( + refresh_queries, + refresh_schemas, + cleanup_query_results, + empty_schedules, +) from .execution import QueryTask, execute_query, enqueue_query diff --git a/redash/tasks/queries/execution.py b/redash/tasks/queries/execution.py index eead672319..463a5b7822 100644 --- a/redash/tasks/queries/execution.py +++ b/redash/tasks/queries/execution.py @@ -28,13 +28,7 @@ def _unlock(query_hash, data_source_id): class QueryTask(object): # TODO: this is mapping to the old Job class statuses. Need to update the client side and remove this - STATUSES = { - 'PENDING': 1, - 'STARTED': 2, - 'SUCCESS': 3, - 'FAILURE': 4, - 'REVOKED': 4 - } + STATUSES = {"PENDING": 1, "STARTED": 2, "SUCCESS": 3, "FAILURE": 4, "REVOKED": 4} def __init__(self, job_id=None, async_result=None): if async_result: @@ -48,9 +42,9 @@ def id(self): def to_dict(self): task_info = self._async_result._get_task_meta() - result, task_status = task_info['result'], task_info['status'] - if task_status == 'STARTED': - updated_at = result.get('start_time', 0) + result, task_status = task_info["result"], task_info["status"] + if task_status == "STARTED": + updated_at = result.get("start_time", 0) else: updated_at = 0 @@ -62,27 +56,27 @@ def to_dict(self): elif isinstance(result, Exception): error = str(result) status = 4 - elif task_status == 'REVOKED': - error = 'Query execution cancelled.' + elif task_status == "REVOKED": + error = "Query execution cancelled." else: - error = '' + error = "" - if task_status == 'SUCCESS' and not error: + if task_status == "SUCCESS" and not error: query_result_id = result else: query_result_id = None return { - 'id': self._async_result.id, - 'updated_at': updated_at, - 'status': status, - 'error': error, - 'query_result_id': query_result_id, + "id": self._async_result.id, + "updated_at": updated_at, + "status": status, + "error": error, + "query_result_id": query_result_id, } @property def is_cancelled(self): - return self._async_result.status == 'REVOKED' + return self._async_result.status == "REVOKED" @property def celery_status(self): @@ -92,10 +86,12 @@ def ready(self): return self._async_result.ready() def cancel(self): - return self._async_result.revoke(terminate=True, signal='SIGINT') + return self._async_result.revoke(terminate=True, signal="SIGINT") -def enqueue_query(query, data_source, user_id, is_api_key=False, scheduled_query=None, metadata={}): +def enqueue_query( + query, data_source, user_id, is_api_key=False, scheduled_query=None, metadata={} +): query_hash = gen_query_hash(query) logging.info("Inserting job for %s with metadata=%s", query_hash, metadata) try_count = 0 @@ -114,7 +110,11 @@ def enqueue_query(query, data_source, user_id, is_api_key=False, scheduled_query job = QueryTask(job_id=job_id) if job.ready(): - logging.info("[%s] job found is ready (%s), removing lock", query_hash, job.celery_status) + logging.info( + "[%s] job found is ready (%s), removing lock", + query_hash, + job.celery_status, + ) redis_connection.delete(_job_lock_id(query_hash, data_source.id)) job = None @@ -128,26 +128,43 @@ def enqueue_query(query, data_source, user_id, is_api_key=False, scheduled_query queue_name = data_source.queue_name scheduled_query_id = None - args = (query, data_source.id, metadata, user_id, scheduled_query_id, is_api_key) - argsrepr = json_dumps({ - 'org_id': data_source.org_id, - 'data_source_id': data_source.id, - 'enqueue_time': time.time(), - 'scheduled': scheduled_query_id is not None, - 'query_id': metadata.get('Query ID'), - 'user_id': user_id - }) - - time_limit = settings.dynamic_settings.query_time_limit(scheduled_query, user_id, data_source.org_id) - - result = execute_query.apply_async(args=args, - argsrepr=argsrepr, - queue=queue_name, - soft_time_limit=time_limit) + args = ( + query, + data_source.id, + metadata, + user_id, + scheduled_query_id, + is_api_key, + ) + argsrepr = json_dumps( + { + "org_id": data_source.org_id, + "data_source_id": data_source.id, + "enqueue_time": time.time(), + "scheduled": scheduled_query_id is not None, + "query_id": metadata.get("Query ID"), + "user_id": user_id, + } + ) + + time_limit = settings.dynamic_settings.query_time_limit( + scheduled_query, user_id, data_source.org_id + ) + + result = execute_query.apply_async( + args=args, + argsrepr=argsrepr, + queue=queue_name, + soft_time_limit=time_limit, + ) job = QueryTask(async_result=result) logging.info("[%s] Created new job: %s", query_hash, job.id) - pipe.set(_job_lock_id(query_hash, data_source.id), job.id, settings.JOB_EXPIRY_TIME) + pipe.set( + _job_lock_id(query_hash, data_source.id), + job.id, + settings.JOB_EXPIRY_TIME, + ) pipe.execute() break @@ -187,14 +204,22 @@ def _resolve_user(user_id, is_api_key, query_id): # We could have created this as a celery.Task derived class, and act as the task itself. But this might result in weird # issues as the task class created once per process, so decided to have a plain object instead. class QueryExecutor(object): - def __init__(self, task, query, data_source_id, user_id, is_api_key, metadata, - scheduled_query): + def __init__( + self, + task, + query, + data_source_id, + user_id, + is_api_key, + metadata, + scheduled_query, + ): self.task = task self.query = query self.data_source_id = data_source_id self.metadata = metadata self.data_source = self._load_data_source() - self.user = _resolve_user(user_id, is_api_key, metadata.get('Query ID')) + self.user = _resolve_user(user_id, is_api_key, metadata.get("Query ID")) # Close DB connection to prevent holding a connection for a long time while the query is executing. models.db.session.close() @@ -209,7 +234,7 @@ def run(self): started_at = time.time() logger.debug("Executing query:\n%s", self.query) - self._log_progress('executing_query') + self._log_progress("executing_query") query_runner = self.data_source.query_runner annotated_query = self._annotate_query(query_runner) @@ -223,49 +248,62 @@ def run(self): error = text_type(e) data = None - logging.warning('Unexpected error while running query:', exc_info=1) + logging.warning("Unexpected error while running query:", exc_info=1) run_time = time.time() - started_at - logger.info("task=execute_query query_hash=%s data_length=%s error=[%s]", - self.query_hash, data and len(data), error) + logger.info( + "task=execute_query query_hash=%s data_length=%s error=[%s]", + self.query_hash, + data and len(data), + error, + ) _unlock(self.query_hash, self.data_source.id) if error is not None and data is None: result = QueryExecutionError(error) if self.scheduled_query is not None: - self.scheduled_query = models.db.session.merge(self.scheduled_query, load=False) + self.scheduled_query = models.db.session.merge( + self.scheduled_query, load=False + ) track_failure(self.scheduled_query, error) raise result else: - if (self.scheduled_query and self.scheduled_query.schedule_failures > 0): - self.scheduled_query = models.db.session.merge(self.scheduled_query, load=False) + if self.scheduled_query and self.scheduled_query.schedule_failures > 0: + self.scheduled_query = models.db.session.merge( + self.scheduled_query, load=False + ) self.scheduled_query.schedule_failures = 0 models.db.session.add(self.scheduled_query) query_result = models.QueryResult.store_result( - self.data_source.org_id, self.data_source, - self.query_hash, self.query, data, - run_time, utcnow()) + self.data_source.org_id, + self.data_source, + self.query_hash, + self.query, + data, + run_time, + utcnow(), + ) updated_query_ids = models.Query.update_latest_result(query_result) models.db.session.commit() # make sure that alert sees the latest query result - self._log_progress('checking_alerts') + self._log_progress("checking_alerts") for query_id in updated_query_ids: check_alerts_for_query.delay(query_id) - self._log_progress('finished') + self._log_progress("finished") result = query_result.id models.db.session.commit() return result def _annotate_query(self, query_runner): - self.metadata['Task ID'] = self.task.request.id - self.metadata['Query Hash'] = self.query_hash - self.metadata['Queue'] = self.task.request.delivery_info['routing_key'] - self.metadata['Scheduled'] = self.scheduled_query is not None + self.metadata["Task ID"] = self.task.request.id + self.metadata["Query Hash"] = self.query_hash + self.metadata["Queue"] = self.task.request.delivery_info["routing_key"] + self.metadata["Scheduled"] = self.scheduled_query is not None return query_runner.annotate_query(self.query, self.metadata) @@ -273,11 +311,15 @@ def _log_progress(self, state): logger.info( "task=execute_query state=%s query_hash=%s type=%s ds_id=%d " "task_id=%s queue=%s query_id=%s username=%s", - state, self.query_hash, self.data_source.type, self.data_source.id, + state, + self.query_hash, + self.data_source.type, + self.data_source.id, self.task.request.id, - self.task.request.delivery_info['routing_key'], - self.metadata.get('Query ID', 'unknown'), - self.metadata.get('Username', 'unknown')) + self.task.request.delivery_info["routing_key"], + self.metadata.get("Query ID", "unknown"), + self.metadata.get("Username", "unknown"), + ) def _load_data_source(self): logger.info("task=execute_query state=load_ds ds_id=%d", self.data_source_id) @@ -287,12 +329,20 @@ def _load_data_source(self): # user_id is added last as a keyword argument for backward compatability -- to support executing previously submitted # jobs before the upgrade to this version. @celery.task(name="redash.tasks.execute_query", bind=True, track_started=True) -def execute_query(self, query, data_source_id, metadata, user_id=None, - scheduled_query_id=None, is_api_key=False): +def execute_query( + self, + query, + data_source_id, + metadata, + user_id=None, + scheduled_query_id=None, + is_api_key=False, +): if scheduled_query_id is not None: scheduled_query = models.Query.query.get(scheduled_query_id) else: scheduled_query = None - return QueryExecutor(self, query, data_source_id, user_id, is_api_key, metadata, - scheduled_query).run() + return QueryExecutor( + self, query, data_source_id, user_id, is_api_key, metadata, scheduled_query + ).run() diff --git a/redash/tasks/queries/maintenance.py b/redash/tasks/queries/maintenance.py index 40d3155ed7..cf72068503 100644 --- a/redash/tasks/queries/maintenance.py +++ b/redash/tasks/queries/maintenance.py @@ -3,8 +3,10 @@ from celery.exceptions import SoftTimeLimitExceeded from redash import models, redis_connection, settings, statsd_client -from redash.models.parameterized_query import (InvalidParameterError, - QueryDetachedFromDataSourceError) +from redash.models.parameterized_query import ( + InvalidParameterError, + QueryDetachedFromDataSourceError, +) from redash.tasks.failure_report import track_failure from redash.utils import json_dumps from redash.worker import job, get_job_logger @@ -31,55 +33,79 @@ def refresh_queries(): outdated_queries_count = 0 query_ids = [] - with statsd_client.timer('manager.outdated_queries_lookup'): + with statsd_client.timer("manager.outdated_queries_lookup"): for query in models.Query.outdated_queries(): if settings.FEATURE_DISABLE_REFRESH_QUERIES: logging.info("Disabled refresh queries.") elif query.org.is_disabled: - logging.debug("Skipping refresh of %s because org is disabled.", query.id) + logging.debug( + "Skipping refresh of %s because org is disabled.", query.id + ) elif query.data_source is None: - logging.debug("Skipping refresh of %s because the datasource is none.", query.id) + logging.debug( + "Skipping refresh of %s because the datasource is none.", query.id + ) elif query.data_source.paused: - logging.debug("Skipping refresh of %s because datasource - %s is paused (%s).", - query.id, query.data_source.name, query.data_source.pause_reason) + logging.debug( + "Skipping refresh of %s because datasource - %s is paused (%s).", + query.id, + query.data_source.name, + query.data_source.pause_reason, + ) else: query_text = query.query_text - parameters = {p['name']: p.get('value') for p in query.parameters} + parameters = {p["name"]: p.get("value") for p in query.parameters} if any(parameters): try: query_text = query.parameterized.apply(parameters).query except InvalidParameterError as e: - error = u"Skipping refresh of {} because of invalid parameters: {}".format(query.id, str(e)) + error = u"Skipping refresh of {} because of invalid parameters: {}".format( + query.id, str(e) + ) track_failure(query, error) continue except QueryDetachedFromDataSourceError as e: - error = ("Skipping refresh of {} because a related dropdown " - "query ({}) is unattached to any datasource.").format(query.id, e.query_id) + error = ( + "Skipping refresh of {} because a related dropdown " + "query ({}) is unattached to any datasource." + ).format(query.id, e.query_id) track_failure(query, error) continue - enqueue_query(query_text, query.data_source, query.user_id, - scheduled_query=query, - metadata={'Query ID': query.id, 'Username': 'Scheduled'}) + enqueue_query( + query_text, + query.data_source, + query.user_id, + scheduled_query=query, + metadata={"Query ID": query.id, "Username": "Scheduled"}, + ) query_ids.append(query.id) outdated_queries_count += 1 - statsd_client.gauge('manager.outdated_queries', outdated_queries_count) + statsd_client.gauge("manager.outdated_queries", outdated_queries_count) - logger.info("Done refreshing queries. Found %d outdated queries: %s" % (outdated_queries_count, query_ids)) + logger.info( + "Done refreshing queries. Found %d outdated queries: %s" + % (outdated_queries_count, query_ids) + ) - status = redis_connection.hgetall('redash:status') + status = redis_connection.hgetall("redash:status") now = time.time() - redis_connection.hmset('redash:status', { - 'outdated_queries_count': outdated_queries_count, - 'last_refresh_at': now, - 'query_ids': json_dumps(query_ids) - }) + redis_connection.hmset( + "redash:status", + { + "outdated_queries_count": outdated_queries_count, + "last_refresh_at": now, + "query_ids": json_dumps(query_ids), + }, + ) - statsd_client.gauge('manager.seconds_since_refresh', now - float(status.get('last_refresh_at', now))) + statsd_client.gauge( + "manager.seconds_since_refresh", now - float(status.get("last_refresh_at", now)) + ) def cleanup_query_results(): @@ -91,52 +117,88 @@ def cleanup_query_results(): the database in case of many such results. """ - logging.info("Running query results clean up (removing maximum of %d unused results, that are %d days old or more)", - settings.QUERY_RESULTS_CLEANUP_COUNT, settings.QUERY_RESULTS_CLEANUP_MAX_AGE) + logging.info( + "Running query results clean up (removing maximum of %d unused results, that are %d days old or more)", + settings.QUERY_RESULTS_CLEANUP_COUNT, + settings.QUERY_RESULTS_CLEANUP_MAX_AGE, + ) - unused_query_results = models.QueryResult.unused(settings.QUERY_RESULTS_CLEANUP_MAX_AGE) + unused_query_results = models.QueryResult.unused( + settings.QUERY_RESULTS_CLEANUP_MAX_AGE + ) deleted_count = models.QueryResult.query.filter( - models.QueryResult.id.in_(unused_query_results.limit(settings.QUERY_RESULTS_CLEANUP_COUNT).subquery()) + models.QueryResult.id.in_( + unused_query_results.limit(settings.QUERY_RESULTS_CLEANUP_COUNT).subquery() + ) ).delete(synchronize_session=False) models.db.session.commit() logger.info("Deleted %d unused query results.", deleted_count) -@job('schemas') +@job("schemas") def refresh_schema(data_source_id): ds = models.DataSource.get_by_id(data_source_id) logger.info(u"task=refresh_schema state=start ds_id=%s", ds.id) start_time = time.time() try: ds.get_schema(refresh=True) - logger.info(u"task=refresh_schema state=finished ds_id=%s runtime=%.2f", ds.id, time.time() - start_time) - statsd_client.incr('refresh_schema.success') + logger.info( + u"task=refresh_schema state=finished ds_id=%s runtime=%.2f", + ds.id, + time.time() - start_time, + ) + statsd_client.incr("refresh_schema.success") except SoftTimeLimitExceeded: - logger.info(u"task=refresh_schema state=timeout ds_id=%s runtime=%.2f", ds.id, time.time() - start_time) - statsd_client.incr('refresh_schema.timeout') + logger.info( + u"task=refresh_schema state=timeout ds_id=%s runtime=%.2f", + ds.id, + time.time() - start_time, + ) + statsd_client.incr("refresh_schema.timeout") except Exception: - logger.warning(u"Failed refreshing schema for the data source: %s", ds.name, exc_info=1) - statsd_client.incr('refresh_schema.error') - logger.info(u"task=refresh_schema state=failed ds_id=%s runtime=%.2f", ds.id, time.time() - start_time) + logger.warning( + u"Failed refreshing schema for the data source: %s", ds.name, exc_info=1 + ) + statsd_client.incr("refresh_schema.error") + logger.info( + u"task=refresh_schema state=failed ds_id=%s runtime=%.2f", + ds.id, + time.time() - start_time, + ) def refresh_schemas(): """ Refreshes the data sources schemas. """ - blacklist = [int(ds_id) for ds_id in redis_connection.smembers('data_sources:schema:blacklist') if ds_id] + blacklist = [ + int(ds_id) + for ds_id in redis_connection.smembers("data_sources:schema:blacklist") + if ds_id + ] global_start_time = time.time() logger.info(u"task=refresh_schemas state=start") for ds in models.DataSource.query: if ds.paused: - logger.info(u"task=refresh_schema state=skip ds_id=%s reason=paused(%s)", ds.id, ds.pause_reason) + logger.info( + u"task=refresh_schema state=skip ds_id=%s reason=paused(%s)", + ds.id, + ds.pause_reason, + ) elif ds.id in blacklist: - logger.info(u"task=refresh_schema state=skip ds_id=%s reason=blacklist", ds.id) + logger.info( + u"task=refresh_schema state=skip ds_id=%s reason=blacklist", ds.id + ) elif ds.org.is_disabled: - logger.info(u"task=refresh_schema state=skip ds_id=%s reason=org_disabled", ds.id) + logger.info( + u"task=refresh_schema state=skip ds_id=%s reason=org_disabled", ds.id + ) else: refresh_schema.delay(ds.id) - logger.info(u"task=refresh_schemas state=finish total_runtime=%.2f", time.time() - global_start_time) + logger.info( + u"task=refresh_schemas state=finish total_runtime=%.2f", + time.time() - global_start_time, + ) diff --git a/redash/utils/__init__.py b/redash/utils/__init__.py index 3302a41587..94f570f625 100644 --- a/redash/utils/__init__.py +++ b/redash/utils/__init__.py @@ -24,8 +24,8 @@ COMMENTS_REGEX = re.compile("/\*.*?\*/") -WRITER_ENCODING = os.environ.get('REDASH_CSV_WRITER_ENCODING', 'utf-8') -WRITER_ERRORS = os.environ.get('REDASH_CSV_WRITER_ERRORS', 'strict') +WRITER_ENCODING = os.environ.get("REDASH_CSV_WRITER_ENCODING", "utf-8") +WRITER_ERRORS = os.environ.get("REDASH_CSV_WRITER_ERRORS", "strict") def utcnow(): @@ -47,7 +47,7 @@ def dt_from_timestamp(timestamp, tz_aware=True): def slugify(s): - return re.sub('[^a-z0-9_\-]+', '-', s.lower()) + return re.sub("[^a-z0-9_\-]+", "-", s.lower()) def gen_query_hash(sql): @@ -60,16 +60,14 @@ def gen_query_hash(sql): """ sql = COMMENTS_REGEX.sub("", sql) sql = "".join(sql.split()).lower() - return hashlib.md5(sql.encode('utf-8')).hexdigest() + return hashlib.md5(sql.encode("utf-8")).hexdigest() def generate_token(length): - chars = ('abcdefghijklmnopqrstuvwxyz' - 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' - '0123456789') + chars = "abcdefghijklmnopqrstuvwxyz" "ABCDEFGHIJKLMNOPQRSTUVWXYZ" "0123456789" rand = random.SystemRandom() - return ''.join(rand.choice(chars) for x in range(length)) + return "".join(rand.choice(chars) for x in range(length)) class JSONEncoder(simplejson.JSONEncoder): @@ -88,8 +86,8 @@ def default(self, o): result = o.isoformat() if o.microsecond: result = result[:23] + result[26:] - if result.endswith('+00:00'): - result = result[:-6] + 'Z' + if result.endswith("+00:00"): + result = result[:-6] + "Z" elif isinstance(o, datetime.date): result = o.isoformat() elif isinstance(o, datetime.time): @@ -116,8 +114,8 @@ def json_loads(data, *args, **kwargs): def json_dumps(data, *args, **kwargs): """A custom JSON dumping function which passes all parameters to the simplejson.dumps function.""" - kwargs.setdefault('cls', JSONEncoder) - kwargs.setdefault('encoding', None) + kwargs.setdefault("cls", JSONEncoder) + kwargs.setdefault("encoding", None) return simplejson.dumps(data, *args, **kwargs) @@ -127,11 +125,11 @@ def mustache_render(template, context=None, **kwargs): def build_url(request, host, path): - parts = request.host.split(':') + parts = request.host.split(":") if len(parts) > 1: port = parts[1] - if (port, request.scheme) not in (('80', 'http'), ('443', 'https')): - host = '{}:{}'.format(host, port) + if (port, request.scheme) not in (("80", "http"), ("443", "https")): + host = "{}:{}".format(host, port) return "{}://{}{}".format(request.scheme, host, path) @@ -176,7 +174,7 @@ def collect_parameters_from_request(args): parameters = {} for k, v in args.items(): - if k.startswith('p_'): + if k.startswith("p_"): parameters[k[2:]] = v return parameters @@ -201,7 +199,7 @@ def to_filename(s): def deprecated(): def wrapper(K): - setattr(K, 'deprecated', True) + setattr(K, "deprecated", True) return K return wrapper diff --git a/redash/utils/configuration.py b/redash/utils/configuration.py index c75d27a6c1..fd66dc1a53 100644 --- a/redash/utils/configuration.py +++ b/redash/utils/configuration.py @@ -5,7 +5,7 @@ from redash.utils import json_dumps, json_loads -SECRET_PLACEHOLDER = '--------' +SECRET_PLACEHOLDER = "--------" class ConfigurationContainer(Mutable): @@ -27,10 +27,10 @@ def __init__(self, config, schema=None): def set_schema(self, schema): configuration_schema = copy.deepcopy(schema) if isinstance(configuration_schema, dict): - for prop in configuration_schema.get('properties', {}).values(): - if 'extendedEnum' in prop: - prop['enum'] = map(lambda v: v['value'], prop['extendedEnum']) - del prop['extendedEnum'] + for prop in configuration_schema.get("properties", {}).values(): + if "extendedEnum" in prop: + prop["enum"] = map(lambda v: v["value"], prop["extendedEnum"]) + del prop["extendedEnum"] self._schema = configuration_schema @property @@ -58,12 +58,12 @@ def iteritems(self): return self._config.items() def to_dict(self, mask_secrets=False): - if mask_secrets is False or 'secret' not in self.schema: + if mask_secrets is False or "secret" not in self.schema: return self._config config = self._config.copy() for key in config: - if key in self.schema['secret']: + if key in self.schema["secret"]: config[key] = SECRET_PLACEHOLDER return config @@ -73,7 +73,7 @@ def update(self, new_config): config = {} for k, v in new_config.items(): - if k in self.schema.get('secret', []) and v == SECRET_PLACEHOLDER: + if k in self.schema.get("secret", []) and v == SECRET_PLACEHOLDER: config[k] = self[k] else: config[k] = v diff --git a/redash/utils/sentry.py b/redash/utils/sentry.py index ea49be4ada..bab88af661 100644 --- a/redash/utils/sentry.py +++ b/redash/utils/sentry.py @@ -7,12 +7,12 @@ from redash import settings, __version__ -NON_REPORTED_EXCEPTIONS = ['QueryExecutionError'] +NON_REPORTED_EXCEPTIONS = ["QueryExecutionError"] def before_send(event, hint): - if 'exc_info' in hint: - exc_type, exc_value, tb = hint['exc_info'] + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] if any([(e in str(type(exc_value))) for e in NON_REPORTED_EXCEPTIONS]): return None @@ -26,6 +26,11 @@ def init(): release=__version__, before_send=before_send, send_default_pii=True, - integrations=[FlaskIntegration(), CeleryIntegration(), SqlalchemyIntegration(), - RedisIntegration(), RqIntegration()] + integrations=[ + FlaskIntegration(), + CeleryIntegration(), + SqlalchemyIntegration(), + RedisIntegration(), + RqIntegration(), + ], ) diff --git a/redash/version_check.py b/redash/version_check.py index 0870460b8b..d4f83f6ac8 100644 --- a/redash/version_check.py +++ b/redash/version_check.py @@ -48,12 +48,20 @@ def usage_data(): data_sources_query = "SELECT type, count(0) FROM data_sources GROUP by 1" visualizations_query = "SELECT type, count(0) FROM visualizations GROUP by 1" - destinations_query = "SELECT type, count(0) FROM notification_destinations GROUP by 1" + destinations_query = ( + "SELECT type, count(0) FROM notification_destinations GROUP by 1" + ) data = {name: value for (name, value) in db.session.execute(counts_query)} - data['data_sources'] = {name: value for (name, value) in db.session.execute(data_sources_query)} - data['visualization_types'] = {name: value for (name, value) in db.session.execute(visualizations_query)} - data['destination_types'] = {name: value for (name, value) in db.session.execute(destinations_query)} + data["data_sources"] = { + name: value for (name, value) in db.session.execute(data_sources_query) + } + data["visualization_types"] = { + name: value for (name, value) in db.session.execute(visualizations_query) + } + data["destination_types"] = { + name: value for (name, value) in db.session.execute(destinations_query) + } return data @@ -62,23 +70,26 @@ def run_version_check(): logging.info("Performing version check.") logging.info("Current version: %s", current_version) - data = { - 'current_version': current_version - } + data = {"current_version": current_version} - if Organization.query.first().get_setting('beacon_consent'): - data['usage'] = usage_data() + if Organization.query.first().get_setting("beacon_consent"): + data["usage"] = usage_data() try: - response = requests.post('https://version.redash.io/api/report?channel=stable', - json=data, timeout=3.0) - latest_version = response.json()['release']['version'] + response = requests.post( + "https://version.redash.io/api/report?channel=stable", + json=data, + timeout=3.0, + ) + latest_version = response.json()["release"]["version"] _compare_and_update(latest_version) except requests.RequestException: logging.exception("Failed checking for new version.") except (ValueError, KeyError): - logging.exception("Failed checking for new version (probably bad/non-JSON response).") + logging.exception( + "Failed checking for new version (probably bad/non-JSON response)." + ) def reset_new_version_status(): diff --git a/redash/worker.py b/redash/worker.py index 90667d9133..655ea016e1 100644 --- a/redash/worker.py +++ b/redash/worker.py @@ -1,4 +1,3 @@ - from datetime import timedelta from functools import partial @@ -12,7 +11,13 @@ from rq import get_current_job from rq.decorators import job as rq_job -from redash import create_app, extensions, settings, redis_connection, rq_redis_connection +from redash import ( + create_app, + extensions, + settings, + redis_connection, + rq_redis_connection, +) from redash.metrics import celery as celery_metrics # noqa logger = get_logger(__name__) @@ -24,14 +29,14 @@ class CurrentJobFilter(logging.Filter): def filter(self, record): current_job = get_current_job() - record.job_id = current_job.id if current_job else '' - record.job_func_name = current_job.func_name if current_job else '' + record.job_id = current_job.id if current_job else "" + record.job_func_name = current_job.func_name if current_job else "" return True def get_job_logger(name): - logger = logging.getLogger('rq.job.' + name) + logger = logging.getLogger("rq.job." + name) handler = logging.StreamHandler() handler.formatter = logging.Formatter(settings.RQ_WORKER_JOB_LOG_FORMAT) @@ -43,22 +48,26 @@ def get_job_logger(name): return logger -celery = Celery('redash', - broker=settings.CELERY_BROKER, - broker_use_ssl=settings.CELERY_SSL_CONFIG, - redis_backend_use_ssl=settings.CELERY_SSL_CONFIG, - include='redash.tasks') - - -celery.conf.update(result_backend=settings.CELERY_RESULT_BACKEND, - timezone='UTC', - result_expires=settings.CELERY_RESULT_EXPIRES, - worker_log_format=settings.CELERYD_WORKER_LOG_FORMAT, - worker_task_log_format=settings.CELERYD_WORKER_TASK_LOG_FORMAT, - worker_prefetch_multiplier=settings.CELERY_WORKER_PREFETCH_MULTIPLIER, - accept_content=settings.CELERY_ACCEPT_CONTENT, - task_serializer=settings.CELERY_TASK_SERIALIZER, - result_serializer=settings.CELERY_RESULT_SERIALIZER) +celery = Celery( + "redash", + broker=settings.CELERY_BROKER, + broker_use_ssl=settings.CELERY_SSL_CONFIG, + redis_backend_use_ssl=settings.CELERY_SSL_CONFIG, + include="redash.tasks", +) + + +celery.conf.update( + result_backend=settings.CELERY_RESULT_BACKEND, + timezone="UTC", + result_expires=settings.CELERY_RESULT_EXPIRES, + worker_log_format=settings.CELERYD_WORKER_LOG_FORMAT, + worker_task_log_format=settings.CELERYD_WORKER_TASK_LOG_FORMAT, + worker_prefetch_multiplier=settings.CELERY_WORKER_PREFETCH_MULTIPLIER, + accept_content=settings.CELERY_ACCEPT_CONTENT, + task_serializer=settings.CELERY_TASK_SERIALIZER, + result_serializer=settings.CELERY_RESULT_SERIALIZER, +) # Create a new Task base class, that pushes a new Flask app context to allow DB connections if needed. TaskBase = celery.Task diff --git a/tests/__init__.py b/tests/__init__.py index 566052c2af..7b02c5bd22 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -4,17 +4,21 @@ from unittest import TestCase from contextlib import contextmanager -os.environ['REDASH_REDIS_URL'] = os.environ.get('REDASH_REDIS_URL', "redis://localhost:6379/0").replace("/0", "/5") +os.environ["REDASH_REDIS_URL"] = os.environ.get( + "REDASH_REDIS_URL", "redis://localhost:6379/0" +).replace("/0", "/5") # Use different url for Celery to avoid DB being cleaned up: -os.environ['REDASH_CELERY_BROKER'] = os.environ.get('REDASH_REDIS_URL', "redis://localhost:6379/0").replace("/5", "/6") +os.environ["REDASH_CELERY_BROKER"] = os.environ.get( + "REDASH_REDIS_URL", "redis://localhost:6379/0" +).replace("/5", "/6") # Dummy values for oauth login -os.environ['REDASH_GOOGLE_CLIENT_ID'] = "dummy" -os.environ['REDASH_GOOGLE_CLIENT_SECRET'] = "dummy" -os.environ['REDASH_MULTI_ORG'] = "true" +os.environ["REDASH_GOOGLE_CLIENT_ID"] = "dummy" +os.environ["REDASH_GOOGLE_CLIENT_SECRET"] = "dummy" +os.environ["REDASH_MULTI_ORG"] = "true" # Make sure rate limit is enabled -os.environ['REDASH_RATELIMIT_ENABLED'] = "true" +os.environ["REDASH_RATELIMIT_ENABLED"] = "true" from redash import limiter, redis_connection from redash.app import create_app @@ -29,7 +33,7 @@ def authenticate_request(c, user): with c.session_transaction() as sess: - sess['user_id'] = user.get_id() + sess["user_id"] = user.get_id() @contextmanager @@ -46,7 +50,7 @@ class BaseTestCase(TestCase): def setUp(self): self.app = create_app() self.db = db - self.app.config['TESTING'] = True + self.app.config["TESTING"] = True limiter.enabled = False self.app_ctx = self.app.app_context() self.app_ctx.push() @@ -62,8 +66,16 @@ def tearDown(self): self.app_ctx.pop() redis_connection.flushdb() - def make_request(self, method, path, org=None, user=None, data=None, - is_json=True, follow_redirects=False): + def make_request( + self, + method, + path, + org=None, + user=None, + data=None, + is_json=True, + follow_redirects=False, + ): if user is None: user = self.factory.user @@ -83,7 +95,7 @@ def make_request(self, method, path, org=None, user=None, data=None, data = json_dumps(data) if is_json: - content_type = 'application/json' + content_type = "application/json" else: content_type = None @@ -110,8 +122,9 @@ def post_request(self, path, data=None, org=None, headers=None): def assertResponseEqual(self, expected, actual): for k, v in expected.items(): - if isinstance(v, datetime.datetime) or isinstance(actual[k], - datetime.datetime): + if isinstance(v, datetime.datetime) or isinstance( + actual[k], datetime.datetime + ): continue if isinstance(v, list): @@ -121,4 +134,8 @@ def assertResponseEqual(self, expected, actual): self.assertResponseEqual(v, actual[k]) continue - self.assertEqual(v, actual[k], "{} not equal (expected: {}, actual: {}).".format(k, v, actual[k])) + self.assertEqual( + v, + actual[k], + "{} not equal (expected: {}, actual: {}).".format(k, v, actual[k]), + ) diff --git a/tests/extensions/redash-dummy/setup.py b/tests/extensions/redash-dummy/setup.py index 7ecef3e50f..e1e79a96cb 100644 --- a/tests/extensions/redash-dummy/setup.py +++ b/tests/extensions/redash-dummy/setup.py @@ -15,9 +15,7 @@ "not_importable_extension = missing_extension_module:extension", "assertive_extension = redash_dummy:assertive_extension", ], - "redash.periodic_tasks": [ - "dummy_periodic_task = redash_dummy:periodic_task" - ], + "redash.periodic_tasks": ["dummy_periodic_task = redash_dummy:periodic_task"], }, py_modules=["redash_dummy"], ) diff --git a/tests/factories.py b/tests/factories.py index 2915cc7d4a..e46ca9c0ad 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -40,107 +40,135 @@ def __call__(self): return self.string.format(self.sequence) -user_factory = ModelFactory(redash.models.User, - name='John Doe', email=Sequence('test{}@example.com'), - password_hash=pwd_context.encrypt('test1234'), - group_ids=[2], - org_id=1) - -org_factory = ModelFactory(redash.models.Organization, - name=Sequence("Org {}"), - slug=Sequence("org{}.example.com"), - settings={}) - -data_source_factory = ModelFactory(redash.models.DataSource, - name=Sequence('Test {}'), - type='pg', - # If we don't use lambda here it will reuse the same options between tests: - options=lambda: ConfigurationContainer.from_json('{"dbname": "test"}'), - org_id=1) - -dashboard_factory = ModelFactory(redash.models.Dashboard, - name='test', - user=user_factory.create, - layout='[]', - is_draft=False, - org=1) - -api_key_factory = ModelFactory(redash.models.ApiKey, - object=dashboard_factory.create) - -query_factory = ModelFactory(redash.models.Query, - name='Query', - description='', - query_text='SELECT 1', - user=user_factory.create, - is_archived=False, - is_draft=False, - schedule=None, - data_source=data_source_factory.create, - org_id=1) - -query_with_params_factory = ModelFactory(redash.models.Query, - name='New Query with Params', - description='', - query_text='SELECT {{param1}}', - user=user_factory.create, - is_archived=False, - is_draft=False, - schedule={}, - data_source=data_source_factory.create, - org_id=1) - -access_permission_factory = ModelFactory(redash.models.AccessPermission, - object_id=query_factory.create, - object_type=redash.models.Query.__name__, - access_type=ACCESS_TYPE_MODIFY, - grantor=user_factory.create, - grantee=user_factory.create) - -alert_factory = ModelFactory(redash.models.Alert, - name=Sequence('Alert {}'), - query_rel=query_factory.create, - user=user_factory.create, - options={}) - -query_result_factory = ModelFactory(redash.models.QueryResult, - data='{"columns":{}, "rows":[]}', - runtime=1, - retrieved_at=utcnow, - query_text="SELECT 1", - query_hash=gen_query_hash('SELECT 1'), - data_source=data_source_factory.create, - org_id=1) - -visualization_factory = ModelFactory(redash.models.Visualization, - type='CHART', - query_rel=query_factory.create, - name='Chart', - description='', - options='{}') - -widget_factory = ModelFactory(redash.models.Widget, - width=1, - options='{}', - dashboard=dashboard_factory.create, - visualization=visualization_factory.create) - -destination_factory = ModelFactory(redash.models.NotificationDestination, - org_id=1, - user=user_factory.create, - name=Sequence('Destination {}'), - type='slack', - options=ConfigurationContainer.from_json('{"url": "https://www.slack.com"}')) - -alert_subscription_factory = ModelFactory(redash.models.AlertSubscription, - user=user_factory.create, - destination=destination_factory.create, - alert=alert_factory.create) - -query_snippet_factory = ModelFactory(redash.models.QuerySnippet, - trigger=Sequence('trigger {}'), - description='description', - snippet='snippet') +user_factory = ModelFactory( + redash.models.User, + name="John Doe", + email=Sequence("test{}@example.com"), + password_hash=pwd_context.encrypt("test1234"), + group_ids=[2], + org_id=1, +) + +org_factory = ModelFactory( + redash.models.Organization, + name=Sequence("Org {}"), + slug=Sequence("org{}.example.com"), + settings={}, +) + +data_source_factory = ModelFactory( + redash.models.DataSource, + name=Sequence("Test {}"), + type="pg", + # If we don't use lambda here it will reuse the same options between tests: + options=lambda: ConfigurationContainer.from_json('{"dbname": "test"}'), + org_id=1, +) + +dashboard_factory = ModelFactory( + redash.models.Dashboard, + name="test", + user=user_factory.create, + layout="[]", + is_draft=False, + org=1, +) + +api_key_factory = ModelFactory(redash.models.ApiKey, object=dashboard_factory.create) + +query_factory = ModelFactory( + redash.models.Query, + name="Query", + description="", + query_text="SELECT 1", + user=user_factory.create, + is_archived=False, + is_draft=False, + schedule=None, + data_source=data_source_factory.create, + org_id=1, +) + +query_with_params_factory = ModelFactory( + redash.models.Query, + name="New Query with Params", + description="", + query_text="SELECT {{param1}}", + user=user_factory.create, + is_archived=False, + is_draft=False, + schedule={}, + data_source=data_source_factory.create, + org_id=1, +) + +access_permission_factory = ModelFactory( + redash.models.AccessPermission, + object_id=query_factory.create, + object_type=redash.models.Query.__name__, + access_type=ACCESS_TYPE_MODIFY, + grantor=user_factory.create, + grantee=user_factory.create, +) + +alert_factory = ModelFactory( + redash.models.Alert, + name=Sequence("Alert {}"), + query_rel=query_factory.create, + user=user_factory.create, + options={}, +) + +query_result_factory = ModelFactory( + redash.models.QueryResult, + data='{"columns":{}, "rows":[]}', + runtime=1, + retrieved_at=utcnow, + query_text="SELECT 1", + query_hash=gen_query_hash("SELECT 1"), + data_source=data_source_factory.create, + org_id=1, +) + +visualization_factory = ModelFactory( + redash.models.Visualization, + type="CHART", + query_rel=query_factory.create, + name="Chart", + description="", + options="{}", +) + +widget_factory = ModelFactory( + redash.models.Widget, + width=1, + options="{}", + dashboard=dashboard_factory.create, + visualization=visualization_factory.create, +) + +destination_factory = ModelFactory( + redash.models.NotificationDestination, + org_id=1, + user=user_factory.create, + name=Sequence("Destination {}"), + type="slack", + options=ConfigurationContainer.from_json('{"url": "https://www.slack.com"}'), +) + +alert_subscription_factory = ModelFactory( + redash.models.AlertSubscription, + user=user_factory.create, + destination=destination_factory.create, + alert=alert_factory.create, +) + +query_snippet_factory = ModelFactory( + redash.models.QuerySnippet, + trigger=Sequence("trigger {}"), + description="description", + snippet="snippet", +) class Factory(object): @@ -162,49 +190,54 @@ def user(self): def data_source(self): if self._data_source is None: self._data_source = data_source_factory.create(org=self.org) - db.session.add(redash.models.DataSourceGroup( - group=self.default_group, - data_source=self._data_source)) + db.session.add( + redash.models.DataSourceGroup( + group=self.default_group, data_source=self._data_source + ) + ) return self._data_source def create_org(self, **kwargs): org = org_factory.create(**kwargs) - self.create_group(org=org, type=redash.models.Group.BUILTIN_GROUP, name="default") - self.create_group(org=org, type=redash.models.Group.BUILTIN_GROUP, name="admin", - permissions=["admin"]) + self.create_group( + org=org, type=redash.models.Group.BUILTIN_GROUP, name="default" + ) + self.create_group( + org=org, + type=redash.models.Group.BUILTIN_GROUP, + name="admin", + permissions=["admin"], + ) return org def create_user(self, **kwargs): - args = { - 'org': self.org, - 'group_ids': [self.default_group.id] - } + args = {"org": self.org, "group_ids": [self.default_group.id]} - if 'org' in kwargs: - args['group_ids'] = [kwargs['org'].default_group.id] + if "org" in kwargs: + args["group_ids"] = [kwargs["org"].default_group.id] args.update(kwargs) return user_factory.create(**args) def create_admin(self, **kwargs): args = { - 'org': self.org, - 'group_ids': [self.admin_group.id, self.default_group.id] + "org": self.org, + "group_ids": [self.admin_group.id, self.default_group.id], } - if 'org' in kwargs: - args['group_ids'] = [kwargs['org'].default_group.id, kwargs['org'].admin_group.id] + if "org" in kwargs: + args["group_ids"] = [ + kwargs["org"].default_group.id, + kwargs["org"].admin_group.id, + ] args.update(kwargs) return user_factory.create(**args) def create_group(self, **kwargs): - args = { - 'name': 'Group', - 'org': self.org - } + args = {"name": "Group", "org": self.org} args.update(kwargs) @@ -212,131 +245,98 @@ def create_group(self, **kwargs): return g def create_alert(self, **kwargs): - args = { - 'user': self.user, - 'query_rel': self.create_query() - } + args = {"user": self.user, "query_rel": self.create_query()} args.update(**kwargs) return alert_factory.create(**args) def create_alert_subscription(self, **kwargs): - args = { - 'user': self.user, - 'alert': self.create_alert() - } + args = {"user": self.user, "alert": self.create_alert()} args.update(**kwargs) return alert_subscription_factory.create(**args) def create_data_source(self, **kwargs): group = None - if 'group' in kwargs: - group = kwargs.pop('group') - args = { - 'org': self.org - } + if "group" in kwargs: + group = kwargs.pop("group") + args = {"org": self.org} args.update(kwargs) - if group and 'org' not in kwargs: - args['org'] = group.org + if group and "org" not in kwargs: + args["org"] = group.org - view_only = args.pop('view_only', False) + view_only = args.pop("view_only", False) data_source = data_source_factory.create(**args) if group: - db.session.add(redash.models.DataSourceGroup( - group=group, - data_source=data_source, - view_only=view_only)) + db.session.add( + redash.models.DataSourceGroup( + group=group, data_source=data_source, view_only=view_only + ) + ) return data_source def create_dashboard(self, **kwargs): - args = { - 'user': self.user, - 'org': self.org - } + args = {"user": self.user, "org": self.org} args.update(kwargs) return dashboard_factory.create(**args) def create_query(self, **kwargs): - args = { - 'user': self.user, - 'data_source': self.data_source, - 'org': self.org - } + args = {"user": self.user, "data_source": self.data_source, "org": self.org} args.update(kwargs) return query_factory.create(**args) def create_query_with_params(self, **kwargs): - args = { - 'user': self.user, - 'data_source': self.data_source, - 'org': self.org - } + args = {"user": self.user, "data_source": self.data_source, "org": self.org} args.update(kwargs) return query_with_params_factory.create(**args) def create_access_permission(self, **kwargs): - args = { - 'grantor': self.user - } + args = {"grantor": self.user} args.update(kwargs) return access_permission_factory.create(**args) def create_query_result(self, **kwargs): - args = { - 'data_source': self.data_source, - } + args = {"data_source": self.data_source} args.update(kwargs) - if 'data_source' in args and 'org' not in args: - args['org'] = args['data_source'].org + if "data_source" in args and "org" not in args: + args["org"] = args["data_source"].org return query_result_factory.create(**args) def create_visualization(self, **kwargs): - args = { - 'query_rel': self.create_query() - } + args = {"query_rel": self.create_query()} args.update(kwargs) return visualization_factory.create(**args) def create_visualization_with_params(self, **kwargs): - args = { - 'query_rel': self.create_query_with_params() - } + args = {"query_rel": self.create_query_with_params()} args.update(kwargs) return visualization_factory.create(**args) def create_widget(self, **kwargs): args = { - 'dashboard': self.create_dashboard(), - 'visualization': self.create_visualization() + "dashboard": self.create_dashboard(), + "visualization": self.create_visualization(), } args.update(kwargs) return widget_factory.create(**args) def create_api_key(self, **kwargs): - args = { - 'org': self.org - } + args = {"org": self.org} args.update(kwargs) return api_key_factory.create(**args) def create_destination(self, **kwargs): - args = { - 'org': self.org - } + args = {"org": self.org} args.update(kwargs) return destination_factory.create(**args) def create_query_snippet(self, **kwargs): - args = { - 'user': self.user, - 'org': self.org - } + args = {"user": self.user, "org": self.org} args.update(kwargs) return query_snippet_factory.create(**args) diff --git a/tests/handlers/test_alerts.py b/tests/handlers/test_alerts.py index 638595cb0d..e1449090dd 100644 --- a/tests/handlers/test_alerts.py +++ b/tests/handlers/test_alerts.py @@ -7,7 +7,7 @@ class TestAlertResourceGet(BaseTestCase): def test_returns_200_if_allowed(self): alert = self.factory.create_alert() - rv = self.make_request('get', "/api/alerts/{}".format(alert.id)) + rv = self.make_request("get", "/api/alerts/{}".format(alert.id)) self.assertEqual(rv.status_code, 200) def test_returns_403_if_not_allowed(self): @@ -15,7 +15,7 @@ def test_returns_403_if_not_allowed(self): query = self.factory.create_query(data_source=data_source) alert = self.factory.create_alert(query_rel=query) db.session.commit() - rv = self.make_request('get', "/api/alerts/{}".format(alert.id)) + rv = self.make_request("get", "/api/alerts/{}".format(alert.id)) self.assertEqual(rv.status_code, 403) def test_returns_404_if_admin_from_another_org(self): @@ -24,14 +24,21 @@ def test_returns_404_if_admin_from_another_org(self): alert = self.factory.create_alert() - rv = self.make_request('get', "/api/alerts/{}".format(alert.id), org=second_org, user=second_org_admin) + rv = self.make_request( + "get", + "/api/alerts/{}".format(alert.id), + org=second_org, + user=second_org_admin, + ) self.assertEqual(rv.status_code, 404) class TestAlertResourcePost(BaseTestCase): def test_updates_alert(self): alert = self.factory.create_alert() - rv = self.make_request('post', '/api/alerts/{}'.format(alert.id), data={"name": "Testing"}) + rv = self.make_request( + "post", "/api/alerts/{}".format(alert.id), data={"name": "Testing"} + ) class TestAlertResourceDelete(BaseTestCase): @@ -39,7 +46,7 @@ def test_removes_alert_and_subscriptions(self): subscription = self.factory.create_alert_subscription() alert = subscription.alert db.session.commit() - rv = self.make_request('delete', "/api/alerts/{}".format(alert.id)) + rv = self.make_request("delete", "/api/alerts/{}".format(alert.id)) self.assertEqual(rv.status_code, 200) self.assertEqual(Alert.query.get(subscription.alert.id), None) @@ -49,10 +56,14 @@ def test_returns_403_if_not_allowed(self): alert = self.factory.create_alert() user = self.factory.create_user() - rv = self.make_request('delete', "/api/alerts/{}".format(alert.id), user=user) + rv = self.make_request("delete", "/api/alerts/{}".format(alert.id), user=user) self.assertEqual(rv.status_code, 403) - rv = self.make_request('delete', "/api/alerts/{}".format(alert.id), user=self.factory.create_admin()) + rv = self.make_request( + "delete", + "/api/alerts/{}".format(alert.id), + user=self.factory.create_admin(), + ) self.assertEqual(rv.status_code, 200) def test_returns_404_for_unauthorized_users(self): @@ -60,29 +71,35 @@ def test_returns_404_for_unauthorized_users(self): second_org = self.factory.create_org() second_org_admin = self.factory.create_admin(org=second_org) - rv = self.make_request('delete', "/api/alerts/{}".format(alert.id), user=second_org_admin) + rv = self.make_request( + "delete", "/api/alerts/{}".format(alert.id), user=second_org_admin + ) self.assertEqual(rv.status_code, 404) class TestAlertListGet(BaseTestCase): def test_returns_all_alerts(self): alert = self.factory.create_alert() - rv = self.make_request('get', "/api/alerts") + rv = self.make_request("get", "/api/alerts") self.assertEqual(rv.status_code, 200) - alert_ids = [a['id'] for a in rv.json] + alert_ids = [a["id"] for a in rv.json] self.assertIn(alert.id, alert_ids) def test_returns_alerts_only_from_users_groups(self): alert = self.factory.create_alert() - query = self.factory.create_query(data_source=self.factory.create_data_source(group=self.factory.create_group())) + query = self.factory.create_query( + data_source=self.factory.create_data_source( + group=self.factory.create_group() + ) + ) alert2 = self.factory.create_alert(query_rel=query) - rv = self.make_request('get', "/api/alerts") + rv = self.make_request("get", "/api/alerts") self.assertEqual(rv.status_code, 200) - alert_ids = [a['id'] for a in rv.json] + alert_ids = [a["id"] for a in rv.json] self.assertIn(alert.id, alert_ids) self.assertNotIn(alert2.id, alert_ids) @@ -92,19 +109,35 @@ def test_returns_200_if_has_access_to_query(self): query = self.factory.create_query() destination = self.factory.create_destination() db.session.commit() - rv = self.make_request('post', "/api/alerts", data=dict(name='Alert', query_id=query.id, - destination_id=destination.id, options={}, - rearm=100)) + rv = self.make_request( + "post", + "/api/alerts", + data=dict( + name="Alert", + query_id=query.id, + destination_id=destination.id, + options={}, + rearm=100, + ), + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['rearm'], 100) + self.assertEqual(rv.json["rearm"], 100) def test_fails_if_doesnt_have_access_to_query(self): data_source = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=data_source) destination = self.factory.create_destination() db.session.commit() - rv = self.make_request('post', "/api/alerts", data=dict(name='Alert', query_id=query.id, - destination_id=destination.id, options={})) + rv = self.make_request( + "post", + "/api/alerts", + data=dict( + name="Alert", + query_id=query.id, + destination_id=destination.id, + options={}, + ), + ) self.assertEqual(rv.status_code, 403) @@ -113,7 +146,11 @@ def test_subscribers_user_to_alert(self): alert = self.factory.create_alert() destination = self.factory.create_destination() - rv = self.make_request('post', "/api/alerts/{}/subscriptions".format(alert.id), data=dict(destination_id=destination.id)) + rv = self.make_request( + "post", + "/api/alerts/{}/subscriptions".format(alert.id), + data=dict(destination_id=destination.id), + ) self.assertEqual(rv.status_code, 200) self.assertIn(self.factory.user, alert.subscribers()) @@ -123,7 +160,11 @@ def test_doesnt_subscribers_user_to_alert_without_access(self): alert = self.factory.create_alert(query_rel=query) destination = self.factory.create_destination() - rv = self.make_request('post', "/api/alerts/{}/subscriptions".format(alert.id), data=dict(destination_id=destination.id)) + rv = self.make_request( + "post", + "/api/alerts/{}/subscriptions".format(alert.id), + data=dict(destination_id=destination.id), + ) self.assertEqual(rv.status_code, 403) self.assertNotIn(self.factory.user, alert.subscribers()) @@ -132,7 +173,7 @@ class TestAlertSubscriptionListResourceGet(BaseTestCase): def test_returns_subscribers(self): alert = self.factory.create_alert() - rv = self.make_request('get', "/api/alerts/{}/subscriptions".format(alert.id)) + rv = self.make_request("get", "/api/alerts/{}/subscriptions".format(alert.id)) self.assertEqual(rv.status_code, 200) def test_doesnt_return_subscribers_when_not_allowed(self): @@ -140,7 +181,7 @@ def test_doesnt_return_subscribers_when_not_allowed(self): query = self.factory.create_query(data_source=data_source) alert = self.factory.create_alert(query_rel=query) - rv = self.make_request('get', "/api/alerts/{}/subscriptions".format(alert.id)) + rv = self.make_request("get", "/api/alerts/{}/subscriptions".format(alert.id)) self.assertEqual(rv.status_code, 403) @@ -149,22 +190,20 @@ def test_only_subscriber_or_admin_can_unsubscribe(self): subscription = self.factory.create_alert_subscription() alert = subscription.alert user = subscription.user - path = '/api/alerts/{}/subscriptions/{}'.format(alert.id, - subscription.id) + path = "/api/alerts/{}/subscriptions/{}".format(alert.id, subscription.id) other_user = self.factory.create_user() - response = self.make_request('delete', path, user=other_user) + response = self.make_request("delete", path, user=other_user) self.assertEqual(response.status_code, 403) - response = self.make_request('delete', path, user=user) + response = self.make_request("delete", path, user=user) self.assertEqual(response.status_code, 200) subscription_two = AlertSubscription(alert=alert, user=other_user) admin_user = self.factory.create_admin() db.session.add_all([subscription_two, admin_user]) db.session.commit() - path = '/api/alerts/{}/subscriptions/{}'.format(alert.id, - subscription_two.id) - response = self.make_request('delete', path, user=admin_user) + path = "/api/alerts/{}/subscriptions/{}".format(alert.id, subscription_two.id) + response = self.make_request("delete", path, user=admin_user) self.assertEqual(response.status_code, 200) diff --git a/tests/handlers/test_authentication.py b/tests/handlers/test_authentication.py index e868821170..ed216be220 100644 --- a/tests/handlers/test_authentication.py +++ b/tests/handlers/test_authentication.py @@ -12,69 +12,97 @@ class TestResetPassword(BaseTestCase): def test_shows_reset_password_form(self): user = self.factory.create_user(is_invitation_pending=False) token = invite_token(user) - response = self.get_request('/reset/{}'.format(token), org=self.factory.org) + response = self.get_request("/reset/{}".format(token), org=self.factory.org) self.assertEqual(response.status_code, 200) class TestInvite(BaseTestCase): def test_expired_invite_token(self): - with mock.patch('time.time') as patched_time: + with mock.patch("time.time") as patched_time: patched_time.return_value = time.time() - (7 * 24 * 3600) - 10 token = invite_token(self.factory.user) - response = self.get_request('/invite/{}'.format(token), org=self.factory.org) + response = self.get_request("/invite/{}".format(token), org=self.factory.org) self.assertEqual(response.status_code, 400) def test_invalid_invite_token(self): - response = self.get_request('/invite/badtoken', org=self.factory.org) + response = self.get_request("/invite/badtoken", org=self.factory.org) self.assertEqual(response.status_code, 400) def test_valid_token(self): user = self.factory.create_user(is_invitation_pending=True) token = invite_token(user) - response = self.get_request('/invite/{}'.format(token), org=self.factory.org) + response = self.get_request("/invite/{}".format(token), org=self.factory.org) self.assertEqual(response.status_code, 200) def test_already_active_user(self): token = invite_token(self.factory.user) - self.post_request('/invite/{}'.format(token), data={'password': 'test1234'}, org=self.factory.org) - response = self.get_request('/invite/{}'.format(token), org=self.factory.org) + self.post_request( + "/invite/{}".format(token), + data={"password": "test1234"}, + org=self.factory.org, + ) + response = self.get_request("/invite/{}".format(token), org=self.factory.org) self.assertEqual(response.status_code, 400) class TestInvitePost(BaseTestCase): def test_empty_password(self): token = invite_token(self.factory.user) - response = self.post_request('/invite/{}'.format(token), data={'password': ''}, org=self.factory.org) + response = self.post_request( + "/invite/{}".format(token), data={"password": ""}, org=self.factory.org + ) self.assertEqual(response.status_code, 400) def test_invalid_password(self): token = invite_token(self.factory.user) - response = self.post_request('/invite/{}'.format(token), data={'password': '1234'}, org=self.factory.org) + response = self.post_request( + "/invite/{}".format(token), data={"password": "1234"}, org=self.factory.org + ) self.assertEqual(response.status_code, 400) def test_bad_token(self): - response = self.post_request('/invite/{}'.format('jdsnfkjdsnfkj'), data={'password': '1234'}, org=self.factory.org) + response = self.post_request( + "/invite/{}".format("jdsnfkjdsnfkj"), + data={"password": "1234"}, + org=self.factory.org, + ) self.assertEqual(response.status_code, 400) def test_user_invited_before_invitation_pending_check(self): user = self.factory.create_user(details={}) token = invite_token(user) - response = self.post_request('/invite/{}'.format(token), data={'password': 'test1234'}, org=self.factory.org) + response = self.post_request( + "/invite/{}".format(token), + data={"password": "test1234"}, + org=self.factory.org, + ) self.assertEqual(response.status_code, 302) def test_already_active_user(self): token = invite_token(self.factory.user) - self.post_request('/invite/{}'.format(token), data={'password': 'test1234'}, org=self.factory.org) - response = self.post_request('/invite/{}'.format(token), data={'password': 'test1234'}, org=self.factory.org) + self.post_request( + "/invite/{}".format(token), + data={"password": "test1234"}, + org=self.factory.org, + ) + response = self.post_request( + "/invite/{}".format(token), + data={"password": "test1234"}, + org=self.factory.org, + ) self.assertEqual(response.status_code, 400) def test_valid_password(self): user = self.factory.create_user(is_invitation_pending=True) token = invite_token(user) - password = 'test1234' - response = self.post_request('/invite/{}'.format(token), data={'password': password}, org=self.factory.org) + password = "test1234" + response = self.post_request( + "/invite/{}".format(token), + data={"password": password}, + org=self.factory.org, + ) self.assertEqual(response.status_code, 302) user = User.query.get(user.id) self.assertTrue(user.verify_password(password)) @@ -85,15 +113,17 @@ class TestLogin(BaseTestCase): def test_throttle_login(self): limiter.enabled = True # Extract the limit from settings (ex: '50/day') - limit = settings.THROTTLE_LOGIN_PATTERN.split('/')[0] + limit = settings.THROTTLE_LOGIN_PATTERN.split("/")[0] for _ in range(0, int(limit)): - self.get_request('/login', org=self.factory.org) + self.get_request("/login", org=self.factory.org) - response = self.get_request('/login', org=self.factory.org) + response = self.get_request("/login", org=self.factory.org) self.assertEqual(response.status_code, 429) class TestSession(BaseTestCase): # really simple test just to trigger this route def test_get(self): - self.make_request('get', '/default/api/session', user=self.factory.user, org=False) + self.make_request( + "get", "/default/api/session", user=self.factory.user, org=False + ) diff --git a/tests/handlers/test_dashboards.py b/tests/handlers/test_dashboards.py index 4207cb2bde..b60fe627a9 100644 --- a/tests/handlers/test_dashboards.py +++ b/tests/handlers/test_dashboards.py @@ -8,12 +8,12 @@ class TestDashboardListResource(BaseTestCase): def test_create_new_dashboard(self): - dashboard_name = 'Test Dashboard' - rv = self.make_request('post', '/api/dashboards', data={'name': dashboard_name}) + dashboard_name = "Test Dashboard" + rv = self.make_request("post", "/api/dashboards", data={"name": dashboard_name}) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], 'Test Dashboard') - self.assertEqual(rv.json['user_id'], self.factory.user.id) - self.assertEqual(rv.json['layout'], []) + self.assertEqual(rv.json["name"], "Test Dashboard") + self.assertEqual(rv.json["user_id"], self.factory.user.id) + self.assertEqual(rv.json["layout"], []) class TestDashboardListGetResource(BaseTestCase): @@ -22,34 +22,38 @@ def test_returns_dashboards(self): d2 = self.factory.create_dashboard() d3 = self.factory.create_dashboard() - rv = self.make_request('get', '/api/dashboards') + rv = self.make_request("get", "/api/dashboards") - assert len(rv.json['results']) == 3 - assert set([result['id'] for result in rv.json['results']]) == set([d1.id, d2.id, d3.id]) + assert len(rv.json["results"]) == 3 + assert set([result["id"] for result in rv.json["results"]]) == set( + [d1.id, d2.id, d3.id] + ) def test_filters_with_tags(self): - d1 = self.factory.create_dashboard(tags=['test']) + d1 = self.factory.create_dashboard(tags=["test"]) d2 = self.factory.create_dashboard() d3 = self.factory.create_dashboard() - rv = self.make_request('get', '/api/dashboards?tags=test') - assert len(rv.json['results']) == 1 - assert set([result['id'] for result in rv.json['results']]) == set([d1.id]) + rv = self.make_request("get", "/api/dashboards?tags=test") + assert len(rv.json["results"]) == 1 + assert set([result["id"] for result in rv.json["results"]]) == set([d1.id]) def test_search_term(self): d1 = self.factory.create_dashboard(name="Sales") d2 = self.factory.create_dashboard(name="Q1 sales") d3 = self.factory.create_dashboard(name="Ops") - rv = self.make_request('get', '/api/dashboards?q=sales') - assert len(rv.json['results']) == 2 - assert set([result['id'] for result in rv.json['results']]) == set([d1.id, d2.id]) + rv = self.make_request("get", "/api/dashboards?q=sales") + assert len(rv.json["results"]) == 2 + assert set([result["id"] for result in rv.json["results"]]) == set( + [d1.id, d2.id] + ) class TestDashboardResourceGet(BaseTestCase): def test_get_dashboard(self): d1 = self.factory.create_dashboard() - rv = self.make_request('get', '/api/dashboards/{0}'.format(d1.slug)) + rv = self.make_request("get", "/api/dashboards/{0}".format(d1.slug)) self.assertEqual(rv.status_code, 200) expected = serialize_dashboard(d1, with_widgets=True, with_favorite_state=False) @@ -60,50 +64,63 @@ def test_get_dashboard(self): def test_get_dashboard_filters_unauthorized_widgets(self): dashboard = self.factory.create_dashboard() - restricted_ds = self.factory.create_data_source(group=self.factory.create_group()) + restricted_ds = self.factory.create_data_source( + group=self.factory.create_group() + ) query = self.factory.create_query(data_source=restricted_ds) vis = self.factory.create_visualization(query_rel=query) - restricted_widget = self.factory.create_widget(visualization=vis, dashboard=dashboard) + restricted_widget = self.factory.create_widget( + visualization=vis, dashboard=dashboard + ) widget = self.factory.create_widget(dashboard=dashboard) - dashboard.layout = '[[{}, {}]]'.format(widget.id, restricted_widget.id) + dashboard.layout = "[[{}, {}]]".format(widget.id, restricted_widget.id) db.session.commit() - rv = self.make_request('get', '/api/dashboards/{0}'.format(dashboard.slug)) + rv = self.make_request("get", "/api/dashboards/{0}".format(dashboard.slug)) self.assertEqual(rv.status_code, 200) - self.assertTrue(rv.json['widgets'][0]['restricted']) - self.assertNotIn('restricted', rv.json['widgets'][1]) + self.assertTrue(rv.json["widgets"][0]["restricted"]) + self.assertNotIn("restricted", rv.json["widgets"][1]) def test_get_non_existing_dashboard(self): - rv = self.make_request('get', '/api/dashboards/not_existing') + rv = self.make_request("get", "/api/dashboards/not_existing") self.assertEqual(rv.status_code, 404) class TestDashboardResourcePost(BaseTestCase): def test_update_dashboard(self): d = self.factory.create_dashboard() - new_name = 'New Name' - rv = self.make_request('post', '/api/dashboards/{0}'.format(d.id), - data={'name': new_name, 'layout': '[]'}) + new_name = "New Name" + rv = self.make_request( + "post", + "/api/dashboards/{0}".format(d.id), + data={"name": new_name, "layout": "[]"}, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], new_name) + self.assertEqual(rv.json["name"], new_name) def test_raises_error_in_case_of_conflict(self): d = self.factory.create_dashboard() - d.name = 'Updated' + d.name = "Updated" db.session.commit() - new_name = 'New Name' - rv = self.make_request('post', '/api/dashboards/{0}'.format(d.id), - data={'name': new_name, 'layout': '[]', 'version': d.version - 1}) + new_name = "New Name" + rv = self.make_request( + "post", + "/api/dashboards/{0}".format(d.id), + data={"name": new_name, "layout": "[]", "version": d.version - 1}, + ) self.assertEqual(rv.status_code, 409) def test_overrides_existing_if_no_version_specified(self): d = self.factory.create_dashboard() - d.name = 'Updated' + d.name = "Updated" - new_name = 'New Name' - rv = self.make_request('post', '/api/dashboards/{0}'.format(d.id), - data={'name': new_name, 'layout': '[]'}) + new_name = "New Name" + rv = self.make_request( + "post", + "/api/dashboards/{0}".format(d.id), + data={"name": new_name, "layout": "[]"}, + ) self.assertEqual(rv.status_code, 200) @@ -111,25 +128,35 @@ def test_works_for_non_owner_with_permission(self): d = self.factory.create_dashboard() user = self.factory.create_user() - new_name = 'New Name' - rv = self.make_request('post', '/api/dashboards/{0}'.format(d.id), - data={'name': new_name, 'layout': '[]', 'version': d.version}, user=user) + new_name = "New Name" + rv = self.make_request( + "post", + "/api/dashboards/{0}".format(d.id), + data={"name": new_name, "layout": "[]", "version": d.version}, + user=user, + ) self.assertEqual(rv.status_code, 403) - AccessPermission.grant(obj=d, access_type=ACCESS_TYPE_MODIFY, grantee=user, grantor=d.user) + AccessPermission.grant( + obj=d, access_type=ACCESS_TYPE_MODIFY, grantee=user, grantor=d.user + ) - rv = self.make_request('post', '/api/dashboards/{0}'.format(d.id), - data={'name': new_name, 'layout': '[]', 'version': d.version}, user=user) + rv = self.make_request( + "post", + "/api/dashboards/{0}".format(d.id), + data={"name": new_name, "layout": "[]", "version": d.version}, + user=user, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], new_name) + self.assertEqual(rv.json["name"], new_name) class TestDashboardResourceDelete(BaseTestCase): def test_delete_dashboard(self): d = self.factory.create_dashboard() - rv = self.make_request('delete', '/api/dashboards/{0}'.format(d.slug)) + rv = self.make_request("delete", "/api/dashboards/{0}".format(d.slug)) self.assertEqual(rv.status_code, 200) d = Dashboard.get_by_slug_and_org(d.slug, d.org) @@ -140,20 +167,24 @@ class TestDashboardShareResourcePost(BaseTestCase): def test_creates_api_key(self): dashboard = self.factory.create_dashboard() - res = self.make_request('post', '/api/dashboards/{}/share'.format(dashboard.id)) + res = self.make_request("post", "/api/dashboards/{}/share".format(dashboard.id)) self.assertEqual(res.status_code, 200) - self.assertEqual(res.json['api_key'], ApiKey.get_by_object(dashboard).api_key) + self.assertEqual(res.json["api_key"], ApiKey.get_by_object(dashboard).api_key) def test_requires_admin_or_owner(self): dashboard = self.factory.create_dashboard() user = self.factory.create_user() - res = self.make_request('post', '/api/dashboards/{}/share'.format(dashboard.id), user=user) + res = self.make_request( + "post", "/api/dashboards/{}/share".format(dashboard.id), user=user + ) self.assertEqual(res.status_code, 403) user.group_ids.append(self.factory.org.admin_group.id) - res = self.make_request('post', '/api/dashboards/{}/share'.format(dashboard.id), user=user) + res = self.make_request( + "post", "/api/dashboards/{}/share".format(dashboard.id), user=user + ) self.assertEqual(res.status_code, 200) @@ -162,24 +193,32 @@ def test_disables_api_key(self): dashboard = self.factory.create_dashboard() ApiKey.create_for_object(dashboard, self.factory.user) - res = self.make_request('delete', '/api/dashboards/{}/share'.format(dashboard.id)) + res = self.make_request( + "delete", "/api/dashboards/{}/share".format(dashboard.id) + ) self.assertEqual(res.status_code, 200) self.assertIsNone(ApiKey.get_by_object(dashboard)) def test_ignores_when_no_api_key_exists(self): dashboard = self.factory.create_dashboard() - res = self.make_request('delete', '/api/dashboards/{}/share'.format(dashboard.id)) + res = self.make_request( + "delete", "/api/dashboards/{}/share".format(dashboard.id) + ) self.assertEqual(res.status_code, 200) def test_requires_admin_or_owner(self): dashboard = self.factory.create_dashboard() user = self.factory.create_user() - res = self.make_request('delete', '/api/dashboards/{}/share'.format(dashboard.id), user=user) + res = self.make_request( + "delete", "/api/dashboards/{}/share".format(dashboard.id), user=user + ) self.assertEqual(res.status_code, 403) user.group_ids.append(self.factory.org.admin_group.id) - res = self.make_request('delete', '/api/dashboards/{}/share'.format(dashboard.id), user=user) + res = self.make_request( + "delete", "/api/dashboards/{}/share".format(dashboard.id), user=user + ) self.assertEqual(res.status_code, 200) diff --git a/tests/handlers/test_data_sources.py b/tests/handlers/test_data_sources.py index 4aa0b4fc61..537ae6af9b 100644 --- a/tests/handlers/test_data_sources.py +++ b/tests/handlers/test_data_sources.py @@ -9,11 +9,19 @@ class TestDataSourceGetSchema(BaseTestCase): def test_fails_if_user_doesnt_belong_to_org(self): other_user = self.factory.create_user(org=self.factory.create_org()) - response = self.make_request("get", "/api/data_sources/{}/schema".format(self.factory.data_source.id), user=other_user) + response = self.make_request( + "get", + "/api/data_sources/{}/schema".format(self.factory.data_source.id), + user=other_user, + ) self.assertEqual(response.status_code, 404) other_admin = self.factory.create_admin(org=self.factory.create_org()) - response = self.make_request("get", "/api/data_sources/{}/schema".format(self.factory.data_source.id), user=other_admin) + response = self.make_request( + "get", + "/api/data_sources/{}/schema".format(self.factory.data_source.id), + user=other_admin, + ) self.assertEqual(response.status_code, 404) @@ -31,26 +39,26 @@ def test_returns_data_sources_ordered_by_id(self): self.factory.create_data_source(group=self.factory.org.default_group) self.factory.create_data_source(group=self.factory.org.default_group) response = self.make_request("get", "/api/data_sources", user=self.factory.user) - ids = [datasource['id'] for datasource in response.json] + ids = [datasource["id"] for datasource in response.json] self.assertTrue(all(left <= right for left, right in pairwise(ids))) class DataSourceTypesTest(BaseTestCase): def test_returns_data_for_admin(self): admin = self.factory.create_admin() - rv = self.make_request('get', "/api/data_sources/types", user=admin) + rv = self.make_request("get", "/api/data_sources/types", user=admin) self.assertEqual(rv.status_code, 200) def test_does_not_show_deprecated_types(self): admin = self.factory.create_admin() - with patch.object(PostgreSQL, 'deprecated', return_value=True): - rv = self.make_request('get', "/api/data_sources/types", user=admin) + with patch.object(PostgreSQL, "deprecated", return_value=True): + rv = self.make_request("get", "/api/data_sources/types", user=admin) - types = [datasource_type['type'] for datasource_type in rv.json] - self.assertNotIn('pg', types) + types = [datasource_type["type"] for datasource_type in rv.json] + self.assertNotIn("pg", types) def test_returns_403_for_non_admin(self): - rv = self.make_request('get', "/api/data_sources/types") + rv = self.make_request("get", "/api/data_sources/types") self.assertEqual(rv.status_code, 403) @@ -61,18 +69,25 @@ def setUp(self): def test_returns_400_when_configuration_invalid(self): admin = self.factory.create_admin() - rv = self.make_request('post', self.path, - data={'name': 'DS 1', 'type': 'pg', 'options': {}}, user=admin) + rv = self.make_request( + "post", + self.path, + data={"name": "DS 1", "type": "pg", "options": {}}, + user=admin, + ) self.assertEqual(rv.status_code, 400) def test_updates_data_source(self): admin = self.factory.create_admin() - new_name = 'New Name' + new_name = "New Name" new_options = {"dbname": "newdb"} - rv = self.make_request('post', self.path, - data={'name': new_name, 'type': 'pg', 'options': new_options}, - user=admin) + rv = self.make_request( + "post", + self.path, + data={"name": new_name, "type": "pg", "options": new_options}, + user=admin, + ) self.assertEqual(rv.status_code, 200) data_source = DataSource.query.get(self.factory.data_source.id) @@ -86,7 +101,9 @@ def test_deletes_the_data_source(self): data_source = self.factory.create_data_source() admin = self.factory.create_admin() - rv = self.make_request('delete', '/api/data_sources/{}'.format(data_source.id), user=admin) + rv = self.make_request( + "delete", "/api/data_sources/{}".format(data_source.id), user=admin + ) self.assertEqual(204, rv.status_code) self.assertIsNone(DataSource.query.get(data_source.id)) @@ -95,49 +112,80 @@ def test_deletes_the_data_source(self): class TestDataSourceListResourcePost(BaseTestCase): def test_returns_400_when_missing_fields(self): admin = self.factory.create_admin() - rv = self.make_request('post', "/api/data_sources", user=admin) + rv = self.make_request("post", "/api/data_sources", user=admin) self.assertEqual(rv.status_code, 400) - rv = self.make_request('post', "/api/data_sources", data={'name': 'DS 1'}, user=admin) + rv = self.make_request( + "post", "/api/data_sources", data={"name": "DS 1"}, user=admin + ) self.assertEqual(rv.status_code, 400) def test_returns_400_when_configuration_invalid(self): admin = self.factory.create_admin() - rv = self.make_request('post', '/api/data_sources', - data={'name': 'DS 1', 'type': 'pg', 'options': {}}, user=admin) + rv = self.make_request( + "post", + "/api/data_sources", + data={"name": "DS 1", "type": "pg", "options": {}}, + user=admin, + ) self.assertEqual(rv.status_code, 400) def test_creates_data_source(self): admin = self.factory.create_admin() - rv = self.make_request('post', '/api/data_sources', - data={'name': 'DS 1', 'type': 'pg', 'options': {"dbname": "redash"}}, user=admin) + rv = self.make_request( + "post", + "/api/data_sources", + data={"name": "DS 1", "type": "pg", "options": {"dbname": "redash"}}, + user=admin, + ) self.assertEqual(rv.status_code, 200) - self.assertIsNotNone(DataSource.query.get(rv.json['id'])) + self.assertIsNotNone(DataSource.query.get(rv.json["id"])) class TestDataSourcePausePost(BaseTestCase): def test_pauses_data_source(self): admin = self.factory.create_admin() - rv = self.make_request('post', '/api/data_sources/{}/pause'.format(self.factory.data_source.id), user=admin) + rv = self.make_request( + "post", + "/api/data_sources/{}/pause".format(self.factory.data_source.id), + user=admin, + ) self.assertEqual(rv.status_code, 200) self.assertEqual(DataSource.query.get(self.factory.data_source.id).paused, True) def test_pause_sets_reason(self): admin = self.factory.create_admin() - rv = self.make_request('post', '/api/data_sources/{}/pause'.format(self.factory.data_source.id), user=admin, data={'reason': 'testing'}) + rv = self.make_request( + "post", + "/api/data_sources/{}/pause".format(self.factory.data_source.id), + user=admin, + data={"reason": "testing"}, + ) self.assertEqual(rv.status_code, 200) self.assertEqual(DataSource.query.get(self.factory.data_source.id).paused, True) - self.assertEqual(DataSource.query.get(self.factory.data_source.id).pause_reason, 'testing') - - rv = self.make_request('post', '/api/data_sources/{}/pause?reason=test'.format(self.factory.data_source.id), user=admin) - self.assertEqual(DataSource.query.get(self.factory.data_source.id).pause_reason, 'test') + self.assertEqual( + DataSource.query.get(self.factory.data_source.id).pause_reason, "testing" + ) + + rv = self.make_request( + "post", + "/api/data_sources/{}/pause?reason=test".format( + self.factory.data_source.id + ), + user=admin, + ) + self.assertEqual( + DataSource.query.get(self.factory.data_source.id).pause_reason, "test" + ) def test_requires_admin(self): - rv = self.make_request('post', '/api/data_sources/{}/pause'.format(self.factory.data_source.id)) + rv = self.make_request( + "post", "/api/data_sources/{}/pause".format(self.factory.data_source.id) + ) self.assertEqual(rv.status_code, 403) @@ -145,10 +193,18 @@ class TestDataSourcePauseDelete(BaseTestCase): def test_resumes_data_source(self): admin = self.factory.create_admin() self.factory.data_source.pause() - rv = self.make_request('delete', '/api/data_sources/{}/pause'.format(self.factory.data_source.id), user=admin) + rv = self.make_request( + "delete", + "/api/data_sources/{}/pause".format(self.factory.data_source.id), + user=admin, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(DataSource.query.get(self.factory.data_source.id).paused, False) + self.assertEqual( + DataSource.query.get(self.factory.data_source.id).paused, False + ) def test_requires_admin(self): - rv = self.make_request('delete', '/api/data_sources/{}/pause'.format(self.factory.data_source.id)) + rv = self.make_request( + "delete", "/api/data_sources/{}/pause".format(self.factory.data_source.id) + ) self.assertEqual(rv.status_code, 403) diff --git a/tests/handlers/test_destinations.py b/tests/handlers/test_destinations.py index 6736e936b7..b1779177dc 100644 --- a/tests/handlers/test_destinations.py +++ b/tests/handlers/test_destinations.py @@ -10,7 +10,7 @@ def test_get_returns_all_destinations(self): d1 = self.factory.create_destination() d2 = self.factory.create_destination() - rv = self.make_request('get', '/api/destinations', user=self.factory.user) + rv = self.make_request("get", "/api/destinations", user=self.factory.user) self.assertEqual(len(rv.json), 2) def test_get_returns_only_destinations_of_current_org(self): @@ -18,75 +18,92 @@ def test_get_returns_only_destinations_of_current_org(self): d2 = self.factory.create_destination() d3 = self.factory.create_destination(org=self.factory.create_org()) - rv = self.make_request('get', '/api/destinations', user=self.factory.user) + rv = self.make_request("get", "/api/destinations", user=self.factory.user) self.assertEqual(len(rv.json), 2) def test_post_creates_new_destination(self): data = { - 'options': {'addresses': 'test@example.com'}, - 'name': 'Test', - 'type': 'email' + "options": {"addresses": "test@example.com"}, + "name": "Test", + "type": "email", } - rv = self.make_request('post', '/api/destinations', user=self.factory.create_admin(), data=data) + rv = self.make_request( + "post", "/api/destinations", user=self.factory.create_admin(), data=data + ) self.assertEqual(rv.status_code, 200) pass def test_post_requires_admin(self): data = { - 'options': {'addresses': 'test@example.com'}, - 'name': 'Test', - 'type': 'email' + "options": {"addresses": "test@example.com"}, + "name": "Test", + "type": "email", } - rv = self.make_request('post', '/api/destinations', user=self.factory.user, data=data) + rv = self.make_request( + "post", "/api/destinations", user=self.factory.user, data=data + ) self.assertEqual(rv.status_code, 403) def test_returns_400_when_name_already_exists(self): d1 = self.factory.create_destination() data = { - 'options': {'addresses': 'test@example.com'}, - 'name': d1.name, - 'type': 'email' + "options": {"addresses": "test@example.com"}, + "name": d1.name, + "type": "email", } - rv = self.make_request('post', '/api/destinations', user=self.factory.create_admin(), data=data) + rv = self.make_request( + "post", "/api/destinations", user=self.factory.create_admin(), data=data + ) self.assertEqual(rv.status_code, 400) class TestDestinationResource(BaseTestCase): def test_get(self): d = self.factory.create_destination() - rv = self.make_request('get', '/api/destinations/{}'.format(d.id), user=self.factory.create_admin()) + rv = self.make_request( + "get", "/api/destinations/{}".format(d.id), user=self.factory.create_admin() + ) self.assertEqual(rv.status_code, 200) def test_delete(self): d = self.factory.create_destination() - rv = self.make_request('delete', '/api/destinations/{}'.format(d.id), user=self.factory.create_admin()) + rv = self.make_request( + "delete", + "/api/destinations/{}".format(d.id), + user=self.factory.create_admin(), + ) self.assertEqual(rv.status_code, 204) self.assertIsNone(NotificationDestination.query.get(d.id)) def test_post(self): d = self.factory.create_destination() data = { - 'name': 'updated', - 'type': d.type, - 'options': {"url": "https://www.slack.com/updated"} + "name": "updated", + "type": d.type, + "options": {"url": "https://www.slack.com/updated"}, } with self.app.app_context(): - rv = self.make_request('post', '/api/destinations/{}'.format(d.id), user=self.factory.create_admin(), data=data) + rv = self.make_request( + "post", + "/api/destinations/{}".format(d.id), + user=self.factory.create_admin(), + data=data, + ) self.assertEqual(rv.status_code, 200) d = NotificationDestination.query.get(d.id) - self.assertEqual(d.name, data['name']) - self.assertEqual(d.options['url'], data['options']['url']) + self.assertEqual(d.name, data["name"]) + self.assertEqual(d.options["url"], data["options"]["url"]) class DestinationTypesTest(BaseTestCase): def test_does_not_show_deprecated_types(self): admin = self.factory.create_admin() - with patch.object(Slack, 'deprecated', return_value=True): - rv = self.make_request('get', "/api/destinations/types", user=admin) + with patch.object(Slack, "deprecated", return_value=True): + rv = self.make_request("get", "/api/destinations/types", user=admin) - types = [destination_type['type'] for destination_type in rv.json] - self.assertNotIn('slack', types) + types = [destination_type["type"] for destination_type in rv.json] + self.assertNotIn("slack", types) diff --git a/tests/handlers/test_embed.py b/tests/handlers/test_embed.py index f4c90c3276..c7d624afab 100644 --- a/tests/handlers/test_embed.py +++ b/tests/handlers/test_embed.py @@ -5,10 +5,10 @@ class TestUnembedables(BaseTestCase): def test_not_embedable(self): query = self.factory.create_query() - res = self.make_request('get', '/api/queries/{0}'.format(query.id)) + res = self.make_request("get", "/api/queries/{0}".format(query.id)) self.assertEqual(res.status_code, 200) - self.assertIn("frame-ancestors 'none'", res.headers['Content-Security-Policy']) - self.assertEqual(res.headers['X-Frame-Options'], 'deny') + self.assertIn("frame-ancestors 'none'", res.headers["Content-Security-Policy"]) + self.assertEqual(res.headers["X-Frame-Options"], "deny") class TestEmbedVisualization(BaseTestCase): @@ -17,9 +17,13 @@ def test_sucesss(self): vis.query_rel.latest_query_data = self.factory.create_query_result() db.session.add(vis.query_rel) - res = self.make_request("get", "/embed/query/{}/visualization/{}".format(vis.query_rel.id, vis.id), is_json=False) + res = self.make_request( + "get", + "/embed/query/{}/visualization/{}".format(vis.query_rel.id, vis.id), + is_json=False, + ) self.assertEqual(res.status_code, 200) - self.assertIn('frame-ancestors *', res.headers['Content-Security-Policy']) + self.assertIn("frame-ancestors *", res.headers["Content-Security-Policy"]) self.assertNotIn("X-Frame-Options", res.headers) @@ -29,26 +33,40 @@ def test_success(self): dashboard = self.factory.create_dashboard() api_key = self.factory.create_api_key(object=dashboard) - res = self.make_request('get', '/public/dashboards/{}'.format(api_key.api_key), user=False, is_json=False) + res = self.make_request( + "get", + "/public/dashboards/{}".format(api_key.api_key), + user=False, + is_json=False, + ) self.assertEqual(res.status_code, 200) - self.assertIn('frame-ancestors *', res.headers['Content-Security-Policy']) + self.assertIn("frame-ancestors *", res.headers["Content-Security-Policy"]) self.assertNotIn("X-Frame-Options", res.headers) def test_works_for_logged_in_user(self): dashboard = self.factory.create_dashboard() api_key = self.factory.create_api_key(object=dashboard) - res = self.make_request('get', '/public/dashboards/{}'.format(api_key.api_key), is_json=False) + res = self.make_request( + "get", "/public/dashboards/{}".format(api_key.api_key), is_json=False + ) self.assertEqual(res.status_code, 200) def test_bad_token(self): - res = self.make_request('get', '/public/dashboards/bad-token', user=False, is_json=False) + res = self.make_request( + "get", "/public/dashboards/bad-token", user=False, is_json=False + ) self.assertEqual(res.status_code, 302) def test_inactive_token(self): dashboard = self.factory.create_dashboard() api_key = self.factory.create_api_key(object=dashboard, active=False) - res = self.make_request('get', '/public/dashboards/{}'.format(api_key.api_key), user=False, is_json=False) + res = self.make_request( + "get", + "/public/dashboards/{}".format(api_key.api_key), + user=False, + is_json=False, + ) self.assertEqual(res.status_code, 302) # Not relevant for now, as tokens in api_keys table are only created for dashboards. Once this changes, we should @@ -62,26 +80,40 @@ def test_success(self): dashboard = self.factory.create_dashboard() api_key = self.factory.create_api_key(object=dashboard) - res = self.make_request('get', '/api/dashboards/public/{}'.format(api_key.api_key), user=False, is_json=False) + res = self.make_request( + "get", + "/api/dashboards/public/{}".format(api_key.api_key), + user=False, + is_json=False, + ) self.assertEqual(res.status_code, 200) - self.assertIn('frame-ancestors *', res.headers['Content-Security-Policy']) + self.assertIn("frame-ancestors *", res.headers["Content-Security-Policy"]) self.assertNotIn("X-Frame-Options", res.headers) def test_works_for_logged_in_user(self): dashboard = self.factory.create_dashboard() api_key = self.factory.create_api_key(object=dashboard) - res = self.make_request('get', '/api/dashboards/public/{}'.format(api_key.api_key), is_json=False) + res = self.make_request( + "get", "/api/dashboards/public/{}".format(api_key.api_key), is_json=False + ) self.assertEqual(res.status_code, 200) def test_bad_token(self): - res = self.make_request('get', '/api/dashboards/public/bad-token', user=False, is_json=False) + res = self.make_request( + "get", "/api/dashboards/public/bad-token", user=False, is_json=False + ) self.assertEqual(res.status_code, 404) def test_inactive_token(self): dashboard = self.factory.create_dashboard() api_key = self.factory.create_api_key(object=dashboard, active=False) - res = self.make_request('get', '/api/dashboards/public/{}'.format(api_key.api_key), user=False, is_json=False) + res = self.make_request( + "get", + "/api/dashboards/public/{}".format(api_key.api_key), + user=False, + is_json=False, + ) self.assertEqual(res.status_code, 404) # Not relevant for now, as tokens in api_keys table are only created for dashboards. Once this changes, we should diff --git a/tests/handlers/test_favorites.py b/tests/handlers/test_favorites.py index 4253edb7bc..7d9576c7ea 100644 --- a/tests/handlers/test_favorites.py +++ b/tests/handlers/test_favorites.py @@ -5,32 +5,32 @@ class TestQueryFavoriteResource(BaseTestCase): def test_favorite(self): query = self.factory.create_query() - rv = self.make_request('post', '/api/queries/{}/favorite'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/favorite".format(query.id)) self.assertEqual(rv.status_code, 200) - rv = self.make_request('get', '/api/queries/{}'.format(query.id)) - self.assertEqual(rv.json['is_favorite'], True) + rv = self.make_request("get", "/api/queries/{}".format(query.id)) + self.assertEqual(rv.json["is_favorite"], True) def test_duplicate_favorite(self): query = self.factory.create_query() - rv = self.make_request('post', '/api/queries/{}/favorite'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/favorite".format(query.id)) self.assertEqual(rv.status_code, 200) - rv = self.make_request('post', '/api/queries/{}/favorite'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/favorite".format(query.id)) self.assertEqual(rv.status_code, 200) def test_unfavorite(self): query = self.factory.create_query() - rv = self.make_request('post', '/api/queries/{}/favorite'.format(query.id)) - rv = self.make_request('delete', '/api/queries/{}/favorite'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/favorite".format(query.id)) + rv = self.make_request("delete", "/api/queries/{}/favorite".format(query.id)) self.assertEqual(rv.status_code, 200) - rv = self.make_request('get', '/api/queries/{}'.format(query.id)) - self.assertEqual(rv.json['is_favorite'], False) + rv = self.make_request("get", "/api/queries/{}".format(query.id)) + self.assertEqual(rv.json["is_favorite"], False) class TestQueryFavoriteListResource(BaseTestCase): def test_get_favorites(self): - rv = self.make_request('get', '/api/queries/favorites') + rv = self.make_request("get", "/api/queries/favorites") self.assertEqual(rv.status_code, 200) diff --git a/tests/handlers/test_groups.py b/tests/handlers/test_groups.py index 9535d0d3df..5bdebba6d2 100644 --- a/tests/handlers/test_groups.py +++ b/tests/handlers/test_groups.py @@ -9,7 +9,11 @@ def test_returns_only_groups_for_current_org(self): group = self.factory.create_group(org=self.factory.create_org()) data_source = self.factory.create_data_source(group=group) db.session.flush() - response = self.make_request('get', '/api/groups/{}/data_sources'.format(group.id), user=self.factory.create_admin()) + response = self.make_request( + "get", + "/api/groups/{}/data_sources".format(group.id), + user=self.factory.create_admin(), + ) self.assertEqual(response.status_code, 404) def test_list(self): @@ -17,80 +21,105 @@ def test_list(self): ds = self.factory.create_data_source(group=group) db.session.flush() response = self.make_request( - 'get', '/api/groups/{}/data_sources'.format(group.id), - user=self.factory.create_admin()) + "get", + "/api/groups/{}/data_sources".format(group.id), + user=self.factory.create_admin(), + ) self.assertEqual(response.status_code, 200) self.assertEqual(len(response.json), 1) - self.assertEqual(response.json[0]['id'], ds.id) + self.assertEqual(response.json[0]["id"], ds.id) class TestGroupResourceList(BaseTestCase): - def test_list_admin(self): self.factory.create_group(org=self.factory.create_org()) - response = self.make_request('get', '/api/groups', - user=self.factory.create_admin()) - g_keys = ['type', 'id', 'name', 'permissions'] + response = self.make_request( + "get", "/api/groups", user=self.factory.create_admin() + ) + g_keys = ["type", "id", "name", "permissions"] def filtergroups(gs): return [project(g, g_keys) for g in gs] - self.assertEqual(filtergroups(response.json), - filtergroups(g.to_dict() for g in [ - self.factory.admin_group, - self.factory.default_group])) + + self.assertEqual( + filtergroups(response.json), + filtergroups( + g.to_dict() + for g in [self.factory.admin_group, self.factory.default_group] + ), + ) def test_list(self): - group1 = self.factory.create_group(org=self.factory.create_org(), - permissions=['view_dashboard']) + group1 = self.factory.create_group( + org=self.factory.create_org(), permissions=["view_dashboard"] + ) db.session.flush() - u = self.factory.create_user(group_ids=[self.factory.default_group.id, - group1.id]) + u = self.factory.create_user( + group_ids=[self.factory.default_group.id, group1.id] + ) db.session.flush() - response = self.make_request('get', '/api/groups', - user=u) - g_keys = ['type', 'id', 'name', 'permissions'] + response = self.make_request("get", "/api/groups", user=u) + g_keys = ["type", "id", "name", "permissions"] def filtergroups(gs): return [project(g, g_keys) for g in gs] - self.assertEqual(filtergroups(response.json), - filtergroups(g.to_dict() for g in [ - self.factory.default_group, - group1])) + + self.assertEqual( + filtergroups(response.json), + filtergroups(g.to_dict() for g in [self.factory.default_group, group1]), + ) class TestGroupResourcePost(BaseTestCase): def test_doesnt_change_builtin_groups(self): current_name = self.factory.default_group.name - response = self.make_request('post', '/api/groups/{}'.format(self.factory.default_group.id), - user=self.factory.create_admin(), - data={'name': 'Another Name'}) + response = self.make_request( + "post", + "/api/groups/{}".format(self.factory.default_group.id), + user=self.factory.create_admin(), + data={"name": "Another Name"}, + ) self.assertEqual(response.status_code, 400) - self.assertEqual(current_name, Group.query.get(self.factory.default_group.id).name) + self.assertEqual( + current_name, Group.query.get(self.factory.default_group.id).name + ) class TestGroupResourceDelete(BaseTestCase): def test_allowed_only_to_admin(self): group = self.factory.create_group() - response = self.make_request('delete', '/api/groups/{}'.format(group.id)) + response = self.make_request("delete", "/api/groups/{}".format(group.id)) self.assertEqual(response.status_code, 403) - response = self.make_request('delete', '/api/groups/{}'.format(group.id), user=self.factory.create_admin()) + response = self.make_request( + "delete", + "/api/groups/{}".format(group.id), + user=self.factory.create_admin(), + ) self.assertEqual(response.status_code, 200) self.assertIsNone(Group.query.get(group.id)) def test_cant_delete_builtin_group(self): for group in [self.factory.default_group, self.factory.admin_group]: - response = self.make_request('delete', '/api/groups/{}'.format(group.id), user=self.factory.create_admin()) + response = self.make_request( + "delete", + "/api/groups/{}".format(group.id), + user=self.factory.create_admin(), + ) self.assertEqual(response.status_code, 400) def test_can_delete_group_with_data_sources(self): group = self.factory.create_group() data_source = self.factory.create_data_source(group=group) - response = self.make_request('delete', '/api/groups/{}'.format(group.id), user=self.factory.create_admin()) + response = self.make_request( + "delete", + "/api/groups/{}".format(group.id), + user=self.factory.create_admin(), + ) self.assertEqual(response.status_code, 200) @@ -99,11 +128,13 @@ def test_can_delete_group_with_data_sources(self): class TestGroupResourceGet(BaseTestCase): def test_returns_group(self): - rv = self.make_request('get', '/api/groups/{}'.format(self.factory.default_group.id)) + rv = self.make_request( + "get", "/api/groups/{}".format(self.factory.default_group.id) + ) self.assertEqual(rv.status_code, 200) def test_doesnt_return_if_user_not_member_or_admin(self): - rv = self.make_request('get', '/api/groups/{}'.format(self.factory.admin_group.id)) + rv = self.make_request( + "get", "/api/groups/{}".format(self.factory.admin_group.id) + ) self.assertEqual(rv.status_code, 403) - - diff --git a/tests/handlers/test_paginate.py b/tests/handlers/test_paginate.py index b251af9ba0..509277560b 100644 --- a/tests/handlers/test_paginate.py +++ b/tests/handlers/test_paginate.py @@ -4,11 +4,14 @@ from unittest import TestCase from mock import MagicMock + class DummyResults(object): items = [i for i in range(25)] + dummy_results = DummyResults() + class TestPaginate(TestCase): def setUp(self): self.query_set = MagicMock() @@ -17,16 +20,23 @@ def setUp(self): def test_returns_paginated_results(self): page = paginate(self.query_set, 1, 25, lambda x: x) - self.assertEqual(page['page'], 1) - self.assertEqual(page['page_size'], 25) - self.assertEqual(page['count'], 102) - self.assertEqual(page['results'], dummy_results.items) + self.assertEqual(page["page"], 1) + self.assertEqual(page["page_size"], 25) + self.assertEqual(page["count"], 102) + self.assertEqual(page["results"], dummy_results.items) def test_raises_error_for_bad_page(self): - self.assertRaises(BadRequest, lambda: paginate(self.query_set, -1, 25, lambda x: x)) - self.assertRaises(BadRequest, lambda: paginate(self.query_set, 6, 25, lambda x: x)) + self.assertRaises( + BadRequest, lambda: paginate(self.query_set, -1, 25, lambda x: x) + ) + self.assertRaises( + BadRequest, lambda: paginate(self.query_set, 6, 25, lambda x: x) + ) def test_raises_error_for_bad_page_size(self): - self.assertRaises(BadRequest, lambda: paginate(self.query_set, 1, 251, lambda x: x)) - self.assertRaises(BadRequest, lambda: paginate(self.query_set, 1, -1, lambda x: x)) - + self.assertRaises( + BadRequest, lambda: paginate(self.query_set, 1, 251, lambda x: x) + ) + self.assertRaises( + BadRequest, lambda: paginate(self.query_set, 1, -1, lambda x: x) + ) diff --git a/tests/handlers/test_permissions.py b/tests/handlers/test_permissions.py index 6910e167c9..509b3687e5 100644 --- a/tests/handlers/test_permissions.py +++ b/tests/handlers/test_permissions.py @@ -8,7 +8,7 @@ class TestObjectPermissionsListGet(BaseTestCase): def test_returns_empty_list_when_no_permissions(self): query = self.factory.create_query() user = self.factory.user - rv = self.make_request('get', '/api/queries/{}/acl'.format(query.id), user=user) + rv = self.make_request("get", "/api/queries/{}/acl".format(query.id), user=user) self.assertEqual(rv.status_code, 200) self.assertEqual({}, rv.json) @@ -17,19 +17,23 @@ def test_returns_permissions(self): query = self.factory.create_query() user = self.factory.user - AccessPermission.grant(obj=query, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, grantee=self.factory.user) + AccessPermission.grant( + obj=query, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=self.factory.user, + ) - rv = self.make_request('get', '/api/queries/{}/acl'.format(query.id), user=user) + rv = self.make_request("get", "/api/queries/{}/acl".format(query.id), user=user) self.assertEqual(rv.status_code, 200) - self.assertIn('modify', rv.json) - self.assertEqual(user.id, rv.json['modify'][0]['id']) + self.assertIn("modify", rv.json) + self.assertEqual(user.id, rv.json["modify"][0]["id"]) def test_returns_404_for_outside_of_organization_users(self): query = self.factory.create_query() user = self.factory.create_user(org=self.factory.create_org()) - rv = self.make_request('get', '/api/queries/{}/acl'.format(query.id), user=user) + rv = self.make_request("get", "/api/queries/{}/acl".format(query.id), user=user) self.assertEqual(rv.status_code, 404) @@ -39,12 +43,11 @@ def test_creates_permission_if_the_user_is_an_owner(self): query = self.factory.create_query() other_user = self.factory.create_user() - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': other_user.id - } + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": other_user.id} - rv = self.make_request('post', '/api/queries/{}/acl'.format(query.id), user=query.user, data=data) + rv = self.make_request( + "post", "/api/queries/{}/acl".format(query.id), user=query.user, data=data + ) self.assertEqual(200, rv.status_code) self.assertTrue(AccessPermission.exists(query, ACCESS_TYPE_MODIFY, other_user)) @@ -53,48 +56,44 @@ def test_returns_403_if_the_user_isnt_owner(self): query = self.factory.create_query() other_user = self.factory.create_user() - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': other_user.id - } + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": other_user.id} - rv = self.make_request('post', '/api/queries/{}/acl'.format(query.id), user=other_user, data=data) + rv = self.make_request( + "post", "/api/queries/{}/acl".format(query.id), user=other_user, data=data + ) self.assertEqual(403, rv.status_code) def test_returns_400_if_the_grantee_isnt_from_organization(self): query = self.factory.create_query() other_user = self.factory.create_user(org=self.factory.create_org()) - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': other_user.id - } + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": other_user.id} - rv = self.make_request('post', '/api/queries/{}/acl'.format(query.id), user=query.user, data=data) + rv = self.make_request( + "post", "/api/queries/{}/acl".format(query.id), user=query.user, data=data + ) self.assertEqual(400, rv.status_code) def test_returns_404_if_the_user_from_different_org(self): query = self.factory.create_query() other_user = self.factory.create_user(org=self.factory.create_org()) - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': other_user.id - } + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": other_user.id} - rv = self.make_request('post', '/api/queries/{}/acl'.format(query.id), user=other_user, data=data) + rv = self.make_request( + "post", "/api/queries/{}/acl".format(query.id), user=other_user, data=data + ) self.assertEqual(404, rv.status_code) def test_accepts_only_correct_access_types(self): query = self.factory.create_query() other_user = self.factory.create_user() - data = { - 'access_type': 'random string', - 'user_id': other_user.id - } + data = {"access_type": "random string", "user_id": other_user.id} - rv = self.make_request('post', '/api/queries/{}/acl'.format(query.id), user=query.user, data=data) + rv = self.make_request( + "post", "/api/queries/{}/acl".format(query.id), user=query.user, data=data + ) self.assertEqual(400, rv.status_code) @@ -105,14 +104,18 @@ def test_removes_permission(self): user = self.factory.user other_user = self.factory.create_user() - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': other_user.id - } + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": other_user.id} - AccessPermission.grant(obj=query, access_type=ACCESS_TYPE_MODIFY, grantor=self.factory.user, grantee=other_user) + AccessPermission.grant( + obj=query, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=other_user, + ) - rv = self.make_request('delete', '/api/queries/{}/acl'.format(query.id), user=user, data=data) + rv = self.make_request( + "delete", "/api/queries/{}/acl".format(query.id), user=user, data=data + ) self.assertEqual(rv.status_code, 200) @@ -122,15 +125,21 @@ def test_removes_permission_created_by_another_user(self): query = self.factory.create_query() other_user = self.factory.create_user() - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': other_user.id - } + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": other_user.id} - AccessPermission.grant(obj=query, access_type=ACCESS_TYPE_MODIFY, grantor=self.factory.user, grantee=other_user) + AccessPermission.grant( + obj=query, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=other_user, + ) - rv = self.make_request('delete', '/api/queries/{}/acl'.format(query.id), user=self.factory.create_admin(), - data=data) + rv = self.make_request( + "delete", + "/api/queries/{}/acl".format(query.id), + user=self.factory.create_admin(), + data=data, + ) self.assertEqual(rv.status_code, 200) @@ -139,11 +148,10 @@ def test_removes_permission_created_by_another_user(self): def test_returns_404_for_outside_of_organization_users(self): query = self.factory.create_query() user = self.factory.create_user(org=self.factory.create_org()) - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': user.id - } - rv = self.make_request('delete', '/api/queries/{}/acl'.format(query.id), user=user, data=data) + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": user.id} + rv = self.make_request( + "delete", "/api/queries/{}/acl".format(query.id), user=user, data=data + ) self.assertEqual(rv.status_code, 404) @@ -151,11 +159,10 @@ def test_returns_403_for_non_owner(self): query = self.factory.create_query() user = self.factory.create_user() - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': user.id - } - rv = self.make_request('delete', '/api/queries/{}/acl'.format(query.id), user=user, data=data) + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": user.id} + rv = self.make_request( + "delete", "/api/queries/{}/acl".format(query.id), user=user, data=data + ) self.assertEqual(rv.status_code, 403) @@ -163,12 +170,11 @@ def test_returns_200_even_if_there_is_no_permission(self): query = self.factory.create_query() user = self.factory.create_user() - data = { - 'access_type': ACCESS_TYPE_MODIFY, - 'user_id': user.id - } + data = {"access_type": ACCESS_TYPE_MODIFY, "user_id": user.id} - rv = self.make_request('delete', '/api/queries/{}/acl'.format(query.id), user=query.user, data=data) + rv = self.make_request( + "delete", "/api/queries/{}/acl".format(query.id), user=query.user, data=data + ) self.assertEqual(rv.status_code, 200) @@ -178,26 +184,43 @@ def test_returns_true_for_existing_permission(self): query = self.factory.create_query() other_user = self.factory.create_user() - AccessPermission.grant(obj=query, access_type=ACCESS_TYPE_MODIFY, grantor=self.factory.user, grantee=other_user) + AccessPermission.grant( + obj=query, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=other_user, + ) - rv = self.make_request('get', '/api/queries/{}/acl/{}'.format(query.id, ACCESS_TYPE_MODIFY), user=other_user) + rv = self.make_request( + "get", + "/api/queries/{}/acl/{}".format(query.id, ACCESS_TYPE_MODIFY), + user=other_user, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(True, rv.json['response']) + self.assertEqual(True, rv.json["response"]) def test_returns_false_for_existing_permission(self): query = self.factory.create_query() other_user = self.factory.create_user() - rv = self.make_request('get', '/api/queries/{}/acl/{}'.format(query.id, ACCESS_TYPE_MODIFY), user=other_user) + rv = self.make_request( + "get", + "/api/queries/{}/acl/{}".format(query.id, ACCESS_TYPE_MODIFY), + user=other_user, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(False, rv.json['response']) + self.assertEqual(False, rv.json["response"]) def test_returns_404_for_outside_of_org_users(self): query = self.factory.create_query() other_user = self.factory.create_user(org=self.factory.create_org()) - rv = self.make_request('get', '/api/queries/{}/acl/{}'.format(query.id, ACCESS_TYPE_MODIFY), user=other_user) + rv = self.make_request( + "get", + "/api/queries/{}/acl/{}".format(query.id, ACCESS_TYPE_MODIFY), + user=other_user, + ) self.assertEqual(rv.status_code, 404) diff --git a/tests/handlers/test_queries.py b/tests/handlers/test_queries.py index 89f39dd933..0091573085 100644 --- a/tests/handlers/test_queries.py +++ b/tests/handlers/test_queries.py @@ -10,30 +10,32 @@ class TestQueryResourceGet(BaseTestCase): def test_get_query(self): query = self.factory.create_query() - rv = self.make_request('get', '/api/queries/{0}'.format(query.id)) + rv = self.make_request("get", "/api/queries/{0}".format(query.id)) self.assertEqual(rv.status_code, 200) expected = serialize_query(query, with_visualizations=True) - expected['can_edit'] = True - expected['is_favorite'] = False + expected["can_edit"] = True + expected["is_favorite"] = False self.assertResponseEqual(expected, rv.json) def test_get_all_queries(self): [self.factory.create_query() for _ in range(10)] - rv = self.make_request('get', '/api/queries') + rv = self.make_request("get", "/api/queries") self.assertEqual(rv.status_code, 200) - self.assertEqual(len(rv.json['results']), 10) + self.assertEqual(len(rv.json["results"]), 10) def test_query_without_data_source_should_be_available_only_by_admin(self): query = self.factory.create_query() query.data_source = None db.session.add(query) - rv = self.make_request('get', '/api/queries/{}'.format(query.id)) + rv = self.make_request("get", "/api/queries/{}".format(query.id)) self.assertEqual(rv.status_code, 403) - rv = self.make_request('get', '/api/queries/{}'.format(query.id), user=self.factory.create_admin()) + rv = self.make_request( + "get", "/api/queries/{}".format(query.id), user=self.factory.create_admin() + ) self.assertEqual(rv.status_code, 200) def test_query_only_accessible_to_users_from_its_organization(self): @@ -44,40 +46,41 @@ def test_query_only_accessible_to_users_from_its_organization(self): query.data_source = None db.session.add(query) - rv = self.make_request('get', '/api/queries/{}'.format(query.id), user=second_org_admin) + rv = self.make_request( + "get", "/api/queries/{}".format(query.id), user=second_org_admin + ) self.assertEqual(rv.status_code, 404) - rv = self.make_request('get', '/api/queries/{}'.format(query.id), user=self.factory.create_admin()) + rv = self.make_request( + "get", "/api/queries/{}".format(query.id), user=self.factory.create_admin() + ) self.assertEqual(rv.status_code, 200) def test_query_search(self): - names = [ - 'Harder', - 'Better', - 'Faster', - 'Stronger', - ] + names = ["Harder", "Better", "Faster", "Stronger"] for name in names: self.factory.create_query(name=name) - rv = self.make_request('get', '/api/queries?q=better') + rv = self.make_request("get", "/api/queries?q=better") self.assertEqual(rv.status_code, 200) - self.assertEqual(len(rv.json['results']), 1) + self.assertEqual(len(rv.json["results"]), 1) - rv = self.make_request('get', '/api/queries?q=better OR faster') + rv = self.make_request("get", "/api/queries?q=better OR faster") self.assertEqual(rv.status_code, 200) - self.assertEqual(len(rv.json['results']), 2) + self.assertEqual(len(rv.json["results"]), 2) # test the old search API and that it redirects to the new one - rv = self.make_request('get', '/api/queries/search?q=stronger') + rv = self.make_request("get", "/api/queries/search?q=stronger") self.assertEqual(rv.status_code, 301) - self.assertIn('/api/queries?q=stronger', rv.headers['Location']) + self.assertIn("/api/queries?q=stronger", rv.headers["Location"]) - rv = self.make_request('get', '/api/queries/search?q=stronger', follow_redirects=True) + rv = self.make_request( + "get", "/api/queries/search?q=stronger", follow_redirects=True + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(len(rv.json['results']), 1) + self.assertEqual(len(rv.json["results"]), 1) class TestQueryResourcePost(BaseTestCase): @@ -89,26 +92,33 @@ def test_update_query(self): new_qr = self.factory.create_query_result() data = { - 'name': 'Testing', - 'query': 'select 2', - 'latest_query_data_id': new_qr.id, - 'data_source_id': new_ds.id + "name": "Testing", + "query": "select 2", + "latest_query_data_id": new_qr.id, + "data_source_id": new_ds.id, } - rv = self.make_request('post', '/api/queries/{0}'.format(query.id), data=data, user=admin) + rv = self.make_request( + "post", "/api/queries/{0}".format(query.id), data=data, user=admin + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], data['name']) - self.assertEqual(rv.json['last_modified_by']['id'], admin.id) - self.assertEqual(rv.json['query'], data['query']) - self.assertEqual(rv.json['data_source_id'], data['data_source_id']) - self.assertEqual(rv.json['latest_query_data_id'], data['latest_query_data_id']) + self.assertEqual(rv.json["name"], data["name"]) + self.assertEqual(rv.json["last_modified_by"]["id"], admin.id) + self.assertEqual(rv.json["query"], data["query"]) + self.assertEqual(rv.json["data_source_id"], data["data_source_id"]) + self.assertEqual(rv.json["latest_query_data_id"], data["latest_query_data_id"]) def test_raises_error_in_case_of_conflict(self): q = self.factory.create_query() q.name = "Another Name" db.session.add(q) - rv = self.make_request('post', '/api/queries/{0}'.format(q.id), data={'name': 'Testing', 'version': q.version - 1}, user=self.factory.user) + rv = self.make_request( + "post", + "/api/queries/{0}".format(q.id), + data={"name": "Testing", "version": q.version - 1}, + user=self.factory.user, + ) self.assertEqual(rv.status_code, 409) def test_allows_association_with_authorized_dropdown_queries(self): @@ -121,52 +131,58 @@ def test_allows_association_with_authorized_dropdown_queries(self): db.session.add(my_query) options = { - 'parameters': [{ - 'name': 'foo', - 'type': 'query', - 'queryId': other_query.id - }, { - 'name': 'bar', - 'type': 'query', - 'queryId': other_query.id - }] + "parameters": [ + {"name": "foo", "type": "query", "queryId": other_query.id}, + {"name": "bar", "type": "query", "queryId": other_query.id}, + ] } - rv = self.make_request('post', '/api/queries/{0}'.format(my_query.id), data={'options': options}, user=self.factory.user) + rv = self.make_request( + "post", + "/api/queries/{0}".format(my_query.id), + data={"options": options}, + user=self.factory.user, + ) self.assertEqual(rv.status_code, 200) def test_prevents_association_with_unauthorized_dropdown_queries(self): - other_data_source = self.factory.create_data_source(group=self.factory.create_group()) + other_data_source = self.factory.create_data_source( + group=self.factory.create_group() + ) other_query = self.factory.create_query(data_source=other_data_source) db.session.add(other_query) - my_data_source = self.factory.create_data_source(group=self.factory.create_group()) + my_data_source = self.factory.create_data_source( + group=self.factory.create_group() + ) my_query = self.factory.create_query(data_source=my_data_source) db.session.add(my_query) - options = { - 'parameters': [{ - 'type': 'query', - 'queryId': other_query.id - }] - } + options = {"parameters": [{"type": "query", "queryId": other_query.id}]} - rv = self.make_request('post', '/api/queries/{0}'.format(my_query.id), data={'options': options}, user=self.factory.user) + rv = self.make_request( + "post", + "/api/queries/{0}".format(my_query.id), + data={"options": options}, + user=self.factory.user, + ) self.assertEqual(rv.status_code, 403) def test_prevents_association_with_non_existing_dropdown_queries(self): - my_data_source = self.factory.create_data_source(group=self.factory.create_group()) + my_data_source = self.factory.create_data_source( + group=self.factory.create_group() + ) my_query = self.factory.create_query(data_source=my_data_source) db.session.add(my_query) - options = { - 'parameters': [{ - 'type': 'query', - 'queryId': 100000 - }] - } + options = {"parameters": [{"type": "query", "queryId": 100000}]} - rv = self.make_request('post', '/api/queries/{0}'.format(my_query.id), data={'options': options}, user=self.factory.user) + rv = self.make_request( + "post", + "/api/queries/{0}".format(my_query.id), + data={"options": options}, + user=self.factory.user, + ) self.assertEqual(rv.status_code, 400) def test_overrides_existing_if_no_version_specified(self): @@ -174,22 +190,39 @@ def test_overrides_existing_if_no_version_specified(self): q.name = "Another Name" db.session.add(q) - rv = self.make_request('post', '/api/queries/{0}'.format(q.id), data={'name': 'Testing'}, user=self.factory.user) + rv = self.make_request( + "post", + "/api/queries/{0}".format(q.id), + data={"name": "Testing"}, + user=self.factory.user, + ) self.assertEqual(rv.status_code, 200) def test_works_for_non_owner_with_permission(self): query = self.factory.create_query() user = self.factory.create_user() - rv = self.make_request('post', '/api/queries/{0}'.format(query.id), data={'name': 'Testing'}, user=user) + rv = self.make_request( + "post", + "/api/queries/{0}".format(query.id), + data={"name": "Testing"}, + user=user, + ) self.assertEqual(rv.status_code, 403) - models.AccessPermission.grant(obj=query, access_type=ACCESS_TYPE_MODIFY, grantee=user, grantor=query.user) + models.AccessPermission.grant( + obj=query, access_type=ACCESS_TYPE_MODIFY, grantee=user, grantor=query.user + ) - rv = self.make_request('post', '/api/queries/{0}'.format(query.id), data={'name': 'Testing'}, user=user) + rv = self.make_request( + "post", + "/api/queries/{0}".format(query.id), + data={"name": "Testing"}, + user=user, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], 'Testing') - self.assertEqual(rv.json['last_modified_by']['id'], user.id) + self.assertEqual(rv.json["name"], "Testing") + self.assertEqual(rv.json["last_modified_by"]["id"], user.id) class TestQueryListResourceGet(BaseTestCase): @@ -198,48 +231,52 @@ def test_returns_queries(self): q2 = self.factory.create_query() q3 = self.factory.create_query() - rv = self.make_request('get', '/api/queries') + rv = self.make_request("get", "/api/queries") - assert len(rv.json['results']) == 3 - assert set([result['id'] for result in rv.json['results']]) == set([q1.id, q2.id, q3.id]) + assert len(rv.json["results"]) == 3 + assert set([result["id"] for result in rv.json["results"]]) == set( + [q1.id, q2.id, q3.id] + ) def test_filters_with_tags(self): - q1 = self.factory.create_query(tags=['test']) + q1 = self.factory.create_query(tags=["test"]) self.factory.create_query() self.factory.create_query() - rv = self.make_request('get', '/api/queries?tags=test') - assert len(rv.json['results']) == 1 - assert set([result['id'] for result in rv.json['results']]) == set([q1.id]) + rv = self.make_request("get", "/api/queries?tags=test") + assert len(rv.json["results"]) == 1 + assert set([result["id"] for result in rv.json["results"]]) == set([q1.id]) def test_search_term(self): q1 = self.factory.create_query(name="Sales") q2 = self.factory.create_query(name="Q1 sales") self.factory.create_query(name="Ops") - rv = self.make_request('get', '/api/queries?q=sales') - assert len(rv.json['results']) == 2 - assert set([result['id'] for result in rv.json['results']]) == set([q1.id, q2.id]) + rv = self.make_request("get", "/api/queries?q=sales") + assert len(rv.json["results"]) == 2 + assert set([result["id"] for result in rv.json["results"]]) == set( + [q1.id, q2.id] + ) class TestQueryListResourcePost(BaseTestCase): def test_create_query(self): query_data = { - 'name': 'Testing', - 'query': 'SELECT 1', - 'schedule': {"interval": "3600"}, - 'data_source_id': self.factory.data_source.id + "name": "Testing", + "query": "SELECT 1", + "schedule": {"interval": "3600"}, + "data_source_id": self.factory.data_source.id, } - rv = self.make_request('post', '/api/queries', data=query_data) + rv = self.make_request("post", "/api/queries", data=query_data) self.assertEqual(rv.status_code, 200) self.assertDictContainsSubset(query_data, rv.json) - self.assertEqual(rv.json['user']['id'], self.factory.user.id) - self.assertIsNotNone(rv.json['api_key']) - self.assertIsNotNone(rv.json['query_hash']) + self.assertEqual(rv.json["user"]["id"], self.factory.user.id) + self.assertIsNotNone(rv.json["api_key"]) + self.assertIsNotNone(rv.json["query_hash"]) - query = models.Query.query.get(rv.json['id']) + query = models.Query.query.get(rv.json["id"]) self.assertEqual(len(list(query.visualizations)), 1) self.assertTrue(query.is_draft) @@ -250,64 +287,53 @@ def test_allows_association_with_authorized_dropdown_queries(self): db.session.add(other_query) query_data = { - 'name': 'Testing', - 'query': 'SELECT 1', - 'schedule': {"interval": "3600"}, - 'data_source_id': self.factory.data_source.id, - 'options': { - 'parameters': [{ - 'name': 'foo', - 'type': 'query', - 'queryId': other_query.id - }, { - 'name': 'bar', - 'type': 'query', - 'queryId': other_query.id - }] - } + "name": "Testing", + "query": "SELECT 1", + "schedule": {"interval": "3600"}, + "data_source_id": self.factory.data_source.id, + "options": { + "parameters": [ + {"name": "foo", "type": "query", "queryId": other_query.id}, + {"name": "bar", "type": "query", "queryId": other_query.id}, + ] + }, } - rv = self.make_request('post', '/api/queries', data=query_data) + rv = self.make_request("post", "/api/queries", data=query_data) self.assertEqual(rv.status_code, 200) def test_prevents_association_with_unauthorized_dropdown_queries(self): - other_data_source = self.factory.create_data_source(group=self.factory.create_group()) + other_data_source = self.factory.create_data_source( + group=self.factory.create_group() + ) other_query = self.factory.create_query(data_source=other_data_source) db.session.add(other_query) - my_data_source = self.factory.create_data_source(group=self.factory.create_group()) + my_data_source = self.factory.create_data_source( + group=self.factory.create_group() + ) query_data = { - 'name': 'Testing', - 'query': 'SELECT 1', - 'schedule': {"interval": "3600"}, - 'data_source_id': my_data_source.id, - 'options': { - 'parameters': [{ - 'type': 'query', - 'queryId': other_query.id - }] - } + "name": "Testing", + "query": "SELECT 1", + "schedule": {"interval": "3600"}, + "data_source_id": my_data_source.id, + "options": {"parameters": [{"type": "query", "queryId": other_query.id}]}, } - rv = self.make_request('post', '/api/queries', data=query_data) + rv = self.make_request("post", "/api/queries", data=query_data) self.assertEqual(rv.status_code, 403) def test_prevents_association_with_non_existing_dropdown_queries(self): query_data = { - 'name': 'Testing', - 'query': 'SELECT 1', - 'schedule': {"interval": "3600"}, - 'data_source_id': self.factory.data_source.id, - 'options': { - 'parameters': [{ - 'type': 'query', - 'queryId': 100000 - }] - } + "name": "Testing", + "query": "SELECT 1", + "schedule": {"interval": "3600"}, + "data_source_id": self.factory.data_source.id, + "options": {"parameters": [{"type": "query", "queryId": 100000}]}, } - rv = self.make_request('post', '/api/queries', data=query_data) + rv = self.make_request("post", "/api/queries", data=query_data) self.assertEqual(rv.status_code, 400) @@ -317,19 +343,23 @@ def test_returns_queries(self): q2 = self.factory.create_query(is_archived=True) self.factory.create_query() - rv = self.make_request('get', '/api/queries/archive') + rv = self.make_request("get", "/api/queries/archive") - assert len(rv.json['results']) == 2 - assert set([result['id'] for result in rv.json['results']]) == set([q1.id, q2.id]) + assert len(rv.json["results"]) == 2 + assert set([result["id"] for result in rv.json["results"]]) == set( + [q1.id, q2.id] + ) def test_search_term(self): q1 = self.factory.create_query(name="Sales", is_archived=True) q2 = self.factory.create_query(name="Q1 sales", is_archived=True) self.factory.create_query(name="Q2 sales") - rv = self.make_request('get', '/api/queries/archive?q=sales') - assert len(rv.json['results']) == 2 - assert set([result['id'] for result in rv.json['results']]) == set([q1.id, q2.id]) + rv = self.make_request("get", "/api/queries/archive?q=sales") + assert len(rv.json["results"]) == 2 + assert set([result["id"] for result in rv.json["results"]]) == set( + [q1.id, q2.id] + ) class QueryRefreshTest(BaseTestCase): @@ -337,24 +367,24 @@ def setUp(self): super(QueryRefreshTest, self).setUp() self.query = self.factory.create_query() - self.path = '/api/queries/{}/refresh'.format(self.query.id) + self.path = "/api/queries/{}/refresh".format(self.query.id) def test_refresh_regular_query(self): - response = self.make_request('post', self.path) + response = self.make_request("post", self.path) self.assertEqual(200, response.status_code) def test_refresh_of_query_with_parameters(self): self.query.query_text = "SELECT {{param}}" db.session.add(self.query) - response = self.make_request('post', "{}?p_param=1".format(self.path)) + response = self.make_request("post", "{}?p_param=1".format(self.path)) self.assertEqual(200, response.status_code) def test_refresh_of_query_with_parameters_without_parameters(self): self.query.query_text = "SELECT {{param}}" db.session.add(self.query) - response = self.make_request('post', "{}".format(self.path)) + response = self.make_request("post", "{}".format(self.path)) self.assertEqual(400, response.status_code) def test_refresh_query_you_dont_have_access_to(self): @@ -362,14 +392,20 @@ def test_refresh_query_you_dont_have_access_to(self): db.session.add(group) db.session.commit() user = self.factory.create_user(group_ids=[group.id]) - response = self.make_request('post', self.path, user=user) + response = self.make_request("post", self.path, user=user) self.assertEqual(403, response.status_code) def test_refresh_forbiden_with_query_api_key(self): - response = self.make_request('post', '{}?api_key={}'.format(self.path, self.query.api_key), user=False) + response = self.make_request( + "post", "{}?api_key={}".format(self.path, self.query.api_key), user=False + ) self.assertEqual(403, response.status_code) - response = self.make_request('post', '{}?api_key={}'.format(self.path, self.factory.user.api_key), user=False) + response = self.make_request( + "post", + "{}?api_key={}".format(self.path, self.factory.user.api_key), + user=False, + ) self.assertEqual(200, response.status_code) @@ -380,7 +416,11 @@ def test_non_admin_cannot_regenerate_api_key_of_other_user(self): other_user = self.factory.create_user() orig_api_key = query.api_key - rv = self.make_request('post', "/api/queries/{}/regenerate_api_key".format(query.id), user=other_user) + rv = self.make_request( + "post", + "/api/queries/{}/regenerate_api_key".format(query.id), + user=other_user, + ) self.assertEqual(rv.status_code, 403) reloaded_query = models.Query.query.get(query.id) @@ -392,7 +432,11 @@ def test_admin_can_regenerate_api_key_of_other_user(self): admin_user = self.factory.create_admin() orig_api_key = query.api_key - rv = self.make_request('post', "/api/queries/{}/regenerate_api_key".format(query.id), user=admin_user) + rv = self.make_request( + "post", + "/api/queries/{}/regenerate_api_key".format(query.id), + user=admin_user, + ) self.assertEqual(rv.status_code, 200) reloaded_query = models.Query.query.get(query.id) @@ -404,7 +448,11 @@ def test_admin_can_regenerate_api_key_of_myself(self): query = self.factory.create_query(user=query_creator) orig_api_key = query.api_key - rv = self.make_request('post', "/api/queries/{}/regenerate_api_key".format(query.id), user=admin_user) + rv = self.make_request( + "post", + "/api/queries/{}/regenerate_api_key".format(query.id), + user=admin_user, + ) self.assertEqual(rv.status_code, 200) updated_query = models.Query.query.get(query.id) @@ -415,7 +463,9 @@ def test_user_can_regenerate_api_key_of_myself(self): query = self.factory.create_query(user=user) orig_api_key = query.api_key - rv = self.make_request('post', "/api/queries/{}/regenerate_api_key".format(query.id), user=user) + rv = self.make_request( + "post", "/api/queries/{}/regenerate_api_key".format(query.id), user=user + ) self.assertEqual(rv.status_code, 200) updated_query = models.Query.query.get(query.id) @@ -424,18 +474,22 @@ def test_user_can_regenerate_api_key_of_myself(self): class TestQueryForkResourcePost(BaseTestCase): def test_forks_a_query(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=False) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=False + ) query = self.factory.create_query(data_source=ds) - rv = self.make_request('post', '/api/queries/{}/fork'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/fork".format(query.id)) self.assertEqual(rv.status_code, 200) def test_must_have_full_access_to_data_source(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=True) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=True + ) query = self.factory.create_query(data_source=ds) - rv = self.make_request('post', '/api/queries/{}/fork'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/fork".format(query.id)) self.assertEqual(rv.status_code, 403) @@ -443,7 +497,7 @@ def test_must_have_full_access_to_data_source(self): class TestFormatSQLQueryAPI(BaseTestCase): def test_format_sql_query(self): admin = self.factory.create_admin() - query = 'select a,b,c FROM foobar Where x=1 and y=2;' + query = "select a,b,c FROM foobar Where x=1 and y=2;" expected = """SELECT a, b, c @@ -451,7 +505,8 @@ def test_format_sql_query(self): WHERE x=1 AND y=2;""" - rv = self.make_request('post', '/api/queries/format', user=admin, data={'query': query}) - - self.assertEqual(rv.json['query'], expected) + rv = self.make_request( + "post", "/api/queries/format", user=admin, data={"query": query} + ) + self.assertEqual(rv.json["query"], expected) diff --git a/tests/handlers/test_query_results.py b/tests/handlers/test_query_results.py index a5e50a0a09..1016e33066 100644 --- a/tests/handlers/test_query_results.py +++ b/tests/handlers/test_query_results.py @@ -10,20 +10,22 @@ def test_uses_cache_headers_for_specific_result(self): query_result = self.factory.create_query_result() query = self.factory.create_query(latest_query_data=query_result) - rv = self.make_request('get', '/api/queries/{}/results/{}.json'.format(query.id, query_result.id)) - self.assertIn('Cache-Control', rv.headers) + rv = self.make_request( + "get", "/api/queries/{}/results/{}.json".format(query.id, query_result.id) + ) + self.assertIn("Cache-Control", rv.headers) def test_doesnt_use_cache_headers_for_non_specific_result(self): query_result = self.factory.create_query_result() query = self.factory.create_query(latest_query_data=query_result) - rv = self.make_request('get', '/api/queries/{}/results.json'.format(query.id)) - self.assertNotIn('Cache-Control', rv.headers) + rv = self.make_request("get", "/api/queries/{}/results.json".format(query.id)) + self.assertNotIn("Cache-Control", rv.headers) def test_returns_404_if_no_cached_result_found(self): query = self.factory.create_query(latest_query_data=None) - rv = self.make_request('get', '/api/queries/{}/results.json'.format(query.id)) + rv = self.make_request("get", "/api/queries/{}/results.json".format(query.id)) self.assertEqual(404, rv.status_code) @@ -32,24 +34,34 @@ def test_get_existing_result(self): query_result = self.factory.create_query_result() query = self.factory.create_query() - rv = self.make_request('post', '/api/query_results', - data={'data_source_id': self.factory.data_source.id, - 'query': query.query_text}) + rv = self.make_request( + "post", + "/api/query_results", + data={ + "data_source_id": self.factory.data_source.id, + "query": query.query_text, + }, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(query_result.id, rv.json['query_result']['id']) + self.assertEqual(query_result.id, rv.json["query_result"]["id"]) def test_execute_new_query(self): query_result = self.factory.create_query_result() query = self.factory.create_query() - rv = self.make_request('post', '/api/query_results', - data={'data_source_id': self.factory.data_source.id, - 'query': query.query_text, - 'max_age': 0}) + rv = self.make_request( + "post", + "/api/query_results", + data={ + "data_source_id": self.factory.data_source.id, + "query": query.query_text, + "max_age": 0, + }, + ) self.assertEqual(rv.status_code, 200) - self.assertNotIn('query_result', rv.json) - self.assertIn('job', rv.json) + self.assertNotIn("query_result", rv.json) + self.assertIn("job", rv.json) def test_execute_query_without_access(self): group = self.factory.create_group() @@ -57,62 +69,87 @@ def test_execute_query_without_access(self): user = self.factory.create_user(group_ids=[group.id]) query = self.factory.create_query() - rv = self.make_request('post', '/api/query_results', - data={'data_source_id': self.factory.data_source.id, - 'query': query.query_text, - 'max_age': 0}, - user=user) + rv = self.make_request( + "post", + "/api/query_results", + data={ + "data_source_id": self.factory.data_source.id, + "query": query.query_text, + "max_age": 0, + }, + user=user, + ) self.assertEqual(rv.status_code, 403) - self.assertIn('job', rv.json) + self.assertIn("job", rv.json) def test_execute_query_with_params(self): query = "SELECT {{param}}" - rv = self.make_request('post', '/api/query_results', - data={'data_source_id': self.factory.data_source.id, - 'query': query, - 'max_age': 0}) + rv = self.make_request( + "post", + "/api/query_results", + data={ + "data_source_id": self.factory.data_source.id, + "query": query, + "max_age": 0, + }, + ) self.assertEqual(rv.status_code, 400) - self.assertIn('job', rv.json) - - rv = self.make_request('post', '/api/query_results', - data={'data_source_id': self.factory.data_source.id, - 'query': query, - 'parameters': {'param': 1}, - 'max_age': 0}) + self.assertIn("job", rv.json) + + rv = self.make_request( + "post", + "/api/query_results", + data={ + "data_source_id": self.factory.data_source.id, + "query": query, + "parameters": {"param": 1}, + "max_age": 0, + }, + ) self.assertEqual(rv.status_code, 200) - self.assertIn('job', rv.json) - - rv = self.make_request('post', '/api/query_results?p_param=1', - data={'data_source_id': self.factory.data_source.id, - 'query': query, - 'max_age': 0}) + self.assertIn("job", rv.json) + + rv = self.make_request( + "post", + "/api/query_results?p_param=1", + data={ + "data_source_id": self.factory.data_source.id, + "query": query, + "max_age": 0, + }, + ) self.assertEqual(rv.status_code, 200) - self.assertIn('job', rv.json) + self.assertIn("job", rv.json) def test_execute_on_paused_data_source(self): self.factory.data_source.pause() - rv = self.make_request('post', '/api/query_results', - data={'data_source_id': self.factory.data_source.id, - 'query': 'SELECT 1', - 'max_age': 0}) + rv = self.make_request( + "post", + "/api/query_results", + data={ + "data_source_id": self.factory.data_source.id, + "query": "SELECT 1", + "max_age": 0, + }, + ) self.assertEqual(rv.status_code, 400) - self.assertNotIn('query_result', rv.json) - self.assertIn('job', rv.json) + self.assertNotIn("query_result", rv.json) + self.assertIn("job", rv.json) def test_execute_without_data_source(self): - rv = self.make_request('post', '/api/query_results', - data={'query': 'SELECT 1', - 'max_age': 0}) + rv = self.make_request( + "post", "/api/query_results", data={"query": "SELECT 1", "max_age": 0} + ) self.assertEqual(rv.status_code, 401) - self.assertDictEqual(rv.json, error_messages['select_data_source'][0]) + self.assertDictEqual(rv.json, error_messages["select_data_source"][0]) class TestQueryResultAPI(BaseTestCase): @@ -120,111 +157,173 @@ def test_has_no_access_to_data_source(self): ds = self.factory.create_data_source(group=self.factory.create_group()) query_result = self.factory.create_query_result(data_source=ds) - rv = self.make_request('get', '/api/query_results/{}'.format(query_result.id)) + rv = self.make_request("get", "/api/query_results/{}".format(query_result.id)) self.assertEqual(rv.status_code, 403) def test_has_view_only_access_to_data_source(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=True) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=True + ) query_result = self.factory.create_query_result(data_source=ds) - rv = self.make_request('get', '/api/query_results/{}'.format(query_result.id)) + rv = self.make_request("get", "/api/query_results/{}".format(query_result.id)) self.assertEqual(rv.status_code, 200) def test_has_full_access_to_data_source(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=False) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=False + ) query_result = self.factory.create_query_result(data_source=ds) - rv = self.make_request('get', '/api/query_results/{}'.format(query_result.id)) + rv = self.make_request("get", "/api/query_results/{}".format(query_result.id)) self.assertEqual(rv.status_code, 200) def test_execute_new_query(self): query = self.factory.create_query() - rv = self.make_request('post', '/api/queries/{}/results'.format(query.id), data={'parameters': {}}) + rv = self.make_request( + "post", "/api/queries/{}/results".format(query.id), data={"parameters": {}} + ) self.assertEqual(rv.status_code, 200) - self.assertIn('job', rv.json) - + self.assertIn("job", rv.json) + def test_execute_but_has_no_access_to_data_source(self): ds = self.factory.create_data_source(group=self.factory.create_group()) query = self.factory.create_query(data_source=ds) - rv = self.make_request('post', '/api/queries/{}/results'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/results".format(query.id)) self.assertEqual(rv.status_code, 403) - self.assertDictEqual(rv.json, error_messages['no_permission'][0]) + self.assertDictEqual(rv.json, error_messages["no_permission"][0]) def test_execute_with_no_parameter_values(self): query = self.factory.create_query() - rv = self.make_request('post', '/api/queries/{}/results'.format(query.id)) + rv = self.make_request("post", "/api/queries/{}/results".format(query.id)) self.assertEqual(rv.status_code, 200) - self.assertIn('job', rv.json) + self.assertIn("job", rv.json) def test_prevents_execution_of_unsafe_queries_on_view_only_data_sources(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=True) - query = self.factory.create_query(data_source=ds, options={"parameters": [{"name": "foo", "type": "text"}]}) - - rv = self.make_request('post', '/api/queries/{}/results'.format(query.id), data={"parameters": {}}) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=True + ) + query = self.factory.create_query( + data_source=ds, options={"parameters": [{"name": "foo", "type": "text"}]} + ) + + rv = self.make_request( + "post", "/api/queries/{}/results".format(query.id), data={"parameters": {}} + ) self.assertEqual(rv.status_code, 403) - self.assertDictEqual(rv.json, error_messages['unsafe_on_view_only'][0]) + self.assertDictEqual(rv.json, error_messages["unsafe_on_view_only"][0]) def test_allows_execution_of_safe_queries_on_view_only_data_sources(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=True) - query = self.factory.create_query(data_source=ds, options={"parameters": [{"name": "foo", "type": "number"}]}) - - rv = self.make_request('post', '/api/queries/{}/results'.format(query.id), data={"parameters": {}}) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=True + ) + query = self.factory.create_query( + data_source=ds, options={"parameters": [{"name": "foo", "type": "number"}]} + ) + + rv = self.make_request( + "post", "/api/queries/{}/results".format(query.id), data={"parameters": {}} + ) self.assertEqual(rv.status_code, 200) def test_prevents_execution_of_unsafe_queries_using_api_key(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=True) - query = self.factory.create_query(data_source=ds, options={"parameters": [{"name": "foo", "type": "text"}]}) - - data = {'parameters': {'foo': 'bar'}} - rv = self.make_request('post', '/api/queries/{}/results?api_key={}'.format(query.id, query.api_key), data=data) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=True + ) + query = self.factory.create_query( + data_source=ds, options={"parameters": [{"name": "foo", "type": "text"}]} + ) + + data = {"parameters": {"foo": "bar"}} + rv = self.make_request( + "post", + "/api/queries/{}/results?api_key={}".format(query.id, query.api_key), + data=data, + ) self.assertEqual(rv.status_code, 403) - self.assertDictEqual(rv.json, error_messages['unsafe_when_shared'][0]) + self.assertDictEqual(rv.json, error_messages["unsafe_when_shared"][0]) def test_access_with_query_api_key(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=False) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=False + ) query = self.factory.create_query() - query_result = self.factory.create_query_result(data_source=ds, query_text=query.query_text) - - rv = self.make_request('get', '/api/queries/{}/results/{}.json?api_key={}'.format(query.id, query_result.id, query.api_key), user=False) + query_result = self.factory.create_query_result( + data_source=ds, query_text=query.query_text + ) + + rv = self.make_request( + "get", + "/api/queries/{}/results/{}.json?api_key={}".format( + query.id, query_result.id, query.api_key + ), + user=False, + ) self.assertEqual(rv.status_code, 200) def test_access_with_query_api_key_without_query_result_id(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=False) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=False + ) query = self.factory.create_query() - query_result = self.factory.create_query_result(data_source=ds, query_text=query.query_text, query_hash=query.query_hash) + query_result = self.factory.create_query_result( + data_source=ds, query_text=query.query_text, query_hash=query.query_hash + ) query.latest_query_data = query_result - rv = self.make_request('get', '/api/queries/{}/results.json?api_key={}'.format(query.id, query.api_key), user=False) + rv = self.make_request( + "get", + "/api/queries/{}/results.json?api_key={}".format(query.id, query.api_key), + user=False, + ) self.assertEqual(rv.status_code, 200) def test_query_api_key_and_different_query_result(self): - ds = self.factory.create_data_source(group=self.factory.org.default_group, view_only=False) + ds = self.factory.create_data_source( + group=self.factory.org.default_group, view_only=False + ) query = self.factory.create_query(query_text="SELECT 8") - query_result2 = self.factory.create_query_result(data_source=ds, query_hash='something-different') - - rv = self.make_request('get', '/api/queries/{}/results/{}.json?api_key={}'.format(query.id, query_result2.id, query.api_key), user=False) + query_result2 = self.factory.create_query_result( + data_source=ds, query_hash="something-different" + ) + + rv = self.make_request( + "get", + "/api/queries/{}/results/{}.json?api_key={}".format( + query.id, query_result2.id, query.api_key + ), + user=False, + ) self.assertEqual(rv.status_code, 404) def test_signed_in_user_and_different_query_result(self): - ds2 = self.factory.create_data_source(group=self.factory.org.admin_group, view_only=False) + ds2 = self.factory.create_data_source( + group=self.factory.org.admin_group, view_only=False + ) query = self.factory.create_query(query_text="SELECT 8") - query_result2 = self.factory.create_query_result(data_source=ds2, query_hash='something-different') + query_result2 = self.factory.create_query_result( + data_source=ds2, query_hash="something-different" + ) - rv = self.make_request('get', '/api/queries/{}/results/{}.json'.format(query.id, query_result2.id)) + rv = self.make_request( + "get", "/api/queries/{}/results/{}.json".format(query.id, query_result2.id) + ) self.assertEqual(rv.status_code, 403) class TestQueryResultDropdownResource(BaseTestCase): def test_checks_for_access_to_the_query(self): - ds2 = self.factory.create_data_source(group=self.factory.org.admin_group, view_only=False) + ds2 = self.factory.create_data_source( + group=self.factory.org.admin_group, view_only=False + ) query = self.factory.create_query(data_source=ds2) - rv = self.make_request('get', '/api/queries/{}/dropdown'.format(query.id)) + rv = self.make_request("get", "/api/queries/{}/dropdown".format(query.id)) self.assertEqual(rv.status_code, 403) @@ -232,13 +331,20 @@ def test_checks_for_access_to_the_query(self): class TestQueryDropdownsResource(BaseTestCase): def test_prevents_access_if_unassociated_and_doesnt_have_access(self): query = self.factory.create_query() - ds2 = self.factory.create_data_source(group=self.factory.org.admin_group, view_only=False) + ds2 = self.factory.create_data_source( + group=self.factory.org.admin_group, view_only=False + ) unrelated_dropdown_query = self.factory.create_query(data_source=ds2) # unrelated_dropdown_query has not been associated with query # user does not have direct access to unrelated_dropdown_query - rv = self.make_request('get', '/api/queries/{}/dropdowns/{}'.format(query.id, unrelated_dropdown_query.id)) + rv = self.make_request( + "get", + "/api/queries/{}/dropdowns/{}".format( + query.id, unrelated_dropdown_query.id + ), + ) self.assertEqual(rv.status_code, 403) @@ -246,59 +352,56 @@ def test_allows_access_if_unassociated_but_user_has_access(self): query = self.factory.create_query() query_result = self.factory.create_query_result() - data = { - 'rows': [], - 'columns': [{'name': 'whatever'}] - } + data = {"rows": [], "columns": [{"name": "whatever"}]} query_result = self.factory.create_query_result(data=json_dumps(data)) - unrelated_dropdown_query = self.factory.create_query(latest_query_data=query_result) + unrelated_dropdown_query = self.factory.create_query( + latest_query_data=query_result + ) # unrelated_dropdown_query has not been associated with query # user has direct access to unrelated_dropdown_query - rv = self.make_request('get', '/api/queries/{}/dropdowns/{}'.format(query.id, unrelated_dropdown_query.id)) + rv = self.make_request( + "get", + "/api/queries/{}/dropdowns/{}".format( + query.id, unrelated_dropdown_query.id + ), + ) self.assertEqual(rv.status_code, 200) def test_allows_access_if_associated_and_has_access_to_parent(self): query_result = self.factory.create_query_result() - data = { - 'rows': [], - 'columns': [{'name': 'whatever'}] - } + data = {"rows": [], "columns": [{"name": "whatever"}]} query_result = self.factory.create_query_result(data=json_dumps(data)) dropdown_query = self.factory.create_query(latest_query_data=query_result) - options = { - 'parameters': [{ - 'type': 'query', - 'queryId': dropdown_query.id - }] - } + options = {"parameters": [{"type": "query", "queryId": dropdown_query.id}]} query = self.factory.create_query(options=options) # dropdown_query has been associated with query # user has access to query - rv = self.make_request('get', '/api/queries/{}/dropdowns/{}'.format(query.id, dropdown_query.id)) + rv = self.make_request( + "get", "/api/queries/{}/dropdowns/{}".format(query.id, dropdown_query.id) + ) self.assertEqual(rv.status_code, 200) def test_prevents_access_if_associated_and_doesnt_have_access_to_parent(self): - ds2 = self.factory.create_data_source(group=self.factory.org.admin_group, view_only=False) + ds2 = self.factory.create_data_source( + group=self.factory.org.admin_group, view_only=False + ) dropdown_query = self.factory.create_query(data_source=ds2) - options = { - 'parameters': [{ - 'type': 'query', - 'queryId': dropdown_query.id - }] - } + options = {"parameters": [{"type": "query", "queryId": dropdown_query.id}]} query = self.factory.create_query(data_source=ds2, options=options) # dropdown_query has been associated with query # user doesnt have access to either query - rv = self.make_request('get', '/api/queries/{}/dropdowns/{}'.format(query.id, dropdown_query.id)) + rv = self.make_request( + "get", "/api/queries/{}/dropdowns/{}".format(query.id, dropdown_query.id) + ) self.assertEqual(rv.status_code, 403) @@ -308,23 +411,24 @@ def test_renders_excel_file(self): query = self.factory.create_query() query_result = self.factory.create_query_result() - rv = self.make_request('get', '/api/queries/{}/results/{}.xlsx'.format(query.id, query_result.id), is_json=False) + rv = self.make_request( + "get", + "/api/queries/{}/results/{}.xlsx".format(query.id, query_result.id), + is_json=False, + ) self.assertEqual(rv.status_code, 200) def test_renders_excel_file_when_rows_have_missing_columns(self): query = self.factory.create_query() data = { - 'rows': [ - {'test': 1}, - {'test': 2, 'test2': 3}, - ], - 'columns': [ - {'name': 'test'}, - {'name': 'test2'}, - ], + "rows": [{"test": 1}, {"test": 2, "test2": 3}], + "columns": [{"name": "test"}, {"name": "test2"}], } query_result = self.factory.create_query_result(data=json_dumps(data)) - rv = self.make_request('get', '/api/queries/{}/results/{}.xlsx'.format(query.id, query_result.id), is_json=False) + rv = self.make_request( + "get", + "/api/queries/{}/results/{}.xlsx".format(query.id, query_result.id), + is_json=False, + ) self.assertEqual(rv.status_code, 200) - diff --git a/tests/handlers/test_query_snippets.py b/tests/handlers/test_query_snippets.py index 0a6109c539..99ef3a6b7b 100644 --- a/tests/handlers/test_query_snippets.py +++ b/tests/handlers/test_query_snippets.py @@ -6,28 +6,30 @@ class TestQuerySnippetResource(BaseTestCase): def test_get_snippet(self): snippet = self.factory.create_query_snippet() - rv = self.make_request('get', '/api/query_snippets/{}'.format(snippet.id)) + rv = self.make_request("get", "/api/query_snippets/{}".format(snippet.id)) - for field in ('snippet', 'description', 'trigger'): + for field in ("snippet", "description", "trigger"): self.assertEqual(rv.json[field], getattr(snippet, field)) def test_update_snippet(self): snippet = self.factory.create_query_snippet() data = { - 'snippet': 'updated', - 'trigger': 'updated trigger', - 'description': 'updated description' + "snippet": "updated", + "trigger": "updated trigger", + "description": "updated description", } - rv = self.make_request('post', '/api/query_snippets/{}'.format(snippet.id), data=data) + rv = self.make_request( + "post", "/api/query_snippets/{}".format(snippet.id), data=data + ) - for field in ('snippet', 'description', 'trigger'): + for field in ("snippet", "description", "trigger"): self.assertEqual(rv.json[field], data[field]) def test_delete_snippet(self): snippet = self.factory.create_query_snippet() - rv = self.make_request('delete', '/api/query_snippets/{}'.format(snippet.id)) + rv = self.make_request("delete", "/api/query_snippets/{}".format(snippet.id)) self.assertIsNone(QuerySnippet.query.get(snippet.id)) @@ -35,23 +37,24 @@ def test_delete_snippet(self): class TestQuerySnippetListResource(BaseTestCase): def test_create_snippet(self): data = { - 'snippet': 'updated', - 'trigger': 'updated trigger', - 'description': 'updated description' + "snippet": "updated", + "trigger": "updated trigger", + "description": "updated description", } - rv = self.make_request('post', '/api/query_snippets', data=data) + rv = self.make_request("post", "/api/query_snippets", data=data) self.assertEqual(rv.status_code, 200) def test_list_all_snippets(self): snippet1 = self.factory.create_query_snippet() snippet2 = self.factory.create_query_snippet() - snippet_diff_org = self.factory.create_query_snippet(org=self.factory.create_org()) + snippet_diff_org = self.factory.create_query_snippet( + org=self.factory.create_org() + ) - rv = self.make_request('get', '/api/query_snippets') - ids = [s['id'] for s in rv.json] + rv = self.make_request("get", "/api/query_snippets") + ids = [s["id"] for s in rv.json] self.assertIn(snippet1.id, ids) self.assertIn(snippet2.id, ids) self.assertNotIn(snippet_diff_org.id, ids) - diff --git a/tests/handlers/test_settings.py b/tests/handlers/test_settings.py index e8c9606e04..6c9e33b9a9 100644 --- a/tests/handlers/test_settings.py +++ b/tests/handlers/test_settings.py @@ -5,26 +5,45 @@ class TestOrganizationSettings(BaseTestCase): def test_post(self): admin = self.factory.create_admin() - rv = self.make_request('post', '/api/settings/organization', data={'auth_password_login_enabled': False}, user=admin) - self.assertEqual(rv.json['settings']['auth_password_login_enabled'], False) - self.assertEqual(self.factory.org.settings['settings']['auth_password_login_enabled'], False) + rv = self.make_request( + "post", + "/api/settings/organization", + data={"auth_password_login_enabled": False}, + user=admin, + ) + self.assertEqual(rv.json["settings"]["auth_password_login_enabled"], False) + self.assertEqual( + self.factory.org.settings["settings"]["auth_password_login_enabled"], False + ) - rv = self.make_request('post', '/api/settings/organization', data={'auth_password_login_enabled': True}, user=admin) + rv = self.make_request( + "post", + "/api/settings/organization", + data={"auth_password_login_enabled": True}, + user=admin, + ) updated_org = Organization.get_by_slug(self.factory.org.slug) - self.assertEqual(rv.json['settings']['auth_password_login_enabled'], True) - self.assertEqual(updated_org.settings['settings']['auth_password_login_enabled'], True) + self.assertEqual(rv.json["settings"]["auth_password_login_enabled"], True) + self.assertEqual( + updated_org.settings["settings"]["auth_password_login_enabled"], True + ) def test_updates_google_apps_domains(self): admin = self.factory.create_admin() - domains = ['example.com'] - rv = self.make_request('post', '/api/settings/organization', data={'auth_google_apps_domains': domains}, user=admin) + domains = ["example.com"] + rv = self.make_request( + "post", + "/api/settings/organization", + data={"auth_google_apps_domains": domains}, + user=admin, + ) updated_org = Organization.get_by_slug(self.factory.org.slug) self.assertEqual(updated_org.google_apps_domains, domains) def test_get_returns_google_appas_domains(self): admin = self.factory.create_admin() - domains = ['example.com'] + domains = ["example.com"] admin.org.settings[Organization.SETTING_GOOGLE_APPS_DOMAINS] = domains - rv = self.make_request('get', '/api/settings/organization', user=admin) - self.assertEqual(rv.json['settings']['auth_google_apps_domains'], domains) + rv = self.make_request("get", "/api/settings/organization", user=admin) + self.assertEqual(rv.json["settings"]["auth_google_apps_domains"], domains) diff --git a/tests/handlers/test_users.py b/tests/handlers/test_users.py index 0657e7eab1..3147ae8c98 100644 --- a/tests/handlers/test_users.py +++ b/tests/handlers/test_users.py @@ -2,93 +2,99 @@ from tests import BaseTestCase from mock import patch + class TestUserListResourcePost(BaseTestCase): def test_returns_403_for_non_admin(self): - rv = self.make_request('post', "/api/users") + rv = self.make_request("post", "/api/users") self.assertEqual(rv.status_code, 403) def test_returns_400_when_missing_fields(self): admin = self.factory.create_admin() - rv = self.make_request('post', "/api/users", user=admin) + rv = self.make_request("post", "/api/users", user=admin) self.assertEqual(rv.status_code, 400) - rv = self.make_request('post', '/api/users', data={'name': 'User'}, user=admin) + rv = self.make_request("post", "/api/users", data={"name": "User"}, user=admin) self.assertEqual(rv.status_code, 400) - rv = self.make_request('post', '/api/users', data={'name': 'User', 'email': 'bademailaddress'}, user=admin) + rv = self.make_request( + "post", + "/api/users", + data={"name": "User", "email": "bademailaddress"}, + user=admin, + ) self.assertEqual(rv.status_code, 400) def test_returns_400_when_using_temporary_email(self): admin = self.factory.create_admin() - test_user = {'name': 'User', 'email': 'user@mailinator.com', 'password': 'test'} - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user = {"name": "User", "email": "user@mailinator.com", "password": "test"} + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 400) - test_user['email'] = 'arik@qq.com' - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user["email"] = "arik@qq.com" + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 400) def test_creates_user(self): admin = self.factory.create_admin() - test_user = {'name': 'User', 'email': 'user@example.com', 'password': 'test'} - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user = {"name": "User", "email": "user@example.com", "password": "test"} + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], test_user['name']) - self.assertEqual(rv.json['email'], test_user['email']) + self.assertEqual(rv.json["name"], test_user["name"]) + self.assertEqual(rv.json["email"], test_user["email"]) - @patch('redash.settings.email_server_is_configured', return_value=False) + @patch("redash.settings.email_server_is_configured", return_value=False) def test_shows_invite_link_when_email_is_not_configured(self, _): admin = self.factory.create_admin() - test_user = {'name': 'User', 'email': 'user@example.com'} - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user = {"name": "User", "email": "user@example.com"} + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 200) - self.assertTrue('invite_link' in rv.json) + self.assertTrue("invite_link" in rv.json) - @patch('redash.settings.email_server_is_configured', return_value=True) + @patch("redash.settings.email_server_is_configured", return_value=True) def test_does_not_show_invite_link_when_email_is_configured(self, _): admin = self.factory.create_admin() - test_user = {'name': 'User', 'email': 'user@example.com'} - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user = {"name": "User", "email": "user@example.com"} + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 200) - self.assertFalse('invite_link' in rv.json) + self.assertFalse("invite_link" in rv.json) def test_creates_user_case_insensitive_email(self): admin = self.factory.create_admin() - test_user = {'name': 'User', 'email': 'User@Example.com', 'password': 'test'} - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user = {"name": "User", "email": "User@Example.com", "password": "test"} + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], test_user['name']) - self.assertEqual(rv.json['email'], 'user@example.com') + self.assertEqual(rv.json["name"], test_user["name"]) + self.assertEqual(rv.json["email"], "user@example.com") def test_returns_400_when_email_taken(self): admin = self.factory.create_admin() - test_user = {'name': 'User', 'email': admin.email, 'password': 'test'} - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user = {"name": "User", "email": admin.email, "password": "test"} + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 400) def test_returns_400_when_email_taken_case_insensitive(self): admin = self.factory.create_admin() - test_user1 = {'name': 'User', 'email': 'user@example.com', 'password': 'test'} - rv = self.make_request('post', '/api/users', data=test_user1, user=admin) + test_user1 = {"name": "User", "email": "user@example.com", "password": "test"} + rv = self.make_request("post", "/api/users", data=test_user1, user=admin) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['email'], 'user@example.com') + self.assertEqual(rv.json["email"], "user@example.com") - test_user2 = {'name': 'User', 'email': 'user@Example.com', 'password': 'test'} - rv = self.make_request('post', '/api/users', data=test_user2, user=admin) + test_user2 = {"name": "User", "email": "user@Example.com", "password": "test"} + rv = self.make_request("post", "/api/users", data=test_user2, user=admin) self.assertEqual(rv.status_code, 400) @@ -101,18 +107,30 @@ class PlainObject(object): result = PlainObject() now = models.db.func.now() - result.enabled_active1 = self.factory.create_user(disabled_at=None, is_invitation_pending=None).id - result.enabled_active2 = self.factory.create_user(disabled_at=None, is_invitation_pending=False).id - result.enabled_pending = self.factory.create_user(disabled_at=None, is_invitation_pending=True).id - result.disabled_active1 = self.factory.create_user(disabled_at=now, is_invitation_pending=None).id - result.disabled_active2 = self.factory.create_user(disabled_at=now, is_invitation_pending=False).id - result.disabled_pending = self.factory.create_user(disabled_at=now, is_invitation_pending=True).id + result.enabled_active1 = self.factory.create_user( + disabled_at=None, is_invitation_pending=None + ).id + result.enabled_active2 = self.factory.create_user( + disabled_at=None, is_invitation_pending=False + ).id + result.enabled_pending = self.factory.create_user( + disabled_at=None, is_invitation_pending=True + ).id + result.disabled_active1 = self.factory.create_user( + disabled_at=now, is_invitation_pending=None + ).id + result.disabled_active2 = self.factory.create_user( + disabled_at=now, is_invitation_pending=False + ).id + result.disabled_pending = self.factory.create_user( + disabled_at=now, is_invitation_pending=True + ).id return result def make_request_and_return_ids(self, *args, **kwargs): rv = self.make_request(*args, **kwargs) - return [user['id'] for user in rv.json['results']] + return [user["id"] for user in rv.json["results"]] def assertUsersListMatches(self, actual_ids, expected_ids, unexpected_ids): actual_ids = set(actual_ids) @@ -127,87 +145,113 @@ def test_returns_users_for_given_org_only(self): org = self.factory.create_org() user3 = self.factory.create_user(org=org) - user_ids = self.make_request_and_return_ids('get', '/api/users') + user_ids = self.make_request_and_return_ids("get", "/api/users") self.assertUsersListMatches(user_ids, [user1.id, user2.id], [user3.id]) def test_gets_all_enabled(self): users = self.create_filters_fixtures() - user_ids = self.make_request_and_return_ids('get', '/api/users') + user_ids = self.make_request_and_return_ids("get", "/api/users") self.assertUsersListMatches( user_ids, [users.enabled_active1, users.enabled_active2, users.enabled_pending], - [users.disabled_active1, users.disabled_active2, users.disabled_pending] + [users.disabled_active1, users.disabled_active2, users.disabled_pending], ) def test_gets_all_disabled(self): users = self.create_filters_fixtures() - user_ids = self.make_request_and_return_ids('get', '/api/users?disabled=true') + user_ids = self.make_request_and_return_ids("get", "/api/users?disabled=true") self.assertUsersListMatches( user_ids, [users.disabled_active1, users.disabled_active2, users.disabled_pending], - [users.enabled_active1, users.enabled_active2, users.enabled_pending] + [users.enabled_active1, users.enabled_active2, users.enabled_pending], ) def test_gets_all_enabled_and_active(self): users = self.create_filters_fixtures() - user_ids = self.make_request_and_return_ids('get', '/api/users?pending=false') + user_ids = self.make_request_and_return_ids("get", "/api/users?pending=false") self.assertUsersListMatches( user_ids, [users.enabled_active1, users.enabled_active2], - [users.enabled_pending, users.disabled_active1, users.disabled_active2, users.disabled_pending] + [ + users.enabled_pending, + users.disabled_active1, + users.disabled_active2, + users.disabled_pending, + ], ) def test_gets_all_enabled_and_pending(self): users = self.create_filters_fixtures() - user_ids = self.make_request_and_return_ids('get', '/api/users?pending=true') + user_ids = self.make_request_and_return_ids("get", "/api/users?pending=true") self.assertUsersListMatches( user_ids, [users.enabled_pending], - [users.enabled_active1, users.enabled_active2, users.disabled_active1, users.disabled_active2, users.disabled_pending] + [ + users.enabled_active1, + users.enabled_active2, + users.disabled_active1, + users.disabled_active2, + users.disabled_pending, + ], ) def test_gets_all_disabled_and_active(self): users = self.create_filters_fixtures() - user_ids = self.make_request_and_return_ids('get', '/api/users?disabled=true&pending=false') + user_ids = self.make_request_and_return_ids( + "get", "/api/users?disabled=true&pending=false" + ) self.assertUsersListMatches( user_ids, [users.disabled_active1, users.disabled_active2], - [users.disabled_pending, users.enabled_active1, users.enabled_active2, users.enabled_pending] + [ + users.disabled_pending, + users.enabled_active1, + users.enabled_active2, + users.enabled_pending, + ], ) def test_gets_all_disabled_and_pending(self): users = self.create_filters_fixtures() - user_ids = self.make_request_and_return_ids('get', '/api/users?disabled=true&pending=true') + user_ids = self.make_request_and_return_ids( + "get", "/api/users?disabled=true&pending=true" + ) self.assertUsersListMatches( user_ids, [users.disabled_pending], - [users.disabled_active1, users.disabled_active2, users.enabled_active1, users.enabled_active2, users.enabled_pending] + [ + users.disabled_active1, + users.disabled_active2, + users.enabled_active1, + users.enabled_active2, + users.enabled_pending, + ], ) class TestUserResourceGet(BaseTestCase): def test_returns_api_key_for_your_own_user(self): - rv = self.make_request('get', "/api/users/{}".format(self.factory.user.id)) - self.assertIn('api_key', rv.json) + rv = self.make_request("get", "/api/users/{}".format(self.factory.user.id)) + self.assertIn("api_key", rv.json) def test_returns_api_key_for_other_user_when_admin(self): other_user = self.factory.user admin = self.factory.create_admin() - rv = self.make_request('get', "/api/users/{}".format(other_user.id), user=admin) - self.assertIn('api_key', rv.json) + rv = self.make_request("get", "/api/users/{}".format(other_user.id), user=admin) + self.assertIn("api_key", rv.json) def test_doesnt_return_api_key_for_other_user(self): other_user = self.factory.create_user() - rv = self.make_request('get', "/api/users/{}".format(other_user.id)) - self.assertNotIn('api_key', rv.json) + rv = self.make_request("get", "/api/users/{}".format(other_user.id)) + self.assertNotIn("api_key", rv.json) def test_doesnt_return_user_from_different_org(self): org = self.factory.create_org() other_user = self.factory.create_user(org=org) - rv = self.make_request('get', "/api/users/{}".format(other_user.id)) + rv = self.make_request("get", "/api/users/{}".format(other_user.id)) self.assertEqual(rv.status_code, 404) @@ -215,39 +259,64 @@ class TestUserResourcePost(BaseTestCase): def test_returns_403_for_non_admin_changing_not_his_own(self): other_user = self.factory.create_user() - rv = self.make_request('post', "/api/users/{}".format(other_user.id), data={"name": "New Name"}) + rv = self.make_request( + "post", "/api/users/{}".format(other_user.id), data={"name": "New Name"} + ) self.assertEqual(rv.status_code, 403) def test_returns_200_for_non_admin_changing_his_own(self): - rv = self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"name": "New Name"}) + rv = self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data={"name": "New Name"}, + ) self.assertEqual(rv.status_code, 200) - @patch('redash.settings.email_server_is_configured', return_value=True) + @patch("redash.settings.email_server_is_configured", return_value=True) def test_marks_email_as_not_verified_when_changed(self, _): user = self.factory.user user.is_email_verified = True - rv = self.make_request('post', "/api/users/{}".format(user.id), data={"email": "donald@trump.biz"}) + rv = self.make_request( + "post", "/api/users/{}".format(user.id), data={"email": "donald@trump.biz"} + ) self.assertFalse(user.is_email_verified) - @patch('redash.settings.email_server_is_configured', return_value=False) - def test_doesnt_mark_email_as_not_verified_when_changed_and_email_server_is_not_configured(self, _): + @patch("redash.settings.email_server_is_configured", return_value=False) + def test_doesnt_mark_email_as_not_verified_when_changed_and_email_server_is_not_configured( + self, _ + ): user = self.factory.user user.is_email_verified = True - rv = self.make_request('post', "/api/users/{}".format(user.id), data={"email": "donald@trump.biz"}) + rv = self.make_request( + "post", "/api/users/{}".format(user.id), data={"email": "donald@trump.biz"} + ) self.assertTrue(user.is_email_verified) def test_returns_200_for_admin_changing_other_user(self): admin = self.factory.create_admin() - rv = self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"name": "New Name"}, user=admin) + rv = self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data={"name": "New Name"}, + user=admin, + ) self.assertEqual(rv.status_code, 200) def test_fails_password_change_without_old_password(self): - rv = self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"password": "new password"}) + rv = self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data={"password": "new password"}, + ) self.assertEqual(rv.status_code, 403) def test_fails_password_change_with_incorrect_old_password(self): - rv = self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"password": "new password", "old_password": "wrong"}) + rv = self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data={"password": "new password", "old_password": "wrong"}, + ) self.assertEqual(rv.status_code, 403) def test_changes_password(self): @@ -257,7 +326,11 @@ def test_changes_password(self): self.factory.user.hash_password(old_password) models.db.session.add(self.factory.user) - rv = self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"password": new_password, "old_password": old_password}) + rv = self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data={"password": new_password, "old_password": old_password}, + ) self.assertEqual(rv.status_code, 200) user = models.User.query.get(self.factory.user.id) @@ -266,43 +339,56 @@ def test_changes_password(self): def test_returns_400_when_using_temporary_email(self): admin = self.factory.create_admin() - test_user = {'email': 'user@mailinator.com'} - rv = self.make_request('post', '/api/users/{}'.format(self.factory.user.id), data=test_user, user=admin) + test_user = {"email": "user@mailinator.com"} + rv = self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data=test_user, + user=admin, + ) self.assertEqual(rv.status_code, 400) - test_user['email'] = 'arik@qq.com' - rv = self.make_request('post', '/api/users', data=test_user, user=admin) + test_user["email"] = "arik@qq.com" + rv = self.make_request("post", "/api/users", data=test_user, user=admin) self.assertEqual(rv.status_code, 400) def test_changing_email_ends_any_other_sessions_of_current_user(self): with self.client as c: # visit profile page - self.make_request('get', "/api/users/{}".format(self.factory.user.id)) + self.make_request("get", "/api/users/{}".format(self.factory.user.id)) with c.session_transaction() as sess: - previous = sess['user_id'] + previous = sess["user_id"] # change e-mail address - this will result in a new `user_id` value inside the session - self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"email": "john@doe.com"}) + self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data={"email": "john@doe.com"}, + ) # force the old `user_id`, simulating that the user is logged in from another browser with c.session_transaction() as sess: - sess['user_id'] = previous + sess["user_id"] = previous rv = self.get_request("/api/users/{}".format(self.factory.user.id)) self.assertEqual(rv.status_code, 404) def test_changing_email_does_not_end_current_session(self): - self.make_request('get', "/api/users/{}".format(self.factory.user.id)) + self.make_request("get", "/api/users/{}".format(self.factory.user.id)) with self.client as c: with c.session_transaction() as sess: - previous = sess['user_id'] + previous = sess["user_id"] - self.make_request('post', "/api/users/{}".format(self.factory.user.id), data={"email": "john@doe.com"}) + self.make_request( + "post", + "/api/users/{}".format(self.factory.user.id), + data={"email": "john@doe.com"}, + ) with self.client as c: with c.session_transaction() as sess: - current = sess['user_id'] + current = sess["user_id"] # make sure the session's `user_id` has changed to reflect the new identity, thus not logging the user out self.assertNotEqual(previous, current) @@ -311,16 +397,23 @@ def test_admin_can_change_user_groups(self): admin_user = self.factory.create_admin() other_user = self.factory.create_user(group_ids=[1]) - rv = self.make_request('post', "/api/users/{}".format(other_user.id), data={"group_ids": [1, 2]}, user=admin_user) + rv = self.make_request( + "post", + "/api/users/{}".format(other_user.id), + data={"group_ids": [1, 2]}, + user=admin_user, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(models.User.query.get(other_user.id).group_ids, [1,2]) + self.assertEqual(models.User.query.get(other_user.id).group_ids, [1, 2]) def test_admin_can_delete_user(self): admin_user = self.factory.create_admin() other_user = self.factory.create_user(is_invitation_pending=True) - rv = self.make_request('delete', "/api/users/{}".format(other_user.id), user=admin_user) + rv = self.make_request( + "delete", "/api/users/{}".format(other_user.id), user=admin_user + ) self.assertEqual(rv.status_code, 200) self.assertEqual(models.User.query.get(other_user.id), None) @@ -331,7 +424,9 @@ def test_non_admin_cannot_disable_user(self): other_user = self.factory.create_user() self.assertFalse(other_user.is_disabled) - rv = self.make_request('post', "/api/users/{}/disable".format(other_user.id), user=other_user) + rv = self.make_request( + "post", "/api/users/{}/disable".format(other_user.id), user=other_user + ) self.assertEqual(rv.status_code, 403) # user should stay enabled @@ -343,7 +438,9 @@ def test_admin_can_disable_user(self): other_user = self.factory.create_user() self.assertFalse(other_user.is_disabled) - rv = self.make_request('post', "/api/users/{}/disable".format(other_user.id), user=admin_user) + rv = self.make_request( + "post", "/api/users/{}/disable".format(other_user.id), user=admin_user + ) self.assertEqual(rv.status_code, 200) # user should become disabled @@ -355,7 +452,9 @@ def test_admin_can_disable_another_admin(self): admin_user2 = self.factory.create_admin() self.assertFalse(admin_user2.is_disabled) - rv = self.make_request('post', "/api/users/{}/disable".format(admin_user2.id), user=admin_user1) + rv = self.make_request( + "post", "/api/users/{}/disable".format(admin_user2.id), user=admin_user1 + ) self.assertEqual(rv.status_code, 200) # user should become disabled @@ -366,7 +465,9 @@ def test_admin_cannot_disable_self(self): admin_user = self.factory.create_admin() self.assertFalse(admin_user.is_disabled) - rv = self.make_request('post', "/api/users/{}/disable".format(admin_user.id), user=admin_user) + rv = self.make_request( + "post", "/api/users/{}/disable".format(admin_user.id), user=admin_user + ) self.assertEqual(rv.status_code, 403) # user should stay enabled @@ -375,10 +476,12 @@ def test_admin_cannot_disable_self(self): def test_admin_can_enable_user(self): admin_user = self.factory.create_admin() - other_user = self.factory.create_user(disabled_at='2018-03-08 00:00') + other_user = self.factory.create_user(disabled_at="2018-03-08 00:00") self.assertTrue(other_user.is_disabled) - rv = self.make_request('delete', "/api/users/{}/disable".format(other_user.id), user=admin_user) + rv = self.make_request( + "delete", "/api/users/{}/disable".format(other_user.id), user=admin_user + ) self.assertEqual(rv.status_code, 200) # user should become enabled @@ -387,10 +490,12 @@ def test_admin_can_enable_user(self): def test_admin_can_enable_another_admin(self): admin_user1 = self.factory.create_admin() - admin_user2 = self.factory.create_admin(disabled_at='2018-03-08 00:00') + admin_user2 = self.factory.create_admin(disabled_at="2018-03-08 00:00") self.assertTrue(admin_user2.is_disabled) - rv = self.make_request('delete', "/api/users/{}/disable".format(admin_user2.id), user=admin_user1) + rv = self.make_request( + "delete", "/api/users/{}/disable".format(admin_user2.id), user=admin_user1 + ) self.assertEqual(rv.status_code, 200) # user should become enabled @@ -398,26 +503,30 @@ def test_admin_can_enable_another_admin(self): self.assertFalse(admin_user2.is_disabled) def test_disabled_user_cannot_login(self): - user = self.factory.create_user(disabled_at='2018-03-08 00:00') - user.hash_password('password') + user = self.factory.create_user(disabled_at="2018-03-08 00:00") + user.hash_password("password") self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.post_request('/login', data={'email': user.email, 'password': 'password'}, org=self.factory.org) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.post_request( + "/login", + data={"email": user.email, "password": "password"}, + org=self.factory.org, + ) # login handler should not be called login_user_mock.assert_not_called() # check if error is raised self.assertEqual(rv.status_code, 200) - self.assertIn('Wrong email or password', rv.data.decode()) + self.assertIn("Wrong email or password", rv.data.decode()) def test_disabled_user_should_not_access_api(self): # Note: some API does not require user, so check the one which requires # 1. create user; the user should have access to API user = self.factory.create_user() - rv = self.make_request('get', '/api/dashboards', user=user) + rv = self.make_request("get", "/api/dashboards", user=user) self.assertEqual(rv.status_code, 200) # 2. disable user; now API access should be forbidden @@ -425,7 +534,7 @@ def test_disabled_user_should_not_access_api(self): self.db.session.add(user) self.db.session.commit() - rv = self.make_request('get', '/api/dashboards', user=user) + rv = self.make_request("get", "/api/dashboards", user=user) self.assertNotEqual(rv.status_code, 200) def test_disabled_user_should_not_receive_restore_password_email(self): @@ -433,9 +542,13 @@ def test_disabled_user_should_not_receive_restore_password_email(self): # user should receive email user = self.factory.create_user() - with patch('redash.handlers.users.send_password_reset_email') as send_password_reset_email_mock: - send_password_reset_email_mock.return_value = 'reset_token' - rv = self.make_request('post', '/api/users/{}/reset_password'.format(user.id), user=admin_user) + with patch( + "redash.handlers.users.send_password_reset_email" + ) as send_password_reset_email_mock: + send_password_reset_email_mock.return_value = "reset_token" + rv = self.make_request( + "post", "/api/users/{}/reset_password".format(user.id), user=admin_user + ) self.assertEqual(rv.status_code, 200) send_password_reset_email_mock.assert_called_with(user) @@ -444,9 +557,13 @@ def test_disabled_user_should_not_receive_restore_password_email(self): self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.users.send_password_reset_email') as send_password_reset_email_mock: - send_password_reset_email_mock.return_value = 'reset_token' - rv = self.make_request('post', '/api/users/{}/reset_password'.format(user.id), user=admin_user) + with patch( + "redash.handlers.users.send_password_reset_email" + ) as send_password_reset_email_mock: + send_password_reset_email_mock.return_value = "reset_token" + rv = self.make_request( + "post", "/api/users/{}/reset_password".format(user.id), user=admin_user + ) self.assertEqual(rv.status_code, 404) send_password_reset_email_mock.assert_not_called() @@ -457,7 +574,11 @@ def test_non_admin_cannot_regenerate_other_user_api_key(self): other_user = self.factory.create_user() orig_api_key = other_user.api_key - rv = self.make_request('post', "/api/users/{}/regenerate_api_key".format(other_user.id), user=admin_user) + rv = self.make_request( + "post", + "/api/users/{}/regenerate_api_key".format(other_user.id), + user=admin_user, + ) self.assertEqual(rv.status_code, 200) other_user = models.User.query.get(other_user.id) @@ -468,7 +589,9 @@ def test_admin_can_regenerate_other_user_api_key(self): user2 = self.factory.create_user() orig_user2_api_key = user2.api_key - rv = self.make_request('post', "/api/users/{}/regenerate_api_key".format(user2.id), user=user1) + rv = self.make_request( + "post", "/api/users/{}/regenerate_api_key".format(user2.id), user=user1 + ) self.assertEqual(rv.status_code, 403) user = models.User.query.get(user2.id) @@ -478,7 +601,11 @@ def test_admin_can_regenerate_api_key_myself(self): admin_user = self.factory.create_admin() orig_api_key = admin_user.api_key - rv = self.make_request('post', "/api/users/{}/regenerate_api_key".format(admin_user.id), user=admin_user) + rv = self.make_request( + "post", + "/api/users/{}/regenerate_api_key".format(admin_user.id), + user=admin_user, + ) self.assertEqual(rv.status_code, 200) user = models.User.query.get(admin_user.id) @@ -488,7 +615,9 @@ def test_user_can_regenerate_api_key_myself(self): user = self.factory.create_user() orig_api_key = user.api_key - rv = self.make_request('post', "/api/users/{}/regenerate_api_key".format(user.id), user=user) + rv = self.make_request( + "post", "/api/users/{}/regenerate_api_key".format(user.id), user=user + ) self.assertEqual(rv.status_code, 200) user = models.User.query.get(user.id) diff --git a/tests/handlers/test_visualizations.py b/tests/handlers/test_visualizations.py index e79058147e..42a5b61668 100644 --- a/tests/handlers/test_visualizations.py +++ b/tests/handlers/test_visualizations.py @@ -8,23 +8,25 @@ def test_create_visualization(self): query = self.factory.create_query() models.db.session.commit() data = { - 'query_id': query.id, - 'name': 'Chart', - 'description': '', - 'options': {}, - 'type': 'CHART' + "query_id": query.id, + "name": "Chart", + "description": "", + "options": {}, + "type": "CHART", } - rv = self.make_request('post', '/api/visualizations', data=data) + rv = self.make_request("post", "/api/visualizations", data=data) self.assertEqual(rv.status_code, 200) - data.pop('query_id') + data.pop("query_id") self.assertDictContainsSubset(data, rv.json) def test_delete_visualization(self): visualization = self.factory.create_visualization() models.db.session.commit() - rv = self.make_request('delete', '/api/visualizations/{}'.format(visualization.id)) + rv = self.make_request( + "delete", "/api/visualizations/{}".format(visualization.id) + ) self.assertEqual(rv.status_code, 200) self.assertEqual(models.Visualization.query.count(), 0) @@ -32,10 +34,14 @@ def test_delete_visualization(self): def test_update_visualization(self): visualization = self.factory.create_visualization() models.db.session.commit() - rv = self.make_request('post', '/api/visualizations/{0}'.format(visualization.id), data={'name': 'After Update'}) + rv = self.make_request( + "post", + "/api/visualizations/{0}".format(visualization.id), + data={"name": "After Update"}, + ) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['name'], 'After Update') + self.assertEqual(rv.json["name"], "After Update") def test_only_owner_collaborator_or_admin_can_create_visualization(self): query = self.factory.create_query() @@ -47,31 +53,41 @@ def test_only_owner_collaborator_or_admin_can_create_visualization(self): models.db.session.refresh(other_user) models.db.session.refresh(admin_from_diff_org) data = { - 'query_id': query.id, - 'name': 'Chart', - 'description': '', - 'options': {}, - 'type': 'CHART' + "query_id": query.id, + "name": "Chart", + "description": "", + "options": {}, + "type": "CHART", } - rv = self.make_request('post', '/api/visualizations', data=data, user=admin) + rv = self.make_request("post", "/api/visualizations", data=data, user=admin) self.assertEqual(rv.status_code, 200) - rv = self.make_request('post', '/api/visualizations', data=data, user=other_user) + rv = self.make_request( + "post", "/api/visualizations", data=data, user=other_user + ) self.assertEqual(rv.status_code, 403) - self.make_request('post', '/api/queries/{}/acl'.format(query.id), data={'access_type': 'modify', 'user_id': other_user.id}) - rv = self.make_request('post', '/api/visualizations', data=data, user=other_user) + self.make_request( + "post", + "/api/queries/{}/acl".format(query.id), + data={"access_type": "modify", "user_id": other_user.id}, + ) + rv = self.make_request( + "post", "/api/visualizations", data=data, user=other_user + ) self.assertEqual(rv.status_code, 200) - rv = self.make_request('post', '/api/visualizations', data=data, user=admin_from_diff_org) + rv = self.make_request( + "post", "/api/visualizations", data=data, user=admin_from_diff_org + ) self.assertEqual(rv.status_code, 404) def test_only_owner_collaborator_or_admin_can_edit_visualization(self): vis = self.factory.create_visualization() models.db.session.flush() - path = '/api/visualizations/{}'.format(vis.id) - data = {'name': 'After Update'} + path = "/api/visualizations/{}".format(vis.id) + data = {"name": "After Update"} other_user = self.factory.create_user() admin = self.factory.create_admin() @@ -81,23 +97,27 @@ def test_only_owner_collaborator_or_admin_can_edit_visualization(self): models.db.session.refresh(other_user) models.db.session.refresh(admin_from_diff_org) - rv = self.make_request('post', path, user=admin, data=data) + rv = self.make_request("post", path, user=admin, data=data) self.assertEqual(rv.status_code, 200) - rv = self.make_request('post', path, user=other_user, data=data) + rv = self.make_request("post", path, user=other_user, data=data) self.assertEqual(rv.status_code, 403) - self.make_request('post', '/api/queries/{}/acl'.format(vis.query_id), data={'access_type': 'modify', 'user_id': other_user.id}) - rv = self.make_request('post', path, user=other_user, data=data) + self.make_request( + "post", + "/api/queries/{}/acl".format(vis.query_id), + data={"access_type": "modify", "user_id": other_user.id}, + ) + rv = self.make_request("post", path, user=other_user, data=data) self.assertEqual(rv.status_code, 200) - rv = self.make_request('post', path, user=admin_from_diff_org, data=data) + rv = self.make_request("post", path, user=admin_from_diff_org, data=data) self.assertEqual(rv.status_code, 404) def test_only_owner_collaborator_or_admin_can_delete_visualization(self): vis = self.factory.create_visualization() models.db.session.flush() - path = '/api/visualizations/{}'.format(vis.id) + path = "/api/visualizations/{}".format(vis.id) other_user = self.factory.create_user() admin = self.factory.create_admin() @@ -107,32 +127,38 @@ def test_only_owner_collaborator_or_admin_can_delete_visualization(self): models.db.session.refresh(admin) models.db.session.refresh(other_user) models.db.session.refresh(admin_from_diff_org) - rv = self.make_request('delete', path, user=admin) + rv = self.make_request("delete", path, user=admin) self.assertEqual(rv.status_code, 200) vis = self.factory.create_visualization() models.db.session.commit() - path = '/api/visualizations/{}'.format(vis.id) + path = "/api/visualizations/{}".format(vis.id) - rv = self.make_request('delete', path, user=other_user) + rv = self.make_request("delete", path, user=other_user) self.assertEqual(rv.status_code, 403) - self.make_request('post', '/api/queries/{}/acl'.format(vis.query_id), data={'access_type': 'modify', 'user_id': other_user.id}) + self.make_request( + "post", + "/api/queries/{}/acl".format(vis.query_id), + data={"access_type": "modify", "user_id": other_user.id}, + ) - rv = self.make_request('delete', path, user=other_user) + rv = self.make_request("delete", path, user=other_user) self.assertEqual(rv.status_code, 200) vis = self.factory.create_visualization() models.db.session.commit() - path = '/api/visualizations/{}'.format(vis.id) + path = "/api/visualizations/{}".format(vis.id) - rv = self.make_request('delete', path, user=admin_from_diff_org) + rv = self.make_request("delete", path, user=admin_from_diff_org) self.assertEqual(rv.status_code, 404) def test_deleting_a_visualization_deletes_dashboard_widgets(self): vis = self.factory.create_visualization() widget = self.factory.create_widget(visualization=vis) - rv = self.make_request('delete', '/api/visualizations/{}'.format(vis.id)) + rv = self.make_request("delete", "/api/visualizations/{}".format(vis.id)) - self.assertIsNone(models.Widget.query.filter(models.Widget.id == widget.id).first()) + self.assertIsNone( + models.Widget.query.filter(models.Widget.id == widget.id).first() + ) diff --git a/tests/handlers/test_widgets.py b/tests/handlers/test_widgets.py index c8e8f6f299..85cb522695 100644 --- a/tests/handlers/test_widgets.py +++ b/tests/handlers/test_widgets.py @@ -5,13 +5,13 @@ class WidgetAPITest(BaseTestCase): def create_widget(self, dashboard, visualization, width=1): data = { - 'visualization_id': visualization.id, - 'dashboard_id': dashboard.id, - 'options': {}, - 'width': width + "visualization_id": visualization.id, + "dashboard_id": dashboard.id, + "options": {}, + "width": width, } - rv = self.make_request('post', '/api/widgets', data=data) + rv = self.make_request("post", "/api/widgets", data=data) return rv @@ -31,36 +31,38 @@ def test_wont_create_widget_for_visualization_you_dont_have_access_to(self): models.db.session.add(vis.query_rel) data = { - 'visualization_id': vis.id, - 'dashboard_id': dashboard.id, - 'options': {}, - 'width': 1 + "visualization_id": vis.id, + "dashboard_id": dashboard.id, + "options": {}, + "width": 1, } - rv = self.make_request('post', '/api/widgets', data=data) + rv = self.make_request("post", "/api/widgets", data=data) self.assertEqual(rv.status_code, 403) def test_create_text_widget(self): dashboard = self.factory.create_dashboard() data = { - 'visualization_id': None, - 'text': 'Sample text.', - 'dashboard_id': dashboard.id, - 'options': {}, - 'width': 2 + "visualization_id": None, + "text": "Sample text.", + "dashboard_id": dashboard.id, + "options": {}, + "width": 2, } - rv = self.make_request('post', '/api/widgets', data=data) + rv = self.make_request("post", "/api/widgets", data=data) self.assertEqual(rv.status_code, 200) - self.assertEqual(rv.json['text'], 'Sample text.') + self.assertEqual(rv.json["text"], "Sample text.") def test_delete_widget(self): widget = self.factory.create_widget() - rv = self.make_request('delete', '/api/widgets/{0}'.format(widget.id)) + rv = self.make_request("delete", "/api/widgets/{0}".format(widget.id)) self.assertEqual(rv.status_code, 200) - dashboard = models.Dashboard.get_by_slug_and_org(widget.dashboard.slug, widget.dashboard.org) + dashboard = models.Dashboard.get_by_slug_and_org( + widget.dashboard.slug, widget.dashboard.org + ) self.assertEqual(dashboard.widgets.count(), 0) diff --git a/tests/models/test_alerts.py b/tests/models/test_alerts.py index 5de14e1c59..f93e11fe32 100644 --- a/tests/models/test_alerts.py +++ b/tests/models/test_alerts.py @@ -41,14 +41,18 @@ def test_return_each_alert_only_once(self): def get_results(value): - return json_dumps({'rows': [{'foo': value}], 'columns': [{'name': 'foo', 'type': 'STRING'}]}) + return json_dumps( + {"rows": [{"foo": value}], "columns": [{"name": "foo", "type": "STRING"}]} + ) class TestAlertEvaluate(BaseTestCase): - def create_alert(self, results, column='foo', value="1"): + def create_alert(self, results, column="foo", value="1"): result = self.factory.create_query_result(data=results) query = self.factory.create_query(latest_query_data_id=result.id) - alert = self.factory.create_alert(query_rel=query, options={'op': 'equals', 'column': column, 'value': value}) + alert = self.factory.create_alert( + query_rel=query, options={"op": "equals", "column": column, "value": value} + ) return alert def test_evaluate_triggers_alert_when_equal(self): @@ -60,29 +64,41 @@ def test_evaluate_number_value_and_string_threshold(self): self.assertEqual(alert.evaluate(), Alert.UNKNOWN_STATE) def test_evaluate_return_unknown_when_missing_column(self): - alert = self.create_alert(get_results(1), column='bar') + alert = self.create_alert(get_results(1), column="bar") self.assertEqual(alert.evaluate(), Alert.UNKNOWN_STATE) def test_evaluate_return_unknown_when_empty_results(self): - results = json_dumps({'rows': [], 'columns': [{'name': 'foo', 'type': 'STRING'}]}) + results = json_dumps( + {"rows": [], "columns": [{"name": "foo", "type": "STRING"}]} + ) alert = self.create_alert(results) self.assertEqual(alert.evaluate(), Alert.UNKNOWN_STATE) class TestNextState(TestCase): def test_numeric_value(self): - self.assertEqual(Alert.TRIGGERED_STATE, next_state(OPERATORS.get('=='), 1, "1")) - self.assertEqual(Alert.TRIGGERED_STATE, next_state(OPERATORS.get('=='), 1, "1.0")) - + self.assertEqual(Alert.TRIGGERED_STATE, next_state(OPERATORS.get("=="), 1, "1")) + self.assertEqual( + Alert.TRIGGERED_STATE, next_state(OPERATORS.get("=="), 1, "1.0") + ) + def test_numeric_value_and_plain_string(self): - self.assertEqual(Alert.UNKNOWN_STATE, next_state(OPERATORS.get('=='), 1, "string")) + self.assertEqual( + Alert.UNKNOWN_STATE, next_state(OPERATORS.get("=="), 1, "string") + ) def test_non_numeric_value(self): - self.assertEqual(Alert.OK_STATE, next_state(OPERATORS.get('=='), "1", "1.0")) + self.assertEqual(Alert.OK_STATE, next_state(OPERATORS.get("=="), "1", "1.0")) def test_string_value(self): - self.assertEqual(Alert.TRIGGERED_STATE, next_state(OPERATORS.get('=='), "string", "string")) - + self.assertEqual( + Alert.TRIGGERED_STATE, next_state(OPERATORS.get("=="), "string", "string") + ) + def test_boolean_value(self): - self.assertEqual(Alert.TRIGGERED_STATE, next_state(OPERATORS.get('=='), False, 'false')) - self.assertEqual(Alert.TRIGGERED_STATE, next_state(OPERATORS.get('!='), False, 'true')) \ No newline at end of file + self.assertEqual( + Alert.TRIGGERED_STATE, next_state(OPERATORS.get("=="), False, "false") + ) + self.assertEqual( + Alert.TRIGGERED_STATE, next_state(OPERATORS.get("!="), False, "true") + ) diff --git a/tests/models/test_api_keys.py b/tests/models/test_api_keys.py index 1a9ad21f94..f01d91931e 100644 --- a/tests/models/test_api_keys.py +++ b/tests/models/test_api_keys.py @@ -14,4 +14,3 @@ def test_returns_only_active_key(self): api_key = self.factory.create_api_key(object=dashboard) self.assertEqual(api_key, ApiKey.get_by_object(dashboard)) - diff --git a/tests/models/test_changes.py b/tests/models/test_changes.py index 124e17a30d..e9847b0a0a 100644 --- a/tests/models/test_changes.py +++ b/tests/models/test_changes.py @@ -4,12 +4,14 @@ def create_object(factory): - obj = Query(name='Query', - description='', - query_text='SELECT 1', - user=factory.user, - data_source=factory.data_source, - org=factory.org) + obj = Query( + name="Query", + description="", + query_text="SELECT 1", + user=factory.user, + data_source=factory.data_source, + org=factory.org, + ) return obj @@ -19,17 +21,19 @@ def test_returns_initial_state(self): obj = create_object(self.factory) for change in Change.query.filter(Change.object == obj): - self.assertIsNone(change.change['previous']) + self.assertIsNone(change.change["previous"]) class TestLogChange(BaseTestCase): def obj(self): - obj = Query(name='Query', - description='', - query_text='SELECT 1', - user=self.factory.user, - data_source=self.factory.data_source, - org=self.factory.org) + obj = Query( + name="Query", + description="", + query_text="SELECT 1", + user=self.factory.user, + data_source=self.factory.data_source, + org=self.factory.org, + ) return obj @@ -54,8 +58,8 @@ def test_skips_unnecessary_fields(self): def test_properly_log_modification(self): obj = create_object(self.factory) obj.record_changes(changed_by=self.factory.user) - obj.name = 'Query 2' - obj.description = 'description' + obj.name = "Query 2" + obj.description = "description" db.session.flush() obj.record_changes(changed_by=self.factory.user) @@ -65,13 +69,18 @@ def test_properly_log_modification(self): # TODO: https://github.com/getredash/redash/issues/1550 # self.assertEqual(change.object_version, 2) self.assertEqual(change.object_version, obj.version) - self.assertIn('name', change.change) - self.assertIn('description', change.change) + self.assertIn("name", change.change) + self.assertIn("description", change.change) def test_logs_create_method(self): - q = Query(name='Query', description='', query_text='', - user=self.factory.user, data_source=self.factory.data_source, - org=self.factory.org) + q = Query( + name="Query", + description="", + query_text="", + user=self.factory.user, + data_source=self.factory.data_source, + org=self.factory.org, + ) change = Change.last_change(q) self.assertIsNotNone(change) diff --git a/tests/models/test_dashboards.py b/tests/models/test_dashboards.py index f168aaab3a..70e6048b78 100644 --- a/tests/models/test_dashboards.py +++ b/tests/models/test_dashboards.py @@ -15,16 +15,16 @@ def create_tagged_dashboard(self, tags): widget1 = self.factory.create_widget(visualization=vis1, dashboard=dashboard) widget2 = self.factory.create_widget(visualization=vis2, dashboard=dashboard) widget3 = self.factory.create_widget(visualization=vis3, dashboard=dashboard) - dashboard.layout = '[[{}, {}, {}]]'.format(widget1.id, widget2.id, widget3.id) + dashboard.layout = "[[{}, {}, {}]]".format(widget1.id, widget2.id, widget3.id) db.session.commit() return dashboard def test_all_tags(self): - self.create_tagged_dashboard(tags=['tag1']) - self.create_tagged_dashboard(tags=['tag1', 'tag2']) - self.create_tagged_dashboard(tags=['tag1', 'tag2', 'tag3']) + self.create_tagged_dashboard(tags=["tag1"]) + self.create_tagged_dashboard(tags=["tag1", "tag2"]) + self.create_tagged_dashboard(tags=["tag1", "tag2", "tag3"]) self.assertEqual( list(Dashboard.all_tags(self.factory.org, self.factory.user)), - [('tag1', 3), ('tag2', 2), ('tag3', 1)] + [("tag1", 3), ("tag2", 2), ("tag3", 1)], ) diff --git a/tests/models/test_data_sources.py b/tests/models/test_data_sources.py index 37d4af663b..54d09ce621 100644 --- a/tests/models/test_data_sources.py +++ b/tests/models/test_data_sources.py @@ -8,9 +8,11 @@ class DataSourceTest(BaseTestCase): def test_get_schema(self): - return_value = [{'name': 'table', 'columns': []}] + return_value = [{"name": "table", "columns": []}] - with mock.patch('redash.query_runner.pg.PostgreSQL.get_schema') as patched_get_schema: + with mock.patch( + "redash.query_runner.pg.PostgreSQL.get_schema" + ) as patched_get_schema: patched_get_schema.return_value = return_value schema = self.factory.data_source.get_schema() @@ -18,8 +20,10 @@ def test_get_schema(self): self.assertEqual(return_value, schema) def test_get_schema_uses_cache(self): - return_value = [{'name': 'table', 'columns': []}] - with mock.patch('redash.query_runner.pg.PostgreSQL.get_schema') as patched_get_schema: + return_value = [{"name": "table", "columns": []}] + with mock.patch( + "redash.query_runner.pg.PostgreSQL.get_schema" + ) as patched_get_schema: patched_get_schema.return_value = return_value self.factory.data_source.get_schema() @@ -29,12 +33,14 @@ def test_get_schema_uses_cache(self): self.assertEqual(patched_get_schema.call_count, 1) def test_get_schema_skips_cache_with_refresh_true(self): - return_value = [{'name': 'table', 'columns': []}] - with mock.patch('redash.query_runner.pg.PostgreSQL.get_schema') as patched_get_schema: + return_value = [{"name": "table", "columns": []}] + with mock.patch( + "redash.query_runner.pg.PostgreSQL.get_schema" + ) as patched_get_schema: patched_get_schema.return_value = return_value self.factory.data_source.get_schema() - new_return_value = [{'name': 'new_table', 'columns': []}] + new_return_value = [{"name": "new_table", "columns": []}] patched_get_schema.return_value = new_return_value schema = self.factory.data_source.get_schema(refresh=True) @@ -44,7 +50,12 @@ def test_get_schema_skips_cache_with_refresh_true(self): class TestDataSourceCreate(BaseTestCase): def test_adds_data_source_to_default_group(self): - data_source = DataSource.create_with_group(org=self.factory.org, name='test', options=ConfigurationContainer.from_json('{"dbname": "test"}'), type='pg') + data_source = DataSource.create_with_group( + org=self.factory.org, + name="test", + options=ConfigurationContainer.from_json('{"dbname": "test"}'), + type="pg", + ) self.assertIn(self.factory.org.default_group.id, data_source.groups) @@ -92,13 +103,18 @@ def test_sets_queries_data_source_to_null(self): def test_deletes_child_models(self): data_source = self.factory.create_data_source() self.factory.create_query_result(data_source=data_source) - self.factory.create_query(data_source=data_source, latest_query_data=self.factory.create_query_result(data_source=data_source)) + self.factory.create_query( + data_source=data_source, + latest_query_data=self.factory.create_query_result(data_source=data_source), + ) data_source.delete() self.assertIsNone(DataSource.query.get(data_source.id)) - self.assertEqual(0, QueryResult.query.filter(QueryResult.data_source == data_source).count()) + self.assertEqual( + 0, QueryResult.query.filter(QueryResult.data_source == data_source).count() + ) - @patch('redash.redis_connection.delete') + @patch("redash.redis_connection.delete") def test_deletes_schema(self, mock_redis): data_source = self.factory.create_data_source() data_source.delete() diff --git a/tests/models/test_parameterized_query.py b/tests/models/test_parameterized_query.py index dc2ac0372d..3d290b4323 100644 --- a/tests/models/test_parameterized_query.py +++ b/tests/models/test_parameterized_query.py @@ -3,7 +3,12 @@ from collections import namedtuple import pytest -from redash.models.parameterized_query import ParameterizedQuery, InvalidParameterError, QueryDetachedFromDataSourceError, dropdown_values +from redash.models.parameterized_query import ( + ParameterizedQuery, + InvalidParameterError, + QueryDetachedFromDataSourceError, + dropdown_values, +) class TestParameterizedQuery(TestCase): @@ -13,36 +18,30 @@ def test_returns_empty_list_for_regular_query(self): def test_finds_all_params_when_missing(self): query = ParameterizedQuery("SELECT {{param}} FROM {{table}}") - self.assertEqual(set(['param', 'table']), query.missing_params) + self.assertEqual(set(["param", "table"]), query.missing_params) def test_finds_all_params(self): - query = ParameterizedQuery("SELECT {{param}} FROM {{table}}").apply({ - 'param': 'value', - 'table': 'value' - }) + query = ParameterizedQuery("SELECT {{param}} FROM {{table}}").apply( + {"param": "value", "table": "value"} + ) self.assertEqual(set([]), query.missing_params) def test_deduplicates_params(self): - query = ParameterizedQuery("SELECT {{param}}, {{param}} FROM {{table}}").apply({ - 'param': 'value', - 'table': 'value' - }) + query = ParameterizedQuery("SELECT {{param}}, {{param}} FROM {{table}}").apply( + {"param": "value", "table": "value"} + ) self.assertEqual(set([]), query.missing_params) def test_handles_nested_params(self): - query = ParameterizedQuery("SELECT {{param}}, {{param}} FROM {{table}} -- {{#test}} {{nested_param}} {{/test}}").apply({ - 'param': 'value', - 'table': 'value' - }) - self.assertEqual(set(['test', 'nested_param']), query.missing_params) + query = ParameterizedQuery( + "SELECT {{param}}, {{param}} FROM {{table}} -- {{#test}} {{nested_param}} {{/test}}" + ).apply({"param": "value", "table": "value"}) + self.assertEqual(set(["test", "nested_param"]), query.missing_params) def test_handles_objects(self): - query = ParameterizedQuery("SELECT * FROM USERS WHERE created_at between '{{ created_at.start }}' and '{{ created_at.end }}'").apply({ - 'created_at': { - 'start': 1, - 'end': 2 - } - }) + query = ParameterizedQuery( + "SELECT * FROM USERS WHERE created_at between '{{ created_at.start }}' and '{{ created_at.end }}'" + ).apply({"created_at": {"start": 1, "end": 2}}) self.assertEqual(set([]), query.missing_params) def test_raises_on_parameters_not_in_schema(self): @@ -127,12 +126,14 @@ def test_raises_on_unlisted_enum_value_parameters(self): query.apply({"bar": "shlomo"}) def test_raises_on_unlisted_enum_list_value_parameters(self): - schema = [{ - "name": "bar", - "type": "enum", - "enumOptions": ["baz", "qux"], - "multiValuesOptions": {"separator": ",", "prefix": "", "suffix": ""} - }] + schema = [ + { + "name": "bar", + "type": "enum", + "enumOptions": ["baz", "qux"], + "multiValuesOptions": {"separator": ",", "prefix": "", "suffix": ""}, + } + ] query = ParameterizedQuery("foo", schema) with pytest.raises(InvalidParameterError): @@ -147,19 +148,24 @@ def test_validates_enum_parameters(self): self.assertEqual("foo baz", query.text) def test_validates_enum_list_value_parameters(self): - schema = [{ - "name": "bar", - "type": "enum", - "enumOptions": ["baz", "qux"], - "multiValuesOptions": {"separator": ",", "prefix": "'", "suffix": "'"} - }] + schema = [ + { + "name": "bar", + "type": "enum", + "enumOptions": ["baz", "qux"], + "multiValuesOptions": {"separator": ",", "prefix": "'", "suffix": "'"}, + } + ] query = ParameterizedQuery("foo {{bar}}", schema) query.apply({"bar": ["qux", "baz"]}) self.assertEqual("foo 'qux','baz'", query.text) - @patch('redash.models.parameterized_query.dropdown_values', return_value=[{"value": "1"}]) + @patch( + "redash.models.parameterized_query.dropdown_values", + return_value=[{"value": "1"}], + ) def test_validation_accepts_integer_values_for_dropdowns(self, _): schema = [{"name": "bar", "type": "query", "queryId": 1}] query = ParameterizedQuery("foo {{bar}}", schema) @@ -168,7 +174,7 @@ def test_validation_accepts_integer_values_for_dropdowns(self, _): self.assertEqual("foo 1", query.text) - @patch('redash.models.parameterized_query.dropdown_values') + @patch("redash.models.parameterized_query.dropdown_values") def test_raises_on_invalid_query_parameters(self, _): schema = [{"name": "bar", "type": "query", "queryId": 1}] query = ParameterizedQuery("foo", schema) @@ -176,7 +182,10 @@ def test_raises_on_invalid_query_parameters(self, _): with pytest.raises(InvalidParameterError): query.apply({"bar": 7}) - @patch('redash.models.parameterized_query.dropdown_values', return_value=[{"value": "baz"}]) + @patch( + "redash.models.parameterized_query.dropdown_values", + return_value=[{"value": "baz"}], + ) def test_raises_on_unlisted_query_value_parameters(self, _): schema = [{"name": "bar", "type": "query", "queryId": 1}] query = ParameterizedQuery("foo", schema) @@ -184,7 +193,10 @@ def test_raises_on_unlisted_query_value_parameters(self, _): with pytest.raises(InvalidParameterError): query.apply({"bar": "shlomo"}) - @patch('redash.models.parameterized_query.dropdown_values', return_value=[{"value": "baz"}]) + @patch( + "redash.models.parameterized_query.dropdown_values", + return_value=[{"value": "baz"}], + ) def test_validates_query_parameters(self, _): schema = [{"name": "bar", "type": "query", "queryId": 1}] query = ParameterizedQuery("foo {{bar}}", schema) @@ -204,7 +216,9 @@ def test_validates_date_range_parameters(self): schema = [{"name": "bar", "type": "date-range"}] query = ParameterizedQuery("foo {{bar.start}} {{bar.end}}", schema) - query.apply({"bar": {"start": "2000-01-01 12:00:00", "end": "2000-12-31 12:00:00"}}) + query.apply( + {"bar": {"start": "2000-01-01 12:00:00", "end": "2000-12-31 12:00:00"}} + ) self.assertEqual("foo 2000-01-01 12:00:00 2000-12-31 12:00:00", query.text) @@ -233,28 +247,43 @@ def test_is_safe_if_not_expecting_any_parameters(self): self.assertTrue(query.is_safe) - @patch('redash.models.parameterized_query._load_result', return_value={ - "columns": [{"name": "id"}, {"name": "Name"}, {"name": "Value"}], - "rows": [{"id": 5, "Name": "John", "Value": "John Doe"}]}) + @patch( + "redash.models.parameterized_query._load_result", + return_value={ + "columns": [{"name": "id"}, {"name": "Name"}, {"name": "Value"}], + "rows": [{"id": 5, "Name": "John", "Value": "John Doe"}], + }, + ) def test_dropdown_values_prefers_name_and_value_columns(self, _): values = dropdown_values(1, None) self.assertEqual(values, [{"name": "John", "value": "John Doe"}]) - @patch('redash.models.parameterized_query._load_result', return_value={ - "columns": [{"name": "id"}, {"name": "fish"}, {"name": "poultry"}], - "rows": [{"fish": "Clown", "id": 5, "poultry": "Hen"}]}) + @patch( + "redash.models.parameterized_query._load_result", + return_value={ + "columns": [{"name": "id"}, {"name": "fish"}, {"name": "poultry"}], + "rows": [{"fish": "Clown", "id": 5, "poultry": "Hen"}], + }, + ) def test_dropdown_values_compromises_for_first_column(self, _): values = dropdown_values(1, None) self.assertEqual(values, [{"name": 5, "value": "5"}]) - @patch('redash.models.parameterized_query._load_result', return_value={ - "columns": [{"name": "ID"}, {"name": "fish"}, {"name": "poultry"}], - "rows": [{"fish": "Clown", "ID": 5, "poultry": "Hen"}]}) + @patch( + "redash.models.parameterized_query._load_result", + return_value={ + "columns": [{"name": "ID"}, {"name": "fish"}, {"name": "poultry"}], + "rows": [{"fish": "Clown", "ID": 5, "poultry": "Hen"}], + }, + ) def test_dropdown_supports_upper_cased_columns(self, _): values = dropdown_values(1, None) self.assertEqual(values, [{"name": 5, "value": "5"}]) - @patch('redash.models.Query.get_by_id_and_org', return_value=namedtuple('Query', 'data_source')(None)) + @patch( + "redash.models.Query.get_by_id_and_org", + return_value=namedtuple("Query", "data_source")(None), + ) def test_dropdown_values_raises_when_query_is_detached_from_data_source(self, _): with pytest.raises(QueryDetachedFromDataSourceError): dropdown_values(1, None) diff --git a/tests/models/test_permissions.py b/tests/models/test_permissions.py index 7406a8f9ec..9a0ef15c1c 100644 --- a/tests/models/test_permissions.py +++ b/tests/models/test_permissions.py @@ -6,9 +6,12 @@ class TestAccessPermissionGrant(BaseTestCase): def test_creates_correct_object(self): q = self.factory.create_query() - permission = AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, - grantee=self.factory.user) + permission = AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=self.factory.user, + ) self.assertEqual(permission.object, q) self.assertEqual(permission.grantor, self.factory.user) @@ -17,13 +20,19 @@ def test_creates_correct_object(self): def test_returns_existing_object_if_exists(self): q = self.factory.create_query() - permission1 = AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, - grantee=self.factory.user) - - permission2 = AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, - grantee=self.factory.user) + permission1 = AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=self.factory.user, + ) + + permission2 = AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=self.factory.user, + ) self.assertEqual(permission1.id, permission2.id) @@ -31,44 +40,66 @@ def test_returns_existing_object_if_exists(self): class TestAccessPermissionRevoke(BaseTestCase): def test_deletes_nothing_when_no_permission_exists(self): q = self.factory.create_query() - self.assertEqual(0, AccessPermission.revoke(q, self.factory.user, ACCESS_TYPE_MODIFY)) + self.assertEqual( + 0, AccessPermission.revoke(q, self.factory.user, ACCESS_TYPE_MODIFY) + ) def test_deletes_permission(self): q = self.factory.create_query() - permission = AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, - grantee=self.factory.user) - self.assertEqual(1, AccessPermission.revoke(q, self.factory.user, ACCESS_TYPE_MODIFY)) + permission = AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=self.factory.user, + ) + self.assertEqual( + 1, AccessPermission.revoke(q, self.factory.user, ACCESS_TYPE_MODIFY) + ) def test_deletes_permission_for_only_given_grantee_on_given_grant_type(self): q = self.factory.create_query() - first_user = self.factory.create_user() + first_user = self.factory.create_user() second_user = self.factory.create_user() - AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, - grantee=first_user) - - AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, - grantee=second_user) - - AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_VIEW, - grantor=self.factory.user, - grantee=second_user) + AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=first_user, + ) + + AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=second_user, + ) + + AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_VIEW, + grantor=self.factory.user, + grantee=second_user, + ) self.assertEqual(1, AccessPermission.revoke(q, second_user, ACCESS_TYPE_VIEW)) def test_deletes_all_permissions_if_no_type_given(self): q = self.factory.create_query() - permission = AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_MODIFY, - grantor=self.factory.user, - grantee=self.factory.user) - - permission = AccessPermission.grant(obj=q, access_type=ACCESS_TYPE_VIEW, - grantor=self.factory.user, - grantee=self.factory.user) + permission = AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_MODIFY, + grantor=self.factory.user, + grantee=self.factory.user, + ) + + permission = AccessPermission.grant( + obj=q, + access_type=ACCESS_TYPE_VIEW, + grantor=self.factory.user, + grantee=self.factory.user, + ) self.assertEqual(2, AccessPermission.revoke(q, self.factory.user)) diff --git a/tests/models/test_queries.py b/tests/models/test_queries.py index 1aaeb64dbc..dbbd817c98 100644 --- a/tests/models/test_queries.py +++ b/tests/models/test_queries.py @@ -20,13 +20,13 @@ def create_tagged_query(self, tags): return query def test_all_tags(self): - self.create_tagged_query(tags=['tag1']) - self.create_tagged_query(tags=['tag1', 'tag2']) - self.create_tagged_query(tags=['tag1', 'tag2', 'tag3']) + self.create_tagged_query(tags=["tag1"]) + self.create_tagged_query(tags=["tag1", "tag2"]) + self.create_tagged_query(tags=["tag1", "tag2", "tag3"]) self.assertEqual( list(Query.all_tags(self.factory.user)), - [('tag1', 3), ('tag2', 2), ('tag3', 1)] + [("tag1", 3), ("tag2", 2), ("tag3", 1)], ) def test_search_finds_in_name(self): @@ -55,7 +55,9 @@ def test_search_finds_in_multi_byte_name_and_description(self): q2 = self.factory.create_query(description="日本語の説明文テスト") q3 = self.factory.create_query(description="Testing search") - queries = Query.search("テスト", [self.factory.default_group.id], multi_byte_search=True) + queries = Query.search( + "テスト", [self.factory.default_group.id], multi_byte_search=True + ) self.assertIn(q1, queries) self.assertIn(q2, queries) @@ -75,7 +77,7 @@ def test_search_by_id_returns_query(self): def test_search_by_number(self): q = self.factory.create_query(description="Testing search 12345") db.session.flush() - queries = Query.search('12345', [self.factory.default_group.id]) + queries = Query.search("12345", [self.factory.default_group.id]) self.assertIn(q, queries) @@ -94,7 +96,9 @@ def test_search_respects_groups(self): self.assertIn(q2, queries) self.assertIn(q3, queries) - queries = list(Query.search("Testing", [other_group.id, self.factory.default_group.id])) + queries = list( + Query.search("Testing", [other_group.id, self.factory.default_group.id]) + ) self.assertIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) @@ -112,7 +116,12 @@ def test_returns_each_query_only_once(self): q1 = self.factory.create_query(description="Testing search", data_source=ds) db.session.flush() - queries = list(Query.search("Testing", [self.factory.default_group.id, other_group.id, second_group.id])) + queries = list( + Query.search( + "Testing", + [self.factory.default_group.id, other_group.id, second_group.id], + ) + ) self.assertEqual(1, len(queries)) @@ -121,20 +130,22 @@ def test_save_updates_updated_at_field(self): one_day_ago = utcnow().date() - datetime.timedelta(days=1) q = self.factory.create_query(created_at=one_day_ago, updated_at=one_day_ago) db.session.flush() - q.name = 'x' + q.name = "x" db.session.flush() self.assertNotEqual(q.updated_at, one_day_ago) def test_search_is_case_insensitive(self): q = self.factory.create_query(name="Testing search") - self.assertIn(q, Query.search('testing', [self.factory.default_group.id])) + self.assertIn(q, Query.search("testing", [self.factory.default_group.id])) def test_search_query_parser_or(self): q1 = self.factory.create_query(name="Testing") q2 = self.factory.create_query(name="search") - queries = list(Query.search('testing or search', [self.factory.default_group.id])) + queries = list( + Query.search("testing or search", [self.factory.default_group.id]) + ) self.assertIn(q1, queries) self.assertIn(q2, queries) @@ -142,7 +153,7 @@ def test_search_query_parser_negation(self): q1 = self.factory.create_query(name="Testing") q2 = self.factory.create_query(name="search") - queries = list(Query.search('testing -search', [self.factory.default_group.id])) + queries = list(Query.search("testing -search", [self.factory.default_group.id])) self.assertIn(q1, queries) self.assertNotIn(q2, queries) @@ -151,7 +162,9 @@ def test_search_query_parser_parenthesis(self): q2 = self.factory.create_query(name="Testing searching") q3 = self.factory.create_query(name="Testing finding") - queries = list(Query.search('(testing search) or finding', [self.factory.default_group.id])) + queries = list( + Query.search("(testing search) or finding", [self.factory.default_group.id]) + ) self.assertIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) @@ -160,7 +173,7 @@ def test_search_query_parser_hyphen(self): q1 = self.factory.create_query(name="Testing search") q2 = self.factory.create_query(name="Testing-search") - queries = list(Query.search('testing search', [self.factory.default_group.id])) + queries = list(Query.search("testing search", [self.factory.default_group.id])) self.assertIn(q1, queries) self.assertIn(q2, queries) @@ -168,15 +181,15 @@ def test_search_query_parser_emails(self): q1 = self.factory.create_query(name="janedoe@example.com") q2 = self.factory.create_query(name="johndoe@example.com") - queries = list(Query.search('example', [self.factory.default_group.id])) + queries = list(Query.search("example", [self.factory.default_group.id])) self.assertIn(q1, queries) self.assertIn(q2, queries) - queries = list(Query.search('com', [self.factory.default_group.id])) + queries = list(Query.search("com", [self.factory.default_group.id])) self.assertIn(q1, queries) self.assertIn(q2, queries) - queries = list(Query.search('johndoe', [self.factory.default_group.id])) + queries = list(Query.search("johndoe", [self.factory.default_group.id])) self.assertNotIn(q1, queries) self.assertIn(q2, queries) @@ -184,10 +197,14 @@ def test_past_scheduled_queries(self): query = self.factory.create_query() one_day_ago = (utcnow() - datetime.timedelta(days=1)).strftime("%Y-%m-%d") one_day_later = (utcnow() + datetime.timedelta(days=1)).strftime("%Y-%m-%d") - query1 = self.factory.create_query(schedule={'interval':'3600','until':one_day_ago}) - query2 = self.factory.create_query(schedule={'interval':'3600','until':one_day_later}) + query1 = self.factory.create_query( + schedule={"interval": "3600", "until": one_day_ago} + ) + query2 = self.factory.create_query( + schedule={"interval": "3600", "until": one_day_later} + ) oq = staticmethod(lambda: [query1, query2]) - with mock.patch.object(query.query.filter(), 'order_by', oq): + with mock.patch.object(query.query.filter(), "order_by", oq): res = query.past_scheduled_queries() self.assertTrue(query1 in res) self.assertFalse(query2 in res) @@ -198,8 +215,13 @@ def test_global_recent(self): q1 = self.factory.create_query() q2 = self.factory.create_query() db.session.flush() - e = Event(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) + e = Event( + org=self.factory.org, + user=self.factory.user, + action="edit", + object_type="query", + object_id=q1.id, + ) db.session.add(e) recent = Query.recent([self.factory.default_group.id]) self.assertIn(q1, recent) @@ -209,14 +231,24 @@ def test_recent_excludes_drafts(self): q1 = self.factory.create_query() q2 = self.factory.create_query(is_draft=True) - db.session.add_all([ - Event(org=self.factory.org, user=self.factory.user, - action="edit", object_type="query", - object_id=q1.id), - Event(org=self.factory.org, user=self.factory.user, - action="edit", object_type="query", - object_id=q2.id) - ]) + db.session.add_all( + [ + Event( + org=self.factory.org, + user=self.factory.user, + action="edit", + object_type="query", + object_id=q1.id, + ), + Event( + org=self.factory.org, + user=self.factory.user, + action="edit", + object_type="query", + object_id=q2.id, + ), + ] + ) recent = Query.recent([self.factory.default_group.id]) self.assertIn(q1, recent) @@ -226,15 +258,24 @@ def test_recent_for_user(self): q1 = self.factory.create_query() q2 = self.factory.create_query() db.session.flush() - e = Event(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) + e = Event( + org=self.factory.org, + user=self.factory.user, + action="edit", + object_type="query", + object_id=q1.id, + ) db.session.add(e) - recent = Query.recent([self.factory.default_group.id], user_id=self.factory.user.id) + recent = Query.recent( + [self.factory.default_group.id], user_id=self.factory.user.id + ) self.assertIn(q1, recent) self.assertNotIn(q2, recent) - recent = Query.recent([self.factory.default_group.id], user_id=self.factory.user.id + 1) + recent = Query.recent( + [self.factory.default_group.id], user_id=self.factory.user.id + 1 + ) self.assertNotIn(q1, recent) self.assertNotIn(q2, recent) @@ -243,10 +284,20 @@ def test_respects_groups(self): ds = self.factory.create_data_source(group=self.factory.create_group()) q2 = self.factory.create_query(data_source=ds) db.session.flush() - Event(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) - Event(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q2.id) + Event( + org=self.factory.org, + user=self.factory.user, + action="edit", + object_type="query", + object_id=q1.id, + ) + Event( + org=self.factory.org, + user=self.factory.user, + action="edit", + object_type="query", + object_id=q2.id, + ) recent = Query.recent([self.factory.default_group.id]) @@ -277,7 +328,11 @@ def test_returns_drafts_by_the_user(self): def test_returns_only_queries_from_groups_the_user_is_member_in(self): q = self.factory.create_query() - q2 = self.factory.create_query(data_source=self.factory.create_data_source(group=self.factory.create_group())) + q2 = self.factory.create_query( + data_source=self.factory.create_data_source( + group=self.factory.create_group() + ) + ) queries = Query.by_user(self.factory.user) @@ -296,21 +351,25 @@ def assert_visualizations(self, origin_q, origin_v, forked_q, forked_v): def test_fork_with_visualizations(self): # prepare original query and visualizations - data_source = self.factory.create_data_source( - group=self.factory.create_group()) - query = self.factory.create_query(data_source=data_source, - description="this is description") + data_source = self.factory.create_data_source(group=self.factory.create_group()) + query = self.factory.create_query( + data_source=data_source, description="this is description" + ) # create default TABLE - query factory does not create it self.factory.create_visualization( - query_rel=query, name="Table", description='', type="TABLE", options="{}") + query_rel=query, name="Table", description="", type="TABLE", options="{}" + ) visualization_chart = self.factory.create_visualization( - query_rel=query, description="chart vis", type="CHART", - options="""{"yAxis": [{"type": "linear"}, {"type": "linear", "opposite": true}], "series": {"stacking": null}, "globalSeriesType": "line", "sortX": true, "seriesOptions": {"count": {"zIndex": 0, "index": 0, "type": "line", "yAxis": 0}}, "xAxis": {"labels": {"enabled": true}, "type": "datetime"}, "columnMapping": {"count": "y", "created_at": "x"}, "bottomMargin": 50, "legend": {"enabled": true}}""") + query_rel=query, + description="chart vis", + type="CHART", + options="""{"yAxis": [{"type": "linear"}, {"type": "linear", "opposite": true}], "series": {"stacking": null}, "globalSeriesType": "line", "sortX": true, "seriesOptions": {"count": {"zIndex": 0, "index": 0, "type": "line", "yAxis": 0}}, "xAxis": {"labels": {"enabled": true}, "type": "datetime"}, "columnMapping": {"count": "y", "created_at": "x"}, "bottomMargin": 50, "legend": {"enabled": true}}""", + ) visualization_box = self.factory.create_visualization( - query_rel=query, description="box vis", type="BOXPLOT", - options="{}") + query_rel=query, description="box vis", type="BOXPLOT", options="{}" + ) fork_user = self.factory.create_user() forked_query = query.fork(fork_user) db.session.flush() @@ -328,21 +387,22 @@ def test_fork_with_visualizations(self): count_table += 1 forked_table = v - self.assert_visualizations(query, visualization_chart, forked_query, - forked_visualization_chart) - self.assert_visualizations(query, visualization_box, forked_query, - forked_visualization_box) + self.assert_visualizations( + query, visualization_chart, forked_query, forked_visualization_chart + ) + self.assert_visualizations( + query, visualization_box, forked_query, forked_visualization_box + ) self.assertEqual(forked_query.org, query.org) self.assertEqual(forked_query.data_source, query.data_source) - self.assertEqual(forked_query.latest_query_data, - query.latest_query_data) + self.assertEqual(forked_query.latest_query_data, query.latest_query_data) self.assertEqual(forked_query.description, query.description) self.assertEqual(forked_query.query_text, query.query_text) self.assertEqual(forked_query.query_hash, query.query_hash) self.assertEqual(forked_query.user, fork_user) self.assertEqual(forked_query.description, query.description) - self.assertTrue(forked_query.name.startswith('Copy')) + self.assertTrue(forked_query.name.startswith("Copy")) # num of TABLE must be 1. default table only self.assertEqual(count_table, 1) self.assertEqual(forked_table.name, "Table") @@ -351,14 +411,15 @@ def test_fork_with_visualizations(self): def test_fork_from_query_that_has_no_visualization(self): # prepare original query and visualizations - data_source = self.factory.create_data_source( - group=self.factory.create_group()) - query = self.factory.create_query(data_source=data_source, - description="this is description") + data_source = self.factory.create_data_source(group=self.factory.create_group()) + query = self.factory.create_query( + data_source=data_source, description="this is description" + ) # create default TABLE - query factory does not create it self.factory.create_visualization( - query_rel=query, name="Table", description='', type="TABLE", options="{}") + query_rel=query, name="Table", description="", type="TABLE", options="{}" + ) fork_user = self.factory.create_user() @@ -391,8 +452,14 @@ def test_updates_existing_queries(self): query3 = self.factory.create_query(query_text=self.query) query_result = QueryResult.store_result( - self.data_source.org_id, self.data_source, self.query_hash, - self.query, self.data, self.runtime, self.utcnow) + self.data_source.org_id, + self.data_source, + self.query_hash, + self.query, + self.data, + self.runtime, + self.utcnow, + ) Query.update_latest_result(query_result) @@ -406,8 +473,14 @@ def test_doesnt_update_queries_with_different_hash(self): query3 = self.factory.create_query(query_text=self.query + "123") query_result = QueryResult.store_result( - self.data_source.org_id, self.data_source, self.query_hash, - self.query, self.data, self.runtime, self.utcnow) + self.data_source.org_id, + self.data_source, + self.query_hash, + self.query, + self.data, + self.runtime, + self.utcnow, + ) Query.update_latest_result(query_result) @@ -418,11 +491,19 @@ def test_doesnt_update_queries_with_different_hash(self): def test_doesnt_update_queries_with_different_data_source(self): query1 = self.factory.create_query(query_text=self.query) query2 = self.factory.create_query(query_text=self.query) - query3 = self.factory.create_query(query_text=self.query, data_source=self.factory.create_data_source()) + query3 = self.factory.create_query( + query_text=self.query, data_source=self.factory.create_data_source() + ) query_result = QueryResult.store_result( - self.data_source.org_id, self.data_source, self.query_hash, - self.query, self.data, self.runtime, self.utcnow) + self.data_source.org_id, + self.data_source, + self.query_hash, + self.query, + self.data, + self.runtime, + self.utcnow, + ) Query.update_latest_result(query_result) diff --git a/tests/models/test_query_results.py b/tests/models/test_query_results.py index aa6b6adacf..9695c40918 100644 --- a/tests/models/test_query_results.py +++ b/tests/models/test_query_results.py @@ -11,19 +11,25 @@ class QueryResultTest(BaseTestCase): def test_get_latest_returns_none_if_not_found(self): - found_query_result = models.QueryResult.get_latest(self.factory.data_source, "SELECT 1", 60) + found_query_result = models.QueryResult.get_latest( + self.factory.data_source, "SELECT 1", 60 + ) self.assertIsNone(found_query_result) def test_get_latest_returns_when_found(self): qr = self.factory.create_query_result() - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, 60) + found_query_result = models.QueryResult.get_latest( + qr.data_source, qr.query_text, 60 + ) self.assertEqual(qr, found_query_result) def test_get_latest_doesnt_return_query_from_different_data_source(self): qr = self.factory.create_query_result() data_source = self.factory.create_data_source() - found_query_result = models.QueryResult.get_latest(data_source, qr.query_text, 60) + found_query_result = models.QueryResult.get_latest( + data_source, qr.query_text, 60 + ) self.assertIsNone(found_query_result) @@ -31,7 +37,9 @@ def test_get_latest_doesnt_return_if_ttl_expired(self): yesterday = utcnow() - datetime.timedelta(days=1) qr = self.factory.create_query_result(retrieved_at=yesterday) - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, max_age=60) + found_query_result = models.QueryResult.get_latest( + qr.data_source, qr.query_text, max_age=60 + ) self.assertIsNone(found_query_result) @@ -39,7 +47,9 @@ def test_get_latest_returns_if_ttl_not_expired(self): yesterday = utcnow() - datetime.timedelta(seconds=30) qr = self.factory.create_query_result(retrieved_at=yesterday) - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, max_age=120) + found_query_result = models.QueryResult.get_latest( + qr.data_source, qr.query_text, max_age=120 + ) self.assertEqual(found_query_result, qr) @@ -48,7 +58,9 @@ def test_get_latest_returns_the_most_recent_result(self): self.factory.create_query_result(retrieved_at=yesterday) qr = self.factory.create_query_result() - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, 60) + found_query_result = models.QueryResult.get_latest( + qr.data_source, qr.query_text, 60 + ) self.assertEqual(found_query_result.id, qr.id) @@ -58,7 +70,9 @@ def test_get_latest_returns_the_last_cached_result_for_negative_ttl(self): yesterday = utcnow() + datetime.timedelta(days=-1) qr = self.factory.create_query_result(retrieved_at=yesterday) - found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, -1) + found_query_result = models.QueryResult.get_latest( + qr.data_source, qr.query_text, -1 + ) self.assertEqual(found_query_result.id, qr.id) @@ -66,7 +80,15 @@ def test_store_result_does_not_modify_query_update_at(self): original_updated_at = utcnow() - datetime.timedelta(hours=1) query = self.factory.create_query(updated_at=original_updated_at) - models.QueryResult.store_result(query.org_id, query.data_source, query.query_hash, query.query_text, "", 0, utcnow()) + models.QueryResult.store_result( + query.org_id, + query.data_source, + query.query_hash, + query.query_text, + "", + 0, + utcnow(), + ) self.assertEqual(original_updated_at, query.updated_at) @@ -79,9 +101,9 @@ def test_updating_data_removes_cached_result(self): p.data = '{"test": 2}' self.assertDictEqual(p.data, {"test": 2}) - @patch('redash.models.json_loads') + @patch("redash.models.json_loads") def test_calls_json_loads_only_once(self, json_loads_patch): - json_loads_patch.return_value = '1' + json_loads_patch.return_value = "1" p = DBPersistence() json_data = '{"test": 1}' p.data = json_data diff --git a/tests/models/test_users.py b/tests/models/test_users.py index 733bbe5afe..e9e0dffdb8 100644 --- a/tests/models/test_users.py +++ b/tests/models/test_users.py @@ -3,7 +3,12 @@ from redash import redis_connection from redash.models import User, db from redash.utils import dt_from_timestamp -from redash.models.users import sync_last_active_at, update_user_active_at, LAST_ACTIVE_KEY +from redash.models.users import ( + sync_last_active_at, + update_user_active_at, + LAST_ACTIVE_KEY, +) + class TestUserUpdateGroupAssignments(BaseTestCase): def test_default_group_always_added(self): @@ -26,29 +31,31 @@ def test_update_group_assignments(self): class TestUserFindByEmail(BaseTestCase): def test_finds_users(self): - user = self.factory.create_user(email='test@example.com') - user2 = self.factory.create_user(email='test@example.com', org=self.factory.create_org()) + user = self.factory.create_user(email="test@example.com") + user2 = self.factory.create_user( + email="test@example.com", org=self.factory.create_org() + ) users = User.find_by_email(user.email) self.assertIn(user, users) self.assertIn(user2, users) def test_finds_users_case_insensitive(self): - user = self.factory.create_user(email='test@example.com') + user = self.factory.create_user(email="test@example.com") - users = User.find_by_email('test@EXAMPLE.com') + users = User.find_by_email("test@EXAMPLE.com") self.assertIn(user, users) class TestUserGetByEmailAndOrg(BaseTestCase): def test_get_user_by_email_and_org(self): - user = self.factory.create_user(email='test@example.com') + user = self.factory.create_user(email="test@example.com") found_user = User.get_by_email_and_org(user.email, user.org) self.assertEqual(user, found_user) def test_get_user_by_email_and_org_case_insensitive(self): - user = self.factory.create_user(email='test@example.com') + user = self.factory.create_user(email="test@example.com") found_user = User.get_by_email_and_org("TEST@example.com", user.org) self.assertEqual(user, found_user) @@ -56,9 +63,9 @@ def test_get_user_by_email_and_org_case_insensitive(self): class TestUserSearch(BaseTestCase): def test_non_unicode_search_string(self): - user = self.factory.create_user(name='אריק') + user = self.factory.create_user(name="אריק") - assert user in User.search(User.all(user.org), term='א') + assert user in User.search(User.all(user.org), term="א") class TestUserRegenerateApiKey(BaseTestCase): @@ -84,24 +91,26 @@ def test_userdetail_db_default(self): def test_userdetail_db_default_save(self): with authenticated_user(self.client) as user: - user.details['test'] = 1 + user.details["test"] = 1 db.session.commit() user_reloaded = User.query.filter_by(id=user.id).first() - self.assertEqual(user.details['test'], 1) + self.assertEqual(user.details["test"], 1) self.assertEqual( user_reloaded, User.query.filter( - User.details['test'].astext.cast(db.Integer) == 1 - ).first() + User.details["test"].astext.cast(db.Integer) == 1 + ).first(), ) def test_sync(self): with authenticated_user(self.client) as user: - rv = self.client.get('/default/') - timestamp = dt_from_timestamp(redis_connection.hget(LAST_ACTIVE_KEY, user.id)) + rv = self.client.get("/default/") + timestamp = dt_from_timestamp( + redis_connection.hget(LAST_ACTIVE_KEY, user.id) + ) sync_last_active_at() - user_reloaded = User.query.filter(User.id==user.id).first() - self.assertIn('active_at', user_reloaded.details) + user_reloaded = User.query.filter(User.id == user.id).first() + self.assertIn("active_at", user_reloaded.details) self.assertEqual(user_reloaded.active_at, timestamp) diff --git a/tests/query_runner/test_athena.py b/tests/query_runner/test_athena.py index fe444de64f..9db7b52f84 100644 --- a/tests/query_runner/test_athena.py +++ b/tests/query_runner/test_athena.py @@ -14,11 +14,14 @@ class TestGlueSchema(TestCase): def setUp(self): client = botocore.session.get_session().create_client( - 'glue', region_name='mars-east-1', aws_access_key_id='foo', aws_secret_access_key='bar' + "glue", + region_name="mars-east-1", + aws_access_key_id="foo", + aws_secret_access_key="bar", ) self.stubber = Stubber(client) - self.patcher = mock.patch('boto3.client') + self.patcher = mock.patch("boto3.client") mocked_client = self.patcher.start() mocked_client.return_value = client @@ -27,130 +30,145 @@ def tearDown(self): def test_external_table(self): """Unpartitioned table crawled through a JDBC connection""" - query_runner = Athena({'glue': True, 'region': 'mars-east-1'}) + query_runner = Athena({"glue": True, "region": "mars-east-1"}) - self.stubber.add_response('get_databases', {'DatabaseList': [{'Name': 'test1'}]}, {}) self.stubber.add_response( - 'get_tables', + "get_databases", {"DatabaseList": [{"Name": "test1"}]}, {} + ) + self.stubber.add_response( + "get_tables", { - 'TableList': [ + "TableList": [ { - 'Name': 'jdbc_table', - 'StorageDescriptor': { - 'Columns': [{'Name': 'row_id', 'Type': 'int'}], - 'Location': 'Database.Schema.Table', - 'Compressed': False, - 'NumberOfBuckets': -1, - 'SerdeInfo': {'Parameters': {}}, - 'BucketColumns': [], - 'SortColumns': [], - 'Parameters': { - 'CrawlerSchemaDeserializerVersion': '1.0', - 'CrawlerSchemaSerializerVersion': '1.0', - 'UPDATED_BY_CRAWLER': 'jdbc', - 'classification': 'sqlserver', - 'compressionType': 'none', - 'connectionName': 'jdbctest', - 'typeOfData': 'view', + "Name": "jdbc_table", + "StorageDescriptor": { + "Columns": [{"Name": "row_id", "Type": "int"}], + "Location": "Database.Schema.Table", + "Compressed": False, + "NumberOfBuckets": -1, + "SerdeInfo": {"Parameters": {}}, + "BucketColumns": [], + "SortColumns": [], + "Parameters": { + "CrawlerSchemaDeserializerVersion": "1.0", + "CrawlerSchemaSerializerVersion": "1.0", + "UPDATED_BY_CRAWLER": "jdbc", + "classification": "sqlserver", + "compressionType": "none", + "connectionName": "jdbctest", + "typeOfData": "view", }, - 'StoredAsSubDirectories': False, + "StoredAsSubDirectories": False, }, - 'PartitionKeys': [], - 'TableType': 'EXTERNAL_TABLE', - 'Parameters': { - 'CrawlerSchemaDeserializerVersion': '1.0', - 'CrawlerSchemaSerializerVersion': '1.0', - 'UPDATED_BY_CRAWLER': 'jdbc', - 'classification': 'sqlserver', - 'compressionType': 'none', - 'connectionName': 'jdbctest', - 'typeOfData': 'view', + "PartitionKeys": [], + "TableType": "EXTERNAL_TABLE", + "Parameters": { + "CrawlerSchemaDeserializerVersion": "1.0", + "CrawlerSchemaSerializerVersion": "1.0", + "UPDATED_BY_CRAWLER": "jdbc", + "classification": "sqlserver", + "compressionType": "none", + "connectionName": "jdbctest", + "typeOfData": "view", }, } ] }, - {'DatabaseName': 'test1'}, + {"DatabaseName": "test1"}, ) with self.stubber: - assert query_runner.get_schema() == [{'columns': ['row_id'], 'name': 'test1.jdbc_table'}] + assert query_runner.get_schema() == [ + {"columns": ["row_id"], "name": "test1.jdbc_table"} + ] def test_partitioned_table(self): """ Partitioned table as created by a GlueContext """ - query_runner = Athena({'glue': True, 'region': 'mars-east-1'}) + query_runner = Athena({"glue": True, "region": "mars-east-1"}) - self.stubber.add_response('get_databases', {'DatabaseList': [{'Name': 'test1'}]}, {}) self.stubber.add_response( - 'get_tables', + "get_databases", {"DatabaseList": [{"Name": "test1"}]}, {} + ) + self.stubber.add_response( + "get_tables", { - 'TableList': [ + "TableList": [ { - 'Name': 'partitioned_table', - 'StorageDescriptor': { - 'Columns': [{'Name': 'sk', 'Type': 'int'}], - 'Location': 's3://bucket/prefix', - 'InputFormat': 'org.apache.hadoop.mapred.TextInputFormat', - 'OutputFormat': 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat', - 'Compressed': False, - 'NumberOfBuckets': -1, - 'SerdeInfo': { - 'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe', - 'Parameters': {'serialization.format': '1'}, + "Name": "partitioned_table", + "StorageDescriptor": { + "Columns": [{"Name": "sk", "Type": "int"}], + "Location": "s3://bucket/prefix", + "InputFormat": "org.apache.hadoop.mapred.TextInputFormat", + "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat", + "Compressed": False, + "NumberOfBuckets": -1, + "SerdeInfo": { + "SerializationLibrary": "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + "Parameters": {"serialization.format": "1"}, }, - 'BucketColumns': [], - 'SortColumns': [], - 'Parameters': {}, - 'SkewedInfo': { - 'SkewedColumnNames': [], - 'SkewedColumnValues': [], - 'SkewedColumnValueLocationMaps': {}, + "BucketColumns": [], + "SortColumns": [], + "Parameters": {}, + "SkewedInfo": { + "SkewedColumnNames": [], + "SkewedColumnValues": [], + "SkewedColumnValueLocationMaps": {}, }, - 'StoredAsSubDirectories': False, + "StoredAsSubDirectories": False, + }, + "PartitionKeys": [{"Name": "category", "Type": "int"}], + "TableType": "EXTERNAL_TABLE", + "Parameters": { + "EXTERNAL": "TRUE", + "transient_lastDdlTime": "1537505313", }, - 'PartitionKeys': [{'Name': 'category', 'Type': 'int'}], - 'TableType': 'EXTERNAL_TABLE', - 'Parameters': {'EXTERNAL': 'TRUE', 'transient_lastDdlTime': '1537505313'}, } ] }, - {'DatabaseName': 'test1'}, + {"DatabaseName": "test1"}, ) with self.stubber: - assert query_runner.get_schema() == [{'columns': ['sk', 'category'], 'name': 'test1.partitioned_table'}] + assert query_runner.get_schema() == [ + {"columns": ["sk", "category"], "name": "test1.partitioned_table"} + ] def test_view(self): - query_runner = Athena({'glue': True, 'region': 'mars-east-1'}) + query_runner = Athena({"glue": True, "region": "mars-east-1"}) - self.stubber.add_response('get_databases', {'DatabaseList': [{'Name': 'test1'}]}, {}) self.stubber.add_response( - 'get_tables', + "get_databases", {"DatabaseList": [{"Name": "test1"}]}, {} + ) + self.stubber.add_response( + "get_tables", { - 'TableList': [ + "TableList": [ { - 'Name': 'view', - 'StorageDescriptor': { - 'Columns': [{'Name': 'sk', 'Type': 'int'}], - 'Location': '', - 'Compressed': False, - 'NumberOfBuckets': 0, - 'SerdeInfo': {}, - 'SortColumns': [], - 'StoredAsSubDirectories': False, + "Name": "view", + "StorageDescriptor": { + "Columns": [{"Name": "sk", "Type": "int"}], + "Location": "", + "Compressed": False, + "NumberOfBuckets": 0, + "SerdeInfo": {}, + "SortColumns": [], + "StoredAsSubDirectories": False, }, - 'PartitionKeys': [], - 'ViewOriginalText': '/* Presto View: ... */', - 'ViewExpandedText': '/* Presto View */', - 'TableType': 'VIRTUAL_VIEW', - 'Parameters': {'comment': 'Presto View', 'presto_view': 'true'}, + "PartitionKeys": [], + "ViewOriginalText": "/* Presto View: ... */", + "ViewExpandedText": "/* Presto View */", + "TableType": "VIRTUAL_VIEW", + "Parameters": {"comment": "Presto View", "presto_view": "true"}, } ] }, - {'DatabaseName': 'test1'}, + {"DatabaseName": "test1"}, ) with self.stubber: - assert query_runner.get_schema() == [{'columns': ['sk'], 'name': 'test1.view'}] + assert query_runner.get_schema() == [ + {"columns": ["sk"], "name": "test1.view"} + ] def test_dodgy_table_does_not_break_schema_listing(self): """ @@ -158,33 +176,40 @@ def test_dodgy_table_does_not_break_schema_listing(self): This may be a Athena Catalog to Glue catalog migration issue. """ - query_runner = Athena({'glue': True, 'region': 'mars-east-1'}) + query_runner = Athena({"glue": True, "region": "mars-east-1"}) - self.stubber.add_response('get_databases', {'DatabaseList': [{'Name': 'test1'}]}, {}) self.stubber.add_response( - 'get_tables', + "get_databases", {"DatabaseList": [{"Name": "test1"}]}, {} + ) + self.stubber.add_response( + "get_tables", { - 'TableList': [ + "TableList": [ { - 'Name': 'csv', - 'StorageDescriptor': { - 'Columns': [{'Name': 'region', 'Type': 'string'}], - 'Location': 's3://bucket/files/', - 'InputFormat': 'org.apache.hadoop.mapred.TextInputFormat', - 'Compressed': False, - 'NumberOfBuckets': 0, - 'SerdeInfo': { - 'SerializationLibrary': 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe', - 'Parameters': {'field.delim': '|', 'skip.header.line.count': '1'}, + "Name": "csv", + "StorageDescriptor": { + "Columns": [{"Name": "region", "Type": "string"}], + "Location": "s3://bucket/files/", + "InputFormat": "org.apache.hadoop.mapred.TextInputFormat", + "Compressed": False, + "NumberOfBuckets": 0, + "SerdeInfo": { + "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", + "Parameters": { + "field.delim": "|", + "skip.header.line.count": "1", + }, }, - 'SortColumns': [], - 'StoredAsSubDirectories': False, + "SortColumns": [], + "StoredAsSubDirectories": False, }, - 'Parameters': {'classification': 'csv'}, + "Parameters": {"classification": "csv"}, } ] }, - {'DatabaseName': 'test1'}, + {"DatabaseName": "test1"}, ) with self.stubber: - assert query_runner.get_schema() == [{'columns': ['region'], 'name': 'test1.csv'}] + assert query_runner.get_schema() == [ + {"columns": ["region"], "name": "test1.csv"} + ] diff --git a/tests/query_runner/test_drill.py b/tests/query_runner/test_drill.py index 9cfe7f1ecc..ca0b0e4d44 100644 --- a/tests/query_runner/test_drill.py +++ b/tests/query_runner/test_drill.py @@ -1,90 +1,98 @@ import datetime from unittest import TestCase -from redash.query_runner import TYPE_DATETIME, TYPE_FLOAT, TYPE_INTEGER, TYPE_BOOLEAN, TYPE_STRING +from redash.query_runner import ( + TYPE_DATETIME, + TYPE_FLOAT, + TYPE_INTEGER, + TYPE_BOOLEAN, + TYPE_STRING, +) from redash.query_runner.drill import convert_type, parse_response class TestConvertType(TestCase): def test_converts_booleans(self): - self.assertEqual(convert_type('true', TYPE_BOOLEAN), True) - self.assertEqual(convert_type('True', TYPE_BOOLEAN), True) - self.assertEqual(convert_type('TRUE', TYPE_BOOLEAN), True) - self.assertEqual(convert_type('false', TYPE_BOOLEAN), False) - self.assertEqual(convert_type('False', TYPE_BOOLEAN), False) - self.assertEqual(convert_type('FALSE', TYPE_BOOLEAN), False) + self.assertEqual(convert_type("true", TYPE_BOOLEAN), True) + self.assertEqual(convert_type("True", TYPE_BOOLEAN), True) + self.assertEqual(convert_type("TRUE", TYPE_BOOLEAN), True) + self.assertEqual(convert_type("false", TYPE_BOOLEAN), False) + self.assertEqual(convert_type("False", TYPE_BOOLEAN), False) + self.assertEqual(convert_type("FALSE", TYPE_BOOLEAN), False) def test_converts_strings(self): - self.assertEqual(convert_type('Текст', TYPE_STRING), 'Текст') - self.assertEqual(convert_type(None, TYPE_STRING), '') - self.assertEqual(convert_type('', TYPE_STRING), '') - self.assertEqual(convert_type('redash', TYPE_STRING), 'redash') + self.assertEqual(convert_type("Текст", TYPE_STRING), "Текст") + self.assertEqual(convert_type(None, TYPE_STRING), "") + self.assertEqual(convert_type("", TYPE_STRING), "") + self.assertEqual(convert_type("redash", TYPE_STRING), "redash") def test_converts_integer(self): - self.assertEqual(convert_type('42', TYPE_INTEGER), 42) + self.assertEqual(convert_type("42", TYPE_INTEGER), 42) def test_converts_float(self): - self.assertAlmostEqual(convert_type('3.14', TYPE_FLOAT), 3.14, 2) + self.assertAlmostEqual(convert_type("3.14", TYPE_FLOAT), 3.14, 2) def test_converts_date(self): - self.assertEqual(convert_type('2018-10-31', TYPE_DATETIME), datetime.datetime(2018, 10, 31, 0, 0)) + self.assertEqual( + convert_type("2018-10-31", TYPE_DATETIME), + datetime.datetime(2018, 10, 31, 0, 0), + ) -empty_response = { - 'columns': [], - 'rows': [{}] -} + +empty_response = {"columns": [], "rows": [{}]} regular_response = { - 'columns': ['key', 'date', 'count', 'avg'], - 'rows': [ - {'key': 'Alpha', 'date': '2018-01-01', 'count': '10', 'avg': '3.14'}, - {'key': 'Beta', 'date': '2018-02-01', 'count': '20', 'avg': '6.28'} - ] + "columns": ["key", "date", "count", "avg"], + "rows": [ + {"key": "Alpha", "date": "2018-01-01", "count": "10", "avg": "3.14"}, + {"key": "Beta", "date": "2018-02-01", "count": "20", "avg": "6.28"}, + ], } + class TestParseResponse(TestCase): def test_parse_empty_reponse(self): parsed = parse_response(empty_response) self.assertIsInstance(parsed, dict) - self.assertIsNotNone(parsed['columns']) - self.assertIsNotNone(parsed['rows']) - self.assertEqual(len(parsed['columns']), 0) - self.assertEqual(len(parsed['rows']), 0) + self.assertIsNotNone(parsed["columns"]) + self.assertIsNotNone(parsed["rows"]) + self.assertEqual(len(parsed["columns"]), 0) + self.assertEqual(len(parsed["rows"]), 0) def test_parse_regular_response(self): parsed = parse_response(regular_response) self.assertIsInstance(parsed, dict) - self.assertIsNotNone(parsed['columns']) - self.assertIsNotNone(parsed['rows']) - self.assertEqual(len(parsed['columns']), 4) - self.assertEqual(len(parsed['rows']), 2) - - key_col = parsed['columns'][0] - self.assertEqual(key_col['name'], 'key') - self.assertEqual(key_col['type'], TYPE_STRING) - - date_col = parsed['columns'][1] - self.assertEqual(date_col['name'], 'date') - self.assertEqual(date_col['type'], TYPE_DATETIME) - - count_col = parsed['columns'][2] - self.assertEqual(count_col['name'], 'count') - self.assertEqual(count_col['type'], TYPE_INTEGER) - - avg_col = parsed['columns'][3] - self.assertEqual(avg_col['name'], 'avg') - self.assertEqual(avg_col['type'], TYPE_FLOAT) - - row_0 = parsed['rows'][0] - self.assertEqual(row_0['key'], 'Alpha') - self.assertEqual(row_0['date'], datetime.datetime(2018, 1, 1, 0, 0)) - self.assertEqual(row_0['count'], 10) - self.assertAlmostEqual(row_0['avg'], 3.14, 2) - - row_1 = parsed['rows'][1] - self.assertEqual(row_1['key'], 'Beta') - self.assertEqual(row_1['date'], datetime.datetime(2018, 2, 1, 0, 0)) - self.assertEqual(row_1['count'], 20) - self.assertAlmostEqual(row_1['avg'], 6.28, 2) + self.assertIsNotNone(parsed["columns"]) + self.assertIsNotNone(parsed["rows"]) + self.assertEqual(len(parsed["columns"]), 4) + self.assertEqual(len(parsed["rows"]), 2) + + key_col = parsed["columns"][0] + self.assertEqual(key_col["name"], "key") + self.assertEqual(key_col["type"], TYPE_STRING) + + date_col = parsed["columns"][1] + self.assertEqual(date_col["name"], "date") + self.assertEqual(date_col["type"], TYPE_DATETIME) + + count_col = parsed["columns"][2] + self.assertEqual(count_col["name"], "count") + self.assertEqual(count_col["type"], TYPE_INTEGER) + + avg_col = parsed["columns"][3] + self.assertEqual(avg_col["name"], "avg") + self.assertEqual(avg_col["type"], TYPE_FLOAT) + + row_0 = parsed["rows"][0] + self.assertEqual(row_0["key"], "Alpha") + self.assertEqual(row_0["date"], datetime.datetime(2018, 1, 1, 0, 0)) + self.assertEqual(row_0["count"], 10) + self.assertAlmostEqual(row_0["avg"], 3.14, 2) + + row_1 = parsed["rows"][1] + self.assertEqual(row_1["key"], "Beta") + self.assertEqual(row_1["date"], datetime.datetime(2018, 2, 1, 0, 0)) + self.assertEqual(row_1["count"], 20) + self.assertAlmostEqual(row_1["avg"], 6.28, 2) diff --git a/tests/query_runner/test_google_spreadsheets.py b/tests/query_runner/test_google_spreadsheets.py index ad236b883d..abce4ddb3d 100644 --- a/tests/query_runner/test_google_spreadsheets.py +++ b/tests/query_runner/test_google_spreadsheets.py @@ -4,34 +4,56 @@ from mock import MagicMock from redash.query_runner import TYPE_DATETIME, TYPE_FLOAT -from redash.query_runner.google_spreadsheets import TYPE_BOOLEAN, TYPE_STRING, _get_columns_and_column_names, _value_eval_list, is_url_key, parse_query -from redash.query_runner.google_spreadsheets import WorksheetNotFoundError, parse_spreadsheet, parse_worksheet +from redash.query_runner.google_spreadsheets import ( + TYPE_BOOLEAN, + TYPE_STRING, + _get_columns_and_column_names, + _value_eval_list, + is_url_key, + parse_query, +) +from redash.query_runner.google_spreadsheets import ( + WorksheetNotFoundError, + parse_spreadsheet, + parse_worksheet, +) class TestValueEvalList(TestCase): def test_handles_unicode(self): - values = ['יוניקוד', 'test', 'value'] - self.assertEqual(values, _value_eval_list(values, [TYPE_STRING]*len(values))) + values = ["יוניקוד", "test", "value"] + self.assertEqual(values, _value_eval_list(values, [TYPE_STRING] * len(values))) def test_handles_boolean(self): - values = ['true', 'false', 'True', 'False', 'TRUE', 'FALSE'] + values = ["true", "false", "True", "False", "TRUE", "FALSE"] converted_values = [True, False, True, False, True, False] - self.assertEqual(converted_values, _value_eval_list(values, [TYPE_BOOLEAN]*len(values))) + self.assertEqual( + converted_values, _value_eval_list(values, [TYPE_BOOLEAN] * len(values)) + ) def test_handles_empty_values(self): - values = ['', None] + values = ["", None] converted_values = [None, None] - self.assertEqual(converted_values, _value_eval_list(values, [TYPE_STRING, TYPE_STRING])) + self.assertEqual( + converted_values, _value_eval_list(values, [TYPE_STRING, TYPE_STRING]) + ) def test_handles_float(self): - values = ['3.14', '-273.15'] + values = ["3.14", "-273.15"] converted_values = [3.14, -273.15] - self.assertEqual(converted_values, _value_eval_list(values, [TYPE_FLOAT, TYPE_FLOAT])) + self.assertEqual( + converted_values, _value_eval_list(values, [TYPE_FLOAT, TYPE_FLOAT]) + ) def test_handles_datetime(self): - values = ['2018-06-28', '2020-2-29'] - converted_values = [datetime.datetime(2018, 6, 28, 0, 0), datetime.datetime(2020, 2, 29, 0, 0)] - self.assertEqual(converted_values, _value_eval_list(values, [TYPE_DATETIME, TYPE_DATETIME])) + values = ["2018-06-28", "2020-2-29"] + converted_values = [ + datetime.datetime(2018, 6, 28, 0, 0), + datetime.datetime(2020, 2, 29, 0, 0), + ] + self.assertEqual( + converted_values, _value_eval_list(values, [TYPE_DATETIME, TYPE_DATETIME]) + ) class TestParseSpreadsheet(TestCase): @@ -46,8 +68,14 @@ def test_returns_meaningful_error_for_missing_worksheet(self): empty_worksheet = [] -only_headers_worksheet = [['Column A', 'Column B']] -regular_worksheet = [['String Column', 'Boolean Column', 'Number Column'], ['A', 'TRUE', '1'], ['B', 'FALSE', '2'], ['C', 'TRUE', '3'], ['D', 'FALSE', '4']] +only_headers_worksheet = [["Column A", "Column B"]] +regular_worksheet = [ + ["String Column", "Boolean Column", "Number Column"], + ["A", "TRUE", "1"], + ["B", "FALSE", "2"], + ["C", "TRUE", "3"], + ["D", "FALSE", "4"], +] # The following test that the parse function doesn't crash. They don't test correct output. @@ -62,49 +90,55 @@ def test_parse_regular_worksheet(self): parse_worksheet(regular_worksheet) def test_parse_worksheet_with_duplicate_column_names(self): - worksheet = [['Column', 'Another Column', 'Column'], ['A', 'TRUE', '1'], ['B', 'FALSE', '2'], ['C', 'TRUE', '3'], ['D', 'FALSE', '4']] + worksheet = [ + ["Column", "Another Column", "Column"], + ["A", "TRUE", "1"], + ["B", "FALSE", "2"], + ["C", "TRUE", "3"], + ["D", "FALSE", "4"], + ] parsed = parse_worksheet(worksheet) - columns = [column['name'] for column in parsed['columns']] - self.assertEqual('Column', columns[0]) - self.assertEqual('Another Column', columns[1]) - self.assertEqual('Column1', columns[2]) + columns = [column["name"] for column in parsed["columns"]] + self.assertEqual("Column", columns[0]) + self.assertEqual("Another Column", columns[1]) + self.assertEqual("Column1", columns[2]) - self.assertEqual('A', parsed['rows'][0]['Column']) - self.assertEqual(True, parsed['rows'][0]['Another Column']) - self.assertEqual(1, parsed['rows'][0]['Column1']) + self.assertEqual("A", parsed["rows"][0]["Column"]) + self.assertEqual(True, parsed["rows"][0]["Another Column"]) + self.assertEqual(1, parsed["rows"][0]["Column1"]) class TestParseQuery(TestCase): def test_parse_query(self): - parsed = parse_query('key|0') - self.assertEqual(('key', 0), parsed) + parsed = parse_query("key|0") + self.assertEqual(("key", 0), parsed) class TestGetColumnsAndColumnNames(TestCase): def test_get_columns(self): - _columns = ['foo', 'bar', 'baz'] + _columns = ["foo", "bar", "baz"] columns, column_names = _get_columns_and_column_names(_columns) self.assertEqual(_columns, column_names) def test_get_columns_with_duplicated(self): - _columns = ['foo', 'bar', 'baz', 'foo', 'baz'] + _columns = ["foo", "bar", "baz", "foo", "baz"] columns, column_names = _get_columns_and_column_names(_columns) - self.assertEqual(['foo', 'bar', 'baz', 'foo1', 'baz2'], column_names) + self.assertEqual(["foo", "bar", "baz", "foo1", "baz2"], column_names) def test_get_columns_with_blank(self): - _columns = ['foo', '', 'baz', ''] + _columns = ["foo", "", "baz", ""] columns, column_names = _get_columns_and_column_names(_columns) - self.assertEqual(['foo', 'column_B', 'baz', 'column_D'], column_names) + self.assertEqual(["foo", "column_B", "baz", "column_D"], column_names) class TestIsUrlKey(TestCase): def test_is_url_key(self): - _key = 'https://docs.google.com/spreadsheets/d/key/edit#gid=12345678' + _key = "https://docs.google.com/spreadsheets/d/key/edit#gid=12345678" self.assertTrue(is_url_key(_key)) - _key = 'key|0' + _key = "key|0" self.assertFalse(is_url_key(_key)) diff --git a/tests/query_runner/test_http.py b/tests/query_runner/test_http.py index c410e134f4..0ea7a1db09 100644 --- a/tests/query_runner/test_http.py +++ b/tests/query_runner/test_http.py @@ -10,24 +10,22 @@ class RequiresAuthQueryRunner(BaseHTTPQueryRunner): class TestBaseHTTPQueryRunner(TestCase): - def test_requires_authentication_default(self): self.assertFalse(BaseHTTPQueryRunner.requires_authentication) schema = BaseHTTPQueryRunner.configuration_schema() - self.assertNotIn('username', schema['required']) - self.assertNotIn('password', schema['required']) + self.assertNotIn("username", schema["required"]) + self.assertNotIn("password", schema["required"]) def test_requires_authentication_true(self): schema = RequiresAuthQueryRunner.configuration_schema() - self.assertIn('username', schema['required']) - self.assertIn('password', schema['required']) + self.assertIn("username", schema["required"]) + self.assertIn("password", schema["required"]) def test_get_auth_with_values(self): - query_runner = BaseHTTPQueryRunner({ - 'username': 'username', - 'password': 'password' - }) - self.assertEqual(query_runner.get_auth(), ('username', 'password')) + query_runner = BaseHTTPQueryRunner( + {"username": "username", "password": "password"} + ) + self.assertEqual(query_runner.get_auth(), ("username", "password")) def test_get_auth_empty(self): query_runner = BaseHTTPQueryRunner({}) @@ -36,54 +34,52 @@ def test_get_auth_empty(self): def test_get_auth_empty_requires_authentication(self): query_runner = RequiresAuthQueryRunner({}) self.assertRaisesRegex( - ValueError, - "Username and Password required", - query_runner.get_auth + ValueError, "Username and Password required", query_runner.get_auth ) - @mock.patch('requests.request') + @mock.patch("requests.request") def test_get_response_success(self, mock_get): mock_response = mock.Mock() mock_response.status_code = 200 mock_response.text = "Success" mock_get.return_value = mock_response - url = 'https://example.com/' + url = "https://example.com/" query_runner = BaseHTTPQueryRunner({}) response, error = query_runner.get_response(url) - mock_get.assert_called_once_with('get', url, auth=None) + mock_get.assert_called_once_with("get", url, auth=None) self.assertEqual(response.status_code, 200) self.assertIsNone(error) - @mock.patch('requests.request') + @mock.patch("requests.request") def test_get_response_success_custom_auth(self, mock_get): mock_response = mock.Mock() mock_response.status_code = 200 mock_response.text = "Success" mock_get.return_value = mock_response - url = 'https://example.com/' + url = "https://example.com/" query_runner = BaseHTTPQueryRunner({}) - auth = ('username', 'password') + auth = ("username", "password") response, error = query_runner.get_response(url, auth=auth) - mock_get.assert_called_once_with('get', url, auth=auth) + mock_get.assert_called_once_with("get", url, auth=auth) self.assertEqual(response.status_code, 200) self.assertIsNone(error) - @mock.patch('requests.request') + @mock.patch("requests.request") def test_get_response_failure(self, mock_get): mock_response = mock.Mock() mock_response.status_code = 301 mock_response.text = "Redirect" mock_get.return_value = mock_response - url = 'https://example.com/' + url = "https://example.com/" query_runner = BaseHTTPQueryRunner({}) response, error = query_runner.get_response(url) - mock_get.assert_called_once_with('get', url, auth=None) + mock_get.assert_called_once_with("get", url, auth=None) self.assertIn(query_runner.response_error, error) - @mock.patch('requests.request') + @mock.patch("requests.request") def test_get_response_httperror_exception(self, mock_get): mock_response = mock.Mock() mock_response.status_code = 500 @@ -92,14 +88,14 @@ def test_get_response_httperror_exception(self, mock_get): mock_response.raise_for_status.side_effect = http_error mock_get.return_value = mock_response - url = 'https://example.com/' + url = "https://example.com/" query_runner = BaseHTTPQueryRunner({}) response, error = query_runner.get_response(url) - mock_get.assert_called_once_with('get', url, auth=None) + mock_get.assert_called_once_with("get", url, auth=None) self.assertIsNotNone(error) self.assertIn("Failed to execute query", error) - @mock.patch('requests.request') + @mock.patch("requests.request") def test_get_response_requests_exception(self, mock_get): mock_response = mock.Mock() mock_response.status_code = 500 @@ -109,14 +105,14 @@ def test_get_response_requests_exception(self, mock_get): mock_response.raise_for_status.side_effect = requests_exception mock_get.return_value = mock_response - url = 'https://example.com/' + url = "https://example.com/" query_runner = BaseHTTPQueryRunner({}) response, error = query_runner.get_response(url) - mock_get.assert_called_once_with('get', url, auth=None) + mock_get.assert_called_once_with("get", url, auth=None) self.assertIsNotNone(error) self.assertEqual(exception_message, error) - @mock.patch('requests.request') + @mock.patch("requests.request") def test_get_response_generic_exception(self, mock_get): mock_response = mock.Mock() mock_response.status_code = 500 @@ -126,11 +122,8 @@ def test_get_response_generic_exception(self, mock_get): mock_response.raise_for_status.side_effect = exception mock_get.return_value = mock_response - url = 'https://example.com/' + url = "https://example.com/" query_runner = BaseHTTPQueryRunner({}) self.assertRaisesRegex( - ValueError, - exception_message, - query_runner.get_response, - url + ValueError, exception_message, query_runner.get_response, url ) diff --git a/tests/query_runner/test_jql.py b/tests/query_runner/test_jql.py index 18428d1ab2..0bb01a0f84 100644 --- a/tests/query_runner/test_jql.py +++ b/tests/query_runner/test_jql.py @@ -3,102 +3,130 @@ class TestFieldMapping(TestCase): - def test_empty(self): field_mapping = FieldMapping({}) - self.assertEqual(field_mapping.get_output_field_name('field1'), 'field1') - self.assertEqual(field_mapping.get_dict_output_field_name('field1','member1'), None) - self.assertEqual(field_mapping.get_dict_members('field1'), []) + self.assertEqual(field_mapping.get_output_field_name("field1"), "field1") + self.assertEqual( + field_mapping.get_dict_output_field_name("field1", "member1"), None + ) + self.assertEqual(field_mapping.get_dict_members("field1"), []) def test_with_mappings(self): - field_mapping = FieldMapping({ - 'field1': 'output_name_1', - 'field2.member1': 'output_name_2', - 'field2.member2': 'output_name_3' - }) - - self.assertEqual(field_mapping.get_output_field_name('field1'), 'output_name_1') - self.assertEqual(field_mapping.get_dict_output_field_name('field1','member1'), None) - self.assertEqual(field_mapping.get_dict_members('field1'), []) - - self.assertEqual(field_mapping.get_output_field_name('field2'), 'field2') - self.assertEqual(field_mapping.get_dict_output_field_name('field2','member1'), 'output_name_2') - self.assertEqual(field_mapping.get_dict_output_field_name('field2','member2'), 'output_name_3') - self.assertEqual(field_mapping.get_dict_output_field_name('field2','member3'), None) - self.assertEqual(field_mapping.get_dict_members('field2'), ['member1','member2']) + field_mapping = FieldMapping( + { + "field1": "output_name_1", + "field2.member1": "output_name_2", + "field2.member2": "output_name_3", + } + ) + + self.assertEqual(field_mapping.get_output_field_name("field1"), "output_name_1") + self.assertEqual( + field_mapping.get_dict_output_field_name("field1", "member1"), None + ) + self.assertEqual(field_mapping.get_dict_members("field1"), []) + + self.assertEqual(field_mapping.get_output_field_name("field2"), "field2") + self.assertEqual( + field_mapping.get_dict_output_field_name("field2", "member1"), + "output_name_2", + ) + self.assertEqual( + field_mapping.get_dict_output_field_name("field2", "member2"), + "output_name_3", + ) + self.assertEqual( + field_mapping.get_dict_output_field_name("field2", "member3"), None + ) + self.assertEqual( + field_mapping.get_dict_members("field2"), ["member1", "member2"] + ) class TestParseIssue(TestCase): issue = { - 'key': 'KEY-1', - 'fields': { - 'string_field': 'value1', - 'int_field': 123, - 'string_list_field': ['value1','value2'], - 'dict_field': {'member1':'value1','member2': 'value2'}, - 'dict_list_field': [ - {'member1':'value1a','member2': 'value2a'}, - {'member1':'value1b','member2': 'value2b'} + "key": "KEY-1", + "fields": { + "string_field": "value1", + "int_field": 123, + "string_list_field": ["value1", "value2"], + "dict_field": {"member1": "value1", "member2": "value2"}, + "dict_list_field": [ + {"member1": "value1a", "member2": "value2a"}, + {"member1": "value1b", "member2": "value2b"}, ], - 'dict_legacy': {'key':'legacyKey','name':'legacyName','dict_legacy':'legacyValue'}, - 'watchers': {'watchCount':10} - } + "dict_legacy": { + "key": "legacyKey", + "name": "legacyName", + "dict_legacy": "legacyValue", + }, + "watchers": {"watchCount": 10}, + }, } def test_no_mapping(self): result = parse_issue(self.issue, FieldMapping({})) - self.assertEqual(result['key'], 'KEY-1') - self.assertEqual(result['string_field'], 'value1') - self.assertEqual(result['int_field'], 123) - self.assertEqual(result['string_list_field'], 'value1,value2') - self.assertEqual('dict_field' in result, False) - self.assertEqual('dict_list_field' in result, False) - self.assertEqual(result['dict_legacy'], 'legacyValue') - self.assertEqual(result['dict_legacy_key'], 'legacyKey') - self.assertEqual(result['dict_legacy_name'], 'legacyName') - self.assertEqual(result['watchers'], 10) + self.assertEqual(result["key"], "KEY-1") + self.assertEqual(result["string_field"], "value1") + self.assertEqual(result["int_field"], 123) + self.assertEqual(result["string_list_field"], "value1,value2") + self.assertEqual("dict_field" in result, False) + self.assertEqual("dict_list_field" in result, False) + self.assertEqual(result["dict_legacy"], "legacyValue") + self.assertEqual(result["dict_legacy_key"], "legacyKey") + self.assertEqual(result["dict_legacy_name"], "legacyName") + self.assertEqual(result["watchers"], 10) def test_mapping(self): - result = parse_issue(self.issue, FieldMapping({ - 'string_field': 'string_output_field', - 'string_list_field': 'string_output_list_field', - 'dict_field.member1': 'dict_field_1', - 'dict_field.member2': 'dict_field_2', - 'dict_list_field.member1': 'dict_list_field_1', - 'dict_legacy.key': 'dict_legacy', - 'watchers.watchCount': 'watchCount', - })) - - self.assertEqual(result['key'], 'KEY-1') - self.assertEqual(result['string_output_field'], 'value1') - self.assertEqual(result['int_field'], 123) - self.assertEqual(result['string_output_list_field'], 'value1,value2') - self.assertEqual(result['dict_field_1'], 'value1') - self.assertEqual(result['dict_field_2'], 'value2') - self.assertEqual(result['dict_list_field_1'], 'value1a,value1b') - self.assertEqual(result['dict_legacy'], 'legacyKey') - self.assertEqual('dict_legacy_key' in result, False) - self.assertEqual('dict_legacy_name' in result, False) - self.assertEqual('watchers' in result, False) - self.assertEqual(result['watchCount'], 10) - + result = parse_issue( + self.issue, + FieldMapping( + { + "string_field": "string_output_field", + "string_list_field": "string_output_list_field", + "dict_field.member1": "dict_field_1", + "dict_field.member2": "dict_field_2", + "dict_list_field.member1": "dict_list_field_1", + "dict_legacy.key": "dict_legacy", + "watchers.watchCount": "watchCount", + } + ), + ) + + self.assertEqual(result["key"], "KEY-1") + self.assertEqual(result["string_output_field"], "value1") + self.assertEqual(result["int_field"], 123) + self.assertEqual(result["string_output_list_field"], "value1,value2") + self.assertEqual(result["dict_field_1"], "value1") + self.assertEqual(result["dict_field_2"], "value2") + self.assertEqual(result["dict_list_field_1"], "value1a,value1b") + self.assertEqual(result["dict_legacy"], "legacyKey") + self.assertEqual("dict_legacy_key" in result, False) + self.assertEqual("dict_legacy_name" in result, False) + self.assertEqual("watchers" in result, False) + self.assertEqual(result["watchCount"], 10) def test_mapping_nonexisting_field(self): - result = parse_issue(self.issue, FieldMapping({ - 'non_existing_field': 'output_name1', - 'dict_field.non_existing_member': 'output_name2', - 'dict_list_field.non_existing_member': 'output_name3' - })) - - self.assertEqual(result['key'], 'KEY-1') - self.assertEqual(result['string_field'], 'value1') - self.assertEqual(result['int_field'], 123) - self.assertEqual(result['string_list_field'], 'value1,value2') - self.assertEqual('dict_field' in result, False) - self.assertEqual('dict_list_field' in result, False) - self.assertEqual(result['dict_legacy'], 'legacyValue') - self.assertEqual(result['dict_legacy_key'], 'legacyKey') - self.assertEqual(result['dict_legacy_name'], 'legacyName') - self.assertEqual(result['watchers'], 10) + result = parse_issue( + self.issue, + FieldMapping( + { + "non_existing_field": "output_name1", + "dict_field.non_existing_member": "output_name2", + "dict_list_field.non_existing_member": "output_name3", + } + ), + ) + + self.assertEqual(result["key"], "KEY-1") + self.assertEqual(result["string_field"], "value1") + self.assertEqual(result["int_field"], 123) + self.assertEqual(result["string_list_field"], "value1,value2") + self.assertEqual("dict_field" in result, False) + self.assertEqual("dict_list_field" in result, False) + self.assertEqual(result["dict_legacy"], "legacyValue") + self.assertEqual(result["dict_legacy_key"], "legacyKey") + self.assertEqual(result["dict_legacy_name"], "legacyName") + self.assertEqual(result["watchers"], 10) diff --git a/tests/query_runner/test_mongodb.py b/tests/query_runner/test_mongodb.py index a4dd3b09c2..bb5d2ce5ea 100644 --- a/tests/query_runner/test_mongodb.py +++ b/tests/query_runner/test_mongodb.py @@ -4,56 +4,51 @@ from pytz import utc from freezegun import freeze_time -from redash.query_runner.mongodb import parse_query_json, parse_results, _get_column_by_name +from redash.query_runner.mongodb import ( + parse_query_json, + parse_results, + _get_column_by_name, +) from redash.utils import json_dumps, parse_human_time class TestParseQueryJson(TestCase): def test_ignores_non_isodate_fields(self): - query = { - 'test': 1, - 'test_list': ['a', 'b', 'c'], - 'test_dict': { - 'a': 1, - 'b': 2 - } - } + query = {"test": 1, "test_list": ["a", "b", "c"], "test_dict": {"a": 1, "b": 2}} query_data = parse_query_json(json_dumps(query)) self.assertDictEqual(query_data, query) def test_parses_isodate_fields(self): query = { - 'test': 1, - 'test_list': ['a', 'b', 'c'], - 'test_dict': { - 'a': 1, - 'b': 2 - }, - 'testIsoDate': "ISODate(\"2014-10-03T00:00\")" + "test": 1, + "test_list": ["a", "b", "c"], + "test_dict": {"a": 1, "b": 2}, + "testIsoDate": 'ISODate("2014-10-03T00:00")', } query_data = parse_query_json(json_dumps(query)) - self.assertEqual(query_data['testIsoDate'], datetime.datetime(2014, 10, 3, 0, 0)) + self.assertEqual( + query_data["testIsoDate"], datetime.datetime(2014, 10, 3, 0, 0) + ) def test_parses_isodate_in_nested_fields(self): query = { - 'test': 1, - 'test_list': ['a', 'b', 'c'], - 'test_dict': { - 'a': 1, - 'b': { - 'date': "ISODate(\"2014-10-04T00:00\")" - } - }, - 'testIsoDate': "ISODate(\"2014-10-03T00:00\")" + "test": 1, + "test_list": ["a", "b", "c"], + "test_dict": {"a": 1, "b": {"date": 'ISODate("2014-10-04T00:00")'}}, + "testIsoDate": 'ISODate("2014-10-03T00:00")', } query_data = parse_query_json(json_dumps(query)) - self.assertEqual(query_data['testIsoDate'], datetime.datetime(2014, 10, 3, 0, 0)) - self.assertEqual(query_data['test_dict']['b']['date'], datetime.datetime(2014, 10, 4, 0, 0)) + self.assertEqual( + query_data["testIsoDate"], datetime.datetime(2014, 10, 3, 0, 0) + ) + self.assertEqual( + query_data["test_dict"]["b"]["date"], datetime.datetime(2014, 10, 4, 0, 0) + ) def test_handles_nested_fields(self): # https://github.com/getredash/redash/issues/597 @@ -62,14 +57,17 @@ def test_handles_nested_fields(self): "aggregate": [ { "$geoNear": { - "near": {"type": "Point", "coordinates": [-22.910079, -43.205161]}, + "near": { + "type": "Point", + "coordinates": [-22.910079, -43.205161], + }, "maxDistance": 100000000, "distanceField": "dist.calculated", "includeLocs": "dist.location", - "spherical": True + "spherical": True, } } - ] + ], } query_data = parse_query_json(json_dumps(query)) @@ -78,71 +76,75 @@ def test_handles_nested_fields(self): def test_supports_extended_json_types(self): query = { - 'test': 1, - 'test_list': ['a', 'b', 'c'], - 'test_dict': { - 'a': 1, - 'b': 2 - }, - 'testIsoDate': "ISODate(\"2014-10-03T00:00\")", - 'test$date': { - '$date': '2014-10-03T00:00:00.0' - }, - 'test$undefined': { - '$undefined': None - } + "test": 1, + "test_list": ["a", "b", "c"], + "test_dict": {"a": 1, "b": 2}, + "testIsoDate": 'ISODate("2014-10-03T00:00")', + "test$date": {"$date": "2014-10-03T00:00:00.0"}, + "test$undefined": {"$undefined": None}, } query_data = parse_query_json(json_dumps(query)) - self.assertEqual(query_data['test$undefined'], None) - self.assertEqual(query_data['test$date'], datetime.datetime(2014, 10, 3, 0, 0).replace(tzinfo=utc)) + self.assertEqual(query_data["test$undefined"], None) + self.assertEqual( + query_data["test$date"], + datetime.datetime(2014, 10, 3, 0, 0).replace(tzinfo=utc), + ) - @freeze_time('2019-01-01 12:00:00') + @freeze_time("2019-01-01 12:00:00") def test_supports_relative_timestamps(self): - query = { - 'ts': {'$humanTime': '1 hour ago'} - } + query = {"ts": {"$humanTime": "1 hour ago"}} one_hour_ago = parse_human_time("1 hour ago") query_data = parse_query_json(json_dumps(query)) - self.assertEqual(query_data['ts'], one_hour_ago) + self.assertEqual(query_data["ts"], one_hour_ago) class TestMongoResults(TestCase): def test_parses_regular_results(self): raw_results = [ - {'column': 1, 'column2': 'test'}, - {'column': 2, 'column2': 'test', 'column3': 'hello'} + {"column": 1, "column2": "test"}, + {"column": 2, "column2": "test", "column3": "hello"}, ] rows, columns = parse_results(raw_results) for i, row in enumerate(rows): self.assertDictEqual(row, raw_results[i]) - self.assertIsNotNone(_get_column_by_name(columns, 'column')) - self.assertIsNotNone(_get_column_by_name(columns, 'column2')) - self.assertIsNotNone(_get_column_by_name(columns, 'column3')) + self.assertIsNotNone(_get_column_by_name(columns, "column")) + self.assertIsNotNone(_get_column_by_name(columns, "column2")) + self.assertIsNotNone(_get_column_by_name(columns, "column3")) def test_parses_nested_results(self): raw_results = [ - {'column': 1, 'column2': 'test', 'nested': { - 'a': 1, - 'b': 'str' - }}, - {'column': 2, 'column2': 'test', 'column3': 'hello', 'nested': { - 'a': 2, - 'b': 'str2', - 'c': 'c' - }} + {"column": 1, "column2": "test", "nested": {"a": 1, "b": "str"}}, + { + "column": 2, + "column2": "test", + "column3": "hello", + "nested": {"a": 2, "b": "str2", "c": "c"}, + }, ] rows, columns = parse_results(raw_results) - self.assertDictEqual(rows[0], { 'column': 1, 'column2': 'test', 'nested.a': 1, 'nested.b': 'str' }) - self.assertDictEqual(rows[1], { 'column': 2, 'column2': 'test', 'column3': 'hello', 'nested.a': 2, 'nested.b': 'str2', 'nested.c': 'c' }) - - self.assertIsNotNone(_get_column_by_name(columns, 'column')) - self.assertIsNotNone(_get_column_by_name(columns, 'column2')) - self.assertIsNotNone(_get_column_by_name(columns, 'column3')) - self.assertIsNotNone(_get_column_by_name(columns, 'nested.a')) - self.assertIsNotNone(_get_column_by_name(columns, 'nested.b')) - self.assertIsNotNone(_get_column_by_name(columns, 'nested.c')) + self.assertDictEqual( + rows[0], {"column": 1, "column2": "test", "nested.a": 1, "nested.b": "str"} + ) + self.assertDictEqual( + rows[1], + { + "column": 2, + "column2": "test", + "column3": "hello", + "nested.a": 2, + "nested.b": "str2", + "nested.c": "c", + }, + ) + + self.assertIsNotNone(_get_column_by_name(columns, "column")) + self.assertIsNotNone(_get_column_by_name(columns, "column2")) + self.assertIsNotNone(_get_column_by_name(columns, "column3")) + self.assertIsNotNone(_get_column_by_name(columns, "nested.a")) + self.assertIsNotNone(_get_column_by_name(columns, "nested.b")) + self.assertIsNotNone(_get_column_by_name(columns, "nested.c")) diff --git a/tests/query_runner/test_pg.py b/tests/query_runner/test_pg.py index 6b244cfede..17fc35315b 100644 --- a/tests/query_runner/test_pg.py +++ b/tests/query_runner/test_pg.py @@ -1,4 +1,3 @@ - from unittest import TestCase from redash.query_runner.pg import build_schema @@ -6,10 +5,14 @@ class TestBuildSchema(TestCase): def test_handles_dups_between_public_and_other_schemas(self): results = { - 'rows': [ - {'table_schema': 'public', 'table_name': 'main.users', 'column_name': 'id'}, - {'table_schema': 'main', 'table_name': 'users', 'column_name': 'id'}, - {'table_schema': 'main', 'table_name': 'users', 'column_name': 'name'}, + "rows": [ + { + "table_schema": "public", + "table_name": "main.users", + "column_name": "id", + }, + {"table_schema": "main", "table_name": "users", "column_name": "id"}, + {"table_schema": "main", "table_name": "users", "column_name": "name"}, ] } @@ -17,7 +20,7 @@ def test_handles_dups_between_public_and_other_schemas(self): build_schema(results, schema) - self.assertIn('main.users', schema.keys()) - self.assertListEqual(schema['main.users']['columns'], ['id', 'name']) + self.assertIn("main.users", schema.keys()) + self.assertListEqual(schema["main.users"]["columns"], ["id", "name"]) self.assertIn('public."main.users"', schema.keys()) - self.assertListEqual(schema['public."main.users"']['columns'], ['id']) \ No newline at end of file + self.assertListEqual(schema['public."main.users"']["columns"], ["id"]) diff --git a/tests/query_runner/test_prometheus.py b/tests/query_runner/test_prometheus.py index 21c8dcdb92..302124e1ce 100644 --- a/tests/query_runner/test_prometheus.py +++ b/tests/query_runner/test_prometheus.py @@ -8,42 +8,24 @@ class TestPrometheus(TestCase): def setUp(self): self.instant_query_result = [ { - "metric": { - "name": "example_metric_name", - "foo_bar": "foo", - }, - "value": [1516937400.781, "7400_foo"] + "metric": {"name": "example_metric_name", "foo_bar": "foo"}, + "value": [1516937400.781, "7400_foo"], }, { - "metric": { - "name": "example_metric_name", - "foo_bar": "bar", - }, - "value": [1516937400.781, "7400_bar"] - } + "metric": {"name": "example_metric_name", "foo_bar": "bar"}, + "value": [1516937400.781, "7400_bar"], + }, ] self.range_query_result = [ { - "metric": { - "name": "example_metric_name", - "foo_bar": "foo", - }, - "values": [ - [1516937400.781, "7400_foo"], - [1516938000.781, "8000_foo"], - ] + "metric": {"name": "example_metric_name", "foo_bar": "foo"}, + "values": [[1516937400.781, "7400_foo"], [1516938000.781, "8000_foo"]], }, { - "metric": { - "name": "example_metric_name", - "foo_bar": "bar", - }, - "values": [ - [1516937400.781, "7400_bar"], - [1516938000.781, "8000_bar"], - ] - } + "metric": {"name": "example_metric_name", "foo_bar": "bar"}, + "values": [[1516937400.781, "7400_bar"], [1516938000.781, "8000_bar"]], + }, ] def test_get_instant_rows(self): @@ -52,13 +34,13 @@ def test_get_instant_rows(self): "name": "example_metric_name", "foo_bar": "foo", "timestamp": datetime.datetime.fromtimestamp(1516937400.781), - "value": "7400_foo" + "value": "7400_foo", }, { "name": "example_metric_name", "foo_bar": "bar", "timestamp": datetime.datetime.fromtimestamp(1516937400.781), - "value": "7400_bar" + "value": "7400_bar", }, ] @@ -72,25 +54,25 @@ def test_get_range_rows(self): "name": "example_metric_name", "foo_bar": "foo", "timestamp": datetime.datetime.fromtimestamp(1516937400.781), - "value": "7400_foo" + "value": "7400_foo", }, { "name": "example_metric_name", "foo_bar": "foo", "timestamp": datetime.datetime.fromtimestamp(1516938000.781), - "value": "8000_foo" + "value": "8000_foo", }, { "name": "example_metric_name", "foo_bar": "bar", "timestamp": datetime.datetime.fromtimestamp(1516937400.781), - "value": "7400_bar" + "value": "7400_bar", }, { "name": "example_metric_name", "foo_bar": "bar", "timestamp": datetime.datetime.fromtimestamp(1516938000.781), - "value": "8000_bar" + "value": "8000_bar", }, ] diff --git a/tests/query_runner/test_query_results.py b/tests/query_runner/test_query_results.py index 30662a5c7d..d64027b75c 100644 --- a/tests/query_runner/test_query_results.py +++ b/tests/query_runner/test_query_results.py @@ -4,8 +4,14 @@ import pytest from redash.query_runner.query_results import ( - CreateTableError, PermissionError, _load_query, create_table, - extract_cached_query_ids, extract_query_ids, fix_column_name) + CreateTableError, + PermissionError, + _load_query, + create_table, + extract_cached_query_ids, + extract_query_ids, + fix_column_name, +) from tests import BaseTestCase @@ -29,164 +35,95 @@ def test_finds_queries_with_whitespace_characters(self): class TestCreateTable(TestCase): def test_creates_table_with_colons_in_column_name(self): - connection = sqlite3.connect(':memory:') + connection = sqlite3.connect(":memory:") results = { - 'columns': [{ - 'name': 'ga:newUsers' - }, { - 'name': 'test2' - }], - 'rows': [{ - 'ga:newUsers': 123, - 'test2': 2 - }] + "columns": [{"name": "ga:newUsers"}, {"name": "test2"}], + "rows": [{"ga:newUsers": 123, "test2": 2}], } - table_name = 'query_123' + table_name = "query_123" create_table(connection, table_name, results) - connection.execute('SELECT 1 FROM query_123') + connection.execute("SELECT 1 FROM query_123") def test_creates_table_with_double_quotes_in_column_name(self): - connection = sqlite3.connect(':memory:') + connection = sqlite3.connect(":memory:") results = { - 'columns': [{ - 'name': 'ga:newUsers' - }, { - 'name': '"test2"' - }], - 'rows': [{ - 'ga:newUsers': 123, - '"test2"': 2 - }] + "columns": [{"name": "ga:newUsers"}, {"name": '"test2"'}], + "rows": [{"ga:newUsers": 123, '"test2"': 2}], } - table_name = 'query_123' + table_name = "query_123" create_table(connection, table_name, results) - connection.execute('SELECT 1 FROM query_123') + connection.execute("SELECT 1 FROM query_123") def test_creates_table(self): - connection = sqlite3.connect(':memory:') - results = { - 'columns': [{ - 'name': 'test1' - }, { - 'name': 'test2' - }], - 'rows': [] - } - table_name = 'query_123' + connection = sqlite3.connect(":memory:") + results = {"columns": [{"name": "test1"}, {"name": "test2"}], "rows": []} + table_name = "query_123" create_table(connection, table_name, results) - connection.execute('SELECT 1 FROM query_123') + connection.execute("SELECT 1 FROM query_123") def test_creates_table_with_missing_columns(self): - connection = sqlite3.connect(':memory:') + connection = sqlite3.connect(":memory:") results = { - 'columns': [{ - 'name': 'test1' - }, { - 'name': 'test2' - }], - 'rows': [{ - 'test1': 1, - 'test2': 2 - }, { - 'test1': 3 - }] + "columns": [{"name": "test1"}, {"name": "test2"}], + "rows": [{"test1": 1, "test2": 2}, {"test1": 3}], } - table_name = 'query_123' + table_name = "query_123" create_table(connection, table_name, results) - connection.execute('SELECT 1 FROM query_123') + connection.execute("SELECT 1 FROM query_123") def test_creates_table_with_spaces_in_column_name(self): - connection = sqlite3.connect(':memory:') + connection = sqlite3.connect(":memory:") results = { - 'columns': [{ - 'name': 'two words' - }, { - 'name': 'test2' - }], - 'rows': [{ - 'two words': 1, - 'test2': 2 - }, { - 'test1': 3 - }] + "columns": [{"name": "two words"}, {"name": "test2"}], + "rows": [{"two words": 1, "test2": 2}, {"test1": 3}], } - table_name = 'query_123' + table_name = "query_123" create_table(connection, table_name, results) - connection.execute('SELECT 1 FROM query_123') + connection.execute("SELECT 1 FROM query_123") def test_creates_table_with_dashes_in_column_name(self): - connection = sqlite3.connect(':memory:') + connection = sqlite3.connect(":memory:") results = { - 'columns': [{ - 'name': 'two-words' - }, { - 'name': 'test2' - }], - 'rows': [{ - 'two-words': 1, - 'test2': 2 - }] + "columns": [{"name": "two-words"}, {"name": "test2"}], + "rows": [{"two-words": 1, "test2": 2}], } - table_name = 'query_123' + table_name = "query_123" create_table(connection, table_name, results) - connection.execute('SELECT 1 FROM query_123') + connection.execute("SELECT 1 FROM query_123") connection.execute('SELECT "two-words" FROM query_123') def test_creates_table_with_non_ascii_in_column_name(self): - connection = sqlite3.connect(':memory:') + connection = sqlite3.connect(":memory:") results = { - 'columns': [{ - 'name': '\xe4' - }, { - 'name': 'test2' - }], - 'rows': [{ - '\xe4': 1, - 'test2': 2 - }] + "columns": [{"name": "\xe4"}, {"name": "test2"}], + "rows": [{"\xe4": 1, "test2": 2}], } - table_name = 'query_123' + table_name = "query_123" create_table(connection, table_name, results) - connection.execute('SELECT 1 FROM query_123') + connection.execute("SELECT 1 FROM query_123") def test_shows_meaningful_error_on_failure_to_create_table(self): - connection = sqlite3.connect(':memory:') - results = {'columns': [], 'rows': []} - table_name = 'query_123' + connection = sqlite3.connect(":memory:") + results = {"columns": [], "rows": []} + table_name = "query_123" with pytest.raises(CreateTableError): create_table(connection, table_name, results) def test_loads_results(self): - connection = sqlite3.connect(':memory:') - rows = [{'test1': 1, 'test2': 'test'}, {'test1': 2, 'test2': 'test2'}] - results = { - 'columns': [{ - 'name': 'test1' - }, { - 'name': 'test2' - }], - 'rows': rows - } - table_name = 'query_123' + connection = sqlite3.connect(":memory:") + rows = [{"test1": 1, "test2": "test"}, {"test1": 2, "test2": "test2"}] + results = {"columns": [{"name": "test1"}, {"name": "test2"}], "rows": rows} + table_name = "query_123" create_table(connection, table_name, results) - self.assertEqual( - len(list(connection.execute('SELECT * FROM query_123'))), 2) + self.assertEqual(len(list(connection.execute("SELECT * FROM query_123"))), 2) def test_loads_list_and_dict_results(self): - connection = sqlite3.connect(':memory:') - rows = [{'test1': [1, 2, 3]}, {'test2': {'a': 'b'}}] - results = { - 'columns': [{ - 'name': 'test1' - }, { - 'name': 'test2' - }], - 'rows': rows - } - table_name = 'query_123' + connection = sqlite3.connect(":memory:") + rows = [{"test1": [1, 2, 3]}, {"test2": {"a": "b"}}] + results = {"columns": [{"name": "test1"}, {"name": "test2"}], "rows": rows} + table_name = "query_123" create_table(connection, table_name, results) - self.assertEqual( - len(list(connection.execute('SELECT * FROM query_123'))), 2) + self.assertEqual(len(list(connection.execute("SELECT * FROM query_123"))), 2) class TestGetQuery(BaseTestCase): @@ -213,7 +150,8 @@ def test_returns_query(self): def test_returns_query_when_user_has_view_only_access(self): ds = self.factory.create_data_source( - group=self.factory.org.default_group, view_only=True) + group=self.factory.org.default_group, view_only=True + ) query = self.factory.create_query(data_source=ds) user = self.factory.create_user() diff --git a/tests/query_runner/test_script.py b/tests/query_runner/test_script.py index 29b31bbfb7..6de190f849 100644 --- a/tests/query_runner/test_script.py +++ b/tests/query_runner/test_script.py @@ -11,26 +11,38 @@ class TestQueryToScript(BaseTestCase): monkeypatch = MonkeyPatch() def test_unspecified(self): - self.assertEqual("/foo/bar/baz.sh", query_to_script_path("*", "/foo/bar/baz.sh")) + self.assertEqual( + "/foo/bar/baz.sh", query_to_script_path("*", "/foo/bar/baz.sh") + ) def test_specified(self): self.assertRaises(IOError, lambda: query_to_script_path("/foo/bar", "baz.sh")) self.monkeypatch.setattr(os.path, "exists", lambda x: True) - self.assertEqual(["/foo/bar/baz.sh"], query_to_script_path("/foo/bar", "baz.sh")) + self.assertEqual( + ["/foo/bar/baz.sh"], query_to_script_path("/foo/bar", "baz.sh") + ) class TestRunScript(BaseTestCase): monkeypatch = MonkeyPatch() def test_success(self): - self.monkeypatch.setattr(subprocess, "check_output", lambda script, shell: "test") + self.monkeypatch.setattr( + subprocess, "check_output", lambda script, shell: "test" + ) self.assertEqual(("test", None), run_script("/foo/bar/baz.sh", True)) def test_failure(self): self.monkeypatch.setattr(subprocess, "check_output", lambda script, shell: None) - self.assertEqual((None, "Error reading output"), run_script("/foo/bar/baz.sh", True)) + self.assertEqual( + (None, "Error reading output"), run_script("/foo/bar/baz.sh", True) + ) self.monkeypatch.setattr(subprocess, "check_output", lambda script, shell: "") - self.assertEqual((None, "Empty output from script"), run_script("/foo/bar/baz.sh", True)) + self.assertEqual( + (None, "Empty output from script"), run_script("/foo/bar/baz.sh", True) + ) self.monkeypatch.setattr(subprocess, "check_output", lambda script, shell: " ") - self.assertEqual((None, "Empty output from script"), run_script("/foo/bar/baz.sh", True)) + self.assertEqual( + (None, "Empty output from script"), run_script("/foo/bar/baz.sh", True) + ) diff --git a/tests/query_runner/test_utils.py b/tests/query_runner/test_utils.py index d2286702ec..77db0bfac9 100644 --- a/tests/query_runner/test_utils.py +++ b/tests/query_runner/test_utils.py @@ -1,33 +1,40 @@ from unittest import TestCase -from redash.query_runner import TYPE_DATETIME, TYPE_FLOAT, TYPE_INTEGER, TYPE_BOOLEAN, TYPE_STRING, guess_type +from redash.query_runner import ( + TYPE_DATETIME, + TYPE_FLOAT, + TYPE_INTEGER, + TYPE_BOOLEAN, + TYPE_STRING, + guess_type, +) class TestGuessType(TestCase): def test_handles_unicode(self): - self.assertEqual(guess_type('Текст'), TYPE_STRING) + self.assertEqual(guess_type("Текст"), TYPE_STRING) def test_detects_booleans(self): - self.assertEqual(guess_type('true'), TYPE_BOOLEAN) - self.assertEqual(guess_type('True'), TYPE_BOOLEAN) - self.assertEqual(guess_type('TRUE'), TYPE_BOOLEAN) - self.assertEqual(guess_type('false'), TYPE_BOOLEAN) - self.assertEqual(guess_type('False'), TYPE_BOOLEAN) - self.assertEqual(guess_type('FALSE'), TYPE_BOOLEAN) + self.assertEqual(guess_type("true"), TYPE_BOOLEAN) + self.assertEqual(guess_type("True"), TYPE_BOOLEAN) + self.assertEqual(guess_type("TRUE"), TYPE_BOOLEAN) + self.assertEqual(guess_type("false"), TYPE_BOOLEAN) + self.assertEqual(guess_type("False"), TYPE_BOOLEAN) + self.assertEqual(guess_type("FALSE"), TYPE_BOOLEAN) self.assertEqual(guess_type(False), TYPE_BOOLEAN) def test_detects_strings(self): self.assertEqual(guess_type(None), TYPE_STRING) - self.assertEqual(guess_type(''), TYPE_STRING) - self.assertEqual(guess_type('redash'), TYPE_STRING) + self.assertEqual(guess_type(""), TYPE_STRING) + self.assertEqual(guess_type("redash"), TYPE_STRING) def test_detects_integer(self): - self.assertEqual(guess_type('42'), TYPE_INTEGER) + self.assertEqual(guess_type("42"), TYPE_INTEGER) self.assertEqual(guess_type(42), TYPE_INTEGER) def test_detects_float(self): - self.assertEqual(guess_type('3.14'), TYPE_FLOAT) + self.assertEqual(guess_type("3.14"), TYPE_FLOAT) self.assertEqual(guess_type(3.14), TYPE_FLOAT) def test_detects_date(self): - self.assertEqual(guess_type('2018-10-31'), TYPE_DATETIME) + self.assertEqual(guess_type("2018-10-31"), TYPE_DATETIME) diff --git a/tests/serializers/test_query_results.py b/tests/serializers/test_query_results.py index 95b4f72724..49e02ce412 100644 --- a/tests/serializers/test_query_results.py +++ b/tests/serializers/test_query_results.py @@ -20,22 +20,22 @@ "columns": [ {"friendly_name": "bool", "type": "boolean", "name": "bool"}, {"friendly_name": "date", "type": "datetime", "name": "datetime"}, - {"friendly_name": "date", "type": "date", "name": "date"} - ] + {"friendly_name": "date", "type": "date", "name": "date"}, + ], } + class QueryResultSerializationTest(BaseTestCase): def test_serializes_all_keys_for_authenticated_users(self): query_result = self.factory.create_query_result(data=json_dumps({})) serialized = serialize_query_result(query_result, False) - self.assertSetEqual(set(query_result.to_dict().keys()), - set(serialized.keys())) + self.assertSetEqual(set(query_result.to_dict().keys()), set(serialized.keys())) def test_doesnt_serialize_sensitive_keys_for_unauthenticated_users(self): query_result = self.factory.create_query_result(data=json_dumps({})) serialized = serialize_query_result(query_result, True) - self.assertSetEqual(set(['data', 'retrieved_at']), - set(serialized.keys())) + self.assertSetEqual(set(["data", "retrieved_at"]), set(serialized.keys())) + class CsvSerializationTest(BaseTestCase): def get_csv_content(self): @@ -43,30 +43,30 @@ def get_csv_content(self): return serialize_query_result_to_csv(query_result) def test_serializes_booleans_correctly(self): - with self.app.test_request_context('/'): + with self.app.test_request_context("/"): parsed = csv.DictReader(io.StringIO(self.get_csv_content())) rows = list(parsed) - self.assertEqual(rows[0]['bool'], 'true') - self.assertEqual(rows[1]['bool'], 'false') - self.assertEqual(rows[2]['bool'], '') + self.assertEqual(rows[0]["bool"], "true") + self.assertEqual(rows[1]["bool"], "false") + self.assertEqual(rows[2]["bool"], "") def test_serializes_datatime_with_correct_format(self): - with self.app.test_request_context('/'): + with self.app.test_request_context("/"): parsed = csv.DictReader(io.StringIO(self.get_csv_content())) rows = list(parsed) - self.assertEqual(rows[0]['datetime'], '26/05/19 12:39') - self.assertEqual(rows[1]['datetime'], '') - self.assertEqual(rows[2]['datetime'], '') - self.assertEqual(rows[0]['date'], '26/05/19') - self.assertEqual(rows[1]['date'], '') - self.assertEqual(rows[2]['date'], '') + self.assertEqual(rows[0]["datetime"], "26/05/19 12:39") + self.assertEqual(rows[1]["datetime"], "") + self.assertEqual(rows[2]["datetime"], "") + self.assertEqual(rows[0]["date"], "26/05/19") + self.assertEqual(rows[1]["date"], "") + self.assertEqual(rows[2]["date"], "") def test_serializes_datatime_as_is_in_case_of_error(self): - with self.app.test_request_context('/'): + with self.app.test_request_context("/"): parsed = csv.DictReader(io.StringIO(self.get_csv_content())) rows = list(parsed) - self.assertEqual(rows[3]['datetime'], '459') - self.assertEqual(rows[3]['date'], '123') + self.assertEqual(rows[3]["datetime"], "459") + self.assertEqual(rows[3]["date"], "123") diff --git a/tests/tasks/test_alerts.py b/tests/tasks/test_alerts.py index 14a330d2d1..305e782436 100644 --- a/tests/tasks/test_alerts.py +++ b/tests/tasks/test_alerts.py @@ -2,7 +2,11 @@ from mock import MagicMock, ANY import redash.tasks.alerts -from redash.tasks.alerts import check_alerts_for_query, notify_subscriptions, should_notify +from redash.tasks.alerts import ( + check_alerts_for_query, + notify_subscriptions, + should_notify, +) from redash.models import Alert @@ -40,4 +44,11 @@ def test_calls_notify_for_subscribers(self): subscription = self.factory.create_alert_subscription() subscription.notify = MagicMock() notify_subscriptions(subscription.alert, Alert.OK_STATE) - subscription.notify.assert_called_with(subscription.alert, subscription.alert.query_rel, subscription.user, Alert.OK_STATE, ANY, ANY) + subscription.notify.assert_called_with( + subscription.alert, + subscription.alert.query_rel, + subscription.user, + Alert.OK_STATE, + ANY, + ANY, + ) diff --git a/tests/tasks/test_empty_schedule.py b/tests/tasks/test_empty_schedule.py index 6dc59a0fd5..5a54244057 100644 --- a/tests/tasks/test_empty_schedule.py +++ b/tests/tasks/test_empty_schedule.py @@ -9,8 +9,10 @@ class TestEmptyScheduleQuery(BaseTestCase): def test_empty_schedules(self): one_day_ago = (utcnow() - datetime.timedelta(days=1)).strftime("%Y-%m-%d") - query = self.factory.create_query(schedule={'interval':'3600','until':one_day_ago}) + query = self.factory.create_query( + schedule={"interval": "3600", "until": one_day_ago} + ) oq = staticmethod(lambda: [query]) - with patch.object(Query, 'past_scheduled_queries', oq): + with patch.object(Query, "past_scheduled_queries", oq): empty_schedules() self.assertEqual(query.schedule, None) diff --git a/tests/tasks/test_failure_report.py b/tests/tasks/test_failure_report.py index 0cefe2131d..d31386a65d 100644 --- a/tests/tasks/test_failure_report.py +++ b/tests/tasks/test_failure_report.py @@ -9,11 +9,12 @@ from redash.tasks.failure_report import notify_of_failure, send_failure_report, key from redash.utils import json_loads + class TestSendAggregatedErrorsTask(BaseTestCase): def setUp(self): super(TestSendAggregatedErrorsTask, self).setUp() redis_connection.flushall() - self.factory.org.set_setting('send_email_on_failed_scheduled_queries', True) + self.factory.org.set_setting("send_email_on_failed_scheduled_queries", True) def notify(self, message="Oh no, I failed!", query=None, **kwargs): if query is None: @@ -22,12 +23,12 @@ def notify(self, message="Oh no, I failed!", query=None, **kwargs): notify_of_failure(message, query) return key(query.user.id) - @mock.patch('redash.tasks.failure_report.render_template', return_value='') + @mock.patch("redash.tasks.failure_report.render_template", return_value="") def send_email(self, user, render_template): send_failure_report(user.id) _, context = render_template.call_args[0] - return context['failures'] + return context["failures"] def test_schedules_email_if_failure_count_is_beneath_limit(self): key = self.notify(schedule_failures=settings.MAX_FAILURE_REPORTS_PER_QUERY - 1) @@ -40,7 +41,7 @@ def test_does_not_report_if_failure_count_is_beyond_limit(self): self.assertFalse(email_pending) def test_does_not_report_if_organization_is_not_subscribed(self): - self.factory.org.set_setting('send_email_on_failed_scheduled_queries', False) + self.factory.org.set_setting("send_email_on_failed_scheduled_queries", False) key = self.notify() email_pending = redis_connection.exists(key) self.assertFalse(email_pending) @@ -55,13 +56,13 @@ def test_does_not_indicate_when_not_near_limit_for_a_query(self): self.notify(schedule_failures=settings.MAX_FAILURE_REPORTS_PER_QUERY / 2) failures = self.send_email(self.factory.user) - self.assertFalse(failures[0]['comment']) + self.assertFalse(failures[0]["comment"]) def test_indicates_when_near_limit_for_a_query(self): self.notify(schedule_failures=settings.MAX_FAILURE_REPORTS_PER_QUERY - 1) failures = self.send_email(self.factory.user) - self.assertTrue(failures[0]['comment']) + self.assertTrue(failures[0]["comment"]) def test_aggregates_different_queries_in_a_single_report(self): key1 = self.notify(message="I'm a failure") @@ -80,11 +81,19 @@ def test_counts_failures_for_each_reason(self): failures = self.send_email(query.user) f1 = next(f for f in failures if f["failure_reason"] == "I'm a failure") - self.assertEqual(2, f1['failure_count']) - f2 = next(f for f in failures if f["failure_reason"] == "I'm a different type of failure") - self.assertEqual(1, f2['failure_count']) - f3 = next(f for f in failures if f["failure_reason"] == "I'm a totally different query") - self.assertEqual(1, f3['failure_count']) + self.assertEqual(2, f1["failure_count"]) + f2 = next( + f + for f in failures + if f["failure_reason"] == "I'm a different type of failure" + ) + self.assertEqual(1, f2["failure_count"]) + f3 = next( + f + for f in failures + if f["failure_reason"] == "I'm a totally different query" + ) + self.assertEqual(1, f3["failure_count"]) def test_shows_latest_failure_time(self): query = self.factory.create_query() @@ -95,5 +104,5 @@ def test_shows_latest_failure_time(self): self.notify(query=query) failures = self.send_email(query.user) - latest_failure = dateutil.parser.parse(failures[0]['failed_at']) + latest_failure = dateutil.parser.parse(failures[0]["failed_at"]) self.assertNotEqual(2000, latest_failure.year) diff --git a/tests/tasks/test_queries.py b/tests/tasks/test_queries.py index ea8efc7126..0fb3b605f5 100644 --- a/tests/tasks/test_queries.py +++ b/tests/tasks/test_queries.py @@ -8,10 +8,14 @@ from redash import redis_connection, models from redash.utils import json_dumps from redash.query_runner.pg import PostgreSQL -from redash.tasks.queries.execution import QueryExecutionError, enqueue_query, execute_query +from redash.tasks.queries.execution import ( + QueryExecutionError, + enqueue_query, + execute_query, +) -FakeResult = namedtuple('FakeResult', 'id') +FakeResult = namedtuple("FakeResult", "id") def gen_hash(*args, **kwargs): @@ -23,40 +27,90 @@ def test_multiple_enqueue_of_same_query(self): query = self.factory.create_query() execute_query.apply_async = mock.MagicMock(side_effect=gen_hash) - enqueue_query(query.query_text, query.data_source, query.user_id, False, query, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query_text, query.data_source, query.user_id, False, query, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query_text, query.data_source, query.user_id, False, query, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) self.assertEqual(1, execute_query.apply_async.call_count) - @mock.patch('redash.settings.dynamic_settings.query_time_limit', return_value=60) + @mock.patch("redash.settings.dynamic_settings.query_time_limit", return_value=60) def test_limits_query_time(self, _): query = self.factory.create_query() execute_query.apply_async = mock.MagicMock(side_effect=gen_hash) - enqueue_query(query.query_text, query.data_source, query.user_id, False, query, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + query, + {"Username": "Arik", "Query ID": query.id}, + ) _, kwargs = execute_query.apply_async.call_args - self.assertEqual(60, kwargs.get('soft_time_limit')) + self.assertEqual(60, kwargs.get("soft_time_limit")) def test_multiple_enqueue_of_different_query(self): query = self.factory.create_query() execute_query.apply_async = mock.MagicMock(side_effect=gen_hash) - enqueue_query(query.query_text, query.data_source, query.user_id, False, None, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query_text + '2', query.data_source, query.user_id, False, None, {'Username': 'Arik', 'Query ID': query.id}) - enqueue_query(query.query_text + '3', query.data_source, query.user_id, False, None, {'Username': 'Arik', 'Query ID': query.id}) + enqueue_query( + query.query_text, + query.data_source, + query.user_id, + False, + None, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text + "2", + query.data_source, + query.user_id, + False, + None, + {"Username": "Arik", "Query ID": query.id}, + ) + enqueue_query( + query.query_text + "3", + query.data_source, + query.user_id, + False, + None, + {"Username": "Arik", "Query ID": query.id}, + ) self.assertEqual(3, execute_query.apply_async.call_count) class QueryExecutorTests(BaseTestCase): - def test_success(self): """ ``execute_query`` invokes the query runner and stores a query result. """ - cm = mock.patch("celery.app.task.Context.delivery_info", {'routing_key': 'test'}) + cm = mock.patch( + "celery.app.task.Context.delivery_info", {"routing_key": "test"} + ) with cm, mock.patch.object(PostgreSQL, "run_query") as qr: query_result_data = {"columns": [], "rows": []} qr.return_value = (json_dumps(query_result_data), None) @@ -69,15 +123,17 @@ def test_success_scheduled(self): """ Scheduled queries remember their latest results. """ - cm = mock.patch("celery.app.task.Context.delivery_info", - {'routing_key': 'test'}) - q = self.factory.create_query(query_text="SELECT 1, 2", schedule={"interval": 300}) + cm = mock.patch( + "celery.app.task.Context.delivery_info", {"routing_key": "test"} + ) + q = self.factory.create_query( + query_text="SELECT 1, 2", schedule={"interval": 300} + ) with cm, mock.patch.object(PostgreSQL, "run_query") as qr: qr.return_value = ([1, 2], None) result_id = execute_query( - "SELECT 1, 2", - self.factory.data_source.id, {}, - scheduled_query_id=q.id) + "SELECT 1, 2", self.factory.data_source.id, {}, scheduled_query_id=q.id + ) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 0) result = models.QueryResult.query.get(result_id) @@ -87,19 +143,30 @@ def test_failure_scheduled(self): """ Scheduled queries that fail have their failure recorded. """ - cm = mock.patch("celery.app.task.Context.delivery_info", - {'routing_key': 'test'}) - q = self.factory.create_query(query_text="SELECT 1, 2", schedule={"interval": 300}) + cm = mock.patch( + "celery.app.task.Context.delivery_info", {"routing_key": "test"} + ) + q = self.factory.create_query( + query_text="SELECT 1, 2", schedule={"interval": 300} + ) with cm, mock.patch.object(PostgreSQL, "run_query") as qr: qr.side_effect = ValueError("broken") with self.assertRaises(QueryExecutionError): - execute_query("SELECT 1, 2", self.factory.data_source.id, {}, - scheduled_query_id=q.id) + execute_query( + "SELECT 1, 2", + self.factory.data_source.id, + {}, + scheduled_query_id=q.id, + ) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 1) with self.assertRaises(QueryExecutionError): - execute_query("SELECT 1, 2", self.factory.data_source.id, {}, - scheduled_query_id=q.id) + execute_query( + "SELECT 1, 2", + self.factory.data_source.id, + {}, + scheduled_query_id=q.id, + ) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 2) @@ -107,22 +174,28 @@ def test_success_after_failure(self): """ Query execution success resets the failure counter. """ - cm = mock.patch("celery.app.task.Context.delivery_info", - {'routing_key': 'test'}) - q = self.factory.create_query(query_text="SELECT 1, 2", schedule={"interval": 300}) + cm = mock.patch( + "celery.app.task.Context.delivery_info", {"routing_key": "test"} + ) + q = self.factory.create_query( + query_text="SELECT 1, 2", schedule={"interval": 300} + ) with cm, mock.patch.object(PostgreSQL, "run_query") as qr: qr.side_effect = ValueError("broken") with self.assertRaises(QueryExecutionError): - execute_query("SELECT 1, 2", - self.factory.data_source.id, {}, - scheduled_query_id=q.id) + execute_query( + "SELECT 1, 2", + self.factory.data_source.id, + {}, + scheduled_query_id=q.id, + ) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 1) with cm, mock.patch.object(PostgreSQL, "run_query") as qr: qr.return_value = ([1, 2], None) - execute_query("SELECT 1, 2", - self.factory.data_source.id, {}, - scheduled_query_id=q.id) + execute_query( + "SELECT 1, 2", self.factory.data_source.id, {}, scheduled_query_id=q.id + ) q = models.Query.get_by_id(q.id) self.assertEqual(q.schedule_failures, 0) diff --git a/tests/tasks/test_refresh_queries.py b/tests/tasks/test_refresh_queries.py index 2300fff9a0..452c1b7702 100644 --- a/tests/tasks/test_refresh_queries.py +++ b/tests/tasks/test_refresh_queries.py @@ -3,7 +3,8 @@ from redash.tasks.queries.maintenance import refresh_queries from redash.models import Query -ENQUEUE_QUERY = 'redash.tasks.queries.maintenance.enqueue_query' +ENQUEUE_QUERY = "redash.tasks.queries.maintenance.enqueue_query" + class TestRefreshQuery(BaseTestCase): def test_enqueues_outdated_queries(self): @@ -13,18 +14,33 @@ def test_enqueues_outdated_queries(self): """ query1 = self.factory.create_query() query2 = self.factory.create_query( - query_text="select 42;", - data_source=self.factory.create_data_source()) + query_text="select 42;", data_source=self.factory.create_data_source() + ) oq = staticmethod(lambda: [query1, query2]) - with patch(ENQUEUE_QUERY) as add_job_mock, \ - patch.object(Query, 'outdated_queries', oq): + with patch(ENQUEUE_QUERY) as add_job_mock, patch.object( + Query, "outdated_queries", oq + ): refresh_queries() self.assertEqual(add_job_mock.call_count, 2) - add_job_mock.assert_has_calls([ - call(query1.query_text, query1.data_source, query1.user_id, - scheduled_query=query1, metadata=ANY), - call(query2.query_text, query2.data_source, query2.user_id, - scheduled_query=query2, metadata=ANY)], any_order=True) + add_job_mock.assert_has_calls( + [ + call( + query1.query_text, + query1.data_source, + query1.user_id, + scheduled_query=query1, + metadata=ANY, + ), + call( + query2.query_text, + query2.data_source, + query2.user_id, + scheduled_query=query2, + metadata=ANY, + ), + ], + any_order=True, + ) def test_doesnt_enqueue_outdated_queries_for_paused_data_source(self): """ @@ -34,7 +50,7 @@ def test_doesnt_enqueue_outdated_queries_for_paused_data_source(self): query = self.factory.create_query() oq = staticmethod(lambda: [query]) query.data_source.pause() - with patch.object(Query, 'outdated_queries', oq): + with patch.object(Query, "outdated_queries", oq): with patch(ENQUEUE_QUERY) as add_job_mock: refresh_queries() add_job_mock.assert_not_called() @@ -44,8 +60,12 @@ def test_doesnt_enqueue_outdated_queries_for_paused_data_source(self): with patch(ENQUEUE_QUERY) as add_job_mock: refresh_queries() add_job_mock.assert_called_with( - query.query_text, query.data_source, query.user_id, - scheduled_query=query, metadata=ANY) + query.query_text, + query.data_source, + query.user_id, + scheduled_query=query, + metadata=ANY, + ) def test_enqueues_parameterized_queries(self): """ @@ -53,19 +73,30 @@ def test_enqueues_parameterized_queries(self): """ query = self.factory.create_query( query_text="select {{n}}", - options={"parameters": [{ - "global": False, - "type": "text", - "name": "n", - "value": "42", - "title": "n"}]}) + options={ + "parameters": [ + { + "global": False, + "type": "text", + "name": "n", + "value": "42", + "title": "n", + } + ] + }, + ) oq = staticmethod(lambda: [query]) - with patch(ENQUEUE_QUERY) as add_job_mock, \ - patch.object(Query, 'outdated_queries', oq): + with patch(ENQUEUE_QUERY) as add_job_mock, patch.object( + Query, "outdated_queries", oq + ): refresh_queries() add_job_mock.assert_called_with( - "select 42", query.data_source, query.user_id, - scheduled_query=query, metadata=ANY) + "select 42", + query.data_source, + query.user_id, + scheduled_query=query, + metadata=ANY, + ) def test_doesnt_enqueue_parameterized_queries_with_invalid_parameters(self): """ @@ -73,35 +104,51 @@ def test_doesnt_enqueue_parameterized_queries_with_invalid_parameters(self): """ query = self.factory.create_query( query_text="select {{n}}", - options={"parameters": [{ - "global": False, - "type": "text", - "name": "n", - "value": 42, # <-- should be text! - "title": "n"}]}) + options={ + "parameters": [ + { + "global": False, + "type": "text", + "name": "n", + "value": 42, # <-- should be text! + "title": "n", + } + ] + }, + ) oq = staticmethod(lambda: [query]) - with patch(ENQUEUE_QUERY) as add_job_mock, \ - patch.object(Query, 'outdated_queries', oq): + with patch(ENQUEUE_QUERY) as add_job_mock, patch.object( + Query, "outdated_queries", oq + ): refresh_queries() add_job_mock.assert_not_called() - def test_doesnt_enqueue_parameterized_queries_with_dropdown_queries_that_are_detached_from_data_source(self): + def test_doesnt_enqueue_parameterized_queries_with_dropdown_queries_that_are_detached_from_data_source( + self + ): """ Scheduled queries with a dropdown parameter which points to a query that is detached from its data source are skipped. """ query = self.factory.create_query( query_text="select {{n}}", - options={"parameters": [{ - "global": False, - "type": "query", - "name": "n", - "queryId": 100, - "title": "n"}]}) + options={ + "parameters": [ + { + "global": False, + "type": "query", + "name": "n", + "queryId": 100, + "title": "n", + } + ] + }, + ) dropdown_query = self.factory.create_query(id=100, data_source=None) oq = staticmethod(lambda: [query]) - with patch(ENQUEUE_QUERY) as add_job_mock, \ - patch.object(Query, 'outdated_queries', oq): + with patch(ENQUEUE_QUERY) as add_job_mock, patch.object( + Query, "outdated_queries", oq + ): refresh_queries() - add_job_mock.assert_not_called() \ No newline at end of file + add_job_mock.assert_not_called() diff --git a/tests/tasks/test_refresh_schemas.py b/tests/tasks/test_refresh_schemas.py index e6426f6668..8cb1210b90 100644 --- a/tests/tasks/test_refresh_schemas.py +++ b/tests/tasks/test_refresh_schemas.py @@ -7,19 +7,25 @@ class TestRefreshSchemas(BaseTestCase): def test_calls_refresh_of_all_data_sources(self): self.factory.data_source # trigger creation - with patch('redash.tasks.queries.maintenance.refresh_schema.delay') as refresh_job: + with patch( + "redash.tasks.queries.maintenance.refresh_schema.delay" + ) as refresh_job: refresh_schemas() refresh_job.assert_called() def test_skips_paused_data_sources(self): self.factory.data_source.pause() - with patch('redash.tasks.queries.maintenance.refresh_schema.delay') as refresh_job: + with patch( + "redash.tasks.queries.maintenance.refresh_schema.delay" + ) as refresh_job: refresh_schemas() refresh_job.assert_not_called() self.factory.data_source.resume() - with patch('redash.tasks.queries.maintenance.refresh_schema.delay') as refresh_job: + with patch( + "redash.tasks.queries.maintenance.refresh_schema.delay" + ) as refresh_job: refresh_schemas() refresh_job.assert_called() diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 4023b7667c..52fc173ff5 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -6,11 +6,13 @@ from flask import request from mock import patch from redash import models, settings -from redash.authentication import (api_key_load_user_from_request, - get_login_url, hmac_load_user_from_request, - sign) -from redash.authentication.google_oauth import (create_and_login_user, - verify_profile) +from redash.authentication import ( + api_key_load_user_from_request, + get_login_url, + hmac_load_user_from_request, + sign, +) +from redash.authentication.google_oauth import create_and_login_user, verify_profile from redash.utils import utcnow from sqlalchemy.orm.exc import NoResultFound from tests import BaseTestCase @@ -22,11 +24,13 @@ class TestApiKeyAuthentication(BaseTestCase): # def setUp(self): super(TestApiKeyAuthentication, self).setUp() - self.api_key = '10' + self.api_key = "10" self.query = self.factory.create_query(api_key=self.api_key) models.db.session.flush() - self.query_url = '/{}/api/queries/{}'.format(self.factory.org.slug, self.query.id) - self.queries_url = '/{}/api/queries'.format(self.factory.org.slug) + self.query_url = "/{}/api/queries/{}".format( + self.factory.org.slug, self.query.id + ) + self.queries_url = "/{}/api/queries".format(self.factory.org.slug) def test_no_api_key(self): with self.app.test_client() as c: @@ -35,24 +39,24 @@ def test_no_api_key(self): def test_wrong_api_key(self): with self.app.test_client() as c: - rv = c.get(self.query_url, query_string={'api_key': 'whatever'}) + rv = c.get(self.query_url, query_string={"api_key": "whatever"}) self.assertIsNone(api_key_load_user_from_request(request)) def test_correct_api_key(self): with self.app.test_client() as c: - rv = c.get(self.query_url, query_string={'api_key': self.api_key}) + rv = c.get(self.query_url, query_string={"api_key": self.api_key}) self.assertIsNotNone(api_key_load_user_from_request(request)) def test_no_query_id(self): with self.app.test_client() as c: - rv = c.get(self.queries_url, query_string={'api_key': self.api_key}) + rv = c.get(self.queries_url, query_string={"api_key": self.api_key}) self.assertIsNone(api_key_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") models.db.session.flush() with self.app.test_client() as c: - rv = c.get(self.queries_url, query_string={'api_key': user.api_key}) + rv = c.get(self.queries_url, query_string={"api_key": user.api_key}) self.assertEqual(user.id, api_key_load_user_from_request(request).id) def test_disabled_user_api_key(self): @@ -60,24 +64,29 @@ def test_disabled_user_api_key(self): user.disable() models.db.session.flush() with self.app.test_client() as c: - rv = c.get(self.queries_url, query_string={'api_key': user.api_key}) + rv = c.get(self.queries_url, query_string={"api_key": user.api_key}) self.assertEqual(None, api_key_load_user_from_request(request)) def test_api_key_header(self): with self.app.test_client() as c: - rv = c.get(self.query_url, headers={'Authorization': "Key {}".format(self.api_key)}) + rv = c.get( + self.query_url, headers={"Authorization": "Key {}".format(self.api_key)} + ) self.assertIsNotNone(api_key_load_user_from_request(request)) def test_api_key_header_with_wrong_key(self): with self.app.test_client() as c: - rv = c.get(self.query_url, headers={'Authorization': "Key oops"}) + rv = c.get(self.query_url, headers={"Authorization": "Key oops"}) self.assertIsNone(api_key_load_user_from_request(request)) def test_api_key_for_wrong_org(self): other_user = self.factory.create_admin(org=self.factory.create_org()) with self.app.test_client() as c: - rv = c.get(self.query_url, headers={'Authorization': "Key {}".format(other_user.api_key)}) + rv = c.get( + self.query_url, + headers={"Authorization": "Key {}".format(other_user.api_key)}, + ) self.assertEqual(404, rv.status_code) @@ -87,10 +96,10 @@ class TestHMACAuthentication(BaseTestCase): # def setUp(self): super(TestHMACAuthentication, self).setUp() - self.api_key = '10' + self.api_key = "10" self.query = self.factory.create_query(api_key=self.api_key) models.db.session.flush() - self.path = '/{}/api/queries/{}'.format(self.query.org.slug, self.query.id) + self.path = "/{}/api/queries/{}".format(self.query.org.slug, self.query.id) self.expires = time.time() + 1800 def signature(self, expires): @@ -103,27 +112,46 @@ def test_no_signature(self): def test_wrong_signature(self): with self.app.test_client() as c: - rv = c.get(self.path, query_string={'signature': 'whatever', 'expires': self.expires}) + rv = c.get( + self.path, + query_string={"signature": "whatever", "expires": self.expires}, + ) self.assertIsNone(hmac_load_user_from_request(request)) def test_correct_signature(self): with self.app.test_client() as c: - rv = c.get(self.path, query_string={'signature': self.signature(self.expires), 'expires': self.expires}) + rv = c.get( + self.path, + query_string={ + "signature": self.signature(self.expires), + "expires": self.expires, + }, + ) self.assertIsNotNone(hmac_load_user_from_request(request)) def test_no_query_id(self): with self.app.test_client() as c: - rv = c.get('/{}/api/queries'.format(self.query.org.slug), query_string={'api_key': self.api_key}) + rv = c.get( + "/{}/api/queries".format(self.query.org.slug), + query_string={"api_key": self.api_key}, + ) self.assertIsNone(hmac_load_user_from_request(request)) def test_user_api_key(self): user = self.factory.create_user(api_key="user_key") - path = '/api/queries/' + path = "/api/queries/" models.db.session.flush() signature = sign(user.api_key, path, self.expires) with self.app.test_client() as c: - rv = c.get(path, query_string={'signature': signature, 'expires': self.expires, 'user_id': user.id}) + rv = c.get( + path, + query_string={ + "signature": signature, + "expires": self.expires, + "user_id": user.id, + }, + ) self.assertEqual(user.id, hmac_load_user_from_request(request).id) @@ -136,23 +164,27 @@ def test_prefers_api_key_over_session_user_id(self): other_user = self.factory.create_user(org=other_org) models.db.session.flush() - rv = self.make_request('get', '/api/queries/{}?api_key={}'.format(query.id, query.api_key), user=other_user) + rv = self.make_request( + "get", + "/api/queries/{}?api_key={}".format(query.id, query.api_key), + user=other_user, + ) self.assertEqual(rv.status_code, 200) class TestCreateAndLoginUser(BaseTestCase): def test_logins_valid_user(self): - user = self.factory.create_user(email='test@example.com') + user = self.factory.create_user(email="test@example.com") - with patch('redash.authentication.login_user') as login_user_mock: + with patch("redash.authentication.login_user") as login_user_mock: create_and_login_user(self.factory.org, user.name, user.email) login_user_mock.assert_called_once_with(user, remember=True) def test_creates_vaild_new_user(self): - email = 'test@example.com' - name = 'Test User' + email = "test@example.com" + name = "Test User" - with patch('redash.authentication.login_user') as login_user_mock: + with patch("redash.authentication.login_user") as login_user_mock: create_and_login_user(self.factory.org, name, email) self.assertTrue(login_user_mock.called) @@ -160,85 +192,118 @@ def test_creates_vaild_new_user(self): self.assertEqual(user.email, email) def test_updates_user_name(self): - user = self.factory.create_user(email='test@example.com') + user = self.factory.create_user(email="test@example.com") - with patch('redash.authentication.login_user') as login_user_mock: + with patch("redash.authentication.login_user") as login_user_mock: create_and_login_user(self.factory.org, "New Name", user.email) login_user_mock.assert_called_once_with(user, remember=True) class TestVerifyProfile(BaseTestCase): def test_no_domain_allowed_for_org(self): - profile = dict(email='arik@example.com') + profile = dict(email="arik@example.com") self.assertFalse(verify_profile(self.factory.org, profile)) def test_domain_not_in_org_domains_list(self): - profile = dict(email='arik@example.com') - self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org'] + profile = dict(email="arik@example.com") + self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = [ + "example.org" + ] self.assertFalse(verify_profile(self.factory.org, profile)) def test_domain_in_org_domains_list(self): - profile = dict(email='arik@example.com') - self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.com'] + profile = dict(email="arik@example.com") + self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = [ + "example.com" + ] self.assertTrue(verify_profile(self.factory.org, profile)) - self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org', 'example.com'] + self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = [ + "example.org", + "example.com", + ] self.assertTrue(verify_profile(self.factory.org, profile)) def test_org_in_public_mode_accepts_any_domain(self): - profile = dict(email='arik@example.com') + profile = dict(email="arik@example.com") self.factory.org.settings[models.Organization.SETTING_IS_PUBLIC] = True self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = [] self.assertTrue(verify_profile(self.factory.org, profile)) def test_user_not_in_domain_but_account_exists(self): - profile = dict(email='arik@example.com') - self.factory.create_user(email='arik@example.com') - self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = ['example.org'] + profile = dict(email="arik@example.com") + self.factory.create_user(email="arik@example.com") + self.factory.org.settings[models.Organization.SETTING_GOOGLE_APPS_DOMAINS] = [ + "example.org" + ] self.assertTrue(verify_profile(self.factory.org, profile)) class TestGetLoginUrl(BaseTestCase): def test_when_multi_org_enabled_and_org_exists(self): - with self.app.test_request_context('/{}/'.format(self.factory.org.slug)): - self.assertEqual(get_login_url(next=None), '/{}/login'.format(self.factory.org.slug)) + with self.app.test_request_context("/{}/".format(self.factory.org.slug)): + self.assertEqual( + get_login_url(next=None), "/{}/login".format(self.factory.org.slug) + ) def test_when_multi_org_enabled_and_org_doesnt_exist(self): - with self.app.test_request_context('/{}_notexists/'.format(self.factory.org.slug)): - self.assertEqual(get_login_url(next=None), '/') + with self.app.test_request_context( + "/{}_notexists/".format(self.factory.org.slug) + ): + self.assertEqual(get_login_url(next=None), "/") class TestRedirectToUrlAfterLoggingIn(BaseTestCase): def setUp(self): super(TestRedirectToUrlAfterLoggingIn, self).setUp() self.user = self.factory.user - self.password = 'test1234' + self.password = "test1234" def test_no_next_param(self): - response = self.post_request('/login', data={'email': self.user.email, 'password': self.password}, org=self.factory.org) - self.assertEqual(response.location, 'http://localhost/{}/'.format(self.user.org.slug)) + response = self.post_request( + "/login", + data={"email": self.user.email, "password": self.password}, + org=self.factory.org, + ) + self.assertEqual( + response.location, "http://localhost/{}/".format(self.user.org.slug) + ) def test_simple_path_in_next_param(self): - response = self.post_request('/login?next=queries', data={'email': self.user.email, 'password': self.password}, org=self.factory.org) - self.assertEqual(response.location, 'http://localhost/default/queries') + response = self.post_request( + "/login?next=queries", + data={"email": self.user.email, "password": self.password}, + org=self.factory.org, + ) + self.assertEqual(response.location, "http://localhost/default/queries") def test_starts_scheme_url_in_next_param(self): - response = self.post_request('/login?next=https://redash.io', data={'email': self.user.email, 'password': self.password}, org=self.factory.org) - self.assertEqual(response.location, 'http://localhost/default/') + response = self.post_request( + "/login?next=https://redash.io", + data={"email": self.user.email, "password": self.password}, + org=self.factory.org, + ) + self.assertEqual(response.location, "http://localhost/default/") def test_without_scheme_url_in_next_param(self): - response = self.post_request('/login?next=//redash.io', data={'email': self.user.email, 'password': self.password}, org=self.factory.org) - self.assertEqual(response.location, 'http://localhost/default/') + response = self.post_request( + "/login?next=//redash.io", + data={"email": self.user.email, "password": self.password}, + org=self.factory.org, + ) + self.assertEqual(response.location, "http://localhost/default/") def test_without_scheme_with_path_url_in_next_param(self): - response = self.post_request('/login?next=//localhost/queries', data={'email': self.user.email, 'password': self.password}, org=self.factory.org) - self.assertEqual(response.location, 'http://localhost/queries') + response = self.post_request( + "/login?next=//localhost/queries", + data={"email": self.user.email, "password": self.password}, + org=self.factory.org, + ) + self.assertEqual(response.location, "http://localhost/queries") class TestRemoteUserAuth(BaseTestCase): - DEFAULT_SETTING_OVERRIDES = { - 'REDASH_REMOTE_USER_LOGIN_ENABLED': 'true' - } + DEFAULT_SETTING_OVERRIDES = {"REDASH_REMOTE_USER_LOGIN_ENABLED": "true"} def setUp(self): # Apply default setting overrides to every test @@ -269,7 +334,14 @@ def override_settings(self, overrides): # once the test ends self.addCleanup(lambda: reload_module(settings)) - def assert_correct_user_attributes(self, user, email='test@example.com', name='test@example.com', groups=None, org=None): + def assert_correct_user_attributes( + self, + user, + email="test@example.com", + name="test@example.com", + groups=None, + org=None, + ): """Helper to assert that the user attributes are correct.""" groups = groups or [] if self.factory.org.default_group.id not in groups: @@ -281,7 +353,7 @@ def assert_correct_user_attributes(self, user, email='test@example.com', name='t self.assertEqual(user.org, org or self.factory.org) self.assertCountEqual(user.group_ids, groups) - def get_test_user(self, email='test@example.com', org=None): + def get_test_user(self, email="test@example.com", org=None): """Helper to fetch an user from the database.""" # Expire all cached objects to ensure these values are read directly @@ -291,32 +363,34 @@ def get_test_user(self, email='test@example.com', org=None): return models.User.get_by_email_and_org(email, org or self.factory.org) def test_remote_login_disabled(self): - self.override_settings({ - 'REDASH_REMOTE_USER_LOGIN_ENABLED': 'false' - }) + self.override_settings({"REDASH_REMOTE_USER_LOGIN_ENABLED": "false"}) - self.get_request('/remote_user/login', org=self.factory.org, headers={ - 'X-Forwarded-Remote-User': 'test@example.com' - }) + self.get_request( + "/remote_user/login", + org=self.factory.org, + headers={"X-Forwarded-Remote-User": "test@example.com"}, + ) with self.assertRaises(NoResultFound): self.get_test_user() def test_remote_login_default_header(self): - self.get_request('/remote_user/login', org=self.factory.org, headers={ - 'X-Forwarded-Remote-User': 'test@example.com' - }) + self.get_request( + "/remote_user/login", + org=self.factory.org, + headers={"X-Forwarded-Remote-User": "test@example.com"}, + ) self.assert_correct_user_attributes(self.get_test_user()) def test_remote_login_custom_header(self): - self.override_settings({ - 'REDASH_REMOTE_USER_HEADER': 'X-Custom-User' - }) + self.override_settings({"REDASH_REMOTE_USER_HEADER": "X-Custom-User"}) - self.get_request('/remote_user/login', org=self.factory.org, headers={ - 'X-Custom-User': 'test@example.com' - }) + self.get_request( + "/remote_user/login", + org=self.factory.org, + headers={"X-Custom-User": "test@example.com"}, + ) self.assert_correct_user_attributes(self.get_test_user()) @@ -325,8 +399,12 @@ class TestUserForgotPassword(BaseTestCase): def test_user_should_receive_password_reset_link(self): user = self.factory.create_user() - with patch('redash.handlers.authentication.send_password_reset_email') as send_password_reset_email_mock: - response = self.post_request('/forgot', org=user.org, data={'email': user.email}) + with patch( + "redash.handlers.authentication.send_password_reset_email" + ) as send_password_reset_email_mock: + response = self.post_request( + "/forgot", org=user.org, data={"email": user.email} + ) self.assertEqual(response.status_code, 200) send_password_reset_email_mock.assert_called_with(user) @@ -336,9 +414,14 @@ def test_disabled_user_should_not_receive_password_reset_link(self): self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.authentication.send_password_reset_email') as send_password_reset_email_mock,\ - patch('redash.handlers.authentication.send_user_disabled_email') as send_user_disabled_email_mock: - response = self.post_request('/forgot', org=user.org, data={'email': user.email}) + with patch( + "redash.handlers.authentication.send_password_reset_email" + ) as send_password_reset_email_mock, patch( + "redash.handlers.authentication.send_user_disabled_email" + ) as send_user_disabled_email_mock: + response = self.post_request( + "/forgot", org=user.org, data={"email": user.email} + ) self.assertEqual(response.status_code, 200) send_password_reset_email_mock.assert_not_called() send_user_disabled_email_mock.assert_called_with(user) diff --git a/tests/test_cli.py b/tests/test_cli.py index 25e7be08ee..537c83d03d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,70 +12,91 @@ class DataSourceCommandTests(BaseTestCase): def test_interactive_new(self): runner = CliRunner() - pg_i = list(query_runners.keys()).index('pg') + 1 + pg_i = list(query_runners.keys()).index("pg") + 1 result = runner.invoke( manager, - ['ds', 'new'], - input="test\n%s\n\n\nexample.com\n\n\ntestdb\n" % (pg_i,)) + ["ds", "new"], + input="test\n%s\n\n\nexample.com\n\n\ntestdb\n" % (pg_i,), + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(DataSource.query.count(), 1) ds = DataSource.query.first() - self.assertEqual(ds.name, 'test') - self.assertEqual(ds.type, 'pg') - self.assertEqual(ds.options['dbname'], 'testdb') + self.assertEqual(ds.name, "test") + self.assertEqual(ds.type, "pg") + self.assertEqual(ds.options["dbname"], "testdb") def test_options_new(self): runner = CliRunner() result = runner.invoke( manager, - ['ds', 'new', - 'test', - '--options', '{"host": "example.com", "dbname": "testdb"}', - '--type', 'pg']) + [ + "ds", + "new", + "test", + "--options", + '{"host": "example.com", "dbname": "testdb"}', + "--type", + "pg", + ], + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(DataSource.query.count(), 1) ds = DataSource.query.first() - self.assertEqual(ds.name, 'test') - self.assertEqual(ds.type, 'pg') - self.assertEqual(ds.options['host'], 'example.com') - self.assertEqual(ds.options['dbname'], 'testdb') + self.assertEqual(ds.name, "test") + self.assertEqual(ds.type, "pg") + self.assertEqual(ds.options["host"], "example.com") + self.assertEqual(ds.options["dbname"], "testdb") def test_bad_type_new(self): runner = CliRunner() - result = runner.invoke( - manager, ['ds', 'new', 'test', '--type', 'wrong']) + result = runner.invoke(manager, ["ds", "new", "test", "--type", "wrong"]) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('not supported', result.output) + self.assertIn("not supported", result.output) self.assertEqual(DataSource.query.count(), 0) def test_bad_options_new(self): runner = CliRunner() result = runner.invoke( - manager, ['ds', 'new', 'test', '--options', - '{"host": 12345, "dbname": "testdb"}', - '--type', 'pg']) + manager, + [ + "ds", + "new", + "test", + "--options", + '{"host": 12345, "dbname": "testdb"}', + "--type", + "pg", + ], + ) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('invalid configuration', result.output) + self.assertIn("invalid configuration", result.output) self.assertEqual(DataSource.query.count(), 0) def test_list(self): self.factory.create_data_source( - name='test1', type='pg', - options=ConfigurationContainer({"host": "example.com", - "dbname": "testdb1"})) + name="test1", + type="pg", + options=ConfigurationContainer( + {"host": "example.com", "dbname": "testdb1"} + ), + ) self.factory.create_data_source( - name='test2', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="test2", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) self.factory.create_data_source( - name='Atest', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="Atest", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) runner = CliRunner() - result = runner.invoke(manager, ['ds', 'list']) + result = runner.invoke(manager, ["ds", "list"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) expected_output = """ @@ -94,46 +115,55 @@ def test_list(self): Type: sqlite Options: {"dbpath": "/tmp/test.db"} """ - self.assertMultiLineEqual(result.output, - textwrap.dedent(expected_output).lstrip()) + self.assertMultiLineEqual( + result.output, textwrap.dedent(expected_output).lstrip() + ) def test_connection_test(self): self.factory.create_data_source( - name='test1', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="test1", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) runner = CliRunner() - result = runner.invoke(manager, ['ds', 'test', 'test1']) + result = runner.invoke(manager, ["ds", "test", "test1"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertIn('Success', result.output) + self.assertIn("Success", result.output) def test_connection_bad_test(self): self.factory.create_data_source( - name='test1', type='sqlite', - options=ConfigurationContainer({"dbpath": __file__})) + name="test1", + type="sqlite", + options=ConfigurationContainer({"dbpath": __file__}), + ) runner = CliRunner() - result = runner.invoke(manager, ['ds', 'test', 'test1']) + result = runner.invoke(manager, ["ds", "test", "test1"]) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('Failure', result.output) + self.assertIn("Failure", result.output) def test_connection_delete(self): self.factory.create_data_source( - name='test1', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="test1", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) runner = CliRunner() - result = runner.invoke(manager, ['ds', 'delete', 'test1']) + result = runner.invoke(manager, ["ds", "delete", "test1"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertIn('Deleting', result.output) + self.assertIn("Deleting", result.output) self.assertEqual(DataSource.query.count(), 0) def test_connection_bad_delete(self): self.factory.create_data_source( - name='test1', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="test1", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) runner = CliRunner() - result = runner.invoke(manager, ['ds', 'delete', 'wrong']) + result = runner.invoke(manager, ["ds", "delete", "wrong"]) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn("Couldn't find", result.output) @@ -141,59 +171,83 @@ def test_connection_bad_delete(self): def test_options_edit(self): self.factory.create_data_source( - name='test1', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="test1", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) runner = CliRunner() result = runner.invoke( - manager, ['ds', 'edit', 'test1', '--options', - '{"host": "example.com", "dbname": "testdb"}', - '--name', 'test2', - '--type', 'pg']) + manager, + [ + "ds", + "edit", + "test1", + "--options", + '{"host": "example.com", "dbname": "testdb"}', + "--name", + "test2", + "--type", + "pg", + ], + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(DataSource.query.count(), 1) ds = DataSource.query.first() - self.assertEqual(ds.name, 'test2') - self.assertEqual(ds.type, 'pg') - self.assertEqual(ds.options['host'], 'example.com') - self.assertEqual(ds.options['dbname'], 'testdb') + self.assertEqual(ds.name, "test2") + self.assertEqual(ds.type, "pg") + self.assertEqual(ds.options["host"], "example.com") + self.assertEqual(ds.options["dbname"], "testdb") def test_bad_type_edit(self): self.factory.create_data_source( - name='test1', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="test1", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) runner = CliRunner() - result = runner.invoke( - manager, ['ds', 'edit', 'test', '--type', 'wrong']) + result = runner.invoke(manager, ["ds", "edit", "test", "--type", "wrong"]) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('not supported', result.output) + self.assertIn("not supported", result.output) ds = DataSource.query.first() - self.assertEqual(ds.type, 'sqlite') + self.assertEqual(ds.type, "sqlite") def test_bad_options_edit(self): ds = self.factory.create_data_source( - name='test1', type='sqlite', - options=ConfigurationContainer({"dbpath": "/tmp/test.db"})) + name="test1", + type="sqlite", + options=ConfigurationContainer({"dbpath": "/tmp/test.db"}), + ) runner = CliRunner() result = runner.invoke( - manager, ['ds', 'new', 'test', '--options', - '{"host": 12345, "dbname": "testdb"}', - '--type', 'pg']) + manager, + [ + "ds", + "new", + "test", + "--options", + '{"host": 12345, "dbname": "testdb"}', + "--type", + "pg", + ], + ) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('invalid configuration', result.output) + self.assertIn("invalid configuration", result.output) ds = DataSource.query.first() - self.assertEqual(ds.type, 'sqlite') + self.assertEqual(ds.type, "sqlite") self.assertEqual(ds.options._config, {"dbpath": "/tmp/test.db"}) class GroupCommandTests(BaseTestCase): def test_create(self): gcount = Group.query.count() - perms = ['create_query', 'edit_query', 'view_query'] + perms = ["create_query", "edit_query", "view_query"] runner = CliRunner() - result = runner.invoke(manager, ['groups', 'create', 'test', '--permissions', ','.join(perms)]) + result = runner.invoke( + manager, ["groups", "create", "test", "--permissions", ",".join(perms)] + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertEqual(Group.query.count(), gcount + 1) @@ -203,30 +257,40 @@ def test_create(self): self.assertEqual(g.permissions, perms) def test_change_permissions(self): - g = self.factory.create_group(permissions=['list_dashboards']) + g = self.factory.create_group(permissions=["list_dashboards"]) db.session.flush() g_id = g.id - perms = ['create_query', 'edit_query', 'view_query'] + perms = ["create_query", "edit_query", "view_query"] runner = CliRunner() result = runner.invoke( - manager, ['groups', 'change_permissions', str(g_id), '--permissions', ','.join(perms)]) + manager, + [ + "groups", + "change_permissions", + str(g_id), + "--permissions", + ",".join(perms), + ], + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) g = Group.query.filter(Group.id == g_id).first() self.assertEqual(g.permissions, perms) def test_list(self): - self.factory.create_group(name='test', permissions=['list_dashboards']) - self.factory.create_group(name='agroup', permissions=['list_dashboards']) - self.factory.create_group(name='bgroup', permissions=['list_dashboards']) + self.factory.create_group(name="test", permissions=["list_dashboards"]) + self.factory.create_group(name="agroup", permissions=["list_dashboards"]) + self.factory.create_group(name="bgroup", permissions=["list_dashboards"]) - self.factory.create_user(name='Fred Foobar', - email='foobar@example.com', - org=self.factory.org, - group_ids=[self.factory.default_group.id]) + self.factory.create_user( + name="Fred Foobar", + email="foobar@example.com", + org=self.factory.org, + group_ids=[self.factory.default_group.id], + ) runner = CliRunner() - result = runner.invoke(manager, ['groups', 'list']) + result = runner.invoke(manager, ["groups", "list"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ @@ -265,15 +329,16 @@ def test_list(self): Permissions: [list_dashboards] Users: """ - self.assertMultiLineEqual(result.output, - textwrap.dedent(output).lstrip()) + self.assertMultiLineEqual(result.output, textwrap.dedent(output).lstrip()) class OrganizationCommandTests(BaseTestCase): def test_set_google_apps_domains(self): - domains = ['example.org', 'example.com'] + domains = ["example.org", "example.com"] runner = CliRunner() - result = runner.invoke(manager, ['org', 'set_google_apps_domains', ','.join(domains)]) + result = runner.invoke( + manager, ["org", "set_google_apps_domains", ",".join(domains)] + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) db.session.add(self.factory.org) @@ -281,25 +346,26 @@ def test_set_google_apps_domains(self): def test_show_google_apps_domains(self): self.factory.org.settings[Organization.SETTING_GOOGLE_APPS_DOMAINS] = [ - 'example.org', 'example.com'] + "example.org", + "example.com", + ] db.session.add(self.factory.org) db.session.commit() runner = CliRunner() - result = runner.invoke(manager, ['org', 'show_google_apps_domains']) + result = runner.invoke(manager, ["org", "show_google_apps_domains"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ Current list of Google Apps domains: example.org, example.com """ - self.assertMultiLineEqual(result.output, - textwrap.dedent(output).lstrip()) + self.assertMultiLineEqual(result.output, textwrap.dedent(output).lstrip()) def test_list(self): - self.factory.create_org(name='test', slug='test_org') - self.factory.create_org(name='Borg', slug='B_org') - self.factory.create_org(name='Aorg', slug='A_org') + self.factory.create_org(name="test", slug="test_org") + self.factory.create_org(name="Borg", slug="B_org") + self.factory.create_org(name="Aorg", slug="A_org") runner = CliRunner() - result = runner.invoke(manager, ['org', 'list']) + result = runner.invoke(manager, ["org", "list"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ @@ -319,40 +385,51 @@ def test_list(self): Name: test Slug: test_org """ - self.assertMultiLineEqual(result.output, - textwrap.dedent(output).lstrip()) + self.assertMultiLineEqual(result.output, textwrap.dedent(output).lstrip()) class UserCommandTests(BaseTestCase): def test_create_basic(self): runner = CliRunner() result = runner.invoke( - manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar'], - input="password1\npassword1\n") + manager, + ["users", "create", "foobar@example.com", "Fred Foobar"], + input="password1\npassword1\n", + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") - self.assertTrue(u.verify_password('password1')) + self.assertTrue(u.verify_password("password1")) self.assertEqual(u.group_ids, [u.org.default_group.id]) def test_create_admin(self): runner = CliRunner() result = runner.invoke( - manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar', - '--password', 'password1', '--admin']) + manager, + [ + "users", + "create", + "foobar@example.com", + "Fred Foobar", + "--password", + "password1", + "--admin", + ], + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") - self.assertTrue(u.verify_password('password1')) - self.assertEqual(u.group_ids, [u.org.default_group.id, - u.org.admin_group.id]) + self.assertTrue(u.verify_password("password1")) + self.assertEqual(u.group_ids, [u.org.default_group.id, u.org.admin_group.id]) def test_create_googleauth(self): runner = CliRunner() result = runner.invoke( - manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar', '--google']) + manager, + ["users", "create", "foobar@example.com", "Fred Foobar", "--google"], + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) u = User.query.filter(User.email == "foobar@example.com").first() @@ -361,61 +438,80 @@ def test_create_googleauth(self): self.assertEqual(u.group_ids, [u.org.default_group.id]) def test_create_bad(self): - self.factory.create_user(email='foobar@example.com') + self.factory.create_user(email="foobar@example.com") runner = CliRunner() result = runner.invoke( - manager, ['users', 'create', 'foobar@example.com', 'Fred Foobar'], - input="password1\npassword1\n") + manager, + ["users", "create", "foobar@example.com", "Fred Foobar"], + input="password1\npassword1\n", + ) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('Failed', result.output) + self.assertIn("Failed", result.output) def test_delete(self): - self.factory.create_user(email='foobar@example.com') + self.factory.create_user(email="foobar@example.com") ucount = User.query.count() runner = CliRunner() - result = runner.invoke(manager, ['users', 'delete', 'foobar@example.com']) + result = runner.invoke(manager, ["users", "delete", "foobar@example.com"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(User.query.filter(User.email == - "foobar@example.com").count(), 0) + self.assertEqual( + User.query.filter(User.email == "foobar@example.com").count(), 0 + ) self.assertEqual(User.query.count(), ucount - 1) def test_delete_bad(self): ucount = User.query.count() runner = CliRunner() - result = runner.invoke(manager, ['users', 'delete', 'foobar@example.com']) - self.assertIn('Deleted 0 users', result.output) + result = runner.invoke(manager, ["users", "delete", "foobar@example.com"]) + self.assertIn("Deleted 0 users", result.output) self.assertEqual(User.query.count(), ucount) def test_password(self): - self.factory.create_user(email='foobar@example.com') + self.factory.create_user(email="foobar@example.com") runner = CliRunner() - result = runner.invoke(manager, ['users', 'password', 'foobar@example.com', 'xyzzy']) + result = runner.invoke( + manager, ["users", "password", "foobar@example.com", "xyzzy"] + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) u = User.query.filter(User.email == "foobar@example.com").first() - self.assertTrue(u.verify_password('xyzzy')) + self.assertTrue(u.verify_password("xyzzy")) def test_password_bad(self): runner = CliRunner() - result = runner.invoke(manager, ['users', 'password', 'foobar@example.com', 'xyzzy']) + result = runner.invoke( + manager, ["users", "password", "foobar@example.com", "xyzzy"] + ) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('not found', result.output) + self.assertIn("not found", result.output) def test_password_bad_org(self): runner = CliRunner() - result = runner.invoke(manager, ['users', 'password', 'foobar@example.com', 'xyzzy', '--org', 'default']) + result = runner.invoke( + manager, + ["users", "password", "foobar@example.com", "xyzzy", "--org", "default"], + ) self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) - self.assertIn('not found', result.output) + self.assertIn("not found", result.output) def test_invite(self): - admin = self.factory.create_user(email='redash-admin@example.com') + admin = self.factory.create_user(email="redash-admin@example.com") runner = CliRunner() - with mock.patch('redash.cli.users.invite_user') as iu: - result = runner.invoke(manager, ['users', 'invite', 'foobar@example.com', 'Fred Foobar', 'redash-admin@example.com']) + with mock.patch("redash.cli.users.invite_user") as iu: + result = runner.invoke( + manager, + [ + "users", + "invite", + "foobar@example.com", + "Fred Foobar", + "redash-admin@example.com", + ], + ) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertTrue(iu.called) @@ -423,23 +519,23 @@ def test_invite(self): db.session.add_all(c) self.assertEqual(c[0].id, self.factory.org.id) self.assertEqual(c[1].id, admin.id) - self.assertEqual(c[2].email, 'foobar@example.com') + self.assertEqual(c[2].email, "foobar@example.com") def test_list(self): - self.factory.create_user(name='Fred Foobar', - email='foobar@example.com', - org=self.factory.org) + self.factory.create_user( + name="Fred Foobar", email="foobar@example.com", org=self.factory.org + ) - self.factory.create_user(name='William Foobar', - email='william@example.com', - org=self.factory.org) + self.factory.create_user( + name="William Foobar", email="william@example.com", org=self.factory.org + ) - self.factory.create_user(name='Andrew Foobar', - email='andrew@example.com', - org=self.factory.org) + self.factory.create_user( + name="Andrew Foobar", email="andrew@example.com", org=self.factory.org + ) runner = CliRunner() - result = runner.invoke(manager, ['users', 'list']) + result = runner.invoke(manager, ["users", "list"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) output = """ @@ -464,18 +560,18 @@ def test_list(self): Active: True Groups: default """ - self.assertMultiLineEqual(result.output, - textwrap.dedent(output).lstrip()) + self.assertMultiLineEqual(result.output, textwrap.dedent(output).lstrip()) def test_grant_admin(self): - u = self.factory.create_user(name='Fred Foobar', - email='foobar@example.com', - org=self.factory.org, - group_ids=[self.factory.default_group.id]) + u = self.factory.create_user( + name="Fred Foobar", + email="foobar@example.com", + org=self.factory.org, + group_ids=[self.factory.default_group.id], + ) runner = CliRunner() - result = runner.invoke(manager, ['users', 'grant_admin', 'foobar@example.com']) + result = runner.invoke(manager, ["users", "grant_admin", "foobar@example.com"]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) db.session.add(u) - self.assertEqual(u.group_ids, [u.org.default_group.id, - u.org.admin_group.id]) + self.assertEqual(u.group_ids, [u.org.default_group.id, u.org.admin_group.id]) diff --git a/tests/test_configuration.py b/tests/test_configuration.py index ba625fe2e5..e3b90ed8b7 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -8,24 +8,18 @@ configuration_schema = { "type": "object", "properties": { - "a": { - "type": "integer" - }, - "e": { - "type": "integer" - }, - "b": { - "type": "string" - } + "a": {"type": "integer"}, + "e": {"type": "integer"}, + "b": {"type": "string"}, }, "required": ["a"], - "secret": ["b"] + "secret": ["b"], } class TestConfigurationToJson(TestCase): def setUp(self): - self.config = {'a': 1, 'b': 'test'} + self.config = {"a": 1, "b": "test"} self.container = ConfigurationContainer(self.config, configuration_schema) def test_returns_plain_dict(self): @@ -33,51 +27,55 @@ def test_returns_plain_dict(self): def test_raises_exception_when_no_schema_set(self): self.container.set_schema(None) - self.assertRaises(RuntimeError, lambda: self.container.to_dict(mask_secrets=True)) + self.assertRaises( + RuntimeError, lambda: self.container.to_dict(mask_secrets=True) + ) def test_returns_dict_with_masked_secrets(self): d = self.container.to_dict(mask_secrets=True) - self.assertEqual(d['a'], self.config['a']) - self.assertNotEqual(d['b'], self.config['b']) + self.assertEqual(d["a"], self.config["a"]) + self.assertNotEqual(d["b"], self.config["b"]) - self.assertEqual(self.config['b'], self.container['b']) + self.assertEqual(self.config["b"], self.container["b"]) class TestConfigurationUpdate(TestCase): def setUp(self): - self.config = {'a': 1, 'b': 'test'} + self.config = {"a": 1, "b": "test"} self.container = ConfigurationContainer(self.config, configuration_schema) def test_rejects_invalid_new_config(self): - self.assertRaises(ValidationError, lambda: self.container.update({'c': 3})) + self.assertRaises(ValidationError, lambda: self.container.update({"c": 3})) def test_fails_if_no_schema_set(self): self.container.set_schema(None) - self.assertRaises(RuntimeError, lambda: self.container.update({'c': 3})) + self.assertRaises(RuntimeError, lambda: self.container.update({"c": 3})) def test_ignores_secret_placehodler(self): self.container.update(self.container.to_dict(mask_secrets=True)) - self.assertEqual(self.container['b'], self.config['b']) + self.assertEqual(self.container["b"], self.config["b"]) def test_updates_secret(self): - new_config = {'a': 2, 'b': 'new'} + new_config = {"a": 2, "b": "new"} self.container.update(new_config) self.assertDictEqual(self.container._config, new_config) def test_doesnt_leave_leftovers(self): - container = ConfigurationContainer({'a': 1, 'b': 'test', 'e': 3}, configuration_schema) + container = ConfigurationContainer( + {"a": 1, "b": "test", "e": 3}, configuration_schema + ) new_config = container.to_dict(mask_secrets=True) - new_config.pop('e') + new_config.pop("e") container.update(new_config) - self.assertEqual(container['a'], 1) - self.assertEqual('test', container['b']) - self.assertNotIn('e', container) + self.assertEqual(container["a"], 1) + self.assertEqual("test", container["b"]) + self.assertNotIn("e", container) def test_works_for_schema_without_secret(self): secretless = configuration_schema.copy() - secretless.pop('secret') - container = ConfigurationContainer({'a': 1, 'b': 'test', 'e': 3}, secretless) - container.update({'a': 2}) - self.assertEqual(container['a'], 2) + secretless.pop("secret") + container = ConfigurationContainer({"a": 1, "b": "test", "e": 3}, secretless) + container.update({"a": 2}) + self.assertEqual(container["a"], 2) diff --git a/tests/test_handlers.py b/tests/test_handlers.py index 7d8e872617..162c9f5e87 100644 --- a/tests/test_handlers.py +++ b/tests/test_handlers.py @@ -14,7 +14,7 @@ def test_returns_404_when_not_unauthenticated(self): def test_returns_content_when_authenticated(self): for path in self.paths: - rv = self.make_request('get', path, is_json=False) + rv = self.make_request("get", path, is_json=False) self.assertEqual(200, rv.status_code) @@ -22,7 +22,7 @@ class TestAuthentication(BaseTestCase): def test_responds_with_success_for_signed_in_user(self): with self.client as c: with c.session_transaction() as sess: - sess['user_id'] = self.factory.user.get_id() + sess["user_id"] = self.factory.user.get_id() rv = self.client.get("/default/") self.assertEqual(200, rv.status_code) @@ -34,7 +34,7 @@ def test_redirects_for_nonsigned_in_user(self): def test_redirects_for_invalid_session_identifier(self): with self.client as c: with c.session_transaction() as sess: - sess['user_id'] = 100 + sess["user_id"] = 100 rv = self.client.get("/default/") self.assertEqual(302, rv.status_code) @@ -42,14 +42,19 @@ def test_redirects_for_invalid_session_identifier(self): class PingTest(BaseTestCase): def test_ping(self): - rv = self.client.get('/ping') + rv = self.client.get("/ping") self.assertEqual(200, rv.status_code) - self.assertEqual(b'PONG.', rv.data) + self.assertEqual(b"PONG.", rv.data) class IndexTest(BaseTestCase): def setUp(self): - self.paths = ['/default/', '/default/dashboard/example', '/default/queries/1', '/default/admin/status'] + self.paths = [ + "/default/", + "/default/dashboard/example", + "/default/queries/1", + "/default/admin/status", + ] super(IndexTest, self).setUp() def test_redirect_to_login_when_not_authenticated(self): @@ -59,7 +64,7 @@ def test_redirect_to_login_when_not_authenticated(self): def test_returns_content_when_authenticated(self): for path in self.paths: - rv = self.make_request('get', path, org=False, is_json=False) + rv = self.make_request("get", path, org=False, is_json=False) self.assertEqual(200, rv.status_code) @@ -67,15 +72,17 @@ class StatusTest(BaseTestCase): def test_returns_data_for_super_admin(self): admin = self.factory.create_admin() models.db.session.commit() - rv = self.make_request('get', '/status.json', org=False, user=admin, is_json=False) + rv = self.make_request( + "get", "/status.json", org=False, user=admin, is_json=False + ) self.assertEqual(rv.status_code, 200) def test_returns_403_for_non_admin(self): - rv = self.make_request('get', '/status.json', org=False, is_json=False) + rv = self.make_request("get", "/status.json", org=False, is_json=False) self.assertEqual(rv.status_code, 403) def test_redirects_non_authenticated_user(self): - rv = self.client.get('/status.json') + rv = self.client.get("/status.json") self.assertEqual(rv.status_code, 302) @@ -88,10 +95,10 @@ def setUp(self): class TestLogin(BaseTestCase): def setUp(self): super(TestLogin, self).setUp() - self.factory.org.set_setting('auth_password_login_enabled', True) + self.factory.org.set_setting("auth_password_login_enabled", True) def test_get_login_form(self): - rv = self.client.get('/default/login') + rv = self.client.get("/default/login") self.assertEqual(rv.status_code, 200) def test_get_login_form_remote_auth(self): @@ -102,100 +109,124 @@ def test_get_login_form_remote_auth(self): try: settings.REMOTE_USER_LOGIN_ENABLED = True settings.LDAP_LOGIN_ENABLED = True - rv = self.client.get('/default/login') + rv = self.client.get("/default/login") self.assertEqual(rv.status_code, 200) - self.assertIn('/{}/remote_user/login'.format(self.factory.org.slug), rv.data.decode()) - self.assertIn('/{}/ldap/login'.format(self.factory.org.slug), rv.data.decode()) + self.assertIn( + "/{}/remote_user/login".format(self.factory.org.slug), rv.data.decode() + ) + self.assertIn( + "/{}/ldap/login".format(self.factory.org.slug), rv.data.decode() + ) finally: settings.REMOTE_USER_LOGIN_ENABLED = old_remote_user_enabled settings.LDAP_LOGIN_ENABLED = old_ldap_login_enabled def test_submit_non_existing_user(self): - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={'email': 'arik', 'password': 'password'}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login", data={"email": "arik", "password": "password"} + ) self.assertEqual(rv.status_code, 200) self.assertFalse(login_user_mock.called) def test_submit_correct_user_and_password(self): user = self.factory.user - user.hash_password('password') + user.hash_password("password") self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={'email': user.email, 'password': 'password'}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login", data={"email": user.email, "password": "password"} + ) self.assertEqual(rv.status_code, 302) login_user_mock.assert_called_with(user, remember=False) def test_submit_case_insensitive_user_and_password(self): user = self.factory.user - user.hash_password('password') + user.hash_password("password") self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={'email': user.email.upper(), 'password': 'password'}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login", + data={"email": user.email.upper(), "password": "password"}, + ) self.assertEqual(rv.status_code, 302) login_user_mock.assert_called_with(user, remember=False) def test_submit_correct_user_and_password_and_remember_me(self): user = self.factory.user - user.hash_password('password') + user.hash_password("password") self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={'email': user.email, 'password': 'password', 'remember': True}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login", + data={"email": user.email, "password": "password", "remember": True}, + ) self.assertEqual(rv.status_code, 302) login_user_mock.assert_called_with(user, remember=True) def test_submit_correct_user_and_password_with_next(self): user = self.factory.user - user.hash_password('password') + user.hash_password("password") self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login?next=/test', - data={'email': user.email, 'password': 'password'}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login?next=/test", + data={"email": user.email, "password": "password"}, + ) self.assertEqual(rv.status_code, 302) - self.assertEqual(rv.location, 'http://localhost/test') + self.assertEqual(rv.location, "http://localhost/test") login_user_mock.assert_called_with(user, remember=False) def test_submit_incorrect_user(self): - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={'email': 'non-existing', 'password': 'password'}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login", data={"email": "non-existing", "password": "password"} + ) self.assertEqual(rv.status_code, 200) self.assertFalse(login_user_mock.called) def test_submit_incorrect_password(self): user = self.factory.user - user.hash_password('password') + user.hash_password("password") self.db.session.add(user) self.db.session.commit() - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={ - 'email': user.email, 'password': 'badbadpassword'}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login", + data={"email": user.email, "password": "badbadpassword"}, + ) self.assertEqual(rv.status_code, 200) self.assertFalse(login_user_mock.called) def test_submit_empty_password(self): user = self.factory.user - with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/default/login', data={'email': user.email, 'password': ''}) + with patch("redash.handlers.authentication.login_user") as login_user_mock: + rv = self.client.post( + "/default/login", data={"email": user.email, "password": ""} + ) self.assertEqual(rv.status_code, 200) self.assertFalse(login_user_mock.called) def test_user_already_loggedin(self): - with authenticated_user(self.client), patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.get('/default/login') + with authenticated_user(self.client), patch( + "redash.handlers.authentication.login_user" + ) as login_user_mock: + rv = self.client.get("/default/login") self.assertEqual(rv.status_code, 302) self.assertFalse(login_user_mock.called) @@ -203,15 +234,15 @@ def test_user_already_loggedin(self): class TestLogout(BaseTestCase): def test_logout_when_not_loggedin(self): with self.app.test_client() as c: - rv = c.get('/default/logout') + rv = c.get("/default/logout") self.assertEqual(rv.status_code, 302) self.assertFalse(current_user.is_authenticated) def test_logout_when_loggedin(self): with self.app.test_client() as c, authenticated_user(c, user=self.factory.user): - rv = c.get('/default/') + rv = c.get("/default/") self.assertTrue(current_user.is_authenticated) - rv = c.get('/default/logout') + rv = c.get("/default/logout") self.assertEqual(rv.status_code, 302) self.assertFalse(current_user.is_authenticated) @@ -219,88 +250,75 @@ def test_logout_when_loggedin(self): class TestQuerySnippet(BaseTestCase): def test_create(self): res = self.make_request( - 'post', - '/api/query_snippets', - data={'trigger': 'x', 'description': 'y', 'snippet': 'z'}, - user=self.factory.user) + "post", + "/api/query_snippets", + data={"trigger": "x", "description": "y", "snippet": "z"}, + user=self.factory.user, + ) self.assertEqual( - project(res.json, ['id', 'trigger', 'description', 'snippet']), { - 'id': 1, - 'trigger': 'x', - 'description': 'y', - 'snippet': 'z', - }) + project(res.json, ["id", "trigger", "description", "snippet"]), + {"id": 1, "trigger": "x", "description": "y", "snippet": "z"}, + ) qs = models.QuerySnippet.query.one() - self.assertEqual(qs.trigger, 'x') - self.assertEqual(qs.description, 'y') - self.assertEqual(qs.snippet, 'z') + self.assertEqual(qs.trigger, "x") + self.assertEqual(qs.description, "y") + self.assertEqual(qs.snippet, "z") def test_edit(self): qs = models.QuerySnippet( - trigger='a', - description='b', - snippet='c', + trigger="a", + description="b", + snippet="c", user=self.factory.user, - org=self.factory.org + org=self.factory.org, ) models.db.session.add(qs) models.db.session.commit() res = self.make_request( - 'post', - '/api/query_snippets/1', - data={'trigger': 'x', 'description': 'y', 'snippet': 'z'}, - user=self.factory.user) + "post", + "/api/query_snippets/1", + data={"trigger": "x", "description": "y", "snippet": "z"}, + user=self.factory.user, + ) self.assertEqual( - project(res.json, ['id', 'trigger', 'description', 'snippet']), { - 'id': 1, - 'trigger': 'x', - 'description': 'y', - 'snippet': 'z', - }) - self.assertEqual(qs.trigger, 'x') - self.assertEqual(qs.description, 'y') - self.assertEqual(qs.snippet, 'z') + project(res.json, ["id", "trigger", "description", "snippet"]), + {"id": 1, "trigger": "x", "description": "y", "snippet": "z"}, + ) + self.assertEqual(qs.trigger, "x") + self.assertEqual(qs.description, "y") + self.assertEqual(qs.snippet, "z") def test_list(self): qs = models.QuerySnippet( - trigger='x', - description='y', - snippet='z', + trigger="x", + description="y", + snippet="z", user=self.factory.user, - org=self.factory.org + org=self.factory.org, ) models.db.session.add(qs) models.db.session.commit() - res = self.make_request( - 'get', - '/api/query_snippets', - user=self.factory.user) + res = self.make_request("get", "/api/query_snippets", user=self.factory.user) self.assertEqual(res.status_code, 200) data = res.json self.assertEqual(len(data), 1) self.assertEqual( - project(data[0], ['id', 'trigger', 'description', 'snippet']), { - 'id': 1, - 'trigger': 'x', - 'description': 'y', - 'snippet': 'z', - }) - self.assertEqual(qs.trigger, 'x') - self.assertEqual(qs.description, 'y') - self.assertEqual(qs.snippet, 'z') + project(data[0], ["id", "trigger", "description", "snippet"]), + {"id": 1, "trigger": "x", "description": "y", "snippet": "z"}, + ) + self.assertEqual(qs.trigger, "x") + self.assertEqual(qs.description, "y") + self.assertEqual(qs.snippet, "z") def test_delete(self): qs = models.QuerySnippet( - trigger='a', - description='b', - snippet='c', + trigger="a", + description="b", + snippet="c", user=self.factory.user, - org=self.factory.org + org=self.factory.org, ) models.db.session.add(qs) models.db.session.commit() - self.make_request( - 'delete', - '/api/query_snippets/1', - user=self.factory.user) + self.make_request("delete", "/api/query_snippets/1", user=self.factory.user) self.assertEqual(models.QuerySnippet.query.count(), 0) diff --git a/tests/test_models.py b/tests/test_models.py index 533fd29f46..d6c35d3700 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ class DashboardTest(BaseTestCase): def test_appends_suffix_to_slug_when_duplicate(self): d1 = self.factory.create_dashboard() db.session.flush() - self.assertEqual(d1.slug, 'test') + self.assertEqual(d1.slug, "test") d2 = self.factory.create_dashboard(user=d1.user) db.session.flush() @@ -43,8 +43,9 @@ def test_exact_time_that_needs_reschedule(self): yesterday = now - datetime.timedelta(days=1) scheduled_datetime = now - datetime.timedelta(hours=3) scheduled_time = "{:02d}:00".format(scheduled_datetime.hour) - self.assertTrue(models.should_schedule_next(yesterday, now, "86400", - scheduled_time)) + self.assertTrue( + models.should_schedule_next(yesterday, now, "86400", scheduled_time) + ) def test_exact_time_that_doesnt_need_reschedule(self): now = date_parse("2015-10-16 20:10") @@ -54,8 +55,7 @@ def test_exact_time_that_doesnt_need_reschedule(self): def test_exact_time_with_day_change(self): now = utcnow().replace(hour=0, minute=1) - previous = (now - datetime.timedelta(days=2)).replace(hour=23, - minute=59) + previous = (now - datetime.timedelta(days=2)).replace(hour=23, minute=59) schedule = "23:59".format(now.hour + 3) self.assertTrue(models.should_schedule_next(previous, now, "86400", schedule)) @@ -65,8 +65,11 @@ def test_exact_time_every_x_days_that_needs_reschedule(self): three_day_interval = "259200" scheduled_datetime = now - datetime.timedelta(hours=3) scheduled_time = "{:02d}:00".format(scheduled_datetime.hour) - self.assertTrue(models.should_schedule_next(four_days_ago, now, three_day_interval, - scheduled_time)) + self.assertTrue( + models.should_schedule_next( + four_days_ago, now, three_day_interval, scheduled_time + ) + ) def test_exact_time_every_x_days_that_doesnt_need_reschedule(self): now = utcnow() @@ -74,15 +77,20 @@ def test_exact_time_every_x_days_that_doesnt_need_reschedule(self): three_day_interval = "259200" scheduled_datetime = now - datetime.timedelta(hours=3) scheduled_time = "{:02d}:00".format(scheduled_datetime.hour) - self.assertFalse(models.should_schedule_next(four_days_ago, now, three_day_interval, - scheduled_time)) + self.assertFalse( + models.should_schedule_next( + four_days_ago, now, three_day_interval, scheduled_time + ) + ) def test_exact_time_every_x_days_with_day_change(self): now = utcnow().replace(hour=23, minute=59) previous = (now - datetime.timedelta(days=2)).replace(hour=0, minute=1) schedule = "23:58" three_day_interval = "259200" - self.assertTrue(models.should_schedule_next(previous, now, three_day_interval, schedule)) + self.assertTrue( + models.should_schedule_next(previous, now, three_day_interval, schedule) + ) def test_exact_time_every_x_weeks_that_needs_reschedule(self): # Setup: @@ -94,14 +102,19 @@ def test_exact_time_every_x_weeks_that_needs_reschedule(self): # Expectation: Even though less than 3 weeks have passed since the # last run 3 weeks ago on Thursday, it's overdue since # it should be running on Tuesdays. - this_thursday = utcnow() + datetime.timedelta(days=list(calendar.day_name).index("Thursday") - utcnow().weekday()) + this_thursday = utcnow() + datetime.timedelta( + days=list(calendar.day_name).index("Thursday") - utcnow().weekday() + ) three_weeks_ago = this_thursday - datetime.timedelta(weeks=3) now = this_thursday - datetime.timedelta(days=1) three_week_interval = "1814400" scheduled_datetime = now - datetime.timedelta(hours=3) scheduled_time = "{:02d}:00".format(scheduled_datetime.hour) - self.assertTrue(models.should_schedule_next(three_weeks_ago, now, three_week_interval, - scheduled_time, "Tuesday")) + self.assertTrue( + models.should_schedule_next( + three_weeks_ago, now, three_week_interval, scheduled_time, "Tuesday" + ) + ) def test_exact_time_every_x_weeks_that_doesnt_need_reschedule(self): # Setup: @@ -113,33 +126,49 @@ def test_exact_time_every_x_weeks_that_doesnt_need_reschedule(self): # Expectation: Even though more than 3 weeks have passed since the # last run 3 weeks ago on Tuesday, it's not overdue since # it should be running on Thursdays. - this_tuesday = utcnow() + datetime.timedelta(days=list(calendar.day_name).index("Tuesday") - utcnow().weekday()) + this_tuesday = utcnow() + datetime.timedelta( + days=list(calendar.day_name).index("Tuesday") - utcnow().weekday() + ) three_weeks_ago = this_tuesday - datetime.timedelta(weeks=3) now = this_tuesday + datetime.timedelta(days=1) three_week_interval = "1814400" scheduled_datetime = now - datetime.timedelta(hours=3) scheduled_time = "{:02d}:00".format(scheduled_datetime.hour) - self.assertFalse(models.should_schedule_next(three_weeks_ago, now, three_week_interval, - scheduled_time, "Thursday")) + self.assertFalse( + models.should_schedule_next( + three_weeks_ago, now, three_week_interval, scheduled_time, "Thursday" + ) + ) def test_backoff(self): now = utcnow() two_hours_ago = now - datetime.timedelta(hours=2) - self.assertTrue(models.should_schedule_next(two_hours_ago, now, "3600", - failures=5)) - self.assertFalse(models.should_schedule_next(two_hours_ago, now, - "3600", failures=10)) + self.assertTrue( + models.should_schedule_next(two_hours_ago, now, "3600", failures=5) + ) + self.assertFalse( + models.should_schedule_next(two_hours_ago, now, "3600", failures=10) + ) def test_next_iteration_overflow(self): now = utcnow() two_hours_ago = now - datetime.timedelta(hours=2) - self.assertFalse(models.should_schedule_next(two_hours_ago, now, "3600", failures=32)) + self.assertFalse( + models.should_schedule_next(two_hours_ago, now, "3600", failures=32) + ) class QueryOutdatedQueriesTest(BaseTestCase): # TODO: this test can be refactored to use mock version of should_schedule_next to simplify it. def test_outdated_queries_skips_unscheduled_queries(self): - query = self.factory.create_query(schedule={'interval':None, 'time': None, 'until':None, 'day_of_week':None}) + query = self.factory.create_query( + schedule={ + "interval": None, + "time": None, + "until": None, + "day_of_week": None, + } + ) query_with_none = self.factory.create_query(schedule=None) queries = models.Query.outdated_queries() @@ -149,8 +178,17 @@ def test_outdated_queries_skips_unscheduled_queries(self): def test_outdated_queries_works_with_ttl_based_schedule(self): two_hours_ago = utcnow() - datetime.timedelta(hours=2) - query = self.factory.create_query(schedule={'interval':'3600', 'time': None, 'until':None, 'day_of_week':None}) - query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=two_hours_ago) + query = self.factory.create_query( + schedule={ + "interval": "3600", + "time": None, + "until": None, + "day_of_week": None, + } + ) + query_result = self.factory.create_query_result( + query=query.query_text, retrieved_at=two_hours_ago + ) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -158,8 +196,17 @@ def test_outdated_queries_works_with_ttl_based_schedule(self): def test_outdated_queries_works_scheduled_queries_tracker(self): two_hours_ago = utcnow() - datetime.timedelta(hours=2) - query = self.factory.create_query(schedule={'interval':'3600', 'time': None, 'until':None, 'day_of_week':None}) - query_result = self.factory.create_query_result(query=query, retrieved_at=two_hours_ago) + query = self.factory.create_query( + schedule={ + "interval": "3600", + "time": None, + "until": None, + "day_of_week": None, + } + ) + query_result = self.factory.create_query_result( + query=query, retrieved_at=two_hours_ago + ) query.latest_query_data = query_result models.scheduled_queries_executions.update(query.id) @@ -169,8 +216,17 @@ def test_outdated_queries_works_scheduled_queries_tracker(self): def test_skips_fresh_queries(self): half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) - query = self.factory.create_query(schedule={'interval':'3600', 'time': None, 'until':None, 'day_of_week':None}) - query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=half_an_hour_ago) + query = self.factory.create_query( + schedule={ + "interval": "3600", + "time": None, + "until": None, + "day_of_week": None, + } + ) + query_result = self.factory.create_query_result( + query=query.query_text, retrieved_at=half_an_hour_ago + ) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -178,8 +234,18 @@ def test_skips_fresh_queries(self): def test_outdated_queries_works_with_specific_time_schedule(self): half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) - query = self.factory.create_query(schedule={'interval':'86400', 'time':half_an_hour_ago.strftime('%H:%M'), 'until':None, 'day_of_week':None}) - query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1)) + query = self.factory.create_query( + schedule={ + "interval": "86400", + "time": half_an_hour_ago.strftime("%H:%M"), + "until": None, + "day_of_week": None, + } + ) + query_result = self.factory.create_query_result( + query=query.query_text, + retrieved_at=half_an_hour_ago - datetime.timedelta(days=1), + ) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -190,14 +256,30 @@ def test_enqueues_query_only_once(self): Only one query per data source with the same text will be reported by Query.outdated_queries(). """ - query = self.factory.create_query(schedule={'interval':'60', 'until':None, 'time': None, 'day_of_week':None}) + query = self.factory.create_query( + schedule={ + "interval": "60", + "until": None, + "time": None, + "day_of_week": None, + } + ) query2 = self.factory.create_query( - schedule={'interval':'60', 'until':None, 'time': None, 'day_of_week':None}, query_text=query.query_text, - query_hash=query.query_hash) + schedule={ + "interval": "60", + "until": None, + "time": None, + "day_of_week": None, + }, + query_text=query.query_text, + query_hash=query.query_hash, + ) retrieved_at = utcnow() - datetime.timedelta(minutes=10) query_result = self.factory.create_query_result( - retrieved_at=retrieved_at, query_text=query.query_text, - query_hash=query.query_hash) + retrieved_at=retrieved_at, + query_text=query.query_text, + query_hash=query.query_hash, + ) query.latest_query_data = query_result query2.latest_query_data = query_result @@ -209,14 +291,30 @@ def test_enqueues_query_with_correct_data_source(self): Query.outdated_queries() even if they have the same query text. """ query = self.factory.create_query( - schedule={'interval':'60', 'until':None, 'time': None, 'day_of_week':None}, data_source=self.factory.create_data_source()) + schedule={ + "interval": "60", + "until": None, + "time": None, + "day_of_week": None, + }, + data_source=self.factory.create_data_source(), + ) query2 = self.factory.create_query( - schedule={'interval':'60', 'until':None, 'time': None, 'day_of_week':None}, query_text=query.query_text, - query_hash=query.query_hash) + schedule={ + "interval": "60", + "until": None, + "time": None, + "day_of_week": None, + }, + query_text=query.query_text, + query_hash=query.query_hash, + ) retrieved_at = utcnow() - datetime.timedelta(minutes=10) query_result = self.factory.create_query_result( - retrieved_at=retrieved_at, query_text=query.query_text, - query_hash=query.query_hash) + retrieved_at=retrieved_at, + query_text=query.query_text, + query_hash=query.query_hash, + ) query.latest_query_data = query_result query2.latest_query_data = query_result @@ -230,14 +328,30 @@ def test_enqueues_only_for_relevant_data_source(self): If multiple queries with the same text exist, only ones that are scheduled to be refreshed are reported by Query.outdated_queries(). """ - query = self.factory.create_query(schedule={'interval':'60', 'until':None, 'time': None, 'day_of_week':None}) + query = self.factory.create_query( + schedule={ + "interval": "60", + "until": None, + "time": None, + "day_of_week": None, + } + ) query2 = self.factory.create_query( - schedule={'interval':'3600', 'until':None, 'time': None, 'day_of_week':None}, query_text=query.query_text, - query_hash=query.query_hash) + schedule={ + "interval": "3600", + "until": None, + "time": None, + "day_of_week": None, + }, + query_text=query.query_text, + query_hash=query.query_hash, + ) retrieved_at = utcnow() - datetime.timedelta(minutes=10) query_result = self.factory.create_query_result( - retrieved_at=retrieved_at, query_text=query.query_text, - query_hash=query.query_hash) + retrieved_at=retrieved_at, + query_text=query.query_text, + query_hash=query.query_hash, + ) query.latest_query_data = query_result query2.latest_query_data = query_result @@ -248,11 +362,21 @@ def test_failure_extends_schedule(self): Execution failures recorded for a query result in exponential backoff for scheduling future execution. """ - query = self.factory.create_query(schedule={'interval':'60', 'until':None, 'time': None, 'day_of_week':None}, schedule_failures=4) + query = self.factory.create_query( + schedule={ + "interval": "60", + "until": None, + "time": None, + "day_of_week": None, + }, + schedule_failures=4, + ) retrieved_at = utcnow() - datetime.timedelta(minutes=16) query_result = self.factory.create_query_result( - retrieved_at=retrieved_at, query_text=query.query_text, - query_hash=query.query_hash) + retrieved_at=retrieved_at, + query_text=query.query_text, + query_hash=query.query_hash, + ) query.latest_query_data = query_result self.assertEqual(list(models.Query.outdated_queries()), []) @@ -267,8 +391,17 @@ def test_schedule_until_after(self): """ one_day_ago = (utcnow() - datetime.timedelta(days=1)).strftime("%Y-%m-%d") two_hours_ago = utcnow() - datetime.timedelta(hours=2) - query = self.factory.create_query(schedule={'interval':'3600', 'until':one_day_ago, 'time':None, 'day_of_week':None}) - query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=two_hours_ago) + query = self.factory.create_query( + schedule={ + "interval": "3600", + "until": one_day_ago, + "time": None, + "day_of_week": None, + } + ) + query_result = self.factory.create_query_result( + query=query.query_text, retrieved_at=two_hours_ago + ) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -281,8 +414,17 @@ def test_schedule_until_before(self): """ one_day_from_now = (utcnow() + datetime.timedelta(days=1)).strftime("%Y-%m-%d") two_hours_ago = utcnow() - datetime.timedelta(hours=2) - query = self.factory.create_query(schedule={'interval':'3600', 'until':one_day_from_now, 'time': None, 'day_of_week':None}) - query_result = self.factory.create_query_result(query=query.query_text, retrieved_at=two_hours_ago) + query = self.factory.create_query( + schedule={ + "interval": "3600", + "until": one_day_from_now, + "time": None, + "day_of_week": None, + } + ) + query_result = self.factory.create_query_result( + query=query.query_text, retrieved_at=two_hours_ago + ) query.latest_query_data = query_result queries = models.Query.outdated_queries() @@ -298,11 +440,19 @@ def test_archive_query_sets_flag(self): self.assertEqual(query.is_archived, True) def test_archived_query_doesnt_return_in_all(self): - query = self.factory.create_query(schedule={'interval':'1', 'until':None, 'time': None, 'day_of_week':None}) + query = self.factory.create_query( + schedule={"interval": "1", "until": None, "time": None, "day_of_week": None} + ) yesterday = utcnow() - datetime.timedelta(days=1) query_result = models.QueryResult.store_result( - query.org_id, query.data_source, query.query_hash, query.query_text, - "1", 123, yesterday) + query.org_id, + query.data_source, + query.query_hash, + query.query_text, + "1", + 123, + yesterday, + ) query.latest_query_data = query_result groups = list(models.Group.query.filter(models.Group.id.in_(query.groups))) @@ -323,7 +473,9 @@ def test_removes_associated_widgets_from_dashboards(self): self.assertEqual(models.Widget.query.get(widget.id), None) def test_removes_scheduling(self): - query = self.factory.create_query(schedule={'interval':'1', 'until':None, 'time': None, 'day_of_week':None}) + query = self.factory.create_query( + schedule={"interval": "1", "until": None, "time": None, "day_of_week": None} + ) query.archive() @@ -364,20 +516,24 @@ def test_returns_only_queries_in_given_groups(self): ds1 = self.factory.create_data_source() ds2 = self.factory.create_data_source() - group1 = models.Group(name="g1", org=ds1.org, permissions=['create', 'view']) - group2 = models.Group(name="g2", org=ds1.org, permissions=['create', 'view']) + group1 = models.Group(name="g1", org=ds1.org, permissions=["create", "view"]) + group2 = models.Group(name="g2", org=ds1.org, permissions=["create", "view"]) q1 = self.factory.create_query(data_source=ds1) q2 = self.factory.create_query(data_source=ds2) - db.session.add_all([ - ds1, ds2, - group1, group2, - q1, q2, - models.DataSourceGroup( - group=group1, data_source=ds1), - models.DataSourceGroup(group=group2, data_source=ds2) - ]) + db.session.add_all( + [ + ds1, + ds2, + group1, + group2, + q1, + q2, + models.DataSourceGroup(group=group1, data_source=ds1), + models.DataSourceGroup(group=group2, data_source=ds2), + ] + ) db.session.flush() self.assertIn(q1, list(models.Query.all_queries([group1.id]))) self.assertNotIn(q2, list(models.Query.all_queries([group1.id]))) @@ -390,11 +546,16 @@ def test_skips_drafts(self): def test_includes_drafts_of_given_user(self): q = self.factory.create_query(is_draft=True) - self.assertIn(q, models.Query.all_queries([self.factory.default_group.id], user_id=q.user_id)) + self.assertIn( + q, + models.Query.all_queries( + [self.factory.default_group.id], user_id=q.user_id + ), + ) def test_order_by_relationship(self): - u1 = self.factory.create_user(name='alice') - u2 = self.factory.create_user(name='bob') + u1 = self.factory.create_user(name="alice") + u2 = self.factory.create_user(name="bob") self.factory.create_query(user=u1) self.factory.create_query(user=u2) db.session.commit() @@ -402,9 +563,9 @@ def test_order_by_relationship(self): # created_at by default base = models.Query.all_queries([self.factory.default_group.id]).order_by(None) qs1 = base.order_by(models.User.name) - self.assertEqual(['alice', 'bob'], [q.user.name for q in qs1]) + self.assertEqual(["alice", "bob"], [q.user.name for q in qs1]) qs2 = base.order_by(models.User.name.desc()) - self.assertEqual(['bob', 'alice'], [q.user.name for q in qs2]) + self.assertEqual(["bob", "alice"], [q.user.name for q in qs2]) class TestGroup(BaseTestCase): @@ -440,8 +601,14 @@ def setUp(self): def test_stores_the_result(self): query_result = models.QueryResult.store_result( - self.data_source.org_id, self.data_source, self.query_hash, - self.query, self.data, self.runtime, self.utcnow) + self.data_source.org_id, + self.data_source, + self.query_hash, + self.query, + self.data, + self.runtime, + self.utcnow, + ) self.assertEqual(query_result._data, self.data) self.assertEqual(query_result.runtime, self.runtime) @@ -457,12 +624,14 @@ def raw_event(self): user = self.factory.user created_at = datetime.datetime.utcfromtimestamp(timestamp) db.session.flush() - raw_event = {"action": "view", - "timestamp": timestamp, - "object_type": "dashboard", - "user_id": user.id, - "object_id": 1, - "org_id": 1} + raw_event = { + "action": "view", + "timestamp": timestamp, + "object_type": "dashboard", + "user_id": user.id, + "object_id": 1, + "org_id": 1, + } return raw_event, user, created_at @@ -479,7 +648,7 @@ def test_records_event(self): def test_records_additional_properties(self): raw_event, _, _ = self.raw_event() - additional_properties = {'test': 1, 'test2': 2, 'whatever': "abc"} + additional_properties = {"test": 1, "test2": 2, "whatever": "abc"} raw_event.update(additional_properties) event = models.Event.record(raw_event) @@ -488,17 +657,19 @@ def test_records_additional_properties(self): def _set_up_dashboard_test(d): - d.g1 = d.factory.create_group(name='First', permissions=['create', 'view']) - d.g2 = d.factory.create_group(name='Second', permissions=['create', 'view']) + d.g1 = d.factory.create_group(name="First", permissions=["create", "view"]) + d.g2 = d.factory.create_group(name="Second", permissions=["create", "view"]) d.ds1 = d.factory.create_data_source() d.ds2 = d.factory.create_data_source() db.session.flush() d.u1 = d.factory.create_user(group_ids=[d.g1.id]) d.u2 = d.factory.create_user(group_ids=[d.g2.id]) - db.session.add_all([ - models.DataSourceGroup(group=d.g1, data_source=d.ds1), - models.DataSourceGroup(group=d.g2, data_source=d.ds2) - ]) + db.session.add_all( + [ + models.DataSourceGroup(group=d.g1, data_source=d.ds1), + models.DataSourceGroup(group=d.g2, data_source=d.ds2), + ] + ) d.q1 = d.factory.create_query(data_source=d.ds1) d.q2 = d.factory.create_query(data_source=d.ds2) d.v1 = d.factory.create_visualization(query_rel=d.q1) @@ -520,46 +691,69 @@ def setUp(self): def test_requires_group_or_user_id(self): d1 = self.factory.create_dashboard() - self.assertNotIn(d1, list(models.Dashboard.all( - d1.user.org, d1.user.group_ids, None))) - l2 = list(models.Dashboard.all( - d1.user.org, [0], d1.user.id)) + self.assertNotIn( + d1, list(models.Dashboard.all(d1.user.org, d1.user.group_ids, None)) + ) + l2 = list(models.Dashboard.all(d1.user.org, [0], d1.user.id)) self.assertIn(d1, l2) def test_returns_dashboards_based_on_groups(self): - self.assertIn(self.w1.dashboard, list(models.Dashboard.all( - self.u1.org, self.u1.group_ids, None))) - self.assertIn(self.w2.dashboard, list(models.Dashboard.all( - self.u2.org, self.u2.group_ids, None))) - self.assertNotIn(self.w1.dashboard, list(models.Dashboard.all( - self.u2.org, self.u2.group_ids, None))) - self.assertNotIn(self.w2.dashboard, list(models.Dashboard.all( - self.u1.org, self.u1.group_ids, None))) + self.assertIn( + self.w1.dashboard, + list(models.Dashboard.all(self.u1.org, self.u1.group_ids, None)), + ) + self.assertIn( + self.w2.dashboard, + list(models.Dashboard.all(self.u2.org, self.u2.group_ids, None)), + ) + self.assertNotIn( + self.w1.dashboard, + list(models.Dashboard.all(self.u2.org, self.u2.group_ids, None)), + ) + self.assertNotIn( + self.w2.dashboard, + list(models.Dashboard.all(self.u1.org, self.u1.group_ids, None)), + ) def test_returns_each_dashboard_once(self): dashboards = list(models.Dashboard.all(self.u2.org, self.u2.group_ids, None)) self.assertEqual(len(dashboards), 2) def test_returns_dashboard_you_have_partial_access_to(self): - self.assertIn(self.w5.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) + self.assertIn( + self.w5.dashboard, + models.Dashboard.all(self.u1.org, self.u1.group_ids, None), + ) def test_returns_dashboards_created_by_user(self): d1 = self.factory.create_dashboard(user=self.u1) db.session.flush() - self.assertIn(d1, list(models.Dashboard.all(self.u1.org, self.u1.group_ids, self.u1.id))) + self.assertIn( + d1, list(models.Dashboard.all(self.u1.org, self.u1.group_ids, self.u1.id)) + ) self.assertIn(d1, list(models.Dashboard.all(self.u1.org, [0], self.u1.id))) - self.assertNotIn(d1, list(models.Dashboard.all(self.u2.org, self.u2.group_ids, self.u2.id))) + self.assertNotIn( + d1, list(models.Dashboard.all(self.u2.org, self.u2.group_ids, self.u2.id)) + ) def test_returns_dashboards_with_text_widgets(self): w1 = self.factory.create_widget(visualization=None) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.group_ids, None)) + self.assertIn( + w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None) + ) + self.assertIn( + w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.group_ids, None) + ) def test_returns_dashboards_from_current_org_only(self): w1 = self.factory.create_widget(visualization=None) user = self.factory.create_user(org=self.factory.create_org()) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) - self.assertNotIn(w1.dashboard, models.Dashboard.all(user.org, user.group_ids, None)) + self.assertIn( + w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None) + ) + self.assertNotIn( + w1.dashboard, models.Dashboard.all(user.org, user.group_ids, None) + ) diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 11cffd1be4..0076af04ae 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -5,13 +5,13 @@ from redash import models -MockUser = namedtuple('MockUser', ['permissions', 'group_ids']) +MockUser = namedtuple("MockUser", ["permissions", "group_ids"]) view_only = True class TestHasAccess(BaseTestCase): def test_allows_admin_regardless_of_groups(self): - user = MockUser(['admin'], []) + user = MockUser(["admin"], []) self.assertTrue(has_access({}, user, view_only)) self.assertTrue(has_access({}, user, not view_only)) @@ -29,10 +29,14 @@ def test_allows_if_user_member_in_group_with_full_access(self): def test_allows_if_user_member_in_multiple_groups(self): user = MockUser([], [1, 2, 3]) - self.assertTrue(has_access({1: not view_only, 2: view_only}, user, not view_only)) + self.assertTrue( + has_access({1: not view_only, 2: view_only}, user, not view_only) + ) self.assertFalse(has_access({1: view_only, 2: view_only}, user, not view_only)) self.assertTrue(has_access({1: view_only, 2: view_only}, user, view_only)) - self.assertTrue(has_access({1: not view_only, 2: not view_only}, user, view_only)) + self.assertTrue( + has_access({1: not view_only, 2: not view_only}, user, view_only) + ) def test_not_allows_if_not_enough_permission(self): user = MockUser([], [1]) @@ -40,7 +44,9 @@ def test_not_allows_if_not_enough_permission(self): self.assertFalse(has_access({1: view_only}, user, not view_only)) self.assertFalse(has_access({2: view_only}, user, not view_only)) self.assertFalse(has_access({2: view_only}, user, view_only)) - self.assertFalse(has_access({2: not view_only, 1: view_only}, user, not view_only)) + self.assertFalse( + has_access({2: not view_only, 1: view_only}, user, not view_only) + ) def test_allows_access_to_query_by_query_api_key(self): query = self.factory.create_query() @@ -64,4 +70,4 @@ def test_allows_access_to_query_by_dashboard_api_key(self): api_key = self.factory.create_api_key(object=dashboard).api_key user = models.ApiUser(api_key, None, []) - self.assertTrue(has_access(query, user, view_only)) \ No newline at end of file + self.assertTrue(has_access(query, user, view_only)) diff --git a/tests/test_schedule.py b/tests/test_schedule.py index 13a38b11ac..55af4e3dad 100644 --- a/tests/test_schedule.py +++ b/tests/test_schedule.py @@ -3,6 +3,7 @@ from redash.schedule import rq_scheduler, schedule_periodic_jobs + class TestSchedule(TestCase): def setUp(self): for job in rq_scheduler.get_jobs(): @@ -18,19 +19,18 @@ def foo(): jobs = [job for job in rq_scheduler.get_jobs()] self.assertEqual(len(jobs), 1) - self.assertTrue(jobs[0].func_name.endswith('foo')) - self.assertEqual(jobs[0].meta['interval'], 60) + self.assertTrue(jobs[0].func_name.endswith("foo")) + self.assertEqual(jobs[0].meta["interval"], 60) def test_doesnt_reschedule_an_existing_job(self): def foo(): pass schedule_periodic_jobs([{"func": foo, "interval": 60}]) - with patch('redash.schedule.rq_scheduler.schedule') as schedule: + with patch("redash.schedule.rq_scheduler.schedule") as schedule: schedule_periodic_jobs([{"func": foo, "interval": 60}]) schedule.assert_not_called() - def test_reschedules_a_modified_job(self): def foo(): pass @@ -41,8 +41,8 @@ def foo(): jobs = [job for job in rq_scheduler.get_jobs()] self.assertEqual(len(jobs), 1) - self.assertTrue(jobs[0].func_name.endswith('foo')) - self.assertEqual(jobs[0].meta['interval'], 120) + self.assertTrue(jobs[0].func_name.endswith("foo")) + self.assertEqual(jobs[0].meta["interval"], 120) def test_removes_jobs_that_are_no_longer_defined(self): def foo(): @@ -51,11 +51,13 @@ def foo(): def bar(): pass - schedule_periodic_jobs([{"func": foo, "interval": 60}, {"func": bar, "interval": 90}]) + schedule_periodic_jobs( + [{"func": foo, "interval": 60}, {"func": bar, "interval": 90}] + ) schedule_periodic_jobs([{"func": foo, "interval": 60}]) jobs = [job for job in rq_scheduler.get_jobs()] self.assertEqual(len(jobs), 1) - self.assertTrue(jobs[0].func_name.endswith('foo')) - self.assertEqual(jobs[0].meta['interval'], 60) \ No newline at end of file + self.assertTrue(jobs[0].func_name.endswith("foo")) + self.assertEqual(jobs[0].meta["interval"], 60) diff --git a/tests/test_utils.py b/tests/test_utils.py index 479f26fc7c..a253402ef5 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,46 +1,72 @@ from collections import namedtuple from unittest import TestCase -from redash.utils import (build_url, collect_parameters_from_request, - filter_none, json_dumps, generate_token) +from redash.utils import ( + build_url, + collect_parameters_from_request, + filter_none, + json_dumps, + generate_token, +) -DummyRequest = namedtuple('DummyRequest', ['host', 'scheme']) +DummyRequest = namedtuple("DummyRequest", ["host", "scheme"]) class TestBuildUrl(TestCase): def test_simple_case(self): - self.assertEqual("http://example.com/test", build_url(DummyRequest("", "http"), "example.com", "/test")) + self.assertEqual( + "http://example.com/test", + build_url(DummyRequest("", "http"), "example.com", "/test"), + ) def test_uses_current_request_port(self): - self.assertEqual("http://example.com:5000/test", build_url(DummyRequest("example.com:5000", "http"), "example.com", "/test")) + self.assertEqual( + "http://example.com:5000/test", + build_url(DummyRequest("example.com:5000", "http"), "example.com", "/test"), + ) def test_uses_current_request_schema(self): - self.assertEqual("https://example.com/test", build_url(DummyRequest("example.com", "https"), "example.com", "/test")) + self.assertEqual( + "https://example.com/test", + build_url(DummyRequest("example.com", "https"), "example.com", "/test"), + ) def test_skips_port_for_default_ports(self): - self.assertEqual("https://example.com/test", build_url(DummyRequest("example.com:443", "https"), "example.com", "/test")) - self.assertEqual("http://example.com/test", build_url(DummyRequest("example.com:80", "http"), "example.com", "/test")) - self.assertEqual("https://example.com:80/test", build_url(DummyRequest("example.com:80", "https"), "example.com", "/test")) - self.assertEqual("http://example.com:443/test", build_url(DummyRequest("example.com:443", "http"), "example.com", "/test")) + self.assertEqual( + "https://example.com/test", + build_url(DummyRequest("example.com:443", "https"), "example.com", "/test"), + ) + self.assertEqual( + "http://example.com/test", + build_url(DummyRequest("example.com:80", "http"), "example.com", "/test"), + ) + self.assertEqual( + "https://example.com:80/test", + build_url(DummyRequest("example.com:80", "https"), "example.com", "/test"), + ) + self.assertEqual( + "http://example.com:443/test", + build_url(DummyRequest("example.com:443", "http"), "example.com", "/test"), + ) class TestCollectParametersFromRequest(TestCase): def test_ignores_non_prefixed_values(self): - self.assertEqual({}, collect_parameters_from_request({'test': 1})) + self.assertEqual({}, collect_parameters_from_request({"test": 1})) def test_takes_prefixed_values(self): - self.assertDictEqual({'test': 1, 'something_else': 'test'}, collect_parameters_from_request({'p_test': 1, 'p_something_else': 'test'})) + self.assertDictEqual( + {"test": 1, "something_else": "test"}, + collect_parameters_from_request({"p_test": 1, "p_something_else": "test"}), + ) class TestSkipNones(TestCase): def test_skips_nones(self): - d = { - 'a': 1, - 'b': None - } + d = {"a": 1, "b": None} - self.assertDictEqual(filter_none(d), {'a': 1}) + self.assertDictEqual(filter_none(d), {"a": 1}) class TestJsonDumps(TestCase):