Skip to content

Commit

Permalink
Table Extracton (#4908)
Browse files Browse the repository at this point in the history
Co-authored-by: kbirk <[email protected]>
  • Loading branch information
kbirk and kbirk authored Sep 26, 2024
1 parent fc76376 commit 82eebc5
Show file tree
Hide file tree
Showing 17 changed files with 382 additions and 30 deletions.
24 changes: 24 additions & 0 deletions containers/scripts/docker-compose-taskrunner.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,27 @@ services:
- ../../packages/text_extraction:/text_extraction_task
- ../../packages/taskrunner:/taskrunner
command: /text_extraction_task/dev.sh

table_extraction-taskrunner:
build:
context: ../..
dockerfile: ./packages/table_extraction/Dockerfile
target: table_extraction_taskrunner_builder
container_name: table_extraction-taskrunner
networks:
- terarium
environment:
TERARIUM_MQ_ADDRESSES: "amqp://rabbitmq:5672"
TERARIUM_MQ_PASSWORD: "terarium123"
TERARIUM_MQ_USERNAME: "terarium"
TERARIUM_TASKRUNNER_REQUEST_TYPE: "table_extraction"
ASKEM_DOC_AI_API_KEY: "${secret_openai_key}"
depends_on:
rabbitmq:
condition: service_healthy
extra_hosts:
- "${local_host_name}:host-gateway"
volumes:
- ../../packages/table_extraction:/table_extraction_task
- ../../packages/taskrunner:/taskrunner
command: /table_extraction_task/dev.sh
4 changes: 3 additions & 1 deletion packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,9 @@ export enum ClientEventType {
Notification = "NOTIFICATION",
SimulationNotification = "SIMULATION_NOTIFICATION",
SimulationPyciemss = "SIMULATION_PYCIEMSS",
TaskEnrichAmr = "TASK_ENRICH_AMR",
TaskExtractTextPdf = "TASK_EXTRACT_TEXT_PDF",
TaskExtractTablePdf = "TASK_EXTRACT_TABLE_PDF",
TaskExtractEquationPdf = "TASK_EXTRACT_EQUATION_PDF",
TaskFunmanValidation = "TASK_FUNMAN_VALIDATION",
TaskGollmCompareModel = "TASK_GOLLM_COMPARE_MODEL",
TaskGollmConfigureModelFromDataset = "TASK_GOLLM_CONFIGURE_MODEL_FROM_DATASET",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ public enum ClientEventType {
NOTIFICATION,
SIMULATION_NOTIFICATION,
SIMULATION_PYCIEMSS,
TASK_ENRICH_AMR,
TASK_EXTRACT_TEXT_PDF,
TASK_EXTRACT_TABLE_PDF,
TASK_EXTRACT_EQUATION_PDF,
TASK_FUNMAN_VALIDATION,
TASK_GOLLM_COMPARE_MODEL,
TASK_GOLLM_CONFIGURE_MODEL_FROM_DATASET,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ public static enum TaskType {
@JsonAlias("equation_extraction_gpu")
EQUATION_EXTRACTION_GPU("equation_extraction_gpu"),
@JsonAlias("text_extraction")
TEXT_EXTRACTION("text_extraction");
TEXT_EXTRACTION("text_extraction"),
@JsonAlias("table_extraction")
TABLE_EXTRACTION("table_extraction");

private final String value;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package software.uncharted.terarium.hmiserver.service;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -11,6 +12,7 @@
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URL;
import java.security.MessageDigest;
import java.util.ArrayList;
Expand Down Expand Up @@ -63,6 +65,7 @@
import software.uncharted.terarium.hmiserver.service.notification.NotificationGroupInstance;
import software.uncharted.terarium.hmiserver.service.notification.NotificationService;
import software.uncharted.terarium.hmiserver.service.tasks.ExtractEquationsResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.ExtractTablesResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.ExtractTextResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.utils.ByteMultipartFile;
Expand Down Expand Up @@ -173,6 +176,7 @@ static class ExtractPDFResponse {
List<ExtractionFile> files = new ArrayList<>();
List<DocumentExtraction> assets = new ArrayList<>();
List<JsonNode> equations = new ArrayList<>();
List<JsonNode> tables = new ArrayList<>();
ArrayNode variableAttributes;
JsonNode gollmCard;
boolean partialFailure = true;
Expand Down Expand Up @@ -204,6 +208,14 @@ public ExtractPDFResponse runExtractPDF(
userId
);

notificationInterface.sendMessage("Starting table extraction...");
log.info("Starting table extraction for document: {}", documentName);
final Future<TableExtraction> tableExtractionFuture = extractTablesFromPDF(
notificationInterface,
userId,
documentContents
);

// wait for text extraction
final TextExtraction textExtraction = textExtractionFuture.get();
notificationInterface.sendMessage("Text extraction complete!");
Expand All @@ -225,6 +237,18 @@ public ExtractPDFResponse runExtractPDF(
extractionResponse.partialFailure = true;
}

try {
// wait for table extraction
final TableExtraction tableExtraction = tableExtractionFuture.get();
notificationInterface.sendMessage("Table extraction complete!");
log.info("Table extraction complete for document: {}", documentName);
extractionResponse.tables = tableExtraction.tables;
} catch (final Exception e) {
notificationInterface.sendMessage("Table extraction failed, continuing");
log.error("Table extraction failed for document: {}", documentName, e);
extractionResponse.partialFailure = true;
}

// if there is text, run variable extraction
if (!extractionResponse.documentText.isEmpty()) {
// run variable extraction
Expand Down Expand Up @@ -308,6 +332,13 @@ public DocumentAsset applyExtractPDFResponse(
document.getMetadata().put("equations", objectMapper.valueToTree(extractionResponse.equations));
}

if (extractionResponse.tables != null) {
if (document.getMetadata() == null) {
document.setMetadata(new HashMap<>());
}
document.getMetadata().put("tables", objectMapper.valueToTree(extractionResponse.tables));
}

log.info("Added extraction to document: {}", documentId);

return documentService.updateAsset(document, projectId, hasWritePermission).orElseThrow();
Expand Down Expand Up @@ -747,7 +778,7 @@ public Future<EquationExtraction> extractEquationsFromPDF(

int responseCode = HttpURLConnection.HTTP_BAD_GATEWAY;
if (!EQUATION_EXTRACTION_GPU_ENDPOINT.isEmpty()) {
final URL url = new URL(String.format("%s/health", EQUATION_EXTRACTION_GPU_ENDPOINT));
final URL url = URI.create(String.format("%s/health", EQUATION_EXTRACTION_GPU_ENDPOINT)).toURL();
final HttpURLConnection connection = (HttpURLConnection) url.openConnection();
connection.setRequestMethod("GET");
responseCode = connection.getResponseCode();
Expand Down Expand Up @@ -836,6 +867,55 @@ public Future<TextExtraction> extractTextFromPDF(
});
}

static class TableExtraction {

List<JsonNode> tables;
}

public Future<TableExtraction> extractTablesFromPDF(
final NotificationGroupInstance<Properties> notificationInterface,
final String userId,
final byte[] pdf
) throws JsonProcessingException, TimeoutException, InterruptedException, ExecutionException, IOException {
final int REQUEST_TIMEOUT_MINUTES = 5;

final TaskRequest req = new TaskRequest();
req.setTimeoutMinutes(REQUEST_TIMEOUT_MINUTES);
req.setInput(pdf);
req.setScript(ExtractTablesResponseHandler.NAME);
req.setUserId(userId);
req.setType(TaskType.TABLE_EXTRACTION);

return executor.submit(() -> {
final TaskResponse resp = taskService.runTaskSync(req);

final byte[] outputBytes = resp.getOutput();
final ExtractTablesResponseHandler.ResponseOutput output = objectMapper.readValue(
outputBytes,
ExtractTablesResponseHandler.ResponseOutput.class
);

// Collect keys
final List<String> keys = new ArrayList<>();
final Iterator<Map.Entry<String, JsonNode>> fieldsIterator = output.getResponse().fields();
while (fieldsIterator.hasNext()) {
final Map.Entry<String, JsonNode> field = fieldsIterator.next();
keys.add(field.getKey());
}

// Sort keys
Collections.sort(keys);

final TableExtraction extraction = new TableExtraction();

for (final String key : keys) {
extraction.tables.add(output.getResponse().get(key));
}

return extraction;
});
}

public Future<TextExtraction> extractTextFromPDFCosmos(
final NotificationGroupInstance<Properties> notificationInterface,
final String documentName,
Expand Down Expand Up @@ -919,7 +999,7 @@ public Future<TextExtraction> extractTextFromPDFCosmos(
final ObjectMapper objectMapper = new ObjectMapper();

final JsonNode rootNode = objectMapper.readTree(bytes);
if (rootNode instanceof ArrayNode arrayNode) {
if (rootNode instanceof final ArrayNode arrayNode) {
for (final JsonNode record : arrayNode) {
if (record.has("detect_cls") && record.get("detect_cls").asText().equals("Abstract")) {
abstractJsonNode = record;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import com.fasterxml.jackson.databind.JsonNode;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@RequiredArgsConstructor
@Slf4j
public class ExtractTablesResponseHandler extends TaskResponseHandler {

public static final String NAME = "table_extraction_task:extract_tables";

@Override
public String getName() {
return NAME;
}

@Data
public static class ResponseOutput {

private JsonNode response;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,19 @@
@Slf4j
public class TaskNotificationEventTypes {

private static final Map<String, ClientEventType> clientEventTypes = Map.of(
ModelCardResponseHandler.NAME,
ClientEventType.TASK_GOLLM_MODEL_CARD,
ConfigureModelFromDocumentResponseHandler.NAME,
ClientEventType.TASK_GOLLM_CONFIGURE_MODEL_FROM_DOCUMENT,
ConfigureModelFromDatasetResponseHandler.NAME,
ClientEventType.TASK_GOLLM_CONFIGURE_MODEL_FROM_DATASET,
CompareModelsResponseHandler.NAME,
ClientEventType.TASK_GOLLM_COMPARE_MODEL,
GenerateSummaryHandler.NAME,
ClientEventType.TASK_GOLLM_GENERATE_SUMMARY,
ValidateModelConfigHandler.NAME,
ClientEventType.TASK_FUNMAN_VALIDATION,
EnrichAmrResponseHandler.NAME,
ClientEventType.TASK_GOLLM_ENRICH_AMR,
AMRToMMTResponseHandler.NAME,
ClientEventType.TASK_MIRA_AMR_TO_MMT,
GenerateModelLatexResponseHandler.NAME,
ClientEventType.TASK_MIRA_GENERATE_MODEL_LATEX,
EquationsFromImageResponseHandler.NAME,
ClientEventType.TASK_GOLLM_EQUATIONS_FROM_IMAGE
private static final Map<String, ClientEventType> clientEventTypes = Map.ofEntries(
Map.entry(ModelCardResponseHandler.NAME, ClientEventType.TASK_GOLLM_MODEL_CARD),
Map.entry(ConfigureModelFromDocumentResponseHandler.NAME, ClientEventType.TASK_GOLLM_CONFIGURE_MODEL_FROM_DOCUMENT),
Map.entry(ConfigureModelFromDatasetResponseHandler.NAME, ClientEventType.TASK_GOLLM_CONFIGURE_MODEL_FROM_DATASET),
Map.entry(CompareModelsResponseHandler.NAME, ClientEventType.TASK_GOLLM_COMPARE_MODEL),
Map.entry(GenerateSummaryHandler.NAME, ClientEventType.TASK_GOLLM_GENERATE_SUMMARY),
Map.entry(ValidateModelConfigHandler.NAME, ClientEventType.TASK_FUNMAN_VALIDATION),
Map.entry(EnrichAmrResponseHandler.NAME, ClientEventType.TASK_GOLLM_ENRICH_AMR),
Map.entry(AMRToMMTResponseHandler.NAME, ClientEventType.TASK_MIRA_AMR_TO_MMT),
Map.entry(ExtractEquationsResponseHandler.NAME, ClientEventType.TASK_EXTRACT_EQUATION_PDF),
Map.entry(ExtractTablesResponseHandler.NAME, ClientEventType.TASK_EXTRACT_TABLE_PDF),
Map.entry(ExtractTextResponseHandler.NAME, ClientEventType.TASK_EXTRACT_TEXT_PDF),
Map.entry(EquationsFromImageResponseHandler.NAME, ClientEventType.TASK_GOLLM_EQUATIONS_FROM_IMAGE)
);

public static ClientEventType getTypeFor(final String taskName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,6 @@ private void onTaskResponseAllInstanceReceive(final Message message) {
}

log.info("Received response status {} for task {}", resp.getStatus(), resp.getId());
if (resp.getOutput() != null) {
log.info("Received response output {} for task {}", new String(resp.getOutput()), resp.getId());
}

if (
resp.getStatus() == TaskStatus.SUCCESS ||
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.apache.http.entity.ContentType;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.security.test.context.support.WithUserDetails;
Expand Down
40 changes: 40 additions & 0 deletions packages/table_extraction/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
HELP.md
.gradle
build/
!gradle/wrapper/gradle-wrapper.jar
!**/src/main/**/build/
!**/src/test/**/build/
*.egg-info

### STS ###
.apt_generated
.classpath
.factorypath
.project
.settings
.springBeans
.sts4-cache
bin/
!**/src/main/**/bin/
!**/src/test/**/bin/

### IntelliJ IDEA ###
.idea
*.iws
*.iml
*.ipr
out/
!**/src/main/**/out/
!**/src/test/**/out/

### NetBeans ###
/nbproject/private/
/nbbuild/
/dist/
/nbdist/
/.nb-gradle/

### VS Code ###
.vscode/
mira.egg-info
__pycache__
Loading

0 comments on commit 82eebc5

Please sign in to comment.