diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/funman/FunmanController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/funman/FunmanController.java index 367335ee9b..4fa9efa51b 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/funman/FunmanController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/funman/FunmanController.java @@ -68,6 +68,7 @@ public ResponseEntity createValidationRequest(@RequestBody final Jso try { final TaskRequest taskRequest = new TaskRequest(); + taskRequest.setTimeoutMinutes(30); taskRequest.setType(TaskType.FUNMAN); taskRequest.setScript(ValidateModelConfigHandler.NAME); taskRequest.setUserId(currentUserService.get().getId()); @@ -80,7 +81,7 @@ public ResponseEntity createValidationRequest(@RequestBody final Jso sim.setExecutionPayload(objectMapper.convertValue(input, JsonNode.class)); // Create new simulatin object to proxy the funman validation process - Simulation newSimulation = simulationService.createAsset(sim); + final Simulation newSimulation = simulationService.createAsset(sim); final ValidateModelConfigHandler.Properties props = new ValidateModelConfigHandler.Properties(); props.setSimulationId(newSimulation.getId()); diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/search/SearchByAssetTypeController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/search/SearchByAssetTypeController.java index 85d9241bac..1380e6ac98 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/search/SearchByAssetTypeController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/search/SearchByAssetTypeController.java @@ -48,7 +48,7 @@ public class SearchByAssetTypeController { private final ElasticsearchConfiguration esConfig; private final CurrentUserService currentUserService; - private static final long REQUEST_TIMEOUT_SECONDS = 30; + private static final int REQUEST_TIMEOUT_MINUTES = 1; private static final String EMBEDDING_MODEL = "text-embedding-ada-002"; private static final List EXCLUDE_FIELDS = List.of("embeddings", "text", "topics"); @@ -120,12 +120,13 @@ public ResponseEntity> searchByAssetType( embeddingRequest.setEmbeddingModel(EMBEDDING_MODEL); final TaskRequest req = new TaskRequest(); + req.setTimeoutMinutes(REQUEST_TIMEOUT_MINUTES); req.setType(TaskType.GOLLM); req.setInput(embeddingRequest); req.setScript("gollm_task:embedding"); req.setUserId(currentUserService.get().getId()); - final TaskResponse resp = taskService.runTaskSync(req, REQUEST_TIMEOUT_SECONDS); + final TaskResponse resp = taskService.runTaskSync(req); if (resp.getStatus() != TaskStatus.SUCCESS) { throw new ResponseStatusException( diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/TaskRequest.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/TaskRequest.java index 3104e906be..27cbe8856a 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/TaskRequest.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/task/TaskRequest.java @@ -41,7 +41,7 @@ public String toString() { protected TaskType type; protected String script; protected byte[] input; - protected int timeoutMinutes = 30; + protected int timeoutMinutes = 5; protected String userId; protected UUID projectId; diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java index 5570e7a6e9..ed5b0ce653 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskService.java @@ -215,6 +215,7 @@ void init() { if (isRunningLocalProfile()) { // sanity check for local development to clear the caches rLock.lock(); + responseCache.clear(); taskIdCache.clear(); rLock.unlock(); @@ -637,7 +638,7 @@ public TaskFuture runTaskAsync(final TaskRequest r) throws JsonProcessingExcepti } } - public TaskResponse runTaskSync(final TaskRequest req, final long timeoutSeconds) + public TaskResponse runTaskSync(final TaskRequest req) throws JsonProcessingException, TimeoutException, InterruptedException, ExecutionException { // send the request @@ -646,7 +647,7 @@ public TaskResponse runTaskSync(final TaskRequest req, final long timeoutSeconds try { // wait for the response log.info("Waiting for response for task id: {}", future.getId()); - final TaskResponse resp = future.get(timeoutSeconds, TimeUnit.SECONDS); + final TaskResponse resp = future.get(req.getTimeoutMinutes(), TimeUnit.MINUTES); if (resp.getStatus() == TaskStatus.CANCELLED) { throw new InterruptedException("Task was cancelled"); } @@ -677,18 +678,11 @@ public TaskResponse runTaskSync(final TaskRequest req, final long timeoutSeconds log.warn("Failed to cancel task: {}", future.getId(), ee); } - throw new TimeoutException( - "Task " + future.getId().toString() + " did not complete within " + timeoutSeconds + " seconds"); + throw new TimeoutException("Task " + future.getId().toString() + " did not complete within " + + req.getTimeoutMinutes() + " minutes"); } } - public TaskResponse runTaskSync(final TaskRequest req) - throws JsonProcessingException, TimeoutException, ExecutionException, InterruptedException { - - final int DEFAULT_TIMEOUT_SECONDS = 60 * 5; // 5 minutes - return runTaskSync(req, DEFAULT_TIMEOUT_SECONDS); - } - public TaskResponse runTask(final TaskMode mode, final TaskRequest req) throws JsonProcessingException, TimeoutException, InterruptedException, ExecutionException { if (mode == TaskMode.SYNC) { diff --git a/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java b/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java index 5c2d6ea6ff..13e8171681 100644 --- a/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java +++ b/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java @@ -113,7 +113,7 @@ public void testItCanSendGoLLMModelCardRequest() throws Exception { req.setScript("gollm_task:model_card"); req.setInput(content.getBytes()); - final TaskResponse resp = taskService.runTaskSync(req, 300); + final TaskResponse resp = taskService.runTaskSync(req); log.info(new String(resp.getOutput())); } @@ -268,7 +268,7 @@ public void testItCanSendAmrToMmtRequest() throws Exception { public void testItCanCacheSuccess() throws Exception { final int TIMEOUT_SECONDS = 20; - final byte[] input = "{\"input\":\"This is my input string\"}".getBytes(); + final byte[] input = "{\"input\":\"This is my input string\",\"include_progress\":true}".getBytes(); final TaskRequest req = new TaskRequest(); req.setType(TaskType.GOLLM); @@ -293,7 +293,7 @@ public void testItCanCacheSuccess() throws Exception { public void testItDoesNotCacheFailure() throws Exception { final int TIMEOUT_SECONDS = 20; - final byte[] input = "{\"input\":\"This is my input string\", \"should_fail\": true}".getBytes(); + final byte[] input = "{\"input\":\"This is my input string\",\"should_fail\": true}".getBytes(); final TaskRequest req = new TaskRequest(); req.setType(TaskType.GOLLM); @@ -318,7 +318,7 @@ public void testItDoesNotCacheFailure() throws Exception { public void testItDoesNotCacheFailureButCacheSuccessAfter() throws Exception { final int TIMEOUT_SECONDS = 20; - final byte[] input = "{\"input\":\"This is my input string\"}".getBytes(); + final byte[] input = "{\"input\":\"This is my input string\"},\"include_progress\":true".getBytes(); final TaskRequest req = new TaskRequest(); req.setType(TaskType.GOLLM); @@ -350,15 +350,15 @@ public void testItDoesNotCacheFailureButCacheSuccessAfter() throws Exception { @WithUserDetails(MockUser.URSULA) public void testItCanCacheWithConcurrency() throws Exception { - final int NUM_REQUESTS = 1024; - final int NUM_UNIQUE_REQUESTS = 32; + final int NUM_REQUESTS = 1024 * 16; + final int NUM_UNIQUE_REQUESTS = 32 * 4; final int NUM_THREADS = 24; - final int TIMEOUT_SECONDS = 20; + final int TIMEOUT_MINUTES = 1; final List reqInput = new ArrayList<>(); for (int i = 0; i < NUM_UNIQUE_REQUESTS; i++) { // success tasks - reqInput.add(("{\"input\":\"" + generateRandomString(1024) + "\"}").getBytes()); + reqInput.add(("{\"input\":\"" + generateRandomString(1024) + "\"},\"include_progress\":true").getBytes()); } for (int i = 0; i < NUM_UNIQUE_REQUESTS; i++) { // failure tasks @@ -376,11 +376,12 @@ public void testItCanCacheWithConcurrency() throws Exception { final Future future = executor.submit(() -> { try { final TaskRequest req = new TaskRequest(); + req.setTimeoutMinutes(TIMEOUT_MINUTES); req.setType(TaskType.GOLLM); req.setScript("/echo.py"); req.setInput(reqInput.get(rand.nextInt(NUM_UNIQUE_REQUESTS * 2))); - final TaskResponse resp = taskService.runTaskSync(req, TIMEOUT_SECONDS); + final TaskResponse resp = taskService.runTaskSync(req); successTaskIds.add(resp.getId()); } catch (final RuntimeException e) { // expected for purposely failed tasks @@ -394,7 +395,7 @@ public void testItCanCacheWithConcurrency() throws Exception { // wait for all the responses to be send for (final Future future : futures) { - future.get(TIMEOUT_SECONDS * 2, TimeUnit.SECONDS); + future.get(TIMEOUT_MINUTES * 2, TimeUnit.MINUTES); } for (final UUID taskId : successTaskIds) { diff --git a/packages/taskrunner/src/main/java/software/uncharted/terarium/taskrunner/models/task/TaskRequest.java b/packages/taskrunner/src/main/java/software/uncharted/terarium/taskrunner/models/task/TaskRequest.java index 1c3c4d47b5..65e3ea6665 100644 --- a/packages/taskrunner/src/main/java/software/uncharted/terarium/taskrunner/models/task/TaskRequest.java +++ b/packages/taskrunner/src/main/java/software/uncharted/terarium/taskrunner/models/task/TaskRequest.java @@ -17,7 +17,7 @@ public class TaskRequest implements Serializable { private UUID id; private String script; private byte[] input; - private int timeoutMinutes = 30; + private int timeoutMinutes = 5; private Object additionalProperties; protected String userId; protected UUID projectId; diff --git a/packages/taskrunner/src/main/resources/application.properties b/packages/taskrunner/src/main/resources/application.properties index bd961431f0..f7aa6a5127 100644 --- a/packages/taskrunner/src/main/resources/application.properties +++ b/packages/taskrunner/src/main/resources/application.properties @@ -17,5 +17,5 @@ spring.rabbitmq.password=${terarium.mq-password} terarium.taskrunner.request-queue=terarium-request-queue terarium.taskrunner.response-exchange=terarium-response-exchange terarium.taskrunner.cancellation-exchange=terarium-cancellation-exchange -terarium.taskrunner.request-concurrency=16 +terarium.taskrunner.request-concurrency=32 terarium.taskrunner.request-type=terarium diff --git a/packages/taskrunner/src/test/resources/echo.py b/packages/taskrunner/src/test/resources/echo.py index b00973334a..64e8b97c5b 100755 --- a/packages/taskrunner/src/test/resources/echo.py +++ b/packages/taskrunner/src/test/resources/echo.py @@ -57,7 +57,7 @@ def write_progress_with_timeout(progress_pipe: str, output: dict, timeout_second return future.result(timeout=timeout_seconds) except concurrent.futures.TimeoutError: print('Writing to progress pipe {} timed out'.format(progress_pipe), flush=True) - raise TimeoutError('Writing to output pipe timed out') + raise TimeoutError('Writing to progress pipe timed out') def finish_progress_with_timeout(progress_pipe: str, timeout_seconds: int): with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: @@ -66,7 +66,7 @@ def finish_progress_with_timeout(progress_pipe: str, timeout_seconds: int): return future.result(timeout=timeout_seconds) except concurrent.futures.TimeoutError: print('Writing to progress pipe {} timed out'.format(progress_pipe), flush=True) - raise TimeoutError('Writing to output pipe timed out') + raise TimeoutError('Writing to progress pipe timed out') def signal_handler(sig, frame): print('Process cancelled, cleanup logic goes here', flush=True) diff --git a/packages/taskrunner/taskrunner.py b/packages/taskrunner/taskrunner.py index 9c16460974..831676d255 100644 --- a/packages/taskrunner/taskrunner.py +++ b/packages/taskrunner/taskrunner.py @@ -131,7 +131,7 @@ def write_progress(progress_pipe: str, progress: str): return future.result(timeout=timeout_seconds) except concurrent.futures.TimeoutError: print('Writing to progress pipe {} timed out'.format(self.progress_pipe), flush=True) - raise TimeoutError('Writing to output pipe timed out') + raise TimeoutError('Writing to progress pipe timed out') def write_progress_dict_with_timeout(self, progress: dict, timeout_seconds: int): return self.write_progress_str_with_timeout(json.dumps(progress), timeout_seconds)