Skip to content

Commit 8debb7a

Browse files
committed
Converted TurbiniaTasks to hash objects in Redis
1 parent 85d4570 commit 8debb7a

File tree

1 file changed

+130
-52
lines changed

1 file changed

+130
-52
lines changed

turbinia/state_manager.py

Lines changed: 130 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,54 @@ def set_client(self, redis_client):
250250
def _validate_data(self, data):
251251
return data
252252

253+
def get_task_legacy(self, task_id: str) -> dict:
254+
"""Returns a dictionary representing a Task object given its ID. This
255+
function is used to get data of old TurbiniaTask objects stored as
256+
string in Redis.
257+
258+
Args:
259+
task_id (str): The ID of the stored task.
260+
261+
Returns:
262+
task_dict (dict): Dict containing task attributes.
263+
"""
264+
try:
265+
return json.loads(self.client.get(task_id))
266+
except redis.RedisError as exception:
267+
error_message = f'Error decoding key {task_id} in Redis'
268+
log.error(f'{error_message}: {exception}')
269+
raise TurbiniaException(error_message) from exception
270+
271+
def get_task(self, task_id: str) -> dict:
272+
"""Returns a dictionary representing a Task object given its ID.
273+
274+
Args:
275+
task_id (str): The ID of the stored task.
276+
277+
Returns:
278+
task_dict (dict): Dict containing task attributes.
279+
"""
280+
task_key = ':'.join(('TurbiniaEvidence', task_id))
281+
282+
if self.get_key_type(task_key) == 'string':
283+
task_dict = self.get_task_legacy(task_id)
284+
else:
285+
task_dict = {}
286+
for attribute_name, attribute_value in self.iterate_attributes(
287+
task_key):
288+
task_dict[attribute_name] = attribute_value
289+
290+
if task_dict.get('last_update'):
291+
task_dict['last_update'] = datetime.strptime(
292+
task_dict.get('last_update'), DATETIME_FORMAT)
293+
if task_dict.get('run_time'):
294+
task_dict['run_time'] = timedelta(seconds=task_dict['run_time'])
295+
296+
return task_dict
297+
253298
def get_task_data(
254-
self, instance, days=0, task_id=None, request_id=None, group_id=None,
255-
user=None):
299+
self, instance: str, days: int=0, task_id: str=None,
300+
request_id: str=None, group_id: str=None, user: str=None):
256301
"""Gets task data from Redis.
257302
258303
Args:
@@ -267,77 +312,91 @@ def get_task_data(
267312
Returns:
268313
List of Task dict objects.
269314
"""
270-
tasks = [
271-
json.loads(self.client.get(task))
272-
for task in self.client.scan_iter('TurbiniaTask:*')
273-
if json.loads(self.client.get(task)).get('instance') == instance or
274-
not instance
275-
]
276-
277-
# Convert relevant date attributes back into dates/timedeltas
278-
for task in tasks:
279-
if task.get('last_update'):
280-
task['last_update'] = datetime.strptime(
281-
task.get('last_update'), DATETIME_FORMAT)
282-
if task.get('run_time'):
283-
task['run_time'] = timedelta(seconds=task['run_time'])
315+
if task_id:
316+
task = self.get_task(task_id)
317+
tasks = [task] if not request_id or task.request_id == request_id else []
318+
elif request_id:
319+
request_key = ':'.join(('TurbiniaRequest', request_id))
320+
if self.key_exists(request_key):
321+
task_ids = self.get_attribute(
322+
request_key, 'task_ids', decode_json = True)
323+
tasks = [self.get_task(task_id) for task_id in task_ids]
324+
else:
325+
tasks = [
326+
self.get_data(task_key) for task_key in self.iterate_keys('Task')]
284327

285328
# pylint: disable=no-else-return
329+
if instance:
330+
tasks = [task for task in tasks if task.get('instance') == instance]
286331
if days:
287332
start_time = datetime.now() - timedelta(days=days)
288333
# Redis only supports strings; we convert to/from datetime here and below
289334
tasks = [task for task in tasks if task.get('last_update') > start_time]
290-
if task_id:
291-
tasks = [task for task in tasks if task.get('id') == task_id]
292-
if request_id:
293-
tasks = [task for task in tasks if task.get('request_id') == request_id]
294335
if group_id:
295336
tasks = [task for task in tasks if task.get('group_id') == group_id]
296337
if user:
297338
tasks = [task for task in tasks if task.get('requester') == user]
298339

299340
return tasks
300341

342+
def format_task(self, task):
343+
task_dict = self.get_task_dict(task)
344+
task_dict['last_update'] = task_dict['last_update'].strftime(
345+
DATETIME_FORMAT)
346+
task_dict['start_time'] = task_dict['start_time'].strftime(DATETIME_FORMAT)
347+
if not task_dict.get('status'):
348+
task_dict['status'] = (
349+
f'Task scheduled at {datetime.now().strftime(DATETIME_FORMAT)}')
350+
if task_dict['run_time']:
351+
task_dict['run_time'] = task_dict['run_time'].total_seconds()
352+
for key, value in task_dict.items():
353+
try:
354+
task_dict[key] = json.dumps(value)
355+
except (TypeError, ValueError) as exception:
356+
error_message = f'Error serializing task attribute for task {task.id}.'
357+
log.error(f'{error_message}: {exception}')
358+
raise TurbiniaException(error_message) from exception
359+
return task_dict
360+
361+
def write_new_task(self, task):
362+
"""Writes task into redis.
363+
364+
Args:
365+
task_dict (dict[str]): A dictionary containing the serialized
366+
request attributes that will be saved.
367+
update (bool): Allows overwriting previous key and blocks writing new
368+
ones.
369+
370+
Returns:
371+
request_key (str): The key corresponding to the evidence in Redis
372+
373+
Raises:
374+
TurbiniaException: If the attribute deserialization fails.
375+
"""
376+
log.info(f'Writing new task {task.name:s} into Redis')
377+
task_key = ':'.join(('TurbiniaTask', task.id))
378+
task_dict = self.format_task(task)
379+
self.write_hash_object(task_key, task_dict)
380+
task.state_key = task_key
381+
return task_key
382+
301383
def update_task(self, task):
302384
task.touch()
303-
key = task.state_key
304-
if not key:
385+
task_key = task.state_key
386+
if not task_key:
305387
self.write_new_task(task)
306388
return
307-
stored_task_data = json.loads(self.client.get(f'TurbiniaTask:{task.id}'))
308-
stored_evidence_size = stored_task_data.get('evidence_size')
309-
stored_evidence_id = stored_task_data.get('evidence_id')
389+
stored_task_dict = self.get_task(task_key)
390+
stored_evidence_size = stored_task_dict.get('evidence_size')
391+
stored_evidence_id = stored_task_dict.get('evidence_id')
310392
if not task.evidence_size and stored_evidence_size:
311393
task.evidence_size = stored_evidence_size
312394
if not task.evidence_id and stored_evidence_id:
313395
task.evidence_id = stored_evidence_id
314396
log.info(f'Updating task {task.name:s} in Redis')
315-
task_data = self.get_task_dict(task)
316-
task_data['last_update'] = task_data['last_update'].strftime(
317-
DATETIME_FORMAT)
318-
task_data['start_time'] = task_data['start_time'].strftime(DATETIME_FORMAT)
319-
# Need to use json.dumps, else redis returns single quoted string which
320-
# is invalid json
321-
if not self.client.set(key, json.dumps(task_data)):
322-
log.error(f'Error updating task {task.name:s} in Redis')
323-
324-
def write_new_task(self, task):
325-
key = ':'.join(['TurbiniaTask', task.id])
326-
log.info(f'Writing new task {task.name:s} into Redis')
327-
task_data = self.get_task_dict(task)
328-
task_data['last_update'] = task_data['last_update'].strftime(
329-
DATETIME_FORMAT)
330-
task_data['start_time'] = task_data['start_time'].strftime(DATETIME_FORMAT)
331-
if not task_data.get('status'):
332-
task_data['status'] = 'Task scheduled at {0:s}'.format(
333-
datetime.now().strftime(DATETIME_FORMAT))
334-
if task_data['run_time']:
335-
task_data['run_time'] = task_data['run_time'].total_seconds()
336-
# nx=True prevents overwriting (i.e. no unintentional task clobbering)
337-
if not self.client.set(key, json.dumps(task_data), nx=True):
338-
log.error(f'Error writing new task {task.name:s} into Redis')
339-
task.state_key = key
340-
return key
397+
task_dict = self.format_task(task)
398+
self.write_hash_object(task_key, task_dict)
399+
return task_key
341400

342401
def set_attribute(
343402
self, redis_key: str, attribute_name: str, json_value: str) -> bool:
@@ -459,7 +518,7 @@ def key_exists(self, redis_key) -> bool:
459518
"""Checks if the key is saved in Redis.
460519
461520
Args:
462-
key (str): The key to be checked.
521+
redis_key (str): The key to be checked.
463522
464523
Returns:
465524
exists (bool): Boolean indicating if evidence is saved.
@@ -474,6 +533,25 @@ def key_exists(self, redis_key) -> bool:
474533
log.error(f'{error_message}: {exception}')
475534
raise TurbiniaException(error_message) from exception
476535

536+
def get_key_type(self, redis_key) -> bool:
537+
"""Gets the type of the Redis key.
538+
539+
Args:
540+
redis_key (str): The key to be checked.
541+
542+
Returns:
543+
type (str): Type of the Redis key.
544+
545+
Raises:
546+
TurbiniaException: If Redis fails in getting the type of the key.
547+
"""
548+
try:
549+
return self.client.type(redis_key)
550+
except redis.RedisError as exception:
551+
error_message = f'Error getting type of {redis_key} in Redis'
552+
log.error(f'{error_message}: {exception}')
553+
raise TurbiniaException(error_message) from exception
554+
477555
def add_to_list(self, redis_key, list_name, new_item, repeated=False):
478556
"""Appends new item to a list attribute in a hashed Redis object.
479557

0 commit comments

Comments
 (0)