Skip to content

Commit b7190c6

Browse files
committed
up
1 parent d5d87a9 commit b7190c6

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

pkg/openai/chat.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -273,16 +273,16 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
273273
return
274274
}
275275
defer resp.Body.Close()
276-
c.Writer.WriteHeader(resp.StatusCode)
277-
for key, value := range resp.Header {
278-
for _, v := range value {
279-
c.Writer.Header().Add(key, v)
280-
}
281-
}
282-
teeReader := io.TeeReader(resp.Body, c.Writer)
283276

284277
var result string
285278
if chatReq.Stream {
279+
for key, value := range resp.Header {
280+
for _, v := range value {
281+
c.Writer.Header().Add(key, v)
282+
}
283+
}
284+
c.Writer.WriteHeader(resp.StatusCode)
285+
teeReader := io.TeeReader(resp.Body, c.Writer)
286286
// 流式响应
287287
scanner := bufio.NewScanner(teeReader)
288288

@@ -318,15 +318,18 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
318318

319319
}
320320
} else {
321+
321322
// 处理非流式响应
322-
body, err := io.ReadAll(teeReader)
323+
body, err := io.ReadAll(resp.Body)
323324
if err != nil {
324325
fmt.Println("Error reading response body:", err)
326+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
325327
return
326328
}
327329
var opiResp ChatCompletionResponse
328330
if err := json.Unmarshal(body, &opiResp); err != nil {
329331
log.Println("Error parsing JSON:", err)
332+
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
330333
return
331334
}
332335
if opiResp.Choices != nil && len(opiResp.Choices) > 0 {
@@ -343,6 +346,16 @@ func ChatProxy(c *gin.Context, chatReq *ChatCompletionRequest) {
343346
}
344347

345348
}
349+
resp.Body = io.NopCloser(bytes.NewBuffer(body))
350+
351+
for k, v := range resp.Header {
352+
c.Writer.Header().Set(k, v[0])
353+
}
354+
c.Writer.WriteHeader(resp.StatusCode)
355+
_, err = io.Copy(c.Writer, resp.Body)
356+
if err != nil {
357+
log.Println(err)
358+
}
346359
}
347360
usagelog.CompletionCount = tokenizer.NumTokensFromStr(result, chatReq.Model)
348361
usagelog.Cost = fmt.Sprintf("%.6f", tokenizer.Cost(usagelog.Model, usagelog.PromptCount, usagelog.CompletionCount))

0 commit comments

Comments
 (0)