Skip to content

Commit

Permalink
feat: init support for RAG
Browse files Browse the repository at this point in the history
  • Loading branch information
medcl committed Feb 16, 2025
1 parent 165be8a commit f566781
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/content.en/docs/release-notes/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Information about release notes of Coco Server is provided here.
- Google Drive Connector
- Yuque Connector
- Notion Connector
- RAG based Chat

### Breaking changes

Expand Down
152 changes: 139 additions & 13 deletions modules/assistant/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ package assistant

import (
"context"
"fmt"
log "github.com/cihub/seelog"
"infini.sh/coco/lib/langchaingo/llms"
"infini.sh/coco/lib/langchaingo/llms/ollama"
"infini.sh/coco/modules/common"
"infini.sh/coco/modules/search"
httprouter "infini.sh/framework/core/api/router"
"infini.sh/framework/core/api/websocket"
"infini.sh/framework/core/orm"
Expand Down Expand Up @@ -44,19 +46,20 @@ const MessageTypeSystem string = "system"

type ChatMessage struct {
orm.ORMObjectBase
MessageType string `json:"type"` // user, assistant, system
SessionID string `json:"session_id"`
From string `json:"from"`
To string `json:"to,omitempty"`
Message string `config:"message" json:"message,omitempty" elastic_mapping:"message:{type:keyword}"`
MessageType string `json:"type"` // user, assistant, system
SessionID string `json:"session_id"`
Parameters util.MapStr `json:"parameters,omitempty"`
From string `json:"from"`
To string `json:"to,omitempty"`
Message string `config:"message" json:"message,omitempty" elastic_mapping:"message:{type:keyword}"`
}

func (h APIHandler) getChatSessions(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {

q := orm.Query{}
q.From = h.GetIntOrDefault(req, "from", 0)
q.Size = h.GetIntOrDefault(req, "size", 20)

q.AddSort("updated", orm.DESC)
err, res := orm.Search(&Session{}, &q)
if err != nil {
h.WriteError(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -130,6 +133,7 @@ func (h APIHandler) getChatHistoryBySession(w http.ResponseWriter, req *http.Req
q.Conds = orm.And(orm.Eq("session_id", ps.MustGetParameter("session_id")))
q.From = h.GetIntOrDefault(req, "from", 0)
q.Size = h.GetIntOrDefault(req, "size", 20)
q.AddSort("updated", orm.ASC)

err, res := orm.Search(&ChatMessage{}, &q)
if err != nil {
Expand Down Expand Up @@ -157,6 +161,57 @@ func (h APIHandler) cancelReplyMessage(w http.ResponseWriter, req *http.Request,
}
}

func formatDocumentReferences(docs []common.Document) string {
var sb strings.Builder
sb.WriteString("<REFERENCES>\n")
for i, doc := range docs {
sb.WriteString(fmt.Sprintf("<Doc>"))
sb.WriteString(fmt.Sprintf("ID #%d - %v\n", i+1, doc.ID))
sb.WriteString(fmt.Sprintf("Title: %s\n", doc.Title))
sb.WriteString(fmt.Sprintf("Source: %s\n", doc.Source))
sb.WriteString(fmt.Sprintf("Updated: %s\n", doc.Updated))
sb.WriteString(fmt.Sprintf("Category: %s\n", doc.GetAllCategories()))
sb.WriteString(fmt.Sprintf("Summary: %s\n", doc.Summary))
sb.WriteString(fmt.Sprintf("Content: %s\n", doc.Content))
sb.WriteString(fmt.Sprintf("</Doc>\n"))

}
sb.WriteString("</REFERENCES>")
return sb.String()
}

func formatDocumentReferencesToDisplay(docs []common.Document) string {
var sb strings.Builder
sb.WriteString(fmt.Sprintf("<Source total=%v>\n", len(docs)))
outDocs := []util.MapStr{}
for _, doc := range docs {
item := util.MapStr{}
item["id"] = doc.ID
item["title"] = doc.Title
item["source"] = doc.Source
item["updated"] = doc.Updated
item["category"] = doc.Category
item["summary"] = doc.Summary
item["icon"] = doc.Icon
item["size"] = doc.Size
item["thumbnail"] = doc.Thumbnail
item["url"] = doc.URL
outDocs = append(outDocs, item)
}
sb.WriteString(util.MustToJSON(outDocs))
sb.WriteString("</Source>")
return sb.String()
}

func fetchDocuments(query *orm.Query) ([]common.Document, error) {
var docs []common.Document
err, _ := orm.SearchWithJSONMapper(&docs, query)
if err != nil {
return nil, fmt.Errorf("failed to fetch documents: %w", err)
}
return docs, nil
}

func (h APIHandler) sendChatMessage(w http.ResponseWriter, req *http.Request, ps httprouter.Params) {

webSocketID := req.Header.Get("WEBSOCKET-SESSION-ID")
Expand All @@ -170,12 +225,33 @@ func (h APIHandler) sendChatMessage(w http.ResponseWriter, req *http.Request, ps
return
}

var (
from = h.GetIntOrDefault(req, "from", 0)
size = h.GetIntOrDefault(req, "size", 10)
datasource = h.GetParameterOrDefault(req, "datasource", "")
category = h.GetParameterOrDefault(req, "category", "")
username = h.GetParameterOrDefault(req, "username", "")
userid = h.GetParameterOrDefault(req, "userid", "")
tags = h.GetParameterOrDefault(req, "tags", "")
subcategory = h.GetParameterOrDefault(req, "subcategory", "")
richCategory = h.GetParameterOrDefault(req, "rich_category", "")
field = h.GetParameterOrDefault(req, "search_field", "title")
source = h.GetParameterOrDefault(req, "source_fields", "*")
)

searchDB := h.GetBoolOrDefault(req, "search", true)

obj := ChatMessage{
SessionID: sessionID,
MessageType: MessageTypeUser,
Message: request.Message,
}

if searchDB {
obj.Parameters = util.MapStr{}
obj.Parameters["search"] = searchDB
}

err := orm.Create(nil, &obj)
if err != nil {
h.WriteError(w, err.Error(), http.StatusInternalServerError)
Expand All @@ -189,15 +265,23 @@ func (h APIHandler) sendChatMessage(w http.ResponseWriter, req *http.Request, ps
}}

if webSocketID != "" {
var query *orm.Query
if searchDB {
mustClauses := search.BuildMustClauses(datasource, category, subcategory, richCategory, username, userid)
query = search.BuildTemplatedQuery(from, size, mustClauses, field, obj.Message, source, tags)

}

//de-duplicate background task per-session, cancelable
taskID := task.RunWithinGroup("assistant-session", func(taskCtx context.Context) error {
//timeout for 30 seconds

log.Debugf("place a assistant background job for session: %v, websocket: %v ", sessionID, webSocketID)

//TODO
//1. retrieve related documents from background server
//2. summary previous history chat as context
//1. expand and rewrite the query
// use the title and summary to judge which document need to fetch in-depth, also the updated time to check the data is fresh or not
// pick N related documents and combine with the memory and the near chat history as the chat context
//2. summary previous history chat as context, update as memory
//3. assemble with the agent's role setting
//4. send to LLM

Expand All @@ -214,27 +298,69 @@ func (h APIHandler) sendChatMessage(w http.ResponseWriter, req *http.Request, ps
}
ctx := context.Background()

// Prepare the system message
content := []llms.MessageContent{
//llms.TextParts(llms.ChatMessageTypeSystem, "You are a company branding design wizard."),
//llms.TextParts(llms.ChatMessageTypeHuman, "What would be a good company name for a comapny that produces Go-backed LLM tools?"),
llms.TextParts(llms.ChatMessageTypeHuman, request.Message),
llms.TextParts(llms.ChatMessageTypeSystem, "You are a personal AI assistant designed by Coco AI(https://coco.rs), the backend team is behind INFINI Labs(https://infinilabs.com)."),
}

var references string
var simpliedReferences string
//Retrieve related documents from background server
if searchDB && query != nil {
docs, err := fetchDocuments(query)
if err != nil {
log.Errorf("Failed to fetch documents from DB: %v", err)
// Proceed without RAG
} else if len(docs) > 0 {
references = formatDocumentReferences(docs)
simpliedReferences = formatDocumentReferencesToDisplay(docs)
}
}

prompt := fmt.Sprintf(`You are a friendly assistant designed to help users access and understand their personal or company data. Your responses should be clear, concise, and based solely on the information provided below. If the information is insufficient, please indicate that you need more details to assist effectively.
Query: %s
Data:
%s`, request.Message, references)

// Append the user's message
content = append(content, llms.TextParts(llms.ChatMessageTypeHuman, prompt))

log.Debug(content)

chunkSeq := 0
messageID := util.GetUUID()
requestMessageID := obj.ID
messageBuffer := strings.Builder{}

if simpliedReferences != "" {
messageBuffer.WriteString(simpliedReferences)
}

sentSource := false

completion, err := llm.GenerateContent(ctx, content,
llms.WithTemperature(0.8),
llms.WithStreamingFunc(func(ctx context.Context, chunk []byte) error {

var txtMsg string
if !sentSource {
if simpliedReferences != "" {
txtMsg = simpliedReferences
}
sentSource = true
}
txtMsg += string(chunk)

chunkSeq += 1
msg := util.MustToJSON(util.MapStr{
"session_id": sessionID,
"message_id": messageID,
"message_type": MessageTypeAssistant,
"reply_to_message": requestMessageID,
"chunk_sequence": chunkSeq,
"message_chunk": string(chunk),
"message_chunk": txtMsg,
})
messageBuffer.Write(chunk)
websocket.SendPrivateMessage(webSocketID, msg)
Expand Down
35 changes: 33 additions & 2 deletions modules/common/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,37 @@ type Document struct {

}

func (document *Document) GetAllCategories() string {
// Initialize a slice to hold all category strings
var allCategories []string

// Add the primary category if available
if document.Category != "" {
allCategories = append(allCategories, document.Category)
}

// Add the subcategory if available
if document.Subcategory != "" {
allCategories = append(allCategories, document.Subcategory)
}

// Add all categories if available
if len(document.Categories) > 0 {
allCategories = append(allCategories, document.Categories...)
}

// Add rich category labels if available (only the text)
if len(document.RichCategories) > 0 {
for _, richCategory := range document.RichCategories {
// Assuming RichLabel has a `Label` field to hold the category text
allCategories = append(allCategories, richCategory.Label)
}
}

// Join all the categories with a comma
return strings.Join(allCategories, ", ")
}

func (document *Document) Cleanup() {
document.TrimLastDuplicatedCategory()
}
Expand All @@ -77,7 +108,7 @@ type EditorInfo struct {

// UserInfo represents information about a user in relation to document edits or ownership.
type UserInfo struct {
UserAvatar string `json:"avatar,omitempty" elastic_mapping:"avatar:{enabled:false}"` // Username of the user
UserName string `json:"username,omitempty" elastic_mapping:"username:{type:keyword,copy_to:combined_fulltext}"` // Username of the user
UserAvatar string `json:"avatar,omitempty" elastic_mapping:"avatar:{enabled:false}"` // Login of the user
UserName string `json:"username,omitempty" elastic_mapping:"username:{type:keyword,copy_to:combined_fulltext}"` // Login of the user
UserID string `json:"userid,omitempty" elastic_mapping:"userid:{type:keyword,copy_to:combined_fulltext}"` // Unique identifier for the user
}

0 comments on commit f566781

Please sign in to comment.