diff --git a/docker-compose.yml b/docker-compose.yml index 3c5fe9f4d..3333c0660 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,6 +14,10 @@ services: interval: 30s timeout: 10s retries: 5 + sandbox: + image: ghcr.io/khoj-ai/terrarium:latest + ports: + - "8080:8080" server: depends_on: database: diff --git a/documentation/docs/features/chat.md b/documentation/docs/features/chat.md index 86c720984..a438bf5ff 100644 --- a/documentation/docs/features/chat.md +++ b/documentation/docs/features/chat.md @@ -43,3 +43,6 @@ Slash commands allows you to change what Khoj uses to respond to your query - **/image**: Generate an image in response to your query. - **/help**: Use /help to get all available commands and general information about Khoj - **/summarize**: Can be used to summarize 1 selected file filter for that conversation. Refer to [File Summarization](summarization) for details. +- **/diagram**: Generate a diagram in response to your query. This is built on [Excalidraw](https://excalidraw.com/). +- **/code**: Generate and run very simple Python code snippets. Refer to [Code Generation](code_generation) for details. +- **/research**: Go deeper in a topic for more accurate, in-depth responses. diff --git a/documentation/docs/features/code_execution.md b/documentation/docs/features/code_execution.md new file mode 100644 index 000000000..8403d466e --- /dev/null +++ b/documentation/docs/features/code_execution.md @@ -0,0 +1,30 @@ +--- +--- + +# Code Execution + +Khoj can generate and run very simple Python code snippets as well. This is useful if you want to generate a plot, run a simple calculation, or do some basic data manipulation. LLMs by default aren't skilled at complex quantitative tasks. Code generation & execution can come in handy for such tasks. + +Just use `/code` in your chat command. + +### Setup (Self-Hosting) +Run [Cohere's Terrarium](https://github.com/cohere-ai/cohere-terrarium) on your machine to enable code generation and execution. + +Check the [instructions](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file#development) for running from source. + +For running with Docker, you can use our [docker-compose.yml](https://github.com/khoj-ai/khoj/blob/master/docker-compose.yml), or start it manually like this: + +```bash +docker pull ghcr.io/khoj-ai/terrarium:latest +docker run -d -p 8080:8080 ghcr.io/khoj-ai/terrarium:latest +``` + +#### Verify +Verify that it's running, by evaluating a simple Python expression: + +```bash +curl -X POST -H "Content-Type: application/json" \ +--url http://localhost:8080 \ +--data-raw '{"code": "1 + 1"}' \ +--no-buffer +``` diff --git a/pyproject.toml b/pyproject.toml index 12c7789c5..ed57f55ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ dependencies = [ "django_apscheduler == 0.6.2", "anthropic == 0.26.1", "docx2txt == 0.8", - "google-generativeai == 0.7.2" + "google-generativeai == 0.8.3" ] dynamic = ["version"] diff --git a/src/interface/web/app/chat/page.tsx b/src/interface/web/app/chat/page.tsx index deb261058..c3d5ff37b 100644 --- a/src/interface/web/app/chat/page.tsx +++ b/src/interface/web/app/chat/page.tsx @@ -29,6 +29,7 @@ interface ChatBodyDataProps { onConversationIdChange?: (conversationId: string) => void; setQueryToProcess: (query: string) => void; streamedMessages: StreamMessage[]; + setStreamedMessages: (messages: StreamMessage[]) => void; setUploadedFiles: (files: string[]) => void; isMobileWidth?: boolean; isLoggedIn: boolean; @@ -118,6 +119,7 @@ function ChatBodyData(props: ChatBodyDataProps) { setAgent={setAgentMetadata} pendingMessage={processingMessage ? message : ""} incomingMessages={props.streamedMessages} + setIncomingMessages={props.setStreamedMessages} customClassName={chatHistoryCustomClassName} /> @@ -351,6 +353,7 @@ export default function Chat() { void; - incomingMessages?: StreamMessage[]; pendingMessage?: string; + incomingMessages?: StreamMessage[]; + setIncomingMessages?: (incomingMessages: StreamMessage[]) => void; publicConversationSlug?: string; setAgent: (agent: AgentData) => void; customClassName?: string; @@ -45,7 +46,7 @@ interface TrainOfThoughtComponentProps { trainOfThought: string[]; lastMessage: boolean; agentColor: string; - key: string; + keyId: string; completed?: boolean; } @@ -56,7 +57,7 @@ function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) { return (
{!props.completed && } {props.completed && @@ -97,6 +98,7 @@ export default function ChatHistory(props: ChatHistoryProps) { const [data, setData] = useState(null); const [currentPage, setCurrentPage] = useState(0); const [hasMoreMessages, setHasMoreMessages] = useState(true); + const [currentTurnId, setCurrentTurnId] = useState(null); const sentinelRef = useRef(null); const scrollAreaRef = useRef(null); const latestUserMessageRef = useRef(null); @@ -177,6 +179,10 @@ export default function ChatHistory(props: ChatHistoryProps) { if (lastMessage && !lastMessage.completed) { setIncompleteIncomingMessageIndex(props.incomingMessages.length - 1); props.setTitle(lastMessage.rawQuery); + // Store the turnId when we get it + if (lastMessage.turnId) { + setCurrentTurnId(lastMessage.turnId); + } } } }, [props.incomingMessages]); @@ -278,6 +284,25 @@ export default function ChatHistory(props: ChatHistoryProps) { return data.agent?.persona; } + const handleDeleteMessage = (turnId?: string) => { + if (!turnId) return; + + setData((prevData) => { + if (!prevData || !turnId) return prevData; + return { + ...prevData, + chat: prevData.chat.filter((msg) => msg.turnId !== turnId), + }; + }); + + // Update incoming messages if they exist + if (props.incomingMessages && props.setIncomingMessages) { + props.setIncomingMessages( + props.incomingMessages.filter((msg) => msg.turnId !== turnId), + ); + } + }; + if (!props.conversationId && !props.publicConversationSlug) { return null; } @@ -293,6 +318,18 @@ export default function ChatHistory(props: ChatHistoryProps) { data.chat && data.chat.map((chatMessage, index) => ( <> + {chatMessage.trainOfThought && chatMessage.by === "khoj" && ( + train.data, + )} + lastMessage={false} + agentColor={data?.agent?.color || "orange"} + key={`${index}trainOfThought`} + keyId={`${index}trainOfThought`} + completed={true} + /> + )} - {chatMessage.trainOfThought && chatMessage.by === "khoj" && ( - train.data, - )} - lastMessage={false} - agentColor={data?.agent?.color || "orange"} - key={`${index}trainOfThought`} - completed={true} - /> - )} ))} {props.incomingMessages && props.incomingMessages.map((message, index) => { + const messageTurnId = message.turnId ?? currentTurnId ?? undefined; return ( {message.trainOfThought && ( )} @@ -373,7 +408,12 @@ export default function ChatHistory(props: ChatHistoryProps) { "memory-type": "", "inferred-queries": message.inferredQueries || [], }, + conversationId: props.conversationId, + turnId: messageTurnId, }} + conversationId={props.conversationId} + turnId={messageTurnId} + onDeleteMessage={handleDeleteMessage} customClassName="fullHistory" borderLeftColor={`${data?.agent?.color}-500`} isLastMessage={true} @@ -393,7 +433,11 @@ export default function ChatHistory(props: ChatHistoryProps) { created: new Date().getTime().toString(), by: "you", automationId: "", + conversationId: props.conversationId, + turnId: undefined, }} + conversationId={props.conversationId} + onDeleteMessage={handleDeleteMessage} customClassName="fullHistory" borderLeftColor={`${data?.agent?.color}-500`} isLastMessage={true} diff --git a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx index b60228df4..74e155235 100644 --- a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx +++ b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx @@ -149,7 +149,7 @@ export const ChatInputArea = forwardRef((pr } let messageToSend = message.trim(); - if (useResearchMode) { + if (useResearchMode && !messageToSend.startsWith("/research")) { messageToSend = `/research ${messageToSend}`; } @@ -398,7 +398,7 @@ export const ChatInputArea = forwardRef((pr e.preventDefault()} className={`${props.isMobileWidth ? "w-[100vw]" : "w-full"} rounded-md`} - side="top" + side="bottom" align="center" /* Offset below text area on home page (i.e where conversationId is unset) */ sideOffset={props.conversationId ? 0 : 80} @@ -590,8 +590,8 @@ export const ChatInputArea = forwardRef((pr - Research Mode allows you to get more deeply researched, detailed - responses. Response times may be longer. + (Experimental) Research Mode allows you to get more deeply researched, + detailed responses. Response times may be longer. diff --git a/src/interface/web/app/components/chatMessage/chatMessage.tsx b/src/interface/web/app/components/chatMessage/chatMessage.tsx index 6d82fe1d9..d05d98291 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.tsx +++ b/src/interface/web/app/components/chatMessage/chatMessage.tsx @@ -29,6 +29,7 @@ import { Check, Code, Shapes, + Trash, } from "@phosphor-icons/react"; import DOMPurify from "dompurify"; @@ -146,6 +147,8 @@ export interface SingleChatMessage { intent?: Intent; agent?: AgentData; images?: string[]; + conversationId: string; + turnId?: string; } export interface StreamMessage { @@ -161,6 +164,7 @@ export interface StreamMessage { images?: string[]; intentType?: string; inferredQueries?: string[]; + turnId?: string; } export interface ChatHistoryData { @@ -242,6 +246,9 @@ interface ChatMessageProps { borderLeftColor?: string; isLastMessage?: boolean; agent?: AgentData; + onDeleteMessage: (turnId?: string) => void; + conversationId: string; + turnId?: string; } interface TrainOfThoughtProps { @@ -654,6 +661,27 @@ const ChatMessage = forwardRef((props, ref) => }); } + const deleteMessage = async (message: SingleChatMessage) => { + const turnId = message.turnId || props.turnId; + const response = await fetch("/api/chat/conversation/message", { + method: "DELETE", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + conversation_id: props.conversationId, + turn_id: turnId, + }), + }); + + if (response.ok) { + // Update the UI after successful deletion + props.onDeleteMessage(turnId); + } else { + console.error("Failed to delete message"); + } + }; + const allReferences = constructAllReferences( props.chatMessage.context, props.chatMessage.onlineContext, @@ -716,6 +744,18 @@ const ChatMessage = forwardRef((props, ref) => /> ))} + {props.chatMessage.turnId && ( + + )}
); @@ -626,7 +630,11 @@ export default function FactChecker() { created: new Date().toISOString(), onlineContext: {}, codeContext: {}, + conversationId: conversationID, + turnId: "", }} + conversationId={conversationID} + onDeleteMessage={(turnId?: string) => {}} isMobileWidth={isMobileWidth} /> diff --git a/src/interface/web/tailwind.config.ts b/src/interface/web/tailwind.config.ts index cea53e8ed..32e66d3b3 100644 --- a/src/interface/web/tailwind.config.ts +++ b/src/interface/web/tailwind.config.ts @@ -32,6 +32,11 @@ const config = { /ring-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/, variants: ["focus-visible", "dark"], }, + { + pattern: + /caret-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/, + variants: ["focus", "dark"], + }, ], darkMode: ["class"], content: [ diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 1454b1645..df0760fbb 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -262,7 +262,7 @@ def configure_server( initialize_content(regenerate, search_type, user) except Exception as e: - raise e + logger.error(f"Failed to load some search models: {e}", exc_info=True) def setup_default_agent(user: KhojUser): diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 5db857eeb..afcaa9f04 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -476,9 +476,8 @@ def get_default_search_model() -> SearchModelConfig: if default_search_model: return default_search_model - else: + elif SearchModelConfig.objects.count() == 0: SearchModelConfig.objects.create() - return SearchModelConfig.objects.first() @@ -1319,6 +1318,8 @@ async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_conf def add_files_to_filter(user: KhojUser, conversation_id: str, files: List[str]): conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) file_list = EntryAdapters.get_all_filenames_by_source(user, "computer") + if not conversation: + return [] for filename in files: if filename in file_list and filename not in conversation.file_filters: conversation.file_filters.append(filename) @@ -1332,6 +1333,8 @@ def add_files_to_filter(user: KhojUser, conversation_id: str, files: List[str]): @staticmethod def remove_files_from_filter(user: KhojUser, conversation_id: str, files: List[str]): conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) + if not conversation: + return [] for filename in files: if filename in conversation.file_filters: conversation.file_filters.remove(filename) @@ -1343,6 +1346,17 @@ def remove_files_from_filter(user: KhojUser, conversation_id: str, files: List[s conversation.save() return conversation.file_filters + @staticmethod + def delete_message_by_turn_id(user: KhojUser, conversation_id: str, turn_id: str): + conversation = ConversationAdapters.get_conversation_by_user(user, conversation_id=conversation_id) + if not conversation or not conversation.conversation_log or not conversation.conversation_log.get("chat"): + return False + conversation_log = conversation.conversation_log + updated_log = [msg for msg in conversation_log["chat"] if msg.get("turnId") != turn_id] + conversation.conversation_log["chat"] = updated_log + conversation.save() + return True + class FileObjectAdapters: @staticmethod diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index b3e1523c1..aaaaa081a 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -1,5 +1,6 @@ import json import logging +import os from datetime import datetime, timedelta from threading import Thread from typing import Any, Iterator, List, Optional, Union @@ -263,8 +264,14 @@ def send_message_to_model_offline( assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) messages_dict = [{"role": message.role, "content": message.content} for message in messages] + seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None response = offline_chat_model.create_chat_completion( - messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type} + messages_dict, + stop=stop, + stream=streaming, + temperature=temperature, + response_format={"type": response_type}, + seed=seed, ) if streaming: diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 6e519f5ad..36ebc679d 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -1,4 +1,5 @@ import logging +import os from threading import Thread from typing import Dict @@ -60,6 +61,9 @@ def completion_with_backoff( model_kwargs.pop("stop", None) model_kwargs.pop("response_format", None) + if os.getenv("KHOJ_LLM_SEED"): + model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) + chat = client.chat.completions.create( stream=stream, messages=formatted_messages, # type: ignore @@ -157,6 +161,9 @@ def llm_thread( model_kwargs.pop("stop", None) model_kwargs.pop("response_format", None) + if os.getenv("KHOJ_LLM_SEED"): + model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED")) + chat = client.chat.completions.create( stream=stream, messages=formatted_messages, diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 1a02dca6b..864b864e1 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -625,25 +625,25 @@ {personality_context} # Instructions -- Ask detailed queries to the tool AIs provided below, one at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration. -- Break down your research process into independent, self-contained steps that can be executed sequentially to answer the user's query. Write your step-by-step plan in the scratchpad. -- Ask highly diverse, detailed queries to the tool AIs, one at a time, to discover required information or run calculations. -- NEVER repeat the same query across iterations. -- Ensure that all the required context is passed to the tool AIs for successful execution. -- Ensure that you go deeper when possible and try more broad, creative strategies when a path is not yielding useful results. Build on the results of the previous iterations. +- Ask highly diverse, detailed queries to the tool AIs, one tool AI at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration. +- Break down your research process into independent, self-contained steps that can be executed sequentially using the available tool AIs to answer the user's query. Write your step-by-step plan in the scratchpad. +- Always ask a new query that was not asked to the tool AI in a previous iteration. Build on the results of the previous iterations. +- Ensure that all required context is passed to the tool AIs for successful execution. They only know the context provided in your query. +- Think step by step to come up with creative strategies when the previous iteration did not yield useful results. - You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question. - Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{scratchpad: "I have all I need", tool: "", query: ""}} # Examples Assuming you can search the user's notes and the internet. -- When they ask for the population of their hometown +- When the user asks for the population of their hometown 1. Try look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc. 2. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc. 3. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI. -- When user for their computer's specs +- When the user asks for their computer's specs 1. Try find their computer model in their notes. - 2. Now find webpages with their computer model's spec online and read them. -- When I ask what clothes to carry for their upcoming trip + 2. Now find webpages with their computer model's spec online. + 3. Ask the the webpage tool AI to extract the required information from the relevant webpages. +- When the user asks what clothes to carry for their upcoming trip 1. Find the itinerary of their upcoming trip in their notes. 2. Next find the weather forecast at the destination online. 3. Then find if they mentioned what clothes they own in their notes. @@ -666,7 +666,7 @@ Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else. Response format: -{{"scratchpad": "", "tool": "", "query": ""}} +{{"scratchpad": "", "query": "", "tool": ""}} """.strip() ) @@ -798,8 +798,8 @@ online_search_conversation_subqueries = PromptTemplate.from_template( """ You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question. -- You will receive the conversation history as context. -- Add as much context from the previous questions and answers as required into your search queries. +- You will receive the actual chat history as context. +- Add as much context from the chat history as required into your search queries. - Break messages into multiple search queries when required to retrieve the relevant information. - Use site: google search operator when appropriate - You have access to the the whole internet to retrieve information. @@ -812,58 +812,56 @@ {username} Here are some examples: -History: +Example Chat History: User: I like to use Hacker News to get my tech news. +Khoj: {{queries: ["what is Hacker News?", "Hacker News website for tech news"]}} AI: Hacker News is an online forum for sharing and discussing the latest tech news. It is a great place to learn about new technologies and startups. -Q: Summarize the top posts on HackerNews +User: Summarize the top posts on HackerNews Khoj: {{"queries": ["top posts on HackerNews"]}} -History: - -Q: Tell me the latest news about the farmers protest in Colombia and China on Reuters +Example Chat History: +User: Tell me the latest news about the farmers protest in Colombia and China on Reuters Khoj: {{"queries": ["site:reuters.com farmers protest Colombia", "site:reuters.com farmers protest China"]}} -History: +Example Chat History: User: I'm currently living in New York but I'm thinking about moving to San Francisco. +Khoj: {{"queries": ["New York city vs San Francisco life", "San Francisco living cost", "New York city living cost"]}} AI: New York is a great city to live in. It has a lot of great restaurants and museums. San Francisco is also a great city to live in. It has good access to nature and a great tech scene. -Q: What is the climate like in those cities? -Khoj: {{"queries": ["climate in new york city", "climate in san francisco"]}} +User: What is the climate like in those cities? +Khoj: {{"queries": ["climate in New York city", "climate in San Francisco"]}} -History: -AI: Hey, how is it going? -User: Going well. Ananya is in town tonight! +Example Chat History: +User: Hey, Ananya is in town tonight! +Khoj: {{"queries": ["events in {location} tonight", "best restaurants in {location}", "places to visit in {location}"]}} AI: Oh that's awesome! What are your plans for the evening? -Q: She wants to see a movie. Any decent sci-fi movies playing at the local theater? +User: She wants to see a movie. Any decent sci-fi movies playing at the local theater? Khoj: {{"queries": ["new sci-fi movies in theaters near {location}"]}} -History: +Example Chat History: User: Can I chat with you over WhatsApp? -AI: Yes, you can chat with me using WhatsApp. - -Q: How Khoj: {{"queries": ["site:khoj.dev chat with Khoj on Whatsapp"]}} +AI: Yes, you can chat with me using WhatsApp. -History: - - -Q: How do I share my files with you? +Example Chat History: +User: How do I share my files with Khoj? Khoj: {{"queries": ["site:khoj.dev sync files with Khoj"]}} -History: +Example Chat History: User: I need to transport a lot of oranges to the moon. Are there any rockets that can fit a lot of oranges? +Khoj: {{"queries": ["current rockets with large cargo capacity", "rocket rideshare cost by cargo capacity"]}} AI: NASA's Saturn V rocket frequently makes lunar trips and has a large cargo capacity. -Q: How many oranges would fit in NASA's Saturn V rocket? -Khoj: {{"queries": ["volume of an orange", "volume of saturn v rocket"]}} +User: How many oranges would fit in NASA's Saturn V rocket? +Khoj: {{"queries": ["volume of an orange", "volume of Saturn V rocket"]}} Now it's your turn to construct Google search queries to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else. -History: +Actual Chat History: {chat_history} -Q: {query} +User: {query} Khoj: """.strip() ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index edef014fe..aae55af54 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -1,9 +1,11 @@ import base64 +import json import logging import math import mimetypes import os import queue +import uuid from dataclasses import dataclass from datetime import datetime from enum import Enum @@ -134,7 +136,11 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A for chat in conversation_history.get("chat", [])[-n:]: if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['message']}\n" + + if chat["intent"].get("inferred-queries"): + chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' + + chat_history += f"{agent_name}: {chat['message']}\n\n" elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): chat_history += f"User: {chat['intent']['query']}\n" chat_history += f"{agent_name}: [generated image redacted for space]\n" @@ -185,6 +191,7 @@ class ChatEvent(Enum): MESSAGE = "message" REFERENCES = "references" STATUS = "status" + METADATA = "metadata" def message_to_log( @@ -232,12 +239,14 @@ def save_to_conversation_log( train_of_thought: List[Any] = [], ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") + turn_id = tracer.get("mid") or str(uuid.uuid4()) updated_conversation = message_to_log( user_message=q, chat_response=chat_response, user_message_metadata={ "created": user_message_time, "images": query_images, + "turnId": turn_id, }, khoj_message_metadata={ "context": compiled_references, @@ -246,6 +255,7 @@ def save_to_conversation_log( "codeContext": code_results, "automationId": automation_id, "trainOfThought": train_of_thought, + "turnId": turn_id, }, conversation_log=meta_log.get("chat", []), train_of_thought=train_of_thought, @@ -501,15 +511,12 @@ def commit_conversation_trace( Returns the path to the repository. """ # Serialize session, system message and response to yaml - system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False) - response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False) + system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False) + response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False) formatted_session = [{"role": message.role, "content": message.content} for message in session] - session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False) + session_yaml = json.dumps(formatted_session, ensure_ascii=False, sort_keys=False) query = ( - yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False) - .strip() - .removeprefix("'") - .removesuffix("'") + json.dumps(session[-1].content, ensure_ascii=False, sort_keys=False).strip().removeprefix("'").removesuffix("'") ) # Extract serialized query from chat session # Extract chat metadata for session diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index a19d85fad..b224e7f51 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ ) from torch import nn -from khoj.utils.helpers import get_device, merge_dicts, timer +from khoj.utils.helpers import fix_json_dict, get_device, merge_dicts, timer from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -31,9 +31,9 @@ def __init__( ): default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} - self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs) - self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) - self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) + self.query_encode_kwargs = merge_dicts(fix_json_dict(query_encode_kwargs), default_query_encode_kwargs) + self.docs_encode_kwargs = merge_dicts(fix_json_dict(docs_encode_kwargs), default_docs_encode_kwargs) + self.model_kwargs = merge_dicts(fix_json_dict(model_kwargs), {"device": get_device()}) self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 329ca2eac..c6fc7c200 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -54,6 +54,7 @@ } DEFAULT_MAX_WEBPAGES_TO_READ = 1 +MAX_WEBPAGES_TO_INFER = 10 async def search_online( @@ -157,13 +158,16 @@ async def read_webpages( query_images: List[str] = None, agent: Agent = None, tracer: dict = {}, + max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") - if send_status_func: - async for event in send_status_func(f"**Inferring web pages to read**"): - yield {ChatEvent.STATUS: event} - urls = await infer_webpage_urls(query, conversation_history, location, user, query_images) + urls = await infer_webpage_urls( + query, conversation_history, location, user, query_images, agent=agent, tracer=tracer + ) + + # Get the top 10 web pages to read + urls = urls[:max_webpages_to_read] logger.info(f"Reading web pages at: {urls}") if send_status_func: diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index cc9185be1..a20982ea6 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -31,6 +31,7 @@ from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.run_code import run_code from khoj.routers.api import extract_references_and_questions +from khoj.routers.email import send_query_feedback from khoj.routers.helpers import ( ApiImageRateLimiter, ApiUserRateLimiter, @@ -38,13 +39,14 @@ ChatRequestBody, CommonQueryParams, ConversationCommandRateLimiter, + DeleteMessageRequestBody, + FeedbackData, agenerate_chat_response, aget_relevant_information_sources, aget_relevant_output_modes, construct_automation_created_message, create_automation, extract_relevant_info, - extract_relevant_summary, generate_excalidraw_diagram, generate_summary_from_files, get_conversation_command, @@ -75,16 +77,12 @@ # Initialize Router logger = logging.getLogger(__name__) conversation_command_rate_limiter = ConversationCommandRateLimiter( - trial_rate_limit=100, subscribed_rate_limit=6000, slug="command" + trial_rate_limit=20, subscribed_rate_limit=75, slug="command" ) api_chat = APIRouter() -from pydantic import BaseModel - -from khoj.routers.email import send_query_feedback - @api_chat.get("/conversation/file-filters/{conversation_id}", response_class=Response) @requires(["authenticated"]) @@ -146,12 +144,6 @@ def remove_file_filter(request: Request, filter: FileFilterRequest) -> Response: return Response(content=json.dumps(file_filters), media_type="application/json", status_code=200) -class FeedbackData(BaseModel): - uquery: str - kquery: str - sentiment: str - - @api_chat.post("/feedback") @requires(["authenticated"]) async def sendfeedback(request: Request, data: FeedbackData): @@ -166,10 +158,10 @@ async def text_to_speech( common: CommonQueryParams, text: str, rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") + ApiUserRateLimiter(requests=30, subscribed_requests=30, window=60, slug="chat_minute") ), rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=50, subscribed_requests=300, window=60 * 60 * 24, slug="chat_day") + ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") ), ) -> Response: voice_model = await ConversationAdapters.aget_voice_model_config(request.user.object) @@ -534,6 +526,19 @@ async def set_conversation_title( ) +@api_chat.delete("/conversation/message", response_class=Response) +@requires(["authenticated"]) +def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response: + user = request.user.object + success = ConversationAdapters.delete_message_by_turn_id( + user, delete_request.conversation_id, delete_request.turn_id + ) + if success: + return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200) + else: + return Response(content=json.dumps({"status": "error", "message": "Message not found"}), status_code=404) + + @api_chat.post("") @requires(["authenticated"]) async def chat( @@ -541,10 +546,10 @@ async def chat( common: CommonQueryParams, body: ChatRequestBody, rate_limiter_per_minute=Depends( - ApiUserRateLimiter(requests=60, subscribed_requests=200, window=60, slug="chat_minute") + ApiUserRateLimiter(requests=20, subscribed_requests=20, window=60, slug="chat_minute") ), rate_limiter_per_day=Depends( - ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day") + ApiUserRateLimiter(requests=100, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") ), image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)), ): @@ -555,6 +560,7 @@ async def chat( stream = body.stream title = body.title conversation_id = body.conversation_id + turn_id = str(body.turn_id or uuid.uuid4()) city = body.city region = body.region country = body.country or get_country_name_from_timezone(body.timezone) @@ -574,7 +580,7 @@ async def event_generator(q: str, images: list[str]): nonlocal conversation_id tracer: dict = { - "mid": f"{uuid.uuid4()}", + "mid": turn_id, "cid": conversation_id, "uid": user.id, "khoj_version": state.khoj_version, @@ -607,7 +613,7 @@ async def send_event(event_type: ChatEvent, data: str | dict): if event_type == ChatEvent.MESSAGE: yield data - elif event_type == ChatEvent.REFERENCES or stream: + elif event_type == ChatEvent.REFERENCES or ChatEvent.METADATA or stream: yield json.dumps({"type": event_type.value, "data": data}, ensure_ascii=False) except asyncio.CancelledError as e: connection_alive = False @@ -651,6 +657,11 @@ def collect_telemetry(): metadata=chat_metadata, ) + if is_query_empty(q): + async for result in send_llm_response("Please ask your query to get started."): + yield result + return + conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation = await ConversationAdapters.aget_conversation_by_user( @@ -666,6 +677,9 @@ def collect_telemetry(): return conversation_id = conversation.id + async for event in send_event(ChatEvent.METADATA, {"conversationId": str(conversation_id), "turnId": turn_id}): + yield event + agent: Agent | None = None default_agent = await AgentAdapters.aget_default_agent() if conversation.agent and conversation.agent != default_agent: @@ -677,17 +691,11 @@ def collect_telemetry(): agent = default_agent await is_ready_to_chat(user) - user_name = await aget_user_name(user) location = None if city or region or country or country_code: location = LocationData(city=city, region=region, country=country, country_code=country_code) - if is_query_empty(q): - async for result in send_llm_response("Please ask your query to get started."): - yield result - return - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") meta_log = conversation.conversation_log @@ -699,7 +707,6 @@ def collect_telemetry(): ## Extract Document References compiled_references: List[Any] = [] inferred_queries: List[Any] = [] - defiltered_query = defilter_query(q) if conversation_commands == [ConversationCommand.Default] or is_automated_task: conversation_commands = await aget_relevant_information_sources( @@ -730,6 +737,12 @@ def collect_telemetry(): if mode not in conversation_commands: conversation_commands.append(mode) + for cmd in conversation_commands: + await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd) + q = q.replace(f"/{cmd.value}", "").strip() + + defiltered_query = defilter_query(q) + if conversation_commands == [ConversationCommand.Research]: async for research_result in execute_information_collection( request=request, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 1cb322b0c..990fa33fa 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -478,6 +478,9 @@ async def infer_webpage_urls( valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)} if is_none_or_empty(valid_unique_urls): raise ValueError(f"Invalid list of urls: {response}") + if len(valid_unique_urls) == 0: + logger.error(f"No valid URLs found in response: {response}") + return [] return list(valid_unique_urls) except Exception: raise ValueError(f"Invalid list of urls: {response}") @@ -1255,6 +1258,7 @@ class ChatRequestBody(BaseModel): stream: Optional[bool] = False title: Optional[str] = None conversation_id: Optional[str] = None + turn_id: Optional[str] = None city: Optional[str] = None region: Optional[str] = None country: Optional[str] = None @@ -1264,6 +1268,17 @@ class ChatRequestBody(BaseModel): create_new: Optional[bool] = False +class DeleteMessageRequestBody(BaseModel): + conversation_id: str + turn_id: str + + +class FeedbackData(BaseModel): + uquery: str + kquery: str + sentiment: str + + class ApiUserRateLimiter: def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): self.requests = requests @@ -1366,7 +1381,7 @@ def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str) self.slug = slug self.trial_rate_limit = trial_rate_limit self.subscribed_rate_limit = subscribed_rate_limit - self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image] + self.restricted_commands = [ConversationCommand.Research] async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand): if state.billing_enabled is False: diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 6001bdc5f..4f9c6b4ec 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -1,12 +1,11 @@ import json import logging from datetime import datetime -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Dict, List, Optional import yaml from fastapi import Request -from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( @@ -191,18 +190,18 @@ async def execute_information_collection( document_results = result[0] this_iteration.context += document_results - if not is_none_or_empty(document_results): - try: - distinct_files = {d["file"] for d in document_results} - distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]) - # Strip only leading # from headings - headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") - async for result in send_status_func( - f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}" - ): - yield result - except Exception as e: - logger.error(f"Error extracting document references: {e}", exc_info=True) + if not is_none_or_empty(document_results): + try: + distinct_files = {d["file"] for d in document_results} + distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d]) + # Strip only leading # from headings + headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "") + async for result in send_status_func( + f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}" + ): + yield result + except Exception as e: + logger.error(f"Error extracting document references: {e}", exc_info=True) elif this_iteration.tool == ConversationCommand.Online: async for result in search_online( @@ -306,13 +305,13 @@ async def execute_information_collection( if document_results or online_results or code_results or summarize_files: results_data = f"**Results**:\n" if document_results: - results_data += f"**Document References**: {yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if online_results: - results_data += f"**Online Results**: {yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if code_results: - results_data += f"**Code Results**: {yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" if summarize_files: - results_data += f"**Summarized Files**: {yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" + results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) this_iteration.summarizedResult = results_data diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index c98016fa9..173a1a730 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -101,6 +101,15 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict +def fix_json_dict(json_dict: dict) -> dict: + for k, v in json_dict.items(): + if v == "True" or v == "False": + json_dict[k] = v == "True" + if isinstance(v, dict): + json_dict[k] = fix_json_dict(v) + return json_dict + + def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]: "Get file type from file mime type" @@ -359,9 +368,9 @@ class ConversationCommand(str, Enum): function_calling_description_for_llm = { ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.", - ConversationCommand.Online: "To search the internet for information. Provide all relevant context to ensure new searches, not previously run, are performed.", - ConversationCommand.Webpage: "To extract information from a webpage. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage link and information to extract in your query.", - ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.", + ConversationCommand.Online: "To search the internet for information. Useful to get a quick, broad overview from the internet. Provide all relevant context to ensure new searches, not in previous iterations, are performed.", + ConversationCommand.Webpage: "To extract information from webpages. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage links and information to extract in your query.", + ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.", } mode_descriptions_for_llm = {