Skip to content

Commit

Permalink
Generate embeddings as part of model / document creation. (#4021)
Browse files Browse the repository at this point in the history
Co-authored-by: dvince <[email protected]>
  • Loading branch information
kbirk and dvince2 authored Jul 2, 2024
1 parent 5dcdf3e commit 3ef9512
Show file tree
Hide file tree
Showing 10 changed files with 253 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ private void uploadPDFFileToDocumentThenExtract(
try {
final byte[] fileAsBytes = DownloadService.getPDF("https://unpaywall.org/" + doi);

// if this service fails, return ok with errors
// if this service fails, return ok with errors.
if (fileAsBytes == null || fileAsBytes.length == 0) {
log.debug("Document has not data, empty bytes, exit early.");
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;
import software.uncharted.terarium.hmiserver.models.TerariumAssetEmbeddings;
import software.uncharted.terarium.hmiserver.models.dataservice.ResponseDeleted;
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.model.Model;
Expand All @@ -54,6 +55,7 @@
import software.uncharted.terarium.hmiserver.service.data.ProjectAssetService;
import software.uncharted.terarium.hmiserver.service.data.ProjectService;
import software.uncharted.terarium.hmiserver.service.data.ProvenanceSearchService;
import software.uncharted.terarium.hmiserver.service.gollm.EmbeddingService;
import software.uncharted.terarium.hmiserver.utils.Messages;
import software.uncharted.terarium.hmiserver.utils.rebac.Schema;

Expand Down Expand Up @@ -86,6 +88,8 @@ public class ModelController {

final ModelConfigRepository modelConfigRepository;

final EmbeddingService embeddingService;

@GetMapping("/descriptions")
@Secured(Roles.USER)
@Operation(summary = "Gets all model descriptions")
Expand Down Expand Up @@ -342,8 +346,24 @@ ResponseEntity<Model> updateModel(
// TerariumAsset have a name field, but it's not used for the model name outside
// the front-end.
final Optional<Model> updated = modelService.updateAsset(model, permission);
return updated.map(ResponseEntity::ok)
.orElseGet(() -> ResponseEntity.notFound().build());

if (updated.isEmpty()) {
return ResponseEntity.notFound().build();
}

final Model updatedModel = updated.get();

if (updatedModel.getPublicAsset() && !updatedModel.getTemporary()) {
try {
final String amr = objectMapper.writeValueAsString(updatedModel);
final TerariumAssetEmbeddings embeddings = embeddingService.generateEmbeddings(amr);
modelService.uploadEmbeddings(model.getId(), embeddings, permission);
} catch (final Exception e) {
log.warn("Unable to update embeddings for model " + model.getId(), e);
}
}

return ResponseEntity.ok(updatedModel);
} catch (final IOException e) {
final String error = "Unable to update model";
log.error(error, e);
Expand Down Expand Up @@ -419,6 +439,17 @@ ResponseEntity<Model> createModel(
final ModelConfiguration modelConfiguration =
ModelConfigurationService.modelConfigurationFromAMR(model, null, null);
modelConfigurationService.createAsset(modelConfiguration, permission);

if (model.getPublicAsset() && !model.getTemporary()) {
try {
final String amr = objectMapper.writeValueAsString(model);
final TerariumAssetEmbeddings embeddings = embeddingService.generateEmbeddings(amr);
modelService.uploadEmbeddings(model.getId(), embeddings, permission);
} catch (final Exception e) {
log.warn("Unable to generate embeddings for model " + model.getId(), e);
}
}

return ResponseEntity.status(HttpStatus.CREATED).body(model);
} catch (final IOException e) {
final String error = "Unable to create model";
Expand All @@ -434,7 +465,7 @@ ResponseEntity<Model> createModel(
value = {
@ApiResponse(
responseCode = "200",
description = "Model configurations found.",
description = "Model configurations found",
content =
@Content(
array =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
import software.uncharted.terarium.hmiserver.service.data.DocumentAssetService;
import software.uncharted.terarium.hmiserver.service.data.ModelService;
import software.uncharted.terarium.hmiserver.service.data.ProjectService;
import software.uncharted.terarium.hmiserver.service.data.SummaryService;
import software.uncharted.terarium.hmiserver.service.tasks.CompareModelsResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.ConfigureFromDatasetResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.ConfigureModelResponseHandler;
Expand All @@ -66,7 +65,6 @@ public class GoLLMController {
private final ModelService modelService;
private final ProjectService projectService;
private final CurrentUserService currentUserService;
private final SummaryService summaryService;

private final ModelCardResponseHandler modelCardResponseHandler;
private final ConfigureModelResponseHandler configureModelResponseHandler;
Expand Down Expand Up @@ -120,26 +118,28 @@ public ResponseEntity<TaskResponse> createModelCardTask(
projectService.checkPermissionCanRead(currentUserService.get().getId(), projectId);

// Grab the document
final Optional<DocumentAsset> document = documentAssetService.getAsset(documentId, permission);
if (document.isEmpty()) {
final Optional<DocumentAsset> documentOpt = documentAssetService.getAsset(documentId, permission);
if (documentOpt.isEmpty()) {
log.warn(String.format("Document %s not found", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.not-found"));
}

final DocumentAsset document = documentOpt.get();

// make sure there is text in the document
if (document.get().getText() == null || document.get().getText().isEmpty()) {
if (document.getText() == null || document.getText().isEmpty()) {
log.warn(String.format("Document %s has no text to send", documentId));
throw new ResponseStatusException(HttpStatus.NOT_FOUND, messages.get("document.extraction.not-done"));
}

// check for input length
if (document.get().getText().length() > ModelCardResponseHandler.MAX_TEXT_SIZE) {
if (document.getText().length() > ModelCardResponseHandler.MAX_TEXT_SIZE) {
log.warn(String.format("Document %s text too long for GoLLM model card task", documentId));
throw new ResponseStatusException(HttpStatus.BAD_REQUEST, messages.get("document.text-length-exceeded"));
}

final ModelCardResponseHandler.Input input = new ModelCardResponseHandler.Input();
input.setResearchPaper(document.get().getText());
input.setResearchPaper(document.getText());

// Create the task
final TaskRequest req = new TaskRequest();
Expand All @@ -158,6 +158,7 @@ public ResponseEntity<TaskResponse> createModelCardTask(

final ModelCardResponseHandler.Properties props = new ModelCardResponseHandler.Properties();
props.setDocumentId(documentId);
props.setUpdateEmbeddings(document.getPublicAsset() && !document.getTemporary()); // update search embeddings
req.setAdditionalProperties(props);

final TaskResponse resp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import io.swagger.v3.oas.annotations.responses.ApiResponse;
import io.swagger.v3.oas.annotations.responses.ApiResponses;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import lombok.Data;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -26,15 +28,11 @@
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;
import software.uncharted.terarium.hmiserver.configuration.ElasticsearchConfiguration;
import software.uncharted.terarium.hmiserver.models.TerariumAssetEmbeddings;
import software.uncharted.terarium.hmiserver.models.dataservice.AssetType;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest.TaskType;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.models.task.TaskStatus;
import software.uncharted.terarium.hmiserver.security.Roles;
import software.uncharted.terarium.hmiserver.service.CurrentUserService;
import software.uncharted.terarium.hmiserver.service.elasticsearch.ElasticsearchService;
import software.uncharted.terarium.hmiserver.service.tasks.TaskService;
import software.uncharted.terarium.hmiserver.service.gollm.EmbeddingService;

@RequestMapping("/search-by-asset-type")
@RestController
Expand All @@ -43,13 +41,9 @@
public class SearchByAssetTypeController {

private final ObjectMapper objectMapper;
private final TaskService taskService;
private final ElasticsearchService esService;
private final ElasticsearchConfiguration esConfig;
private final CurrentUserService currentUserService;

private static final int REQUEST_TIMEOUT_MINUTES = 1;
private static final String EMBEDDING_MODEL = "text-embedding-ada-002";
private final EmbeddingService embeddingService;

private static final List<String> EXCLUDE_FIELDS = List.of("embeddings", "text", "topics");

Expand Down Expand Up @@ -93,7 +87,8 @@ public ResponseEntity<List<JsonNode>> searchByAssetType(
@RequestParam(value = "text", defaultValue = "") final String text,
@RequestParam(value = "k", defaultValue = "100") final int k,
@RequestParam(value = "num-candidates", defaultValue = "1000") final int numCandidates,
@RequestParam(value = "embedding-model", defaultValue = EMBEDDING_MODEL) final String embeddingModel,
@RequestParam(value = "embedding-model", defaultValue = EmbeddingService.EMBEDDING_MODEL)
final String embeddingModel,
@RequestParam(value = "index", defaultValue = "") String index) {
final AssetType assetType = AssetType.getAssetType(assetTypeName, objectMapper);
try {
Expand All @@ -114,32 +109,12 @@ public ResponseEntity<List<JsonNode>> searchByAssetType(
KnnQuery knn = null;
if (text != null && !text.isEmpty()) {

// create the embedding search request
final GoLLMSearchRequest embeddingRequest = new GoLLMSearchRequest();
embeddingRequest.setText(text);
embeddingRequest.setEmbeddingModel(EMBEDDING_MODEL);

final TaskRequest req = new TaskRequest();
req.setTimeoutMinutes(REQUEST_TIMEOUT_MINUTES);
req.setType(TaskType.GOLLM);
req.setInput(embeddingRequest);
req.setScript("gollm_task:embedding");
req.setUserId(currentUserService.get().getId());

final TaskResponse resp = taskService.runTaskSync(req);

if (resp.getStatus() != TaskStatus.SUCCESS) {
throw new ResponseStatusException(
org.springframework.http.HttpStatus.INTERNAL_SERVER_ERROR,
"Unable to generate vectors for knn search");
}

final byte[] outputBytes = resp.getOutput();
final JsonNode output = objectMapper.readTree(outputBytes);
final TerariumAssetEmbeddings embeddings = embeddingService.generateEmbeddings(text);

final EmbeddingsResponse embeddingResp = objectMapper.convertValue(output, EmbeddingsResponse.class);

final List<Float> vector = embeddingResp.getResponse();
final List<Float> vector = Arrays.stream(
embeddings.getEmbeddings().get(0).getVector())
.mapToObj(d -> (float) d)
.collect(Collectors.toList());

knn = new KnnQuery.Builder()
.field("embeddings.vector")
Expand All @@ -165,7 +140,6 @@ public ResponseEntity<List<JsonNode>> searchByAssetType(
docs.add(source);
}
}

return ResponseEntity.ok(docs);

} catch (final Exception e) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package software.uncharted.terarium.hmiserver.models;

import java.util.ArrayList;
import java.util.List;
import lombok.Data;

@Data
public class TerariumAssetEmbeddings {

@Data
public static class Embeddings {
private String embeddingId;
private double[] vector;
private long[] spans;
}

private List<Embeddings> embeddings = new ArrayList<>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.springframework.stereotype.Service;
import org.springframework.web.server.ResponseStatusException;
import software.uncharted.terarium.hmiserver.models.ClientEventType;
import software.uncharted.terarium.hmiserver.models.TerariumAssetEmbeddings;
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentAsset;
import software.uncharted.terarium.hmiserver.models.dataservice.document.DocumentExtraction;
import software.uncharted.terarium.hmiserver.models.dataservice.document.ExtractionAssetType;
Expand All @@ -39,14 +40,13 @@
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceRelationType;
import software.uncharted.terarium.hmiserver.models.dataservice.provenance.ProvenanceType;
import software.uncharted.terarium.hmiserver.models.task.TaskRequest;
import software.uncharted.terarium.hmiserver.models.task.TaskResponse;
import software.uncharted.terarium.hmiserver.models.task.TaskStatus;
import software.uncharted.terarium.hmiserver.proxies.documentservice.ExtractionProxy;
import software.uncharted.terarium.hmiserver.proxies.skema.SkemaUnifiedProxy;
import software.uncharted.terarium.hmiserver.proxies.skema.SkemaUnifiedProxy.IntegratedTextExtractionsBody;
import software.uncharted.terarium.hmiserver.service.data.DocumentAssetService;
import software.uncharted.terarium.hmiserver.service.data.ModelService;
import software.uncharted.terarium.hmiserver.service.data.ProvenanceService;
import software.uncharted.terarium.hmiserver.service.gollm.EmbeddingService;
import software.uncharted.terarium.hmiserver.service.notification.NotificationGroupInstance;
import software.uncharted.terarium.hmiserver.service.notification.NotificationService;
import software.uncharted.terarium.hmiserver.service.tasks.ModelCardResponseHandler;
Expand All @@ -69,12 +69,13 @@ public class ExtractionService {
private final NotificationService notificationService;
private final TaskService taskService;
private final ProvenanceService provenanceService;
private final EmbeddingService embeddingService;
private final CurrentUserService currentUserService;

// Used to get the Abstract text from PDF
private static final String NODE_CONTENT = "content";

// time the progress takes to reach each subsequent half
// time the progress takes to reach each subsequent half.
final Double HALFTIME_SECONDS = 2.0;

@Value("${terarium.extractionService.poolSize:10}")
Expand Down Expand Up @@ -109,7 +110,7 @@ public Future<DocumentAsset> extractPDF(
final UUID projectId,
final Schema.Permission hasWritePermission) {

final NotificationGroupInstance<Properties> notificationInterface = new NotificationGroupInstance<Properties>(
final NotificationGroupInstance<Properties> notificationInterface = new NotificationGroupInstance<>(
clientEventService,
notificationService,
ClientEventType.EXTRACTION_PDF,
Expand Down Expand Up @@ -340,13 +341,13 @@ public Future<DocumentAsset> extractPDF(

final ModelCardResponseHandler.Properties props = new ModelCardResponseHandler.Properties();
props.setDocumentId(documentId);
props.setUpdateEmbeddings(true); // update the embeddings using the card
req.setAdditionalProperties(props);

notificationInterface.sendMessage("Sending GoLLM model card request");
final TaskResponse resp = taskService.runTaskSync(req);
if (resp.getStatus() != TaskStatus.SUCCESS) {
throw new RuntimeException("GoLLM model card task failed");
}

taskService.runTaskSync(req);

notificationInterface.sendMessage("Model Card created");
}
}
Expand Down Expand Up @@ -471,7 +472,7 @@ public Future<DocumentAsset> extractVariables(
final String domain,
final Schema.Permission hasWritePermission) {
// Set up the client interface
final NotificationGroupInstance<Properties> notificationInterface = new NotificationGroupInstance<Properties>(
final NotificationGroupInstance<Properties> notificationInterface = new NotificationGroupInstance<>(
clientEventService,
notificationService,
ClientEventType.EXTRACTION,
Expand All @@ -491,7 +492,7 @@ public Future<DocumentAsset> extractVariables(
public Future<Model> alignAMR(
final UUID documentId, final UUID modelId, final Schema.Permission hasWritePermission) {

final NotificationGroupInstance<Properties> notificationInterface = new NotificationGroupInstance<Properties>(
final NotificationGroupInstance<Properties> notificationInterface = new NotificationGroupInstance<>(
clientEventService,
notificationService,
ClientEventType.EXTRACTION,
Expand Down Expand Up @@ -575,6 +576,25 @@ public Model runAlignAMR(
ProvenanceType.DOCUMENT);
provenanceService.createProvenance(provenance);

// update model embeddings
if (model.getPublicAsset() && !model.getTemporary()) {

final JsonNode card = document.getMetadata().get("gollmCard");
if (card != null) {

final String cardText = objectMapper.writeValueAsString(card);
try {
final TerariumAssetEmbeddings embeddings = embeddingService.generateEmbeddings(cardText);

modelService.uploadEmbeddings(modelId, embeddings, hasWritePermission);
notificationInterface.sendMessage("Embeddings created");

} catch (final Exception e) {
log.warn("Unable to generate embedding vectors for model");
}
}
}

return model;

} catch (final FeignException e) {
Expand Down
Loading

0 comments on commit 3ef9512

Please sign in to comment.