diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index e85f9e0334..fedd9d3abe 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -36,6 +36,7 @@ import com.azure.ai.openai.models.ChatCompletionsJsonResponseFormat; import com.azure.ai.openai.models.ChatCompletionsOptions; import com.azure.ai.openai.models.ChatCompletionsResponseFormat; +import com.azure.ai.openai.models.ChatCompletionStreamOptions; import com.azure.ai.openai.models.ChatCompletionsTextResponseFormat; import com.azure.ai.openai.models.ChatCompletionsToolCall; import com.azure.ai.openai.models.ChatCompletionsToolDefinition; @@ -496,8 +497,9 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { options = this.merge(options, this.defaultOptions); + AzureOpenAiChatOptions updatedRuntimeOptions; + if (prompt.getOptions() != null) { - AzureOpenAiChatOptions updatedRuntimeOptions; if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) { updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class, AzureOpenAiChatOptions.class); @@ -521,6 +523,15 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) { options.setTools(tools2); } + Boolean enableStreamUsage = (prompt.getOptions() instanceof AzureOpenAiChatOptions azureOpenAiChatOptions + && azureOpenAiChatOptions.getStreamUsage() != null) ? 
azureOpenAiChatOptions.getStreamUsage() + : this.defaultOptions.getStreamUsage(); + + if (Boolean.TRUE.equals(enableStreamUsage) && options.getStreamOptions() == null) { + ChatCompletionsOptionsAccessHelper.setStreamOptions(options, + new ChatCompletionStreamOptions().setIncludeUsage(true)); + } + return options; } @@ -644,6 +655,8 @@ Prompt buildRequestPrompt(Prompt prompt) { requestOptions.setInternalToolExecutionEnabled( ModelOptionsUtils.mergeOption(runtimeOptions.getInternalToolExecutionEnabled(), this.defaultOptions.getInternalToolExecutionEnabled())); + requestOptions.setStreamUsage(ModelOptionsUtils.mergeOption(runtimeOptions.getStreamUsage(), + this.defaultOptions.getStreamUsage())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), @@ -653,6 +666,7 @@ Prompt buildRequestPrompt(Prompt prompt) { } else { requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.getInternalToolExecutionEnabled()); + requestOptions.setStreamUsage(this.defaultOptions.getStreamUsage()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); requestOptions.setToolContext(this.defaultOptions.getToolContext()); diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 78651bd5a6..9b116c0c44 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -199,6 +199,13 @@ public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @JsonIgnore 
private Boolean internalToolExecutionEnabled; + /** + * Whether to include token usage information in streaming chat completion responses. + * Only applies to streaming responses. + */ + @JsonIgnore + private Boolean enableStreamUsage; + @Override @JsonIgnore public List getToolCallbacks() { @@ -259,6 +266,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) .toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null) .responseFormat(fromOptions.getResponseFormat()) + .streamUsage(fromOptions.getStreamUsage()) .seed(fromOptions.getSeed()) .logprobs(fromOptions.isLogprobs()) .topLogprobs(fromOptions.getTopLogProbs()) @@ -391,6 +399,14 @@ public void setResponseFormat(AzureOpenAiResponseFormat responseFormat) { this.responseFormat = responseFormat; } + public Boolean getStreamUsage() { + return this.enableStreamUsage; + } + + public void setStreamUsage(Boolean enableStreamUsage) { + this.enableStreamUsage = enableStreamUsage; + } + @Override @JsonIgnore public Integer getTopK() { @@ -472,6 +488,7 @@ public boolean equals(Object o) { && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs) && Objects.equals(this.enhancements, that.enhancements) && Objects.equals(this.streamOptions, that.streamOptions) + && Objects.equals(this.enableStreamUsage, that.enableStreamUsage) && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) && Objects.equals(this.presencePenalty, that.presencePenalty) @@ -482,8 +499,8 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, 
this.logprobs, - this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens, - this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP); + this.topLogProbs, this.enhancements, this.streamOptions, this.enableStreamUsage, this.toolContext, + this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP); } public static class Builder { @@ -553,6 +570,11 @@ public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) { return this; } + public Builder streamUsage(Boolean enableStreamUsage) { + this.options.enableStreamUsage = enableStreamUsage; + return this; + } + public Builder seed(Long seed) { this.options.seed = seed; return this; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java index b3a8bfd6d7..5246864325 100644 --- a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -56,6 +56,7 @@ void testBuilderWithAllFields() { .topP(0.9) .user("test-user") .responseFormat(responseFormat) + .streamUsage(true) .seed(12345L) .logprobs(true) .topLogprobs(5) @@ -65,11 +66,11 @@ void testBuilderWithAllFields() { assertThat(options) .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop", - "temperature", "topP", "user", "responseFormat", "seed", "logprobs", "topLogProbs", "enhancements", - "streamOptions") + "temperature", "topP", "user", "responseFormat", "streamUsage", "seed", "logprobs", "topLogProbs", + "enhancements", "streamOptions") .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8, - List.of("stop1", "stop2"), 0.7, 0.9, 
"test-user", responseFormat, 12345L, true, 5, enhancements, - streamOptions); + List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, 12345L, true, 5, + enhancements, streamOptions); } @Test @@ -94,6 +95,7 @@ void testCopy() { .topP(0.9) .user("test-user") .responseFormat(responseFormat) + .streamUsage(true) .seed(12345L) .logprobs(true) .topLogprobs(5) @@ -128,6 +130,7 @@ void testSetters() { options.setTopP(0.9); options.setUser("test-user"); options.setResponseFormat(responseFormat); + options.setStreamUsage(true); options.setSeed(12345L); options.setLogprobs(true); options.setTopLogProbs(5); @@ -148,6 +151,7 @@ void testSetters() { assertThat(options.getTopP()).isEqualTo(0.9); assertThat(options.getUser()).isEqualTo("test-user"); assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getStreamUsage()).isTrue(); assertThat(options.getSeed()).isEqualTo(12345L); assertThat(options.isLogprobs()).isTrue(); assertThat(options.getTopLogProbs()).isEqualTo(5); @@ -171,6 +175,7 @@ void testDefaultValues() { assertThat(options.getTopP()).isNull(); assertThat(options.getUser()).isNull(); assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getStreamUsage()).isNull(); assertThat(options.getSeed()).isNull(); assertThat(options.isLogprobs()).isNull(); assertThat(options.getTopLogProbs()).isNull(); diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc index 00160fd09e..9705f31f46 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc @@ -185,6 +185,7 @@ Deployments model name to provide as part of this completions request. | gpt-4o | spring.ai.azure.openai.chat.options.topP | An alternative to sampling with temperature called nucleus sampling. 
This value causes the model to consider the results of tokens with the provided probability mass. | - | spring.ai.azure.openai.chat.options.logitBias | A map between GPT token IDs and bias scores that influences the probability of specific tokens appearing in a completions response. Token IDs are computed via external tokenizer tools, while bias scores reside in the range of -100 to 100 with minimum and maximum values corresponding to a full ban or exclusive selection of a token, respectively. The exact behavior of a given bias score varies by model. | - | spring.ai.azure.openai.chat.options.user | An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. | - +| spring.ai.azure.openai.chat.options.stream-usage | (For streaming only) Set to `true` to add an additional chunk with token usage statistics for the entire request. The `choices` field for this chunk is an empty array; all other chunks also include a `usage` field, but with a null value. | false | spring.ai.azure.openai.chat.options.n | The number of chat completions choices that should be generated for a chat completions response. | - | spring.ai.azure.openai.chat.options.stop | A collection of textual sequences that will end completions generation. | - | spring.ai.azure.openai.chat.options.presencePenalty | A value that influences the probability of generated tokens appearing based on their existing presence in generated text. Positive values will make tokens less likely to appear when they already exist and increase the model's likelihood to output new topics. | -