Skip to content

Commit 6d57e65

Browse files
authored
Merge pull request #30 from headout/ft-sql-triggerer
feat: moving db polling to triggers
2 parents dace85f + 8b53218 commit 6d57e65

File tree

3 files changed

+106
-7
lines changed

3 files changed

+106
-7
lines changed

src/operators/deferred_job_result.py

+34-7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from airflow.triggers.temporal import TimeDeltaTrigger
88
from ergo.exceptions import ErgoFailedResultException
99
from ergo.models import ErgoJob, ErgoTask
10+
from ergo.triggers.task_poll import TaskPollTrigger
1011
from sqlalchemy.orm import joinedload
1112
from airflow.triggers.temporal import TimeDeltaTrigger
1213

@@ -18,6 +19,7 @@ def __init__(
1819
self,
1920
pusher_task_id: str,
2021
wait_for_state=list(State.finished),
22+
task_poll_trigger=True,
2123
*args,
2224
**kwargs
2325
):
@@ -26,12 +28,14 @@ def __init__(
2628
kwargs.pop('ti_dict', None)
2729
super().__init__(*args, **kwargs)
2830
self.pusher_task_id = pusher_task_id
31+
self.task_poll_trigger = task_poll_trigger
2932
if not isinstance(wait_for_state, (list, tuple)):
3033
wait_for_state = (wait_for_state,)
3134
self.wait_for_state = list(State.finished)
3235
if self.wait_for_state != wait_for_state:
3336
self.wait_for_state.extend(wait_for_state)
3437

38+
3539
def _get_ergo_task(self, ti_dict, session=None):
3640
return (
3741
session.query(ErgoTask)
@@ -40,17 +44,12 @@ def _get_ergo_task(self, ti_dict, session=None):
4044
).one()
4145

4246
@provide_session
43-
def execute(self, context, session=None, event=None):
44-
ti_dict = context.get('ti_dict', dict())
45-
if not ti_dict:
46-
ti = context['ti']
47-
ti_dict['dag_id'] = ti.dag_id
48-
ti_dict['run_id'] = ti.run_id
47+
def _get_task_status(self, context, ti_dict, session=None):
4948
task = self._get_ergo_task(ti_dict, session=session)
5049
job = task.job
5150

51+
# Fallback from triggerer to worker polling if triggerer fails
5252
while task.state not in self.wait_for_state:
53-
self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=15)), method_name="execute")
5453
task = self._get_ergo_task(ti_dict, session=session)
5554
self.log.info('Received task - %s... STATE: %s', str(task), task.state)
5655
job = task.job
@@ -61,14 +60,42 @@ def execute(self, context, session=None, event=None):
6160
self.log.info('Waiting for task "%s" to be queued...', str(task))
6261
self.log.info('Waiting for task "%s" to reach state %s...', str(task), self.wait_for_state)
6362

63+
if self.task_poll_trigger:
64+
self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), method_name="execute_complete")
65+
6466
if task.state == State.FAILED:
6567
if job is not None:
6668
self.log.info('Job - (%s)' + (f'responded back at {job.response_at}' if job.response_at else ''), str(job))
6769
raise ErgoFailedResultException(job.result_code, job.error_msg)
6870
else:
6971
raise ErgoFailedResultException(400, "Cron execution failed due to unknown reason")
7072

73+
self.xcom_push(context, "ergo_task_state", "success")
7174
self.log.info('Task - %s reached state %s', str(task), task.state)
75+
76+
77+
def execute(self, context, event=None):
    """Start waiting for the Ergo task pushed by ``pusher_task_id``.

    Pushes ``ergo_task_state = "started"`` to XCom, resolves the
    ``dag_id``/``run_id`` identifiers, and then either hands the wait over to
    a ``TaskPollTrigger`` (when ``self.task_poll_trigger`` is truthy) or checks
    the task status directly via ``_get_task_status``.

    :param context: task context; an optional ``ti_dict`` entry supplies the
        identifiers, otherwise they are read from ``context['ti']``.
    :param event: unused here; accepted so the signature matches the resume
        method ``execute_complete``.
    """
    ti_dict = context.get('ti_dict', dict())
    self.xcom_push(context, "ergo_task_state", "started")

    if not ti_dict:
        # Nothing was supplied through the context: identify the run from the
        # task instance that is executing this operator.
        task_instance = context['ti']
        ti_dict['dag_id'] = task_instance.dag_id
        ti_dict['run_id'] = task_instance.run_id

    if not self.task_poll_trigger:
        # Trigger-based polling is switched off: resolve the status here.
        self._get_task_status(context, ti_dict)
        return

    self.log.info('Polling DB task status using task poll trigger. Check triggerer logs to get more state info')
    # Poll every 20 seconds from the trigger; execution resumes in
    # ``execute_complete`` once the trigger fires.
    poll_trigger = TaskPollTrigger(ti_dict, self.pusher_task_id, self.wait_for_state, 20)
    self.defer(trigger=poll_trigger, method_name="execute_complete")
90+
91+
def execute_complete(self, context, event=None):
    """Resume on the worker after a deferral and finish the status check.

    Rebuilds the ``dag_id``/``run_id`` identifiers the same way ``execute``
    does (from ``context['ti_dict']`` when present and non-empty, otherwise
    from ``context['ti']``) and delegates to ``_get_task_status``.

    :param context: task context for the resumed task instance.
    :param event: payload of the trigger event that resumed the task; it is
        not inspected here.
    """
    ti_dict = context.get('ti_dict', dict())

    if not ti_dict:
        # Fall back to the identifiers of the task instance being resumed.
        task_instance = context['ti']
        ti_dict['dag_id'] = task_instance.dag_id
        ti_dict['run_id'] = task_instance.run_id

    self.log.info("Control transferred from trigger to worker for remaining operator execution")
    self._get_task_status(context, ti_dict)
73100

74101

src/triggers/__init__.py

Whitespace-only changes.

src/triggers/task_poll.py

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import asyncio
2+
import os
3+
from concurrent.futures import ThreadPoolExecutor
4+
from airflow.triggers.base import BaseTrigger, TriggerEvent
5+
from airflow.utils.db import provide_session
6+
from airflow.utils.state import State
7+
from ergo.exceptions import ErgoFailedResultException
8+
from ergo.models import ErgoJob, ErgoTask
9+
from sqlalchemy.orm import joinedload
10+
11+
# Sets PYTHONASYNCIODEBUG=1 in this process's environment as a side effect of
# importing the module (the variable asyncio consults to enable its debug mode).
# NOTE(review): this looks like a debugging leftover; it affects every consumer
# of the environment in the importing process, not just this trigger. Confirm
# whether it is meant to ship.
os.environ['PYTHONASYNCIODEBUG'] = '1'
12+
13+
14+
class TaskPollTrigger(BaseTrigger):
    """Trigger that polls the database until an Ergo task reaches a wanted state.

    The trigger looks up the ``ErgoTask`` row identified by ``pusher_task_id``
    together with the ``dag_id``/``run_id`` held in ``ti_dict``, and fires a
    single ``TriggerEvent(True)`` once ``task.state`` is one of
    ``wait_for_state``. Between checks it sleeps ``poke_interval`` seconds.

    :param ti_dict: mapping with at least ``dag_id`` and ``run_id`` keys.
    :param pusher_task_id: task id of the task that pushed the Ergo task.
    :param wait_for_state: states that end the wait.
    :param poke_interval: seconds to sleep between two database checks.
    """

    def __init__(
        self,
        ti_dict,
        pusher_task_id: str,
        wait_for_state=list(State.finished),
        poke_interval: float = 20,
    ):
        super().__init__()
        self.ti_dict = ti_dict
        self.pusher_task_id = pusher_task_id
        # The default list is built once at definition time; copy collections
        # so instances never alias it (or a caller-owned list). Anything else
        # is stored untouched to keep the previous behaviour.
        if isinstance(wait_for_state, (list, tuple, set, frozenset)):
            wait_for_state = list(wait_for_state)
        self.wait_for_state = wait_for_state
        self.poke_interval = poke_interval

    def serialize(self):
        """Return the classpath and keyword arguments needed to rebuild this trigger."""
        return (
            "ergo.triggers.task_poll.TaskPollTrigger",
            {
                "ti_dict": self.ti_dict,
                "pusher_task_id": self.pusher_task_id,
                "wait_for_state": self.wait_for_state,
                "poke_interval": self.poke_interval,
            },
        )

    def _get_ergo_task(self, session=None):
        """Fetch the single matching ``ErgoTask`` (with its job eagerly loaded).

        This is a plain blocking query and is only called from
        ``_poll_task_state``, which runs off the event loop. ``.one()`` raises
        if there is not exactly one matching row.
        """
        return (
            session.query(ErgoTask)
            .options(joinedload('job'))
            .filter_by(
                ti_task_id=self.pusher_task_id,
                ti_dag_id=self.ti_dict['dag_id'],
                ti_run_id=self.ti_dict['run_id'],
            )
        ).one()

    @provide_session
    def _poll_task_state(self, session=None):
        """Run one blocking status check; return True once a wanted state is reached.

        Kept synchronous on purpose: ``provide_session`` wraps an ordinary
        call, so the session it opens stays valid for the whole query only
        when the decorated function is not a coroutine function.
        """
        task = self._get_ergo_task(session=session)

        if task.state in self.wait_for_state:
            self.log.info('Task - %s reached state %s', str(task), task.state)
            return True

        self.log.info('Received task - %s... STATE: %s', str(task), task.state)
        job = task.job
        if job is not None:
            self.log.info(
                'Job - (%s)' + (f'responded back at {job.response_at}' if job.response_at else ''), str(job))
        else:
            self.log.info('Waiting for task "%s" to be queued...', str(task))
        self.log.info('Waiting for task "%s" to reach state %s...', str(task), self.wait_for_state)
        return False

    async def _check_task_status(self, session=None):
        """Check the task status without blocking the event loop.

        The blocking database work is pushed to the loop's default executor.
        A session is only forwarded when the caller supplied one; otherwise
        ``_poll_task_state`` obtains its own inside the worker thread.

        :return: True when the task reached one of ``wait_for_state``.
        """
        extra = {} if session is None else {'session': session}
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, lambda: self._poll_task_state(**extra))

    async def run(self):
        """Poll until the task reaches a wanted state, then fire exactly one event."""
        while True:
            if await self._check_task_status():
                yield TriggerEvent(True)
                # Stop here: continuing the loop would keep querying the
                # database and could emit duplicate events.
                return
            await asyncio.sleep(self.poke_interval)

0 commit comments

Comments
 (0)