Commit
Merge pull request #90 from dod-advana/UOT-130237
Feature/UOT-130237
rha930 committed Feb 4, 2022
2 parents f651614 + 87e715a commit c79215e
Showing 3 changed files with 110 additions and 28 deletions.
87 changes: 65 additions & 22 deletions gamechangerml/api/fastapi/routers/controls.py
@@ -4,6 +4,7 @@
 import json
 import tarfile
 import shutil
+import time

 from datetime import datetime
 from gamechangerml import DATA_PATH
@@ -251,20 +252,25 @@ async def create_LTR_model(response: Response):
     """
     number_files = 0
     resp = None
-    try:
-        model = []
+    model = []

-        def ltr_process():
+    def ltr_process():
+        try:
             pipeline.create_ltr()
             processmanager.update_status(processmanager.ltr_creation, 1, 1)
+        except Exception as e:
+            logger.warning(e)
+            logger.warning(f"There is an issue with LTR creation")
+            response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+            processmanager.update_status(processmanager.ltr_creation, failed=True)

-        ltr_thread = MlThread(ltr_process)
-
-        ltr_thread.start()
-
-    except Exception as e:
-        logger.warning(e)
-        logger.warning(f"There is an issue with LTR creation")
-        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+    ltr_thread = MlThread(ltr_process)
+    ltr_thread.start()
+    processmanager.running_threads[ltr_thread.ident] = ltr_thread
+    processmanager.update_status(processmanager.ltr_creation, 0, 1, thread_id=ltr_thread.ident)
     return response.status_code
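One plausible reason the try/except moved inside ltr_process: an exception raised on a worker thread never propagates to the route handler that spawned it, so a failure can only be recorded from inside the thread. A minimal standalone illustration (not from this commit):

    import threading

    def worker():
        raise RuntimeError("boom")  # visible only on the worker thread

    t = threading.Thread(target=worker)
    t.start()
    t.join()  # returns normally; the exception is not re-raised in the caller
    print("caller never sees the RuntimeError")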


@@ -313,7 +319,6 @@ async def download(response: Response):
     Args:
     Returns:
     """
-    processmanager.update_status(processmanager.s3_dependency, 0, 1)
     def download_s3_thread():
         try:
             logger.info("Attempting to download dependencies from S3")
@@ -329,6 +334,8 @@ def download_s3_thread():

     thread = MlThread(download_s3_thread)
     thread.start()
+    processmanager.running_threads[thread.ident] = thread
+    processmanager.update_status(processmanager.s3_dependency, 0, 1, thread_id=thread.ident)
     return await get_process_status()


@@ -339,19 +346,23 @@ async def download_s3_file(file_dict: dict, response: Response):
     Args: file_dict - dict {"file": (file or folder path), "type": ("ml-data" or "models")}
     Returns: process status
     """
-    processmanager.update_status(processmanager.s3_file_download, 0, 1)
     def download_s3_thread():
         logger.info(f'downloading file {file_dict["file"]}')
         try:

             path = "gamechangerml/models/" if file_dict['type'] == "models" else "gamechangerml/"
             downloaded_files = utils.get_model_s3(file_dict['file'], f"bronze/gamechanger/{file_dict['type']}/", path)
-            # downloaded_files = ['gamechangerml/models/20210223.tar.gz']
-            processmanager.update_status(processmanager.s3_file_download, 0, len(downloaded_files))
             logger.info(downloaded_files)
+
+            if len(downloaded_files) == 0:
+                processmanager.update_status(f's3: {file_dict["file"]}', failed=True, message="No files found")
+                return
+
+            processmanager.update_status(f's3: {file_dict["file"]}', 0, len(downloaded_files))
             i = 0
             for f in downloaded_files:
                 i += 1
-                processmanager.update_status(processmanager.s3_file_download, 0, i)
+                processmanager.update_status(f's3: {file_dict["file"]}', 0, i)
                 logger.info(f)
                 if '.tar' in f:
                     tar = tarfile.open(f)
@@ -371,6 +382,8 @@ def download_s3_thread():
                     tar.extractall(path=path, members=[member for member in tar.getmembers() if ('.git' not in member.name and '.DS_Store' not in member.name)])
                     tar.close()

+            processmanager.update_status(f's3: {file_dict["file"]}', len(downloaded_files), len(downloaded_files))
+
         except PermissionError:
             failedExtracts = []
             for member in tar.getmembers():
@@ -380,17 +393,19 @@ def download_s3_thread():
                     failedExtracts.append(member.name)

             logger.warning(f'Could not extract {failedExtracts} with permission errors')
+            processmanager.update_status(f's3: {file_dict["file"]}', failed=True, message="Permission error; not all files extracted")

         except Exception as e:
             logger.warning(e)
             logger.warning(f"Could not download {file_dict['file']} from S3")
             response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
-            processmanager.update_status(processmanager.s3_file_download, failed=True)
-
-            processmanager.update_status(processmanager.s3_file_download, len(downloaded_files), len(downloaded_files))
+            processmanager.update_status(f's3: {file_dict["file"]}', failed=True, message=e)

     thread = MlThread(download_s3_thread)
     thread.start()
+    processmanager.running_threads[thread.ident] = thread
+    processmanager.update_status(f's3: {file_dict["file"]}', 0, 1, thread_id=thread.ident)

     return await get_process_status()
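For reference, the request body this endpoint expects is a small JSON object naming the S3 object and the bucket prefix it lives under; a hypothetical example (file name modeled on the commented-out path removed above):

    {"file": "20210223.tar.gz", "type": "models"}

With this change the status key becomes s3: <file name> rather than the shared s3_file_download key, so each download's progress is tracked under its own entry.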


@@ -429,9 +444,8 @@ async def reload_models(model_dict: dict, response: Response):
     """
     try:
         total = len(model_dict)
-        processmanager.update_status(processmanager.reloading, 0, total)

         # put the reload process on a thread
         def reload_thread(model_dict):
             try:
                 progress = 0
@@ -480,6 +494,8 @@ def reload_thread(model_dict):
         args = {"model_dict": model_dict}
         thread = MlThread(reload_thread, args)
         thread.start()
+        processmanager.running_threads[thread.ident] = thread
+        processmanager.update_status(processmanager.reloading, 0, total, thread_id=thread.ident)
     except Exception as e:
         logger.warning(e)
@@ -502,12 +518,16 @@ async def download_corpus(corpus_dict: dict, response: Response):
             corpus_dict = S3_CORPUS_PATH
         args = {"s3_corpus_dir": corpus_dict["corpus"], "output_dir": CORPUS_DIR}
         logger.info(args)
-        processmanager.update_status(processmanager.corpus_download)
         corpus_thread = MlThread(utils.get_s3_corpus, args)
         corpus_thread.start()
-    except:
+        processmanager.running_threads[corpus_thread.ident] = corpus_thread
+        processmanager.update_status(processmanager.corpus_download, 0, 1, thread_id=corpus_thread.ident)
+    except Exception as e:
         logger.warning(f"Could not get corpus from S3")
         logger.warning(e)
         response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+        processmanager.update_status(processmanager.corpus_download, failed=True, message=e)

     return await get_process_status()


@@ -664,8 +684,31 @@ async def train_model(model_dict: dict, response: Response):
         # Set the training method to be loaded onto the thread
         training_thread = MlThread(training_method, args={"model_dict": model_dict})
         training_thread.start()

+        processmanager.running_threads[training_thread.ident] = training_thread
+        processmanager.update_status(processmanager.training, 0, 1, thread_id=training_thread.ident)
     except:
         logger.warning(f"Could not train/evaluate the model")
         response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+        processmanager.update_status(processmanager.training, failed=True)

     return await get_process_status()


+@router.post("/stopProcess")
+async def stop_process(thread_dict: dict, response: Response):
+    """stop_process - endpoint for stopping a process running on a thread
+    Args:
+        thread_dict: dict; {"thread_id": (int thread id), "process": (name of the process, so its status can also be updated in Redis)}
+        response: Response class; for status codes (part of FastAPI; no need to pass as a param)
+    Returns:
+        The stopped thread id
+    """
+    logger.info(processmanager.running_threads)
+    with processmanager.thread_lock:
+        if thread_dict['thread_id'] in processmanager.running_threads:
+            processmanager.running_threads[thread_dict['thread_id']].kill()
+            del processmanager.running_threads[thread_dict['thread_id']]
+            processmanager.update_status(thread_dict['process'], failed=True, message='Killed by user')
+
+    return {'stopped': thread_dict['thread_id']}
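Taken together: each long-running route now registers its MlThread in processmanager.running_threads and stamps the process status with the thread id, which is what /stopProcess looks up. A hypothetical client call, sketched with requests (the host, port, and concrete id are assumptions, not part of this commit):

    import requests

    # The thread_id comes from the process-status payload, which now
    # carries a "thread_id" field for each running process.
    payload = {"thread_id": 140397740415744, "process": "models: ltr_creation"}
    resp = requests.post("http://localhost:5000/stopProcess", json=payload)
    print(resp.json())  # -> {"stopped": 140397740415744}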
30 changes: 24 additions & 6 deletions gamechangerml/api/utils/processmanager.py
@@ -1,7 +1,7 @@
 import threading
 from datetime import datetime
 from gamechangerml.api.utils.redisdriver import CacheVariable
-
+from gamechangerml.api.fastapi.settings import logger
 # Process Keys
 clear_corpus = "corpus: corpus_download"
 corpus_download = "corpus: corpus_download"
@@ -14,6 +14,8 @@
 ltr_creation = "models: ltr_creation"
 topics_creation = "models: topics_creation"

+running_threads = {}
+
 # the dictionary that holds all the progress values
 try:
     PROCESS_STATUS = CacheVariable("process_status", True)
@@ -40,7 +42,8 @@
 if COMPLETED_PROCESS.value == None:
     COMPLETED_PROCESS.value = []

-def update_status(key, progress=0, total=100, message="", failed=False, completed_max=20):
+
+def update_status(key, progress=0, total=100, message="", failed=False, thread_id="", completed_max=20):

     try:
         if progress == total or failed:
@@ -55,21 +58,36 @@ def update_status(key, progress=0, total=100, message="", failed=False, completed_max=20):
             with thread_lock:
                 if key in PROCESS_STATUS.value:
                     temp = PROCESS_STATUS.value
-                    temp.pop(key, None)
-                    temp["flags"][key] = False
+                    tempProcess = temp.pop(key, None)
+                    if key in temp["flags"]:
+                        temp["flags"][key] = False
                     PROCESS_STATUS.value = temp
+                    if tempProcess['thread_id'] in running_threads:
+                        del running_threads[tempProcess['thread_id']]
                     if not failed:
                         completed_list = COMPLETED_PROCESS.value
                         if len(completed_list) == completed_max:
                             completed_list.pop(0)
                         completed_list.append(completed)
                         COMPLETED_PROCESS.value = completed_list
+                    else:
+                        completed['date'] = 'Failed'
+                        completed_list = COMPLETED_PROCESS.value
+                        completed_list.append(completed)
+                        COMPLETED_PROCESS.value = completed_list
         else:
             status = {"progress": progress, "total": total}
             with thread_lock:
                 status_dict = PROCESS_STATUS.value
-                status_dict[key] = status
-                status_dict["flags"][key] = True
+
+                if key not in status_dict:
+                    status['thread_id'] = thread_id
+                    status_dict[key] = status
+                else:
+                    status_dict[key].update(status)
+
+                if key in status_dict["flags"]:
+                    status_dict["flags"][key] = True
                 PROCESS_STATUS.value = status_dict
     except Exception as e:
         print(e)
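Note how update_status now merges repeated updates for an existing key instead of overwriting the entry, so the thread_id recorded on the first call survives later progress calls. A standalone sketch of that merge logic (plain dicts stand in for the Redis-backed CacheVariable; the key and id are illustrative):

    status_dict = {"flags": {}}
    key, thread_id = "models: ltr_creation", 140397740415744

    # First update creates the entry and records the thread id.
    status = {"progress": 0, "total": 1}
    if key not in status_dict:
        status["thread_id"] = thread_id
        status_dict[key] = status
    else:
        status_dict[key].update(status)

    # A later progress update merges, so "thread_id" survives.
    status_dict[key].update({"progress": 1, "total": 1})
    print(status_dict[key])  # {'progress': 1, 'total': 1, 'thread_id': 140397740415744}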
21 changes: 21 additions & 0 deletions gamechangerml/api/utils/threaddriver.py
@@ -1,20 +1,41 @@
 import threading
 import json
+import sys
 from gamechangerml.api.utils.logger import logger
+from gamechangerml.api.utils import processmanager

 # A class that takes in a function and a dictionary of arguments.
 # The keys in args have to match the parameters in the function.
 class MlThread(threading.Thread):
     def __init__(self, function, args = {}):
         super(MlThread, self).__init__()
         self.function = function
         self.args = args
+        self.killed = False

     def run(self):
         try:
+            sys.settrace(self.globaltrace)
             self.function(**self.args)
         except Exception as e:
             logger.error(e)
             logger.info("Thread errored out attempting " + self.function.__name__ + " with parameters: " + json.dumps(self.args))

+    def globaltrace(self, frame, why, arg):
+        if why == 'call':
+            return self.localtrace
+        else:
+            return None
+
+    def localtrace(self, frame, why, arg):
+        if self.killed:
+            if why == 'line':
+                raise SystemExit()
+        return self.localtrace
+
+    def kill(self):
+        logger.info(f'killing {self.function}')
+        self.killed = True

 # Pass in a function and args which is an array of dicts.
 # A way to load multiple jobs and run them on threads.
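The kill mechanism is cooperative: kill() only sets a flag, and the trace function installed via sys.settrace raises SystemExit on the next traced line the thread executes. A self-contained sketch of the same pattern (independent of gamechangerml; note that tracing adds overhead, and a thread blocked inside a C call, such as a long sleep or a network read, is not interrupted until Python bytecode runs again):

    import sys
    import threading
    import time

    class KillableThread(threading.Thread):
        def __init__(self, target):
            super().__init__()
            self.target = target
            self.killed = False

        def run(self):
            sys.settrace(self.globaltrace)  # install tracer before user code runs
            self.target()

        def globaltrace(self, frame, why, arg):
            # Attach the local tracer to every function call made by the thread.
            return self.localtrace if why == 'call' else None

        def localtrace(self, frame, why, arg):
            if self.killed and why == 'line':
                raise SystemExit()  # unwinds the thread at the next traced line
            return self.localtrace

        def kill(self):
            self.killed = True

    def busy_loop():
        while True:
            time.sleep(0.1)

    t = KillableThread(busy_loop)
    t.start()
    time.sleep(0.5)
    t.kill()
    t.join()
    print("thread alive?", t.is_alive())  # -> False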
