7
7
from airflow .triggers .temporal import TimeDeltaTrigger
8
8
from ergo .exceptions import ErgoFailedResultException
9
9
from ergo .models import ErgoJob , ErgoTask
10
+ from ergo .triggers .task_poll import TaskPollTrigger
10
11
from sqlalchemy .orm import joinedload
11
12
from airflow .triggers .temporal import TimeDeltaTrigger
12
13
@@ -18,6 +19,7 @@ def __init__(
18
19
self ,
19
20
pusher_task_id : str ,
20
21
wait_for_state = list (State .finished ),
22
+ task_poll_trigger = True ,
21
23
* args ,
22
24
** kwargs
23
25
):
@@ -26,12 +28,14 @@ def __init__(
26
28
kwargs .pop ('ti_dict' , None )
27
29
super ().__init__ (* args , ** kwargs )
28
30
self .pusher_task_id = pusher_task_id
31
+ self .task_poll_trigger = task_poll_trigger
29
32
if not isinstance (wait_for_state , (list , tuple )):
30
33
wait_for_state = (wait_for_state ,)
31
34
self .wait_for_state = list (State .finished )
32
35
if self .wait_for_state != wait_for_state :
33
36
self .wait_for_state .extend (wait_for_state )
34
37
38
+
35
39
def _get_ergo_task (self , ti_dict , session = None ):
36
40
return (
37
41
session .query (ErgoTask )
@@ -40,17 +44,12 @@ def _get_ergo_task(self, ti_dict, session=None):
40
44
).one ()
41
45
42
46
@provide_session
43
- def execute (self , context , session = None , event = None ):
44
- ti_dict = context .get ('ti_dict' , dict ())
45
- if not ti_dict :
46
- ti = context ['ti' ]
47
- ti_dict ['dag_id' ] = ti .dag_id
48
- ti_dict ['run_id' ] = ti .run_id
47
+ def _get_task_status (self , context , ti_dict , session = None ):
49
48
task = self ._get_ergo_task (ti_dict , session = session )
50
49
job = task .job
51
50
51
+ # Fallback from triggerer to worker polling if triggerer fails
52
52
while task .state not in self .wait_for_state :
53
- self .defer (trigger = TimeDeltaTrigger (timedelta (seconds = 15 )), method_name = "execute" )
54
53
task = self ._get_ergo_task (ti_dict , session = session )
55
54
self .log .info ('Received task - %s... STATE: %s' , str (task ), task .state )
56
55
job = task .job
@@ -61,14 +60,42 @@ def execute(self, context, session=None, event=None):
61
60
self .log .info ('Waiting for task "%s" to be queued...' , str (task ))
62
61
self .log .info ('Waiting for task "%s" to reach state %s...' , str (task ), self .wait_for_state )
63
62
63
+ if self .task_poll_trigger :
64
+ self .defer (trigger = TimeDeltaTrigger (timedelta (seconds = 20 )), method_name = "execute_complete" )
65
+
64
66
if task .state == State .FAILED :
65
67
if job is not None :
66
68
self .log .info ('Job - (%s)' + (f'responded back at { job .response_at } ' if job .response_at else '' ), str (job ))
67
69
raise ErgoFailedResultException (job .result_code , job .error_msg )
68
70
else :
69
71
raise ErgoFailedResultException (400 , "Cron execution failed due to unknown reason" )
70
72
73
+ self .xcom_push (context , "ergo_task_state" , "success" )
71
74
self .log .info ('Task - %s reached state %s' , str (task ), task .state )
75
+
76
+
77
+ def execute (self , context , event = None ):
78
+ ti_dict = context .get ('ti_dict' , dict ())
79
+ self .xcom_push (context , "ergo_task_state" , "started" )
80
+ if not ti_dict :
81
+ ti = context ['ti' ]
82
+ ti_dict ['dag_id' ] = ti .dag_id
83
+ ti_dict ['run_id' ] = ti .run_id
84
+ if self .task_poll_trigger :
85
+ self .log .info ('Polling DB task status using task poll trigger. Check triggerer logs to get more state info' )
86
+ self .defer (trigger = TaskPollTrigger (ti_dict , self .pusher_task_id , self .wait_for_state , 20 ), method_name = "execute_complete" )
87
+ else :
88
+ self ._get_task_status (context , ti_dict )
89
+ return
90
+
91
+ def execute_complete (self , context , event = None ):
92
+ ti_dict = context .get ('ti_dict' , dict ())
93
+ if not ti_dict :
94
+ ti = context ['ti' ]
95
+ ti_dict ['dag_id' ] = ti .dag_id
96
+ ti_dict ['run_id' ] = ti .run_id
97
+ self .log .info ("Control transferred from trigger to worker for remaining operator execution" )
98
+ self ._get_task_status (context , ti_dict )
72
99
return
73
100
74
101
0 commit comments