fixup! fix LLMChoice with improved openai schema
kyriediculous committed Jan 10, 2025
1 parent a943e98 commit 97af593
Showing 2 changed files with 106 additions and 117 deletions.
173 changes: 87 additions & 86 deletions worker/runner.gen.go

Some generated files are not rendered by default.

50 changes: 19 additions & 31 deletions worker/worker.go
@@ -397,7 +397,6 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq
func (w *Worker) LLM(ctx context.Context, req GenLLMJSONRequestBody) (interface{}, error) {
isStreaming := req.Stream != nil && *req.Stream
ctx, cancel := context.WithCancel(ctx)
defer cancel()
c, err := w.borrowContainer(ctx, "llm", *req.Model)
if err != nil {
return nil, err
@@ -419,8 +418,8 @@ func (w *Worker) LLM(ctx context.Context, req GenLLMJSONRequestBody) (interface{
}
return w.handleStreamingResponse(ctx, c, resp, cancel)
}
defer cancel()

defer cancel()
resp, err := c.Client.GenLLMWithResponse(ctx, req)
if err != nil {
return nil, err
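The two hunks above are about who owns cancel: the streaming branch passes cancel into handleStreamingResponse as the returnContainer callback, so a defer cancel() placed before that branch would tear down the context as soon as LLM returned, killing the stream mid-flight. Only the non-streaming path keeps a local defer. A minimal, self-contained sketch of the pattern (illustrative names, not the worker's actual API):

package main

import (
	"context"
	"fmt"
)

// runRequest mirrors the cancel-ownership split: the streaming path
// transfers cancel to the goroutine that drains the stream, while the
// non-streaming path defers it locally.
func runRequest(ctx context.Context, streaming bool) (<-chan string, error) {
	ctx, cancel := context.WithCancel(ctx)

	if streaming {
		out := make(chan string)
		go func() {
			defer cancel() // the goroutine now owns cancel
			defer close(out)
			for i := 0; i < 3; i++ {
				select {
				case out <- fmt.Sprintf("chunk %d", i):
				case <-ctx.Done():
					return
				}
			}
		}()
		return out, nil
	}

	defer cancel() // non-streaming: context scoped to this call
	return nil, nil
}

func main() {
	out, _ := runRequest(context.Background(), true)
	for chunk := range out {
		fmt.Println(chunk)
	}
}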
@@ -762,52 +761,41 @@ func (w *Worker) handleNonStreamingResponse(c *RunnerContainer, resp *GenLLMResp
return resp.JSON200, nil
}

type LlmStreamChunk struct {
Chunk string `json:"chunk,omitempty"`
TokensUsed int `json:"tokens_used,omitempty"`
Done bool `json:"done,omitempty"`
}

func (w *Worker) handleStreamingResponse(ctx context.Context, c *RunnerContainer, resp *http.Response, returnContainer func()) (<-chan LlmStreamChunk, error) {
func (w *Worker) handleStreamingResponse(ctx context.Context, c *RunnerContainer, resp *http.Response, returnContainer func()) (<-chan *LLMResponse, error) {
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

outputChan := make(chan LlmStreamChunk, 10)
outputChan := make(chan *LLMResponse, 10)

go func() {
defer close(outputChan)
defer returnContainer()

scanner := bufio.NewScanner(resp.Body)
totalTokens := 0

for scanner.Scan() {
select {
case <-ctx.Done():
return
default:
line := scanner.Text()
if strings.HasPrefix(line, "data: ") {
data := strings.TrimPrefix(line, "data: ")
if data == "[DONE]" {
outputChan <- LlmStreamChunk{Chunk: "[DONE]", Done: true, TokensUsed: totalTokens}
return
}

var streamData LlmStreamChunk
if err := json.Unmarshal([]byte(data), &streamData); err != nil {
slog.Error("Error unmarshaling stream data", slog.String("err", err.Error()))
continue
}

totalTokens += streamData.TokensUsed

select {
case outputChan <- streamData:
case <-ctx.Done():
return
}
data := strings.TrimPrefix(line, "data: ")

if data == "[DONE]" {
break
}

var llmRes LLMResponse
if err := json.Unmarshal([]byte(data), &llmRes); err != nil {
slog.Error("Error unmarshaling stream data", slog.String("err", err.Error()))
continue
}

select {
case outputChan <- &llmRes:
case <-ctx.Done():
return
}
}
}
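A Go subtlety in the new [DONE] handling: the break sits inside a select, so it exits the select rather than the for scanner.Scan() loop; the loop then ends only when the scanner is exhausted, whereas the old code returned immediately. A tiny standalone demonstration of that scoping:

package main

import "fmt"

func main() {
	done := make(chan struct{}) // never signaled in this demo

	for _, line := range []string{"a", "b", "[DONE]", "c"} {
		select {
		case <-done:
			return
		default:
			if line == "[DONE]" {
				break // exits the select only; the for loop continues
			}
			fmt.Println(line)
		}
	}
	// Prints a, b, c: "[DONE]" skipped one iteration but did not stop
	// the loop. A labeled break or a return is needed to do that.
	fmt.Println("loop finished")
}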

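With the channel element type changed from LlmStreamChunk to *LLMResponse, a streaming caller now ranges over OpenAI-style response chunks. A rough consumer sketch against the worker package (imports context and fmt; the Choices, Delta, and Content field names are assumptions modeled on the OpenAI schema this commit targets, not confirmed against runner.gen.go):

// consumeLLMStream drains the channel returned by Worker.LLM in
// streaming mode. It assumes the worker package's types; the field
// names on LLMResponse are hypothetical.
func consumeLLMStream(ctx context.Context, w *Worker, req GenLLMJSONRequestBody) error {
	out, err := w.LLM(ctx, req)
	if err != nil {
		return err
	}
	stream, ok := out.(<-chan *LLMResponse)
	if !ok {
		return fmt.Errorf("expected a streaming response, got %T", out)
	}
	for res := range stream {
		for _, choice := range res.Choices { // assumed field
			fmt.Print(choice.Delta.Content) // assumed fields
		}
	}
	return nil
}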