Skip to content

Commit 10d59d5

Browse files
authored
openai: finish_reason as tool_calls for streaming with tools (ollama#7963)
1 parent a4f69a0 commit 10d59d5

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

openai/openai.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020
"github.com/ollama/ollama/types/model"
2121
)
2222

23+
var finishReasonToolCalls = "tool_calls"
24+
2325
type Error struct {
2426
Message string `json:"message"`
2527
Type string `json:"type"`
@@ -266,7 +268,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
266268
}
267269
}
268270

269-
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
271+
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
270272
toolCalls := toToolCalls(r.Message.ToolCalls)
271273
return ChatCompletionChunk{
272274
Id: id,
@@ -279,6 +281,9 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
279281
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls},
280282
FinishReason: func(reason string) *string {
281283
if len(reason) > 0 {
284+
if toolCallSent {
285+
return &finishReasonToolCalls
286+
}
282287
return &reason
283288
}
284289
return nil
@@ -585,6 +590,7 @@ type ChatWriter struct {
585590
stream bool
586591
streamOptions *StreamOptions
587592
id string
593+
toolCallSent bool
588594
BaseWriter
589595
}
590596

@@ -634,11 +640,14 @@ func (w *ChatWriter) writeResponse(data []byte) (int, error) {
634640

635641
// chat chunk
636642
if w.stream {
637-
c := toChunk(w.id, chatResponse)
643+
c := toChunk(w.id, chatResponse, w.toolCallSent)
638644
d, err := json.Marshal(c)
639645
if err != nil {
640646
return 0, err
641647
}
648+
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
649+
w.toolCallSent = true
650+
}
642651

643652
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
644653
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))

0 commit comments

Comments
 (0)