ONNX Embedding Model Thread-Safety Issue #2152

Open
alfredogangemi opened this issue Jan 31, 2025 · 5 comments

Comments

@alfredogangemi

Bug description

When using the default ONNX embedding model in Spring AI (all-MiniLM-L6-v2), running the embedding process asynchronously with a ThreadPoolTaskExecutor results in inconsistent behavior and occasional runtime exceptions. The issue does not occur when executing the process synchronously.

Environment

Java Version: 17
Spring Boot Version: Latest
Spring AI Version: 1.0.0-M5
Vector Store: Qdrant (though likely unrelated)
ONNX Model: Default (all-MiniLM-L6-v2)

Steps to reproduce

Configure a ThreadPoolTaskExecutor for handling embedding asynchronously

@Bean(name = "embeddingThreadPoolTaskExecutor")
public ThreadPoolTaskExecutor threadPoolTaskExecutor() {
    ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
    executor.setCorePoolSize(2);
    executor.setMaxPoolSize(4);
    executor.setQueueCapacity(25);
    executor.setThreadNamePrefix("EmbeddingAsyncExecutor-");
    executor.setWaitForTasksToCompleteOnShutdown(true);
    executor.setAwaitTerminationSeconds(60);
    executor.initialize();
    return executor;
}

Call an embedding process asynchronously

@Async("embeddingThreadPoolTaskExecutor")
public UploadResponse load(String tenant, MultipartFile file, DocumentMetadata metadata) {
    try {
        byte[] fileBytes = file.getBytes();
        ByteArrayResource resource = new ByteArrayResource(fileBytes);
        TikaDocumentReader documentReader = new TikaDocumentReader(resource);

        List<Document> documents = documentReader.get();
        documents.forEach(document -> document.getMetadata().put("entityId", metadata.getEntityId()));
        List<Document> splitDocuments = tokenTextSplitter.apply(documents);

        vectorStoreService.getVectorStore(tenant).add(splitDocuments);
        return new UploadResponse(true, "OK");
    } catch (Exception e) {
        log.error("Error processing file: {}", file.getOriginalFilename(), e);
        return new UploadResponse(false, e.getMessage());
    }
}

Process multiple files in parallel

@Test
@SneakyThrows
public void load() {
    Resource folderResource = new ClassPathResource("foo");
    File folder = folderResource.getFile();
    for (File fileEntry : Objects.requireNonNull(folder.listFiles())) {
        InputStream inputStream = new FileInputStream(fileEntry);
        String mimeType = URLConnection.guessContentTypeFromName(fileEntry.getName());
        MultipartFile file = new MockMultipartFile("file", fileEntry.getName(), mimeType, inputStream);
        documentVectorService.loadAsync("foo", file, new DocumentMetadata());
    }
}

Expected behavior

The embedding process should run correctly across multiple threads.

Observed behavior

  • Running the method synchronously works fine.
  • Running it asynchronously causes intermittent failures.
  • Running the embedding model in a single-threaded executor gives the same error.
  • Switching to OpenAI embeddings works fine, which suggests the problem is ONNX-related.

Logs

2025-01-31 17:06:11.001 ERROR [semantic-search-server,,] [EmbeddingAsyncExecutor-1] i.c.w.s.s.etl.DocumentVectorService     : Error while uploading file: Svizzera.pdf (load DocumentVectorService.java 70)
java.lang.RuntimeException: ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running Add node. Name:'/encoder/layer.0/attention/self/Add' Status Message: D:\a\_work\1\s\include\onnxruntime\core/common/logging/logging.h:340 onnxruntime::logging::LoggingManager::DefaultLogger Attempt to use DefaultLogger but none has been registered.

	at org.springframework.ai.transformers.TransformersEmbeddingModel.lambda$call$3(TransformersEmbeddingModel.java:351)
	at io.micrometer.observation.Observation.observe(Observation.java:564)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.call(TransformersEmbeddingModel.java:298)
	at org.springframework.ai.embedding.EmbeddingModel.embed(EmbeddingModel.java:91)
	at org.springframework.ai.vectorstore.qdrant.QdrantVectorStore.doAdd(QdrantVectorStore.java:220)
	at org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore.lambda$add$1(AbstractObservationVectorStore.java:91)
	at io.micrometer.observation.Observation.observe(Observation.java:498)
	at org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore.add(AbstractObservationVectorStore.java:91)
	at it.cegeka.wemaind.semantic_search_server.service.etl.DocumentVectorService.load(DocumentVectorService.java:65)
	at it.cegeka.wemaind.semantic_search_server.service.etl.DocumentVectorService.loadAsync(DocumentVectorService.java:42)
	at jdk.internal.reflect.GeneratedMethodAccessor20.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at org.springframework.aop.support.AopUtils.invokeJoinpointUsingReflection(AopUtils.java:359)
	at org.springframework.aop.framework.ReflectiveMethodInvocation.invokeJoinpoint(ReflectiveMethodInvocation.java:196)
	at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)
	at org.springframework.aop.interceptor.AsyncExecutionInterceptor.lambda$invoke$0(AsyncExecutionInterceptor.java:114)
	at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running Add node. Name:'/encoder/layer.0/attention/self/Add' Status Message: D:\a\_work\1\s\include\onnxruntime\core/common/logging/logging.h:340 onnxruntime::logging::LoggingManager::DefaultLogger Attempt to use DefaultLogger but none has been registered.

	at ai.onnxruntime.OrtSession.run(Native Method)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:395)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:242)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:210)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.lambda$call$3(TransformersEmbeddingModel.java:327)
	... 20 common frames omitted

Other information

The JVM crashed in some tests:

#
# A fatal error has been detected by the Java Runtime Environment:
#
#  EXCEPTION_ACCESS_VIOLATION (0xc0000005) at pc=0x00007fff9109bef3, pid=6540, tid=18268
#
# JRE version: OpenJDK Runtime Environment (17.0.10+13) (build 17.0.10+13-LTS)
# Java VM: OpenJDK 64-Bit Server VM (17.0.10+13-LTS, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, windows-amd64)
# Problematic frame:
# C  [onnxruntime.dll+0x71bef3]
# [thread 29912 also had an error]
#
# No core dump will be written. Minidumps are not enabled by default on client versions of Windows
#
# An error report file with more information is saved as:
# C:\Users\alfredog\IdeaProjects\Microservices\semantic-search-server\hs_err_pid6540.log
#
# If you would like to submit a bug report, please visit:
#   https://bell-sw.com/support
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#
@Craigacp

Craigacp commented Feb 3, 2025

I'm working on running this down in ONNX Runtime, and while looking at the Spring AI code I noticed that it closes the SessionOptions during construction - https://github.com/spring-projects/spring-ai/blame/main/models/spring-ai-transformers/src/main/java/org/springframework/ai/transformers/TransformersEmbeddingModel.java#L236. The session options should outlive the sessions that use them, because they may be used to cache things, as noted in the documentation - https://onnxruntime.ai/docs/api/java/ai/onnxruntime/OrtSession.SessionOptions.html. I don't think that's the root cause of this error, but it could cause other errors, particularly when using GPUs.
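A minimal, JDK-only sketch of the lifetime rule described above, assuming a component that creates its session options once and keeps them open until every session built from them has been closed. `ResourceOwner` is a hypothetical helper, not part of Spring AI or ONNX Runtime, and the two lambdas in `main` are stand-ins for `OrtSession.SessionOptions` and `OrtSession` so the example runs without native libraries.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

// Hypothetical helper: owns native resources and closes them in reverse
// registration order (last in, first out), the same order try-with-resources uses.
// Register the session options first and the sessions built from them afterwards,
// so the options stay open until every session has been closed.
final class ResourceOwner implements AutoCloseable {
    private final Deque<AutoCloseable> resources = new ArrayDeque<>();

    <T extends AutoCloseable> T register(T resource) {
        resources.push(resource); // newest resource ends up on top of the stack
        return resource;
    }

    @Override
    public void close() throws Exception {
        Exception first = null;
        while (!resources.isEmpty()) {
            try {
                resources.pop().close(); // newest (e.g. the session) is closed first
            } catch (Exception e) {
                if (first == null) first = e; else first.addSuppressed(e);
            }
        }
        if (first != null) throw first;
    }

    public static void main(String[] args) throws Exception {
        List<String> closed = new ArrayList<>();
        // Stand-ins for OrtSession.SessionOptions and OrtSession (assumption for the demo).
        AutoCloseable options = () -> closed.add("options");
        AutoCloseable session = () -> closed.add("session");
        try (ResourceOwner owner = new ResourceOwner()) {
            owner.register(options);
            owner.register(session);
            // ... run inference for the whole lifetime of the component ...
        }
        System.out.println(closed); // [session, options]
    }
}
```

With the real ONNX Runtime Java API, the registration order would be along the lines of `opts = owner.register(new OrtSession.SessionOptions())` followed by `session = owner.register(env.createSession(modelPath, opts))`, where `env` is `OrtEnvironment.getEnvironment()` and `modelPath` is a placeholder, with `owner.close()` called when the embedding model bean is destroyed rather than right after construction.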

@Craigacp

Craigacp commented Feb 3, 2025

Link to the ONNX Runtime issue - microsoft/onnxruntime#23555.

@alfredogangemi
Author

https://github.com/alfredogangemi/spring-ai-onnx-async

This is a Spring test project that reproduces the error. It performs exactly the same operations as my main project, using the same files. It also includes the JVM crash logs and the error file produced in other cases.

@alfredogangemi
Author

Thanks @Craigacp for analysing this issue. See microsoft/onnxruntime#23555 for more details.

@Craigacp

Craigacp commented Feb 5, 2025

The root cause of this crash is that ONNX Runtime registers a shutdown hook to close the OrtEnvironment. However, shutdown hooks run concurrently with daemon threads (as used in the async thread pool executor), so if the main thread has terminated and triggered the shutdown hook, the daemon threads can observe an inconsistent state, which can crash the JVM. I'm not sure there's anything to be done about this on the Spring AI side, but I don't understand the async threading mechanisms very well, so maybe there's a fix on your end.

microsoft/onnxruntime#10670
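Given the shutdown-hook race described above, one caller-side mitigation is to make the code that submits the work (here, the test method) wait until every embedding task has finished before it returns. The JDK-only sketch below assumes that goal; `embedFile` is a hypothetical stand-in for the per-file Tika/split/`VectorStore.add` work, and the class and thread names are illustrative. It shows the general pattern, not a confirmed fix for the ONNX Runtime issue.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

class AwaitAsyncEmbeddings {

    // Hypothetical stand-in for reading, splitting and embedding one file.
    static String embedFile(String fileName) {
        return "embedded:" + fileName;
    }

    static List<String> loadAll(List<String> fileNames, ExecutorService executor) {
        List<CompletableFuture<String>> futures = new ArrayList<>();
        for (String name : fileNames) {
            futures.add(CompletableFuture.supplyAsync(() -> embedFile(name), executor));
        }
        // Block until every task has finished, so the caller cannot return and let
        // the JVM start running shutdown hooks while inference is still in progress.
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
        List<String> results = new ArrayList<>();
        for (CompletableFuture<String> f : futures) {
            results.add(f.join()); // already complete, so join() does not block here
        }
        return results;
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(4, r -> {
            Thread t = new Thread(r, "EmbeddingAsyncExecutor");
            t.setDaemon(true); // daemon threads do not keep the JVM alive on their own
            return t;
        });
        try {
            System.out.println(loadAll(List.of("a.pdf", "b.pdf"), executor));
        } finally {
            executor.shutdown();
        }
    }
}
```

In the Spring project, the analogous change would be for the `@Async` method to return `CompletableFuture<UploadResponse>` (for example via `CompletableFuture.completedFuture(...)`) and for the test to `join()` the returned futures. Spring's async interceptor only hands a result back to the caller for `Future`-style return types, so a method that returns `UploadResponse` directly gives the caller nothing to wait on.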
