Skip to content

Commit 8fa1dbc

Browse files
committed
Add support for stream usage in Azure OpenAI
Signed-off-by: Andres da Silva Santos <[email protected]>
1 parent 81b715b commit 8fa1dbc

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

+15
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti
259259
fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null)
260260
.toolNames(fromOptions.getToolNames() != null ? new HashSet<>(fromOptions.getToolNames()) : null)
261261
.responseFormat(fromOptions.getResponseFormat())
262+
.streamUsage(fromOptions.getStreamUsage())
262263
.seed(fromOptions.getSeed())
263264
.logprobs(fromOptions.isLogprobs())
264265
.topLogprobs(fromOptions.getTopLogProbs())
@@ -391,6 +392,14 @@ public void setResponseFormat(AzureOpenAiResponseFormat responseFormat) {
391392
this.responseFormat = responseFormat;
392393
}
393394

395+
public Boolean getStreamUsage() {
396+
return this.streamOptions != null;
397+
}
398+
399+
public void setStreamUsage(Boolean enableStreamUsage) {
400+
this.streamOptions = (enableStreamUsage) ? new ChatCompletionStreamOptions().setIncludeUsage(true) : null;
401+
}
402+
394403
@Override
395404
@JsonIgnore
396405
public Integer getTopK() {
@@ -553,6 +562,12 @@ public Builder responseFormat(AzureOpenAiResponseFormat responseFormat) {
553562
return this;
554563
}
555564

565+
public Builder streamUsage(boolean enableStreamUsage) {
566+
this.options.streamOptions = (enableStreamUsage) ? new ChatCompletionStreamOptions().setIncludeUsage(true)
567+
: null;
568+
return this;
569+
}
570+
556571
public Builder seed(Long seed) {
557572
this.options.seed = seed;
558573
return this;

models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java

+9-4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ void testBuilderWithAllFields() {
5656
.topP(0.9)
5757
.user("test-user")
5858
.responseFormat(responseFormat)
59+
.streamUsage(true)
5960
.seed(12345L)
6061
.logprobs(true)
6162
.topLogprobs(5)
@@ -65,11 +66,11 @@ void testBuilderWithAllFields() {
6566

6667
assertThat(options)
6768
.extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop",
68-
"temperature", "topP", "user", "responseFormat", "seed", "logprobs", "topLogProbs", "enhancements",
69-
"streamOptions")
69+
"temperature", "topP", "user", "responseFormat", "streamUsage", "seed", "logprobs", "topLogProbs",
70+
"enhancements", "streamOptions")
7071
.containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8,
71-
List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, 12345L, true, 5, enhancements,
72-
streamOptions);
72+
List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, true, 12345L, true, 5,
73+
enhancements, streamOptions);
7374
}
7475

7576
@Test
@@ -94,6 +95,7 @@ void testCopy() {
9495
.topP(0.9)
9596
.user("test-user")
9697
.responseFormat(responseFormat)
98+
.streamUsage(true)
9799
.seed(12345L)
98100
.logprobs(true)
99101
.topLogprobs(5)
@@ -128,6 +130,7 @@ void testSetters() {
128130
options.setTopP(0.9);
129131
options.setUser("test-user");
130132
options.setResponseFormat(responseFormat);
133+
options.setStreamUsage(true);
131134
options.setSeed(12345L);
132135
options.setLogprobs(true);
133136
options.setTopLogProbs(5);
@@ -148,6 +151,7 @@ void testSetters() {
148151
assertThat(options.getTopP()).isEqualTo(0.9);
149152
assertThat(options.getUser()).isEqualTo("test-user");
150153
assertThat(options.getResponseFormat()).isEqualTo(responseFormat);
154+
assertThat(options.getStreamUsage()).isTrue();
151155
assertThat(options.getSeed()).isEqualTo(12345L);
152156
assertThat(options.isLogprobs()).isTrue();
153157
assertThat(options.getTopLogProbs()).isEqualTo(5);
@@ -171,6 +175,7 @@ void testDefaultValues() {
171175
assertThat(options.getTopP()).isNull();
172176
assertThat(options.getUser()).isNull();
173177
assertThat(options.getResponseFormat()).isNull();
178+
assertThat(options.getStreamUsage()).isFalse();
174179
assertThat(options.getSeed()).isNull();
175180
assertThat(options.isLogprobs()).isNull();
176181
assertThat(options.getTopLogProbs()).isNull();

spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/azure-openai-chat.adoc

+1
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ Deployments model name to provide as part of this completions request. | gpt-4o
185185
| spring.ai.azure.openai.chat.options.topP | An alternative to sampling with temperature called nucleus sampling. This value causes the model to consider the results of tokens with the provided probability mass. | -
186186
| spring.ai.azure.openai.chat.options.logitBias | A map between GPT token IDs and bias scores that influences the probability of specific tokens appearing in a completions response. Token IDs are computed via external tokenizer tools, while bias scores reside in the range of -100 to 100 with minimum and maximum values corresponding to a full ban or exclusive selection of a token, respectively. The exact behavior of a given bias score varies by model. | -
187187
| spring.ai.azure.openai.chat.options.user | An identifier for the caller or end user of the operation. This may be used for tracking or rate-limiting purposes. | -
188+
| spring.ai.azure.openai.chat.options.stream-usage | (For streaming only) When set to true, an additional chunk containing the token usage statistics for the entire request is streamed. The `choices` field for this chunk is an empty array, and all other chunks also include a `usage` field, but with a null value. | false
188189
| spring.ai.azure.openai.chat.options.n | The number of chat completions choices that should be generated for a chat completions response. | -
189190
| spring.ai.azure.openai.chat.options.stop | A collection of textual sequences that will end completions generation. | -
190191
| spring.ai.azure.openai.chat.options.presencePenalty | A value that influences the probability of generated tokens appearing based on their existing presence in generated text. Positive values will make tokens less likely to appear when they already exist and increase the model's likelihood to output new topics. | -

0 commit comments

Comments
 (0)