Skip to content

Commit

Permalink
Unify task timeouts, lower default from 30min to 5min. (#3660)
Browse files Browse the repository at this point in the history
Co-authored-by: kbirk <[email protected]>
  • Loading branch information
kbirk and kbirk authored May 17, 2024
1 parent 7b32a86 commit 5923ad7
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public ResponseEntity<Simulation> createValidationRequest(@RequestBody final Jso

try {
final TaskRequest taskRequest = new TaskRequest();
taskRequest.setTimeoutMinutes(30);
taskRequest.setType(TaskType.FUNMAN);
taskRequest.setScript(ValidateModelConfigHandler.NAME);
taskRequest.setUserId(currentUserService.get().getId());
Expand All @@ -80,7 +81,7 @@ public ResponseEntity<Simulation> createValidationRequest(@RequestBody final Jso
sim.setExecutionPayload(objectMapper.convertValue(input, JsonNode.class));

// Create new simulatin object to proxy the funman validation process
Simulation newSimulation = simulationService.createAsset(sim);
final Simulation newSimulation = simulationService.createAsset(sim);

final ValidateModelConfigHandler.Properties props = new ValidateModelConfigHandler.Properties();
props.setSimulationId(newSimulation.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class SearchByAssetTypeController {
private final ElasticsearchConfiguration esConfig;
private final CurrentUserService currentUserService;

private static final long REQUEST_TIMEOUT_SECONDS = 30;
private static final int REQUEST_TIMEOUT_MINUTES = 1;
private static final String EMBEDDING_MODEL = "text-embedding-ada-002";

private static final List<String> EXCLUDE_FIELDS = List.of("embeddings", "text", "topics");
Expand Down Expand Up @@ -120,12 +120,13 @@ public ResponseEntity<List<JsonNode>> searchByAssetType(
embeddingRequest.setEmbeddingModel(EMBEDDING_MODEL);

final TaskRequest req = new TaskRequest();
req.setTimeoutMinutes(REQUEST_TIMEOUT_MINUTES);
req.setType(TaskType.GOLLM);
req.setInput(embeddingRequest);
req.setScript("gollm_task:embedding");
req.setUserId(currentUserService.get().getId());

final TaskResponse resp = taskService.runTaskSync(req, REQUEST_TIMEOUT_SECONDS);
final TaskResponse resp = taskService.runTaskSync(req);

if (resp.getStatus() != TaskStatus.SUCCESS) {
throw new ResponseStatusException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public String toString() {
protected TaskType type;
protected String script;
protected byte[] input;
protected int timeoutMinutes = 30;
protected int timeoutMinutes = 5;
protected String userId;
protected UUID projectId;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ void init() {
if (isRunningLocalProfile()) {
// sanity check for local development to clear the caches
rLock.lock();

responseCache.clear();
taskIdCache.clear();
rLock.unlock();
Expand Down Expand Up @@ -637,7 +638,7 @@ public TaskFuture runTaskAsync(final TaskRequest r) throws JsonProcessingExcepti
}
}

public TaskResponse runTaskSync(final TaskRequest req, final long timeoutSeconds)
public TaskResponse runTaskSync(final TaskRequest req)
throws JsonProcessingException, TimeoutException, InterruptedException, ExecutionException {

// send the request
Expand All @@ -646,7 +647,7 @@ public TaskResponse runTaskSync(final TaskRequest req, final long timeoutSeconds
try {
// wait for the response
log.info("Waiting for response for task id: {}", future.getId());
final TaskResponse resp = future.get(timeoutSeconds, TimeUnit.SECONDS);
final TaskResponse resp = future.get(req.getTimeoutMinutes(), TimeUnit.MINUTES);
if (resp.getStatus() == TaskStatus.CANCELLED) {
throw new InterruptedException("Task was cancelled");
}
Expand Down Expand Up @@ -677,18 +678,11 @@ public TaskResponse runTaskSync(final TaskRequest req, final long timeoutSeconds
log.warn("Failed to cancel task: {}", future.getId(), ee);
}

throw new TimeoutException(
"Task " + future.getId().toString() + " did not complete within " + timeoutSeconds + " seconds");
throw new TimeoutException("Task " + future.getId().toString() + " did not complete within "
+ req.getTimeoutMinutes() + " minutes");
}
}

public TaskResponse runTaskSync(final TaskRequest req)
throws JsonProcessingException, TimeoutException, ExecutionException, InterruptedException {

final int DEFAULT_TIMEOUT_SECONDS = 60 * 5; // 5 minutes
return runTaskSync(req, DEFAULT_TIMEOUT_SECONDS);
}

public TaskResponse runTask(final TaskMode mode, final TaskRequest req)
throws JsonProcessingException, TimeoutException, InterruptedException, ExecutionException {
if (mode == TaskMode.SYNC) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ public void testItCanSendGoLLMModelCardRequest() throws Exception {
req.setScript("gollm_task:model_card");
req.setInput(content.getBytes());

final TaskResponse resp = taskService.runTaskSync(req, 300);
final TaskResponse resp = taskService.runTaskSync(req);

log.info(new String(resp.getOutput()));
}
Expand Down Expand Up @@ -268,7 +268,7 @@ public void testItCanSendAmrToMmtRequest() throws Exception {
public void testItCanCacheSuccess() throws Exception {
final int TIMEOUT_SECONDS = 20;

final byte[] input = "{\"input\":\"This is my input string\"}".getBytes();
final byte[] input = "{\"input\":\"This is my input string\",\"include_progress\":true}".getBytes();

final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
Expand All @@ -293,7 +293,7 @@ public void testItCanCacheSuccess() throws Exception {
public void testItDoesNotCacheFailure() throws Exception {
final int TIMEOUT_SECONDS = 20;

final byte[] input = "{\"input\":\"This is my input string\", \"should_fail\": true}".getBytes();
final byte[] input = "{\"input\":\"This is my input string\",\"should_fail\": true}".getBytes();

final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
Expand All @@ -318,7 +318,7 @@ public void testItDoesNotCacheFailure() throws Exception {
public void testItDoesNotCacheFailureButCacheSuccessAfter() throws Exception {
final int TIMEOUT_SECONDS = 20;

final byte[] input = "{\"input\":\"This is my input string\"}".getBytes();
final byte[] input = "{\"input\":\"This is my input string\"},\"include_progress\":true".getBytes();

final TaskRequest req = new TaskRequest();
req.setType(TaskType.GOLLM);
Expand Down Expand Up @@ -350,15 +350,15 @@ public void testItDoesNotCacheFailureButCacheSuccessAfter() throws Exception {
@WithUserDetails(MockUser.URSULA)
public void testItCanCacheWithConcurrency() throws Exception {

final int NUM_REQUESTS = 1024;
final int NUM_UNIQUE_REQUESTS = 32;
final int NUM_REQUESTS = 1024 * 16;
final int NUM_UNIQUE_REQUESTS = 32 * 4;
final int NUM_THREADS = 24;
final int TIMEOUT_SECONDS = 20;
final int TIMEOUT_MINUTES = 1;

final List<byte[]> reqInput = new ArrayList<>();
for (int i = 0; i < NUM_UNIQUE_REQUESTS; i++) {
// success tasks
reqInput.add(("{\"input\":\"" + generateRandomString(1024) + "\"}").getBytes());
reqInput.add(("{\"input\":\"" + generateRandomString(1024) + "\"},\"include_progress\":true").getBytes());
}
for (int i = 0; i < NUM_UNIQUE_REQUESTS; i++) {
// failure tasks
Expand All @@ -376,11 +376,12 @@ public void testItCanCacheWithConcurrency() throws Exception {
final Future<?> future = executor.submit(() -> {
try {
final TaskRequest req = new TaskRequest();
req.setTimeoutMinutes(TIMEOUT_MINUTES);
req.setType(TaskType.GOLLM);
req.setScript("/echo.py");
req.setInput(reqInput.get(rand.nextInt(NUM_UNIQUE_REQUESTS * 2)));

final TaskResponse resp = taskService.runTaskSync(req, TIMEOUT_SECONDS);
final TaskResponse resp = taskService.runTaskSync(req);
successTaskIds.add(resp.getId());
} catch (final RuntimeException e) {
// expected for purposely failed tasks
Expand All @@ -394,7 +395,7 @@ public void testItCanCacheWithConcurrency() throws Exception {

// wait for all the responses to be send
for (final Future<?> future : futures) {
future.get(TIMEOUT_SECONDS * 2, TimeUnit.SECONDS);
future.get(TIMEOUT_MINUTES * 2, TimeUnit.MINUTES);
}

for (final UUID taskId : successTaskIds) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public class TaskRequest implements Serializable {
private UUID id;
private String script;
private byte[] input;
private int timeoutMinutes = 30;
private int timeoutMinutes = 5;
private Object additionalProperties;
protected String userId;
protected UUID projectId;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ spring.rabbitmq.password=${terarium.mq-password}
terarium.taskrunner.request-queue=terarium-request-queue
terarium.taskrunner.response-exchange=terarium-response-exchange
terarium.taskrunner.cancellation-exchange=terarium-cancellation-exchange
terarium.taskrunner.request-concurrency=16
terarium.taskrunner.request-concurrency=32
terarium.taskrunner.request-type=terarium
4 changes: 2 additions & 2 deletions packages/taskrunner/src/test/resources/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def write_progress_with_timeout(progress_pipe: str, output: dict, timeout_second
return future.result(timeout=timeout_seconds)
except concurrent.futures.TimeoutError:
print('Writing to progress pipe {} timed out'.format(progress_pipe), flush=True)
raise TimeoutError('Writing to output pipe timed out')
raise TimeoutError('Writing to progress pipe timed out')

def finish_progress_with_timeout(progress_pipe: str, timeout_seconds: int):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
Expand All @@ -66,7 +66,7 @@ def finish_progress_with_timeout(progress_pipe: str, timeout_seconds: int):
return future.result(timeout=timeout_seconds)
except concurrent.futures.TimeoutError:
print('Writing to progress pipe {} timed out'.format(progress_pipe), flush=True)
raise TimeoutError('Writing to output pipe timed out')
raise TimeoutError('Writing to progress pipe timed out')

def signal_handler(sig, frame):
print('Process cancelled, cleanup logic goes here', flush=True)
Expand Down
2 changes: 1 addition & 1 deletion packages/taskrunner/taskrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def write_progress(progress_pipe: str, progress: str):
return future.result(timeout=timeout_seconds)
except concurrent.futures.TimeoutError:
print('Writing to progress pipe {} timed out'.format(self.progress_pipe), flush=True)
raise TimeoutError('Writing to output pipe timed out')
raise TimeoutError('Writing to progress pipe timed out')

def write_progress_dict_with_timeout(self, progress: dict, timeout_seconds: int):
return self.write_progress_str_with_timeout(json.dumps(progress), timeout_seconds)
Expand Down

0 comments on commit 5923ad7

Please sign in to comment.