diff --git a/gamechangerml/api/fastapi/routers/controls.py b/gamechangerml/api/fastapi/routers/controls.py
index 13b99dca..24fdf36c 100644
--- a/gamechangerml/api/fastapi/routers/controls.py
+++ b/gamechangerml/api/fastapi/routers/controls.py
@@ -4,6 +4,7 @@
 import json
 import tarfile
 import shutil
+import time
 from datetime import datetime
 
 from gamechangerml import DATA_PATH
@@ -251,20 +252,25 @@ async def create_LTR_model(response: Response):
     """
     number_files = 0
     resp = None
-    try:
-        model = []
+    model = []
+
+    def ltr_process():
+        try:
-        def ltr_process():
             pipeline.create_ltr()
+            processmanager.update_status(processmanager.ltr_creation, 1, 1)
+        except Exception as e:
+            logger.warning(e)
+            logger.warning(f"There is an issue with LTR creation")
+            response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+            processmanager.update_status(processmanager.ltr_creation, failed=True)
 
-        ltr_thread = MlThread(ltr_process)
+    ltr_thread = MlThread(ltr_process)
+    ltr_thread.start()
+    processmanager.running_threads[ltr_thread.ident] = ltr_thread
+    processmanager.update_status(processmanager.ltr_creation, 0, 1, thread_id=ltr_thread.ident)
-        ltr_thread.start()
-    except Exception as e:
-        logger.warning(e)
-        logger.warning(f"There is an issue with LTR creation")
-        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
 
     return response.status_code
@@ -313,7 +319,6 @@ async def download(response: Response):
     Args:
     Returns:
     """
-    processmanager.update_status(processmanager.s3_dependency, 0, 1)
     def download_s3_thread():
         try:
             logger.info("Attempting to download dependencies from S3")
@@ -329,6 +334,8 @@ def download_s3_thread():
 
     thread = MlThread(download_s3_thread)
     thread.start()
+    processmanager.running_threads[thread.ident] = thread
+    processmanager.update_status(processmanager.s3_dependency, 0, 1, thread_id=thread.ident)
 
     return await get_process_status()
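The pattern introduced above recurs throughout this change: start the MlThread first (a thread's `ident` is only assigned once `start()` runs), register the live thread in `processmanager.running_threads` under that ident, then publish a status entry that carries the same `thread_id`, so the process can later be looked up and stopped. A minimal sketch of that lifecycle, using only APIs defined in this diff; `run_tracked` and `process_key` are illustrative names, not part of this change:

```python
from gamechangerml.api.utils import processmanager
from gamechangerml.api.utils.threaddriver import MlThread

def run_tracked(process_key, work_fn, args=None):
    """Illustrative helper: start a worker and make it pollable and stoppable."""
    thread = MlThread(work_fn, args or {})
    thread.start()  # ident is None until the thread has started
    # Register the live thread so /stopProcess can find it by id...
    processmanager.running_threads[thread.ident] = thread
    # ...and publish a 0-of-1 status entry that records the same id.
    processmanager.update_status(process_key, 0, 1, thread_id=thread.ident)
    return thread.ident
```

This ordering is also why the `update_status(..., 0, 1, thread_id=...)` calls sit below `thread.start()` in each endpoint: the ident does not exist before the thread starts.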
@@ -339,19 +346,23 @@ async def download_s3_file(file_dict: dict, response: Response):
     """download_s3_file - endpoint for downloading a single file from S3
     Args: file_dict - dict {"file": (file or folder path), "type": (whether from ml-data or models)}
     Returns: process status
     """
-    processmanager.update_status(processmanager.s3_file_download, 0, 1)
     def download_s3_thread():
         logger.info(f'downloading file {file_dict["file"]}')
         try:
             path = "gamechangerml/models/" if file_dict['type'] == "models" else "gamechangerml/"
             downloaded_files = utils.get_model_s3(file_dict['file'], f"bronze/gamechanger/{file_dict['type']}/", path)
-            # downloaded_files = ['gamechangerml/models/20210223.tar.gz']
-            processmanager.update_status(processmanager.s3_file_download, 0, len(downloaded_files))
+            logger.info(downloaded_files)
+
+            if len(downloaded_files) == 0:
+                processmanager.update_status(f's3: {file_dict["file"]}', failed=True, message="No files found")
+                return
+
+            processmanager.update_status(f's3: {file_dict["file"]}', 0, len(downloaded_files))
             i = 0
             for f in downloaded_files:
                 i += 1
-                processmanager.update_status(processmanager.s3_file_download, 0, i)
+                processmanager.update_status(f's3: {file_dict["file"]}', 0, i)
                 logger.info(f)
                 if '.tar' in f:
                     tar = tarfile.open(f)
@@ -371,6 +382,8 @@
                     tar.extractall(path=path, members=[member for member in tar.getmembers() if ('.git' not in member.name and '.DS_Store' not in member.name)])
                     tar.close()
+                    processmanager.update_status(f's3: {file_dict["file"]}', len(downloaded_files), len(downloaded_files))
+
                 except PermissionError:
                     failedExtracts = []
                     for member in tar.getmembers():
@@ -380,17 +393,19 @@
                         try:
                             tar.extract(member, path=path)
                         except Exception as e:
                             failedExtracts.append(member.name)
                     logger.warning(f'Could not extract {failedExtracts} with permission errors')
+                    processmanager.update_status(f's3: {file_dict["file"]}', failed=True, message="Permission error, not all files extracted")
 
         except Exception as e:
             logger.warning(e)
             logger.warning(f"Could not download {file_dict['file']} from S3")
             response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
-            processmanager.update_status(processmanager.s3_file_download, failed=True)
-
-        processmanager.update_status(processmanager.s3_file_download, len(downloaded_files), len(downloaded_files))
+            processmanager.update_status(f's3: {file_dict["file"]}', failed=True, message=str(e))
 
     thread = MlThread(download_s3_thread)
     thread.start()
+    processmanager.running_threads[thread.ident] = thread
+    processmanager.update_status(f's3: {file_dict["file"]}', 0, 1, thread_id=thread.ident)
+
     return await get_process_status()
@@ -429,9 +444,8 @@ async def reload_models(model_dict: dict, response: Response):
     """
     try:
         total = len(model_dict)
-        processmanager.update_status(processmanager.reloading, 0, total)
+
         # put the reload process on a thread
-
         def reload_thread(model_dict):
             try:
                 progress = 0
@@ -480,6 +494,8 @@ def reload_thread(model_dict):
         args = {"model_dict": model_dict}
         thread = MlThread(reload_thread, args)
         thread.start()
+        processmanager.running_threads[thread.ident] = thread
+        processmanager.update_status(processmanager.reloading, 0, total, thread_id=thread.ident)
 
     except Exception as e:
         logger.warning(e)
@@ -502,12 +518,16 @@ async def download_corpus(corpus_dict: dict, response: Response):
             corpus_dict = S3_CORPUS_PATH
         args = {"s3_corpus_dir": corpus_dict["corpus"], "output_dir": CORPUS_DIR}
         logger.info(args)
-        processmanager.update_status(processmanager.corpus_download)
        corpus_thread = MlThread(utils.get_s3_corpus, args)
         corpus_thread.start()
-    except:
+        processmanager.running_threads[corpus_thread.ident] = corpus_thread
+        processmanager.update_status(processmanager.corpus_download, 0, 1, thread_id=corpus_thread.ident)
+    except Exception as e:
         logger.warning(f"Could not get corpus from S3")
+        logger.warning(e)
         response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+        processmanager.update_status(processmanager.corpus_download, failed=True, message=str(e))
+
     return await get_process_status()
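With the status key switched from the shared `processmanager.s3_file_download` to the per-file `f's3: {file_dict["file"]}'`, two concurrent downloads no longer overwrite each other's progress entries. A hypothetical client call against this endpoint; the route decorator and host are outside the hunks shown, so the URL below is an assumption:

```python
import requests

# Hypothetical host and route; the @router.post(...) decorator is not shown in this diff.
resp = requests.post(
    "http://localhost:5000/downloadS3File",
    json={"file": "some-model.tar.gz", "type": "models"},  # payload shape from the docstring
)
# The endpoint returns the aggregate process status (get_process_status()),
# which should now contain a key like "s3: some-model.tar.gz".
print(resp.json())
```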
@@ -664,8 +684,31 @@ async def train_model(model_dict: dict, response: Response):
         # Set the training method to be loaded onto the thread
         training_thread = MlThread(training_method, args={"model_dict": model_dict})
         training_thread.start()
-
+        processmanager.running_threads[training_thread.ident] = training_thread
+        processmanager.update_status(processmanager.training, 0, 1, thread_id=training_thread.ident)
     except:
         logger.warning(f"Could not train/evaluate the model")
         response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
+        processmanager.update_status(processmanager.training, failed=True)
+
     return await get_process_status()
+
+
+@router.post("/stopProcess")
+async def stop_process(thread_dict: dict, response: Response):
+    """stop_process - endpoint for stopping a process running in a thread
+    Args:
+        thread_dict: dict; {"thread_id": (int thread id), "process": (name of the process, so its status can also be updated in Redis)}
+        response: Response class; for status codes (part of FastAPI, does not need to be passed)
+    Returns:
+        Stopped thread id
+    """
+    logger.info(processmanager.running_threads)
+    with processmanager.thread_lock:
+        if thread_dict['thread_id'] in processmanager.running_threads:
+            processmanager.running_threads[thread_dict['thread_id']].kill()
+            del processmanager.running_threads[thread_dict['thread_id']]
+            processmanager.update_status(thread_dict['process'], failed=True, message='Killed by user')
+
+    return {'stopped': thread_dict['thread_id']}
diff --git a/gamechangerml/api/utils/processmanager.py b/gamechangerml/api/utils/processmanager.py
index c4488092..59e2e3c2 100644
--- a/gamechangerml/api/utils/processmanager.py
+++ b/gamechangerml/api/utils/processmanager.py
@@ -1,7 +1,7 @@
 import threading
 from datetime import datetime
 from gamechangerml.api.utils.redisdriver import CacheVariable
-
+from gamechangerml.api.fastapi.settings import logger
 # Process Keys
 clear_corpus = "corpus: corpus_download"
 corpus_download = "corpus: corpus_download"
@@ -14,6 +14,8 @@
 ltr_creation = "models: ltr_creation"
 topics_creation = "models: topics_creation"
 
+running_threads = {}
+
 # the dictionary that holds all the progress values
 try:
     PROCESS_STATUS = CacheVariable("process_status", True)
@@ -40,7 +42,8 @@
 if COMPLETED_PROCESS.value == None:
     COMPLETED_PROCESS.value = []
 
-def update_status(key, progress=0, total=100, message="", failed=False, completed_max = 20):
+
+def update_status(key, progress=0, total=100, message="", failed=False, thread_id="", completed_max=20):
     try:
         if progress == total or failed:
@@ -55,21 +58,36 @@
             with thread_lock:
                 if key in PROCESS_STATUS.value:
                     temp = PROCESS_STATUS.value
-                    temp.pop(key, None)
-                    temp["flags"][key] = False
+                    tempProcess = temp.pop(key, None)
+                    if key in temp["flags"]:
+                        temp["flags"][key] = False
                     PROCESS_STATUS.value = temp
+                    if tempProcess.get('thread_id') in running_threads:
+                        del running_threads[tempProcess['thread_id']]
                 if not failed:
                     completed_list = COMPLETED_PROCESS.value
                     if len(completed_list) == completed_max:
                         completed_list.pop(0)
                     completed_list.append(completed)
                     COMPLETED_PROCESS.value = completed_list
+                else:
+                    completed['date'] = 'Failed'
+                    completed_list = COMPLETED_PROCESS.value
+                    completed_list.append(completed)
+                    COMPLETED_PROCESS.value = completed_list
         else:
             status = {"progress": progress, "total": total}
             with thread_lock:
                 status_dict = PROCESS_STATUS.value
-                status_dict[key] = status
-                status_dict["flags"][key] = True
+
+                if key not in status_dict:
+                    status['thread_id'] = thread_id
+                    status_dict[key] = status
+                else:
+                    status_dict[key].update(status)
+
+                if key in status_dict["flags"]:
+                    status_dict["flags"][key] = True
                 PROCESS_STATUS.value = status_dict
     except Exception as e:
         print(e)
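The `update_status` changes turn each in-progress entry into a small dict that is extended rather than replaced: the first call for a key stamps `thread_id`, later calls merge progress via `dict.update`, and completion (progress equals total, or `failed=True`) pops the entry, unregisters the thread, and appends a record to the completed list, with failures marked by `date = 'Failed'`. A rough walk-through of the resulting states; the key is a real process key from this module, the values are examples, and `flags` entries are only toggled for keys already seeded in that dict:

```python
from gamechangerml.api.utils import processmanager

# First call for a key: creates {"progress": 0, "total": 1, "thread_id": 12345}.
processmanager.update_status("models: ltr_creation", 0, 1, thread_id=12345)

# Subsequent calls merge into the existing entry, preserving thread_id.
processmanager.update_status("models: ltr_creation", 0, 1)

# progress == total: the entry is popped, its thread is removed from
# running_threads, and a completed record is appended to COMPLETED_PROCESS.
processmanager.update_status("models: ltr_creation", 1, 1)

# failed=True takes the same completion path but records date='Failed'.
processmanager.update_status("models: ltr_creation", failed=True, message="S3 timeout")
```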
diff --git a/gamechangerml/api/utils/threaddriver.py b/gamechangerml/api/utils/threaddriver.py
index f6b2e001..e2103ee7 100644
--- a/gamechangerml/api/utils/threaddriver.py
+++ b/gamechangerml/api/utils/threaddriver.py
@@ -1,6 +1,9 @@
 import threading
 import json
+import sys
 from gamechangerml.api.utils.logger import logger
+from gamechangerml.api.utils import processmanager
+
 
 # A class that takes in a function and a dictionary of arguments.
 # The keys in args have to match the parameters in the function.
@@ -8,13 +11,31 @@ class MlThread(threading.Thread):
     def __init__(self, function, args = {}):
         super(MlThread, self).__init__()
         self.function = function
         self.args = args
+        self.killed = False
+
     def run(self):
         try:
+            sys.settrace(self.globaltrace)
             self.function(**self.args)
         except Exception as e:
             logger.error(e)
             logger.info("Thread errored out attempting " + self.function.__name__ + " with parameters: " + json.dumps(self.args))
 
+    def globaltrace(self, frame, why, arg):
+        if why == 'call':
+            return self.localtrace
+        else:
+            return None
+
+    def localtrace(self, frame, why, arg):
+        if self.killed:
+            if why == 'line':
+                raise SystemExit()
+        return self.localtrace
+
+    def kill(self):
+        logger.info(f'killing {self.function}')
+        self.killed = True
 
 # Pass in a function and args which is an array of dicts
 # A way to load multiple jobs and run them on threads.
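`kill()` here is cooperative, not preemptive: `sys.settrace` installs a per-thread trace hook, and once `killed` is set the hook raises `SystemExit` inside the thread at its next `'line'` event. A self-contained sketch of the same technique, independent of `MlThread`; class and variable names are illustrative:

```python
import sys
import threading
import time

class KillableThread(threading.Thread):
    """Trace-based cooperative kill: the thread dies at the next traced Python line."""

    def __init__(self, target):
        super().__init__()
        self._target = target
        self.killed = False

    def run(self):
        sys.settrace(self._trace)  # affects only this thread
        try:
            self._target()
        except SystemExit:
            pass  # raised by the trace hook after kill()

    def _trace(self, frame, why, arg):
        if self.killed and why == "line":
            raise SystemExit()
        return self._trace  # keep tracing nested calls

    def kill(self):
        self.killed = True

t = KillableThread(lambda: [time.sleep(0.2) for _ in range(50)])
t.start()
time.sleep(0.5)
t.kill()
t.join()  # returns once the worker hits its next Python line
```

Because the check only fires on Python `'line'` events, a thread blocked inside native code (a long S3 transfer, or model training in a C extension) will not stop until control returns to the interpreter, and tracing adds per-line overhead while the thread runs.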