Job management (#70)
* job management

* refactor
maciejmaciejko-gid authored Feb 28, 2024
1 parent 43e66d2 commit f464234
Showing 23 changed files with 187 additions and 80 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,11 @@

## [Unreleased]

- Upgrade Flink to 1.17 (minimum required version).
- Handle execution config.
- Job management (stop job with savepoint, initialize from savepoint).
- Drop table between restarts.

## [1.3.8] - 2023-02-16

- Support computed / metadata column
8 changes: 4 additions & 4 deletions dbt/adapters/flink/connections.py
@@ -9,7 +9,7 @@
import yaml
from dbt.adapters.base import Credentials
from dbt.adapters.sql import SQLConnectionManager # type: ignore
from dbt.contracts.connection import Connection
from dbt.contracts.connection import Connection, ConnectionState
from dbt.events import AdapterLogger

from dbt.adapters.flink.handler import FlinkHandler, FlinkCursor
@@ -74,12 +74,12 @@ def exception_handler(self, sql: str):
raise dbt.exceptions.RuntimeException(str(e))

@classmethod
def open(cls, connection):
def open(cls, connection: Connection):
"""
Receives a connection object and a Credentials object
and moves it to the "open" state.
"""
if connection.state == "open":
if connection.state == ConnectionState.OPEN:
logger.debug("Connection is already open, skipping open.")
return connection

@@ -95,7 +95,7 @@ def open(cls, connection):
logger.info(f"Session created: {session.session_handle}")
FlinkConnectionManager._store_session_handle(session)

connection.state = "open"
connection.state = ConnectionState.OPEN
connection.handle = FlinkHandler(session)

except Exception as e:
4 changes: 4 additions & 0 deletions dbt/adapters/flink/constants.py
@@ -0,0 +1,4 @@
class ExecutionConfig:
SAVEPOINT_PATH = "execution.savepoint.path"
STATE_PATH = "state.savepoints.dir"
JOB_NAME = "pipeline.name"
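These constants are Flink configuration keys, so values passed through an `execution_config(...)` hint map one-to-one onto Flink settings. A minimal sketch of a config dict built from them (job name and savepoint location are hypothetical):

```python
from dbt.adapters.flink.constants import ExecutionConfig

# Hypothetical job name and savepoint location.
execution_config = {
    ExecutionConfig.JOB_NAME: "orders_enrichment",         # pipeline.name
    ExecutionConfig.STATE_PATH: "s3://bucket/savepoints",  # state.savepoints.dir
}

# After stopping a job with a savepoint, the handler re-adds:
#   execution_config[ExecutionConfig.SAVEPOINT_PATH] = <returned path>
# so the next submission resumes from that savepoint.
```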
82 changes: 73 additions & 9 deletions dbt/adapters/flink/handler.py
@@ -1,10 +1,16 @@
from datetime import datetime
from time import sleep
from typing import Sequence, Tuple, Optional, Any, List
from typing import Dict, Sequence, Tuple, Optional, Any, List

from dbt.events import AdapterLogger

from dbt.adapters.flink.query_hints_parser import QueryHints, QueryHintsParser, QueryMode
from dbt.adapters.flink.constants import ExecutionConfig
from dbt.adapters.flink.query_hints_parser import (
QueryHints,
QueryHintsParser,
QueryMode,
)

from flink.sqlgateway.client import FlinkSqlGatewayClient
from flink.sqlgateway.operation import SqlGatewayOperation
from flink.sqlgateway.result_parser import SqlGatewayResult
@@ -95,21 +101,38 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None:
logger.debug('Preparing statement "{}"'.format(sql))
if bindings is not None:
sql = sql.format(*[self._convert_binding(binding) for binding in bindings])
logger.info('Executing statement "{}"'.format(sql))
self.last_query_hints: QueryHints = QueryHintsParser.parse(sql)
execution_config = self.last_query_hints.execution_config
start_from_savepoint = False
if execution_config:
if not self.last_query_hints.test_query:
savepoint_path = FlinkJobManager(self.session).stop_job(execution_config)
if savepoint_path:
logger.info("f: {}", savepoint_path)
execution_config[ExecutionConfig.SAVEPOINT_PATH] = savepoint_path
start_from_savepoint = True
if not start_from_savepoint:
logger.info("Job starting without savepoint")
execution_config.pop(ExecutionConfig.SAVEPOINT_PATH, None)

if self.last_query_hints.drop_statement:
logger.info("Executing drop statement: {}", self.last_query_hints.drop_statement)
FlinkCursor(self.session).execute(self.last_query_hints.drop_statement)

self._set_query_mode()
operation_handle = FlinkSqlGatewayClient.execute_statement(self.session, sql)
logger.info("Executing statement:\n{}\nExecution config:\n{}", sql, execution_config)
operation_handle = FlinkSqlGatewayClient.execute_statement(
self.session, sql, execution_config
)
status = self._wait_till_finished(operation_handle)
logger.info(
"Statement executed. Status {}, operation handle: {}".format(
status, operation_handle.operation_handle
)
"Statement executed. Status {}, operation handle: {}",
status,
operation_handle.operation_handle,
)
if status == "ERROR":
raise Exception("Statement execution failed")

self.last_query_start_time = self._get_current_timestamp()

self.last_operation = operation_handle

def _convert_binding(self, binding):
Expand Down Expand Up @@ -184,3 +207,44 @@ def __init__(self, session: SqlGatewaySession):

def cursor(self) -> FlinkCursor:
return FlinkCursor(self.session)


class FlinkJobManager:
def __init__(self, session: SqlGatewaySession):
self.session = session

def stop_job(
self, execution_config: Dict[str, str], with_savepoint: bool = True
) -> Optional[str]:
if ExecutionConfig.JOB_NAME not in execution_config:
return None
job_name = execution_config[ExecutionConfig.JOB_NAME]
logger.info("Getting job by name {}", job_name)
job_id = self._get_job_id(job_name)
if job_id:
state_path = execution_config.get(ExecutionConfig.STATE_PATH)
logger.info("Stopping job {} using path {}", job_id, state_path)
path = self._do_stop_job(job_id, with_savepoint, state_path)
logger.info("Job stopped {}", job_id)
return path
return None

def _do_stop_job(
self, job_id: str, with_savepoint: bool, path: Optional[str] = None
) -> Optional[str]:
cursor = FlinkCursor(self.session)
hints = f"/** execution_config('{ExecutionConfig.STATE_PATH}={path}') */ " if path else ""
savepoint_statement = " WITH SAVEPOINT" if with_savepoint else ""
cursor.execute(f"{hints} STOP JOB '{job_id}'{savepoint_statement}")
for result in cursor.fetchall():
return result[0]
return None

def _get_job_id(self, job_name: str) -> Optional[str]:
cursor = FlinkCursor(self.session)
cursor.execute("SHOW JOBS")
for job in cursor.fetchall():
if job_name == job[1]:
if job[2] not in ("FAILED", "FINISHED", "CANCELED"):
return job[0]
return None
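A sketch of driving `FlinkJobManager` directly, assuming an already-open `SqlGatewaySession`; internally it resolves the job id via `SHOW JOBS` and issues `STOP JOB '<id>' WITH SAVEPOINT`:

```python
from dbt.adapters.flink.constants import ExecutionConfig
from dbt.adapters.flink.handler import FlinkJobManager

# `session` is assumed to be an already-created SqlGatewaySession;
# job name and state path are hypothetical.
execution_config = {
    ExecutionConfig.JOB_NAME: "orders_enrichment",
    ExecutionConfig.STATE_PATH: "s3://bucket/savepoints",
}

# Returns the savepoint path if a running job with that name was
# found and stopped, otherwise None.
savepoint_path = FlinkJobManager(session).stop_job(execution_config)
if savepoint_path:
    # Feed the savepoint back so the next submission resumes from it.
    execution_config[ExecutionConfig.SAVEPOINT_PATH] = savepoint_path
```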
11 changes: 11 additions & 0 deletions dbt/adapters/flink/query_hints_parser.py
@@ -13,6 +13,8 @@ class QueryHints:
fetch_timeout_ms: Optional[int] = None
mode: Optional[QueryMode] = None
test_query: bool = False
execution_config: Optional[Dict[str, str]] = None
drop_statement: Optional[str] = None

def __init__(self, hints=None):
if hints is None:
@@ -25,6 +27,15 @@ def __init__(self, hints=None):
self.mode = QueryMode(hints["mode"].lower())
if "test_query" in hints:
self.test_query = bool(hints["test_query"])
if "execution_config" in hints:
self.execution_config = {}
for cfg_item in hints["execution_config"].split(";"):
key_val = cfg_item.split("=")
if len(key_val) != 2:
raise RuntimeError(f"Improper format of execution config item '{cfg_item}'")
self.execution_config[key_val[0]] = key_val[1]
if "drop_statement" in hints:
self.drop_statement = hints["drop_statement"]


class QueryHintsParser:
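The parser collects every hint found in the statement (see test_multiple_hints_in_single_comment in the test changes below), so the two new hints can be combined; a sketch with illustrative table and job names:

```python
from dbt.adapters.flink.query_hints_parser import QueryHintsParser

sql = (
    "/** execution_config('pipeline.name=orders_enrichment;"
    "state.savepoints.dir=s3://bucket/savepoints') */ "
    "/** drop_statement('drop table if exists orders_enriched') */ "
    "create table orders_enriched as select * from orders"
)

hints = QueryHintsParser.parse(sql)
# Expected:
#   hints.execution_config == {"pipeline.name": "orders_enrichment",
#                              "state.savepoints.dir": "s3://bucket/savepoints"}
#   hints.drop_statement == "drop table if exists orders_enriched"
```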
@@ -23,10 +23,13 @@
{%- set sql_header = config.get('sql_header', none) -%}
{% set connector_properties = config.get('default_connector_properties', {}) %}
{% set _dummy = connector_properties.update(config.get('connector_properties', {})) %}
{% set execution_config = config.get('default_execution_config', {}) %}
{% set _dummy = execution_config.update(config.get('execution_config', {})) %}

{{ sql_header if sql_header is not none }}

create {% if temporary: -%}temporary{%- endif %} table
{% if execution_config %}/** execution_config('{% for cfg_name in execution_config %}{{cfg_name}}={{execution_config[cfg_name]}}{% if not loop.last %};{% endif %}{% endfor %}') */{% endif %}
/** drop_statement('drop {% if temporary: -%}temporary {%- endif %}table if exists {{ this.render() }}') */
create {% if temporary: -%}temporary {%- endif %}table
{{ this.render() }}
{% if type %}/** mode('{{type}}')*/{% endif %}
with (
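The two `set` lines at the top of this macro merge project-level defaults with model-level overrides before the hint is rendered; a Python analogue of that merge and of the `key=value;...` serialization the macro's for-loop produces (property values are hypothetical):

```python
# Python analogue of the macro's config merge: model-level
# execution_config entries override default_execution_config.
default_execution_config = {"state.savepoints.dir": "s3://bucket/savepoints"}
model_execution_config = {"pipeline.name": "orders_enrichment"}

execution_config = dict(default_execution_config)
execution_config.update(model_execution_config)

# Serialized the same way the macro's for-loop does:
hint = "/** execution_config('{}') */".format(
    ";".join(f"{name}={value}" for name, value in execution_config.items())
)
# -> /** execution_config('state.savepoints.dir=s3://bucket/savepoints;pipeline.name=orders_enrichment') */
```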
@@ -17,7 +17,8 @@
{%- set sql_header = config.get('sql_header', none) -%}

{{ sql_header if sql_header is not none }}
create view if not exists /*TODO {{ relation }}*/ {{ this.render() }} {% if type %}/** mode('{{type}}')*/{% endif %} as (
/** drop_statement('drop view if exists {{ this.render() }}') */
create view /*TODO {{ relation }}*/ {{ this.render() }} {% if type %}/** mode('{{type}}')*/{% endif %} as (
{{ sql }}
);
{%- endmacro %}
3 changes: 2 additions & 1 deletion dbt/include/flink/macros/materializations/sources/source.sql
@@ -7,7 +7,8 @@
{% set watermark_properties = node.config.get('watermark') %}
{% set type = node.config.get('type', None) %}
{% set table_column_ids = node.columns.keys() %}
CREATE TABLE IF NOT EXISTS {{ node.identifier }} {% if type %}/** mode('{{type}}')*/{% endif %} (
/** drop_statement('DROP TABLE IF EXISTS {{ node.identifier }}') */
CREATE TABLE {{ node.identifier }} {% if type %}/** mode('{{type}}')*/{% endif %} (
{% for column_id in table_column_ids %}
{%- if node.columns[column_id]["column_type"] == 'metadata' %} `{{ node.columns[column_id]["name"] }}` {{ node.columns[column_id]["data_type"] }} METADATA {% if node.columns[column_id]["expression"] %} FROM '{{node.columns[column_id]["expression"]}}' {% endif %}
{%- elif node.columns[column_id]["column_type"] == 'computed' %} `{{ node.columns[column_id]["name"] }}` AS {{ node.columns[column_id]["expression"] }}
@@ -1,7 +1,7 @@
version: '2'
services:
jobmanager:
image: flink:1.16.0-scala_2.12-java11
image: flink:1.17.2-scala_2.12-java11
command: jobmanager
ports:
- "8081:8081"
@@ -11,10 +11,10 @@ services:
jobmanager.rpc.address: jobmanager
taskmanager.numberOfTaskSlots: 10
volumes:
- ./opt/flink-sql-connector-kafka-1.16.0.jar:/opt/flink/lib/flink-sql-connector-kafka-1.16.0.jar
- ./opt/flink-sql-connector-kafka-1.17.2.jar:/opt/flink/lib/flink-sql-connector-kafka-1.17.2.jar

taskmanager:
image: flink:1.16.0-scala_2.12-java11
image: flink:1.17.2-scala_2.12-java11
command: taskmanager
depends_on:
- jobmanager
@@ -24,10 +24,10 @@
jobmanager.rpc.address: jobmanager
taskmanager.numberOfTaskSlots: 10
volumes:
- ./opt/flink-sql-connector-kafka-1.16.0.jar:/opt/flink/lib/flink-sql-connector-kafka-1.16.0.jar
- ./opt/flink-sql-connector-kafka-1.17.2.jar:/opt/flink/lib/flink-sql-connector-kafka-1.17.2.jar

sql-gateway:
image: flink:1.16.0-scala_2.12-java11
image: flink:1.17.2-scala_2.12-java11
ports:
- "8083:8083"
entrypoint: /bin/sh
@@ -45,7 +45,7 @@ services:
rest.address: jobmanager
sql-gateway.endpoint.rest.address: 0.0.0.0
volumes:
- ./opt/flink-sql-connector-kafka-1.16.0.jar:/opt/flink/lib/flink-sql-connector-kafka-1.16.0.jar
- ./opt/flink-sql-connector-kafka-1.17.2.jar:/opt/flink/lib/flink-sql-connector-kafka-1.17.2.jar

networks:
default:
9 changes: 7 additions & 2 deletions flink/sqlgateway/client.py
@@ -1,3 +1,4 @@
from typing import Dict, Optional
from flink.sqlgateway.operation import SqlGatewayOperation
from flink.sqlgateway.session import SqlGatewaySession
from flink.sqlgateway.config import SqlGatewayConfig
@@ -10,10 +11,14 @@ def create_session(host: str, port: int, session_name: str) -> SqlGatewaySession:
return SqlGatewaySession.create(config)

@staticmethod
def execute_statement(session: SqlGatewaySession, sql: str) -> SqlGatewayOperation:
def execute_statement(
session: SqlGatewaySession, sql: str, execution_config: Optional[Dict[str, str]] = None
) -> SqlGatewayOperation:
if session.session_handle is None:
raise Exception(
f"Session '{session.config.session_name}' is not created. Call create() method first"
)

return SqlGatewayOperation.execute_statement(session=session, sql=sql)
return SqlGatewayOperation.execute_statement(
session=session, sql=sql, execution_config=execution_config
)
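A sketch of the extended client API, assuming `create_session` yields a session with an active handle; host and port match the docker-compose setup elsewhere in this commit, and the session name is hypothetical. `execution_config` is optional and is forwarded to the gateway with the statement:

```python
from flink.sqlgateway.client import FlinkSqlGatewayClient

# Gateway address from the docker-compose file in this commit.
session = FlinkSqlGatewayClient.create_session("127.0.0.1", 8083, "dbt_session")

operation = FlinkSqlGatewayClient.execute_statement(
    session,
    "SHOW JOBS",
    execution_config=None,  # or e.g. {"pipeline.name": "orders_enrichment"}
)
```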
9 changes: 5 additions & 4 deletions flink/sqlgateway/operation.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
import json
from typing import Optional
from typing import Optional, Dict
import requests
from flink.sqlgateway.result_parser import SqlGatewayResult, SqlGatewayResultParser
from flink.sqlgateway.session import SqlGatewaySession
@@ -16,8 +16,10 @@ def __init__(self, session: SqlGatewaySession, operation_handle: str):
self.operation_handle = operation_handle

@staticmethod
def execute_statement(session: SqlGatewaySession, sql: str) -> "SqlGatewayOperation":
statement_request = {"statement": sql}
def execute_statement(
session: SqlGatewaySession, sql: str, execution_config: Optional[Dict[str, str]] = None
) -> "SqlGatewayOperation":
statement_request = {"executionConfig": execution_config, "statement": sql}

response = requests.post(
url=f"{session.session_endpoint_url()}/statements",
@@ -26,7 +28,6 @@ def execute_statement(session: SqlGatewaySession, sql: str) -> "SqlGatewayOperation":
"Content-Type": "application/json",
},
)

print(f"SQL gateway response: {json.dumps(response.json())}")

if response.status_code == 200:
6 changes: 3 additions & 3 deletions flink/sqlgateway/result_parser.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, List, Any
from typing import Dict, List, Any, Optional

from dbt.events import AdapterLogger

@@ -9,7 +9,7 @@
@dataclass
class SqlGatewayResult:
rows: List[Dict[str, Any]]
next_result_url: str
next_result_url: Optional[str]
column_names: List[str]
is_end_of_stream: bool

@@ -31,9 +31,9 @@ class SqlGatewayResultParser:
def parse_result(data: Dict[str, Any]) -> SqlGatewayResult:
columns = data["results"]["columns"]
rows: List[Dict[str, Any]] = []
next_result_url = data["nextResultUri"]
column_names: List[str] = list(map(lambda c: c["name"], columns))
is_end_of_steam = data["resultType"] == "EOS"
next_result_url = data.get("nextResultUri", None)

logger.info(f"SQL rows returned: {data['results']['data']}")
for record in data["results"]["data"]:
10 changes: 10 additions & 0 deletions tests/adapters/flink/test_query_hints_parser.py
@@ -30,6 +30,16 @@ def test_test_query(self):
hints = QueryHintsParser.parse(sql)
assert hints.test_query is True

def test_execution_config(self):
sql = "/** execution_config('key_a=value_a;key_b=value_b') */ select 1"
hints = QueryHintsParser.parse(sql)
assert hints.execution_config == {"key_a": "value_a", "key_b": "value_b"}

def test_drop_statement(self):
sql = "/** drop_statement('DROP TABLE IF EXISTS TABLE_A') */ CREATE TABLE TABLE_A (id STRING)"
hints = QueryHintsParser.parse(sql)
assert hints.drop_statement == "DROP TABLE IF EXISTS TABLE_A"

def test_multiple_hints_in_single_comment(self):
sql = "select /** fetch_max(10) fetch_timeout_ms(1000) */ from input"
hints = QueryHintsParser.parse(sql)
13 changes: 7 additions & 6 deletions tests/conftest.py
@@ -1,6 +1,7 @@
import pytest

import os

# import json

# Import the functional fixtures as a plugin
@@ -13,10 +14,10 @@
@pytest.fixture(scope="class")
def dbt_profile_target():
return {
'type': 'flink',
'threads': 1,
'host': os.getenv('FLINK_SQL_GATEWAY_HOST', '127.0.0.1'),
'port': int(os.getenv('FLINK_SQL_GATEWAY_PORT', '8083')),
'session_name': os.getenv('SESSION_NAME', 'test_session'),
'database': os.getenv('DATABASE_NAME', 'test_db'),
"type": "flink",
"threads": 1,
"host": os.getenv("FLINK_SQL_GATEWAY_HOST", "127.0.0.1"),
"port": int(os.getenv("FLINK_SQL_GATEWAY_PORT", "8083")),
"session_name": os.getenv("SESSION_NAME", "test_session"),
"database": os.getenv("DATABASE_NAME", "test_db"),
}