From e5cd64a0e847a49bcaecf52566b1886a234d9949 Mon Sep 17 00:00:00 2001 From: David Gauldie Date: Tue, 15 Oct 2024 10:22:51 -0400 Subject: [PATCH] 5070 create new gollm endpoint that takes in a list of latex equations a returns cleaned up ones (#5145) --- packages/gollm/entities.py | 4 + .../gollm_openai/prompts/equations_cleanup.py | 24 ++++++ .../gollm_openai/prompts/latex_style_guide.py | 27 +++---- packages/gollm/gollm_openai/tool_utils.py | 46 ++++++++++- packages/gollm/setup.py | 3 +- packages/gollm/tasks/equations_cleanup.py | 40 ++++++++++ .../knowledge/KnowledgeController.java | 76 ++++++++++++++++++- .../EquationsCleanupResponseHandler.java | 40 ++++++++++ 8 files changed, 240 insertions(+), 20 deletions(-) create mode 100644 packages/gollm/gollm_openai/prompts/equations_cleanup.py create mode 100644 packages/gollm/tasks/equations_cleanup.py create mode 100644 packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsCleanupResponseHandler.java diff --git a/packages/gollm/entities.py b/packages/gollm/entities.py index 9333fb7e3b..23ec51ac87 100644 --- a/packages/gollm/entities.py +++ b/packages/gollm/entities.py @@ -29,6 +29,10 @@ class ModelCompareModel(BaseModel): amrs: List[str] # expects AMRs to be a stringified JSON object +class EquationsCleanup(BaseModel): + equations: List[str] + + class EquationsFromImage(BaseModel): image: str # expects a base64 encoded image diff --git a/packages/gollm/gollm_openai/prompts/equations_cleanup.py b/packages/gollm/gollm_openai/prompts/equations_cleanup.py new file mode 100644 index 0000000000..4e92957138 --- /dev/null +++ b/packages/gollm/gollm_openai/prompts/equations_cleanup.py @@ -0,0 +1,24 @@ +EQUATIONS_CLEANUP_PROMPT = """ +You are a helpful agent designed to reformat latex equations based on a supplied style guide. +The style guide will contain a list of rules that you should follow when reformatting the equations. You should reformat the equations to match the style guide as closely as possible. + +Do not respond in full sentences; only create a JSON object that satisfies the JSON schema specified in the response format. + +Use the following style guide to ensure that your LaTeX equations are correctly formatted: + +--- STYLE GUIDE START --- + +{style_guide} + +--- STYLE GUIDE END --- + +The equations that you need to reformat are as follows: + +--- EQUATIONS START --- + +{equations} + +--- EQUATIONS END --- + +Answer: +""" diff --git a/packages/gollm/gollm_openai/prompts/latex_style_guide.py b/packages/gollm/gollm_openai/prompts/latex_style_guide.py index 7cdc39ed4f..434f8e379b 100644 --- a/packages/gollm/gollm_openai/prompts/latex_style_guide.py +++ b/packages/gollm/gollm_openai/prompts/latex_style_guide.py @@ -1,19 +1,14 @@ LATEXT_STYLE_GUIDE = """ -1) Derivatives must be written in Leibniz notation (for example, \\frac{d X}{d t}). Equations that are written in other notations, like Newton or Lagrange, should be converted. +1) Derivatives must be written in Leibniz notation (for example, '\\frac{d X}{d t}'). Equations that are written in other notations, like Newton or Lagrange, should be converted. 2) First-order derivative must be on the left of the equal sign -3) Use whitespace to indicate multiplication - a) "*" is optional but probably should be avoided -4) "(t)" is optional and probably should be avoided -5) Avoid superscripts and LaTeX superscripts "^", particularly to denote sub-variables -6) Subscripts using LaTeX "_" are permitted - a) Ensure that all characters used in the subscript are surrounded by a pair of curly brackets "{...}" -7) Avoid mathematical constants like pi or Euler's number - a) Replace them as floats with 3 decimal places of precision -8) Avoid parentheses -9) Avoid capital sigma and pi notations for summation and product -10) Avoid non-ASCII characters when possible -11) Avoid using homoglyphs -12) Avoid words or multi-character names for variables and names - a) Use camel case to express multi-word or multi-character names -13) Do not use \\cdot for multiplication. Use whitespace instead. +3) the use of '(t)' should be avoided +4) Avoid superscripts and LaTeX superscripts '^', particularly to denote sub-variables +5) Subscripts using LaTeX '_' are permitted. However, Ensure that all characters used in the subscript are surrounded by a pair of curly brackets '{...}' +6) Avoid parentheses +7) Avoid capital sigma and pi notations for summation and product +8) Avoid non-ASCII characters when possible +9) Avoid using homoglyphs +10) Avoid words or multi-character names for variables and names. Use camel case to express multi-word or multi-character names +11) Do not use '\\cdot' or '*' to indicate multiplication. Use whitespace instead. +12) Avoid using notation for mathematical constants like 'e' and 'pi'. Use their actual values up to 3 decimal places instead. """ diff --git a/packages/gollm/gollm_openai/tool_utils.py b/packages/gollm/gollm_openai/tool_utils.py index ab9b5c3ef2..7d73b5e731 100644 --- a/packages/gollm/gollm_openai/tool_utils.py +++ b/packages/gollm/gollm_openai/tool_utils.py @@ -13,6 +13,7 @@ CONFIGURE_FROM_DATASET_MATRIX_PROMPT ) from gollm_openai.prompts.config_from_document import CONFIGURE_FROM_DOCUMENT_PROMPT +from gollm_openai.prompts.equations_cleanup import EQUATIONS_CLEANUP_PROMPT from gollm_openai.prompts.equations_from_image import EQUATIONS_FROM_IMAGE_PROMPT from gollm_openai.prompts.general_instruction import GENERAL_INSTRUCTION_PROMPT from gollm_openai.prompts.interventions_from_document import INTERVENTIONS_FROM_DOCUMENT_PROMPT @@ -59,6 +60,47 @@ def get_image_format_string(image_format: str) -> str: return format_strings.get(image_format.lower()) +def equations_cleanup(equations: List[str]) -> dict: + print("Reformatting equations...") + + print("Uploading and validating equations schema...") + config_path = os.path.join(SCRIPT_DIR, 'schemas', 'equations.json') + with open(config_path, 'r') as config_file: + response_schema = json.load(config_file) + validate_schema(response_schema) + + print("Building prompt to reformat equations...") + prompt = EQUATIONS_CLEANUP_PROMPT.format( + style_guide=LATEXT_STYLE_GUIDE, + equations="\n".join(equations) + ) + + client = OpenAI() + output = client.chat.completions.create( + model="gpt-4o-mini", + top_p=1, + frequency_penalty=0, + presence_penalty=0, + temperature=0, + seed=123, + max_tokens=1024, + response_format={ + "type": "json_schema", + "json_schema": { + "name": "equations", + "schema": response_schema + } + }, + messages=[ + {"role": "user", "content": prompt}, + ] + ) + print("Received response from OpenAI API. Formatting response to work with HMI...") + output_json = json.loads(output.choices[0].message.content) + + return output_json + + def equations_from_image(image: str) -> dict: print("Translating equations from image...") @@ -340,7 +382,7 @@ def embedding_chain(text: str) -> List: return output.data[0].embedding -def model_config_from_dataset(amr: str, dataset: List[str], matrix: str) -> str: +def model_config_from_dataset(amr: str, dataset: List[str], matrix: str) -> dict: print("Extracting datasets...") dataset_text = os.linesep.join(dataset) @@ -392,7 +434,7 @@ def model_config_from_dataset(amr: str, dataset: List[str], matrix: str) -> str: return model_config_adapter(output_json) -def compare_models(amrs: List[str]) -> str: +def compare_models(amrs: List[str]) -> dict: print("Comparing models...") print("Building prompt to compare models...") diff --git a/packages/gollm/setup.py b/packages/gollm/setup.py index dafd1fcbe2..7d9cf01af2 100644 --- a/packages/gollm/setup.py +++ b/packages/gollm/setup.py @@ -22,11 +22,12 @@ "gollm:configure_model_from_document=tasks.configure_model_from_document:main", "gollm:embedding=tasks.embedding:main", "gollm:enrich_amr=tasks.enrich_amr:main", + "gollm:equations_cleanup=tasks.equations_cleanup:main", "gollm:equations_from_image=tasks.equations_from_image:main", "gollm:generate_response=tasks.generate_response:main", "gollm:generate_summary=tasks.generate_summary:main", "gollm:interventions_from_document=tasks.interventions_from_document:main", - "gollm:model_card=tasks.model_card:main", + "gollm:model_card=tasks.model_card:main" ], }, python_requires=">=3.11", diff --git a/packages/gollm/tasks/equations_cleanup.py b/packages/gollm/tasks/equations_cleanup.py new file mode 100644 index 0000000000..1bce1a1775 --- /dev/null +++ b/packages/gollm/tasks/equations_cleanup.py @@ -0,0 +1,40 @@ +import sys +from entities import EquationsCleanup +from gollm_openai.tool_utils import equations_cleanup + +from taskrunner import TaskRunnerInterface + + +def cleanup(): + pass + + +def main(): + exitCode = 0 + try: + taskrunner = TaskRunnerInterface(description="Reformat equations based on style guide") + taskrunner.on_cancellation(cleanup) + + input_dict = taskrunner.read_input_dict_with_timeout() + + taskrunner.log("Creating EquationsCleanup from input") + input_model = EquationsCleanup(**input_dict) + + taskrunner.log("Sending request to OpenAI API") + response = equations_cleanup(equations=input_model.equations) + taskrunner.log("Received response from OpenAI API") + + taskrunner.write_output_dict_with_timeout({"response": response}) + + except Exception as e: + sys.stderr.write(f"Error: {str(e)}\n") + sys.stderr.flush() + exitCode = 1 + + taskrunner.log("Shutting down") + taskrunner.shutdown() + sys.exit(exitCode) + + +if __name__ == "__main__": + main() diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/knowledge/KnowledgeController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/knowledge/KnowledgeController.java index f9434b6e36..5c9f46fae7 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/knowledge/KnowledgeController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/knowledge/KnowledgeController.java @@ -9,6 +9,7 @@ import io.swagger.v3.oas.annotations.media.Content; import io.swagger.v3.oas.annotations.responses.ApiResponse; import io.swagger.v3.oas.annotations.responses.ApiResponses; +import jakarta.annotation.PostConstruct; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; @@ -60,6 +61,7 @@ import software.uncharted.terarium.hmiserver.models.extractionservice.ExtractionResponse; import software.uncharted.terarium.hmiserver.models.task.TaskRequest; import software.uncharted.terarium.hmiserver.models.task.TaskRequest.TaskType; +import software.uncharted.terarium.hmiserver.models.task.TaskResponse; import software.uncharted.terarium.hmiserver.proxies.mit.MitProxy; import software.uncharted.terarium.hmiserver.proxies.skema.SkemaUnifiedProxy; import software.uncharted.terarium.hmiserver.security.Roles; @@ -73,6 +75,7 @@ import software.uncharted.terarium.hmiserver.service.data.ProvenanceSearchService; import software.uncharted.terarium.hmiserver.service.data.ProvenanceService; import software.uncharted.terarium.hmiserver.service.tasks.EnrichAmrResponseHandler; +import software.uncharted.terarium.hmiserver.service.tasks.EquationsCleanupResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.TaskService; import software.uncharted.terarium.hmiserver.service.tasks.TaskService.TaskMode; import software.uncharted.terarium.hmiserver.utils.ByteMultipartFile; @@ -106,11 +109,18 @@ public class KnowledgeController { final ProjectService projectService; final CurrentUserService currentUserService; + private final EquationsCleanupResponseHandler equationsCleanupResponseHandler; + final Messages messages; @Value("${openai-api-key:}") String OPENAI_API_KEY; + @PostConstruct + void init() { + taskService.addResponseHandler(equationsCleanupResponseHandler); + } + private void enrichModel( final UUID projectId, final UUID documentId, @@ -236,9 +246,47 @@ public ResponseEntity equationsToModel( } } + // Cleanup equations from the request + List equations = new ArrayList<>(); + if (req.get("equations") != null) { + for (final JsonNode equation : req.get("equations")) { + equations.add(equation.asText()); + } + } + TaskRequest cleanupReq = cleanupEquationsTaskRequest(projectId, equations); + TaskResponse cleanupResp = null; + try { + cleanupResp = taskService.runTask(TaskMode.SYNC, cleanupReq); + } catch (final JsonProcessingException e) { + log.warn("Unable to clean-up equations due to a JsonProcessingException. Reverting to original equations.", e); + } catch (final TimeoutException e) { + log.warn("Unable to clean-up equations due to a TimeoutException. Reverting to original equations.", e); + } catch (final InterruptedException e) { + log.warn("Unable to clean-up equations due to a InterruptedException. Reverting to original equations.", e); + } catch (final ExecutionException e) { + log.warn("Unable to clean-up equations due to a ExecutionException. Reverting to original equations.", e); + } + + // get the equations from the cleanup response, or use the original equations + JsonNode equationsReq = req.get("equations"); + if (cleanupResp != null && cleanupResp.getOutput() != null) { + try { + JsonNode output = mapper.readValue(cleanupResp.getOutput(), JsonNode.class); + if (output.get("response") != null && output.get("response").get("equations") != null) { + equationsReq = output.get("response").get("equations"); + } + } catch (IOException e) { + log.warn("Unable to retrive cleaned-up equations from GoLLM response. Reverting to original equations.", e); + } + } + + // Create a new request with the cleaned-up equations, so that we don't modify the original request. + JsonNode newReq = req.deepCopy(); + ((ObjectNode) newReq).set("equations", equationsReq); + // Get an AMR from Skema Unified Service try { - responseAMR = skemaUnifiedProxy.consolidatedEquationsToAMR(req).getBody(); + responseAMR = skemaUnifiedProxy.consolidatedEquationsToAMR(newReq).getBody(); if (responseAMR == null) { log.warn("Skema Unified Service did not return a valid AMR based on the provided equations"); throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("skema.bad-equations")); @@ -840,4 +888,30 @@ private ResponseStatusException handleSkemaFeignException(final FeignException e ); return new ResponseStatusException(httpStatus, messages.get("generic.unknown")); } + + private TaskRequest cleanupEquationsTaskRequest(UUID projectId, List equations) { + final EquationsCleanupResponseHandler.Input input = new EquationsCleanupResponseHandler.Input(); + input.setEquations(equations); + + // Create the task + final TaskRequest req = new TaskRequest(); + req.setType(TaskType.GOLLM); + req.setScript(EquationsCleanupResponseHandler.NAME); + req.setUserId(currentUserService.get().getId()); + + try { + req.setInput(mapper.writeValueAsBytes(input)); + } catch (final Exception e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); + } + + req.setProjectId(projectId); + + final EquationsCleanupResponseHandler.Properties props = new EquationsCleanupResponseHandler.Properties(); + props.setProjectId(projectId); + req.setAdditionalProperties(props); + + return req; + } } diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsCleanupResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsCleanupResponseHandler.java new file mode 100644 index 0000000000..4b2c4607c7 --- /dev/null +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/EquationsCleanupResponseHandler.java @@ -0,0 +1,40 @@ +package software.uncharted.terarium.hmiserver.service.tasks; + +import com.fasterxml.jackson.databind.JsonNode; +import java.util.List; +import java.util.UUID; +import lombok.Data; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +@Component +@RequiredArgsConstructor +@Slf4j +public class EquationsCleanupResponseHandler extends TaskResponseHandler { + + public static final String NAME = "gollm:equations_cleanup"; + + @Override + public String getName() { + return NAME; + } + + @Data + public static class Input { + + List equations; + } + + @Data + public static class Properties { + + UUID projectId; + } + + @Data + public static class Response { + + JsonNode response; + } +}