ONNX Embedding Model Thread-Safety Issue #2152

Open
@alfredogangemi

Bug description

When using the default ONNX embedding model in Spring AI (all-MiniLM-L6-v2), running the embedding process asynchronously with a ThreadPoolTaskExecutor results in inconsistent behavior and occasional runtime exceptions. The issue does not occur when executing the process synchronously.

Environment

Java Version: 17
Spring Boot Version: Latest
Spring AI Version: 1.0.0-M5
Vector Store: Qdrant (though likely unrelated)
ONNX Model: Default (all-MiniLM-L6-v2)

Steps to reproduce

Configure a ThreadPoolTaskExecutor for handling embedding asynchronously

@Bean(name = "embeddingThreadPoolTaskExecutor")
public ThreadPoolTaskExecutor threadPoolTaskExecutor() {
    ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
    executor.setCorePoolSize(2);
    executor.setMaxPoolSize(4);
    executor.setQueueCapacity(25);
    executor.setThreadNamePrefix("EmbeddingAsyncExecutor-");
    executor.setWaitForTasksToCompleteOnShutdown(true);
    executor.setAwaitTerminationSeconds(60);
    executor.initialize();
    return executor;
}

Call an embedding process asynchronously

@Async("embeddingThreadPoolTaskExecutor")
public UploadResponse load(String tenant, MultipartFile file, DocumentMetadata metadata) {
    try {
        byte[] fileBytes = file.getBytes();
        ByteArrayResource resource = new ByteArrayResource(fileBytes);
        TikaDocumentReader documentReader = new TikaDocumentReader(resource);

        List<Document> documents = documentReader.get();
        documents.forEach(document -> document.getMetadata().put("entityId", metadata.getEntityId()));
        List<Document> splitDocuments = tokenTextSplitter.apply(documents);

        vectorStoreService.getVectorStore(tenant).add(splitDocuments);
        return new UploadResponse(true, "OK");
    } catch (Exception e) {
        log.error("Error processing file: {}", file.getOriginalFilename(), e);
        return new UploadResponse(false, e.getMessage());
    }
}

Process multiple files in parallel

@Test
@SneakyThrows
public void load() {
    Resource folderResource = new ClassPathResource("foo");
    File folder = folderResource.getFile();
    for (File fileEntry : Objects.requireNonNull(folder.listFiles())) {
        InputStream inputStream = new FileInputStream(fileEntry);
        String mimeType = URLConnection.guessContentTypeFromName(fileEntry.getName());
        MultipartFile file = new MockMultipartFile("file", fileEntry.getName(), mimeType, inputStream);
        documentVectorService.loadAsync("foo", file, new DocumentMetadata());
    }
}
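
A note on the reproduction: the loop above only submits the files, so the test method will typically return while uploads are still running on the executor. A minimal sketch of waiting for every submission before the test exits, in plain Java without Spring, is below. The class AwaitAllUploads, the method runAll, and the String results are hypothetical stand-ins and not part of the project or of Spring AI; the sketch assumes the async method could be changed to hand back a future.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

public class AwaitAllUploads {

    /**
     * Runs the task for every input on a fixed-size pool and blocks until all
     * of them have finished. A failed task is reported as a "FAILED: ..." string
     * instead of being lost on a worker thread.
     */
    public static <T> List<String> runAll(List<T> inputs, Function<T, String> task, int threads) {
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<CompletableFuture<String>> futures = new ArrayList<>();
            for (T input : inputs) {
                futures.add(CompletableFuture
                        .supplyAsync(() -> task.apply(input), pool)
                        .exceptionally(ex -> "FAILED: " + ex));
            }
            // join() in submission order, so results line up with the inputs.
            List<String> results = new ArrayList<>();
            for (CompletableFuture<String> future : futures) {
                results.add(future.join());
            }
            return results;
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // Hypothetical task: stands in for "read, split, embed, and store one file".
        List<String> results = runAll(List.of("a.pdf", "b.pdf", "c.pdf"), name -> "OK " + name, 2);
        System.out.println(results);
    }
}
```

With something like this in the test, a failure in any upload shows up in the returned list rather than only in the log.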

Expected behavior

The embedding process should run correctly across multiple threads.

Observed behavior

  • Running the method synchronously works fine.
  • Running it asynchronously causes intermittent failures.
  • Running the embedding model in a single-threaded executor gives the same error.
  • Switching to OpenAI embeddings works fine, reinforcing the idea that the problem is ONNX-related.
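
To narrow down whether results actually diverge between threads, a probe along the following lines could compare a sequential baseline with the same inputs embedded from a thread pool. This is a sketch only: EmbeddingConcurrencyProbe and countDivergences are hypothetical names, the Function<String, float[]> is a stand-in for whatever call produces the embedding, and exact array equality assumes the model output is deterministic (a tolerance may be needed otherwise).

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.Function;

public class EmbeddingConcurrencyProbe {

    /**
     * Embeds every text once on the calling thread (baseline), then again on a
     * pool, and returns how many pooled results threw or differed from the baseline.
     */
    public static int countDivergences(Function<String, float[]> embed, List<String> texts, int threads) {
        List<float[]> baseline = new ArrayList<>();
        for (String text : texts) {
            baseline.add(embed.apply(text));
        }

        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Callable<float[]>> jobs = new ArrayList<>();
            for (String text : texts) {
                jobs.add(() -> embed.apply(text));
            }
            // invokeAll returns futures in the same order as the submitted jobs.
            List<Future<float[]>> futures = pool.invokeAll(jobs);
            int divergences = 0;
            for (int i = 0; i < futures.size(); i++) {
                try {
                    if (!Arrays.equals(baseline.get(i), futures.get(i).get())) {
                        divergences++;
                    }
                } catch (ExecutionException e) {
                    divergences++; // the pooled call threw
                }
            }
            return divergences;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for pooled calls", e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // Hypothetical deterministic "model" so the sketch runs without ONNX.
        Function<String, float[]> fake = text -> new float[] { text.length() };
        System.out.println(countDivergences(fake, List.of("a", "bb", "ccc"), 3));
    }
}
```

Running the same probe with a pool size of 1 would correspond to the single-threaded executor case in the list above.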

Logs

2025-01-31 17:06:11.001 ERROR [semantic-search-server,,] [EmbeddingAsyncExecutor-1] i.c.w.s.s.etl.DocumentVectorService     : Error during file upload: Svizzera.pdf (load DocumentVectorService.java 70)
java.lang.RuntimeException: ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running Add node. Name:'/encoder/layer.0/attention/self/Add' Status Message: D:\a\_work\1\s\include\onnxruntime\core/common/logging/logging.h:340 onnxruntime::logging::LoggingManager::DefaultLogger Attempt to use DefaultLogger but none has been registered.

	at org.springframework.ai.transformers.TransformersEmbeddingModel.lambda$call$3(TransformersEmbeddingModel.java:351)
	at io.micrometer.observation.Observation.observe(Observation.java:564)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.call(TransformersEmbeddingModel.java:298)
	at org.springframework.ai.embedding.EmbeddingModel.embed(EmbeddingModel.java:91)
	at org.springframework.ai.vectorstore.qdrant.QdrantVectorStore.doAdd(QdrantVectorStore.java:220)
	at org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore.lambda$add$1(AbstractObservationVectorStore.java:91)
	at io.micrometer.observation.Observation.observe(Observation.java:498)
	at org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore.add(AbstractObservationVectorStore.java:91)
	at it.cegeka.wemaind.semantic_search_server.service.etl.DocumentVectorService.load(DocumentVectorService.java:65)
	at it.cegeka.wemaind.semantic_search_server.service.etl.DocumentVectorService.loadAsync(DocumentVectorService.java:42)
	at jdk.internal.reflect.GeneratedMethodAccessor20.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at org.springframework.aop.support.AopUtils.invokeJoinpointUsingReflection(AopUtils.java:359)
	at org.springframework.aop.framework.ReflectiveMethodInvocation.invokeJoinpoint(ReflectiveMethodInvocation.java:196)
	at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)
	at org.springframework.aop.interceptor.AsyncExecutionInterceptor.lambda$invoke$0(AsyncExecutionInterceptor.java:114)
	at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:840)
Caused by: ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Non-zero status code returned while running Add node. Name:'/encoder/layer.0/attention/self/Add' Status Message: D:\a\_work\1\s\include\onnxruntime\core/common/logging/logging.h:340 onnxruntime::logging::LoggingManager::DefaultLogger Attempt to use DefaultLogger but none has been registered.

	at ai.onnxruntime.OrtSession.run(Native Method)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:395)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:242)
	at ai.onnxruntime.OrtSession.run(OrtSession.java:210)
	at org.springframework.ai.transformers.TransformersEmbeddingModel.lambda$call$3(TransformersEmbeddingModel.java:327)
	... 20 common frames omitted

Additional information

In some test runs the JVM itself crashed:

#
# A fatal error has been detected by the Java Runtime Environment:
#
#  EXCEPTION_ACCESS_VIOLATION (0xc0000005) at pc=0x00007fff9109bef3, pid=6540, tid=18268
#
# JRE version: OpenJDK Runtime Environment (17.0.10+13) (build 17.0.10+13-LTS)
# Java VM: OpenJDK 64-Bit Server VM (17.0.10+13-LTS, mixed mode, sharing, tiered, compressed oops, compressed class ptrs, g1 gc, windows-amd64)
# Problematic frame:
# C[thread 29912 also had an error]
  [onnxruntime.dll+0x71bef3]
#
# No core dump will be written. Minidumps are not enabled by default on client versions of Windows
#
# An error report file with more information is saved as:
# C:\Users\alfredog\IdeaProjects\Microservices\semantic-search-server\hs_err_pid6540.log
#
# If you would like to submit a bug report, please visit:
#   https://bell-sw.com/support
# The crash happened outside the Java Virtual Machine in native code.
# See problematic frame for where to report the bug.
#
