Skip to content

Commit

Permalink
Remove user customized search model (#946)
Browse files Browse the repository at this point in the history
- Use a single standard search model across the server. There's diminishing benefits for having multiple user-customizable search models. 
- We may want to add server-level customization for specific tasks
- Store the search model used to generate a given entry on the `Entry` object
- Remove user-facing APIs and view
- Add a management command for migrating the default search model on the server

In a future PR (after running the migration), we'll also remove the `UserSearchModelConfig`
  • Loading branch information
sabaimran authored Oct 24, 2024
1 parent f3ce47b commit 5120597
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 90 deletions.
23 changes: 1 addition & 22 deletions src/interface/web/app/settings/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ export default function SettingsView() {
};

const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active && name !== "search") {
if (!userConfig?.is_active) {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
Expand Down Expand Up @@ -1233,27 +1233,6 @@ export default function SettingsView() {
</CardFooter>
</Card>
)}
{userConfig.search_model_options.length > 0 && (
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
Search
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the search model to find your documents
</p>
<DropdownComponent
items={userConfig.search_model_options}
selected={
userConfig.selected_search_model_config
}
callbackFunc={updateModel("search")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
</Card>
)}
{userConfig.paint_model_options.length > 0 && (
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
Expand Down
33 changes: 13 additions & 20 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config


def get_user_search_model_or_default(user=None):
if user and UserSearchModelConfig.objects.filter(user=user).exists():
return UserSearchModelConfig.objects.filter(user=user).first().setting
def get_default_search_model() -> SearchModelConfig:
default_search_model = SearchModelConfig.objects.filter(name="default").first()

if SearchModelConfig.objects.filter(name="default").exists():
return SearchModelConfig.objects.filter(name="default").first()
if default_search_model:
return default_search_model
else:
SearchModelConfig.objects.create()

return SearchModelConfig.objects.first()


def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
if user:
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
if user_search_model:
return user_search_model.setting

return get_default_search_model()


def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
Expand All @@ -487,21 +495,6 @@ def get_or_create_search_models():
return search_models


async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config


async def aget_user_search_model(user: KhojUser):
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting


class ProcessLockAdapters:
@staticmethod
def get_process_lock(process_name: str):
Expand Down
2 changes: 2 additions & 0 deletions src/khoj/database/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin):
"created_at",
"updated_at",
"user",
"agent",
"file_source",
"file_type",
"file_name",
Expand All @@ -135,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin):
list_filter = (
"file_type",
"user__email",
"search_model__name",
)
ordering = ("-created_at",)

Expand Down
182 changes: 182 additions & 0 deletions src/khoj/database/management/commands/change_default_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import logging
from typing import List

from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Count, Q
from tqdm import tqdm

from khoj.database.adapters import get_default_search_model
from khoj.database.models import (
Agent,
Entry,
KhojUser,
SearchModelConfig,
UserSearchModelConfig,
)
from khoj.processor.embeddings import EmbeddingsModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Command(BaseCommand):
help = "Convert all existing Entry objects to use a new default Search model."

def add_arguments(self, parser):
# Pass default SearchModelConfig ID
parser.add_argument(
"--search_model_id",
action="store",
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
required=True,
)

# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
parser.add_argument(
"--apply",
action="store_true",
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
)

def handle(self, *args, **options):
@transaction.atomic
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
entries = Entry.objects.filter(entry_filter).all()
compiled_entries = [entry.compiled for entry in entries]
updated_entries: List[Entry] = []
try:
embeddings = embeddings_model.embed_documents(compiled_entries)

except Exception as e:
logger.error(f"Error embedding documents: {e}")
return

for i, entry in enumerate(tqdm(entries)):
entry.embeddings = embeddings[i]
entry.search_model_id = search_model.id
updated_entries.append(entry)

Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])

search_model_config_id = options.get("search_model_id")
apply = options.get("apply")

logger.info(f"SearchModelConfig ID: {search_model_config_id}")
logger.info(f"Apply: {apply}")

embeddings_model = dict()

search_models = SearchModelConfig.objects.all()
for model in search_models:
embeddings_model.update(
{
model.name: EmbeddingsModel(
model.bi_encoder,
model.embeddings_inference_endpoint,
model.embeddings_inference_endpoint_api_key,
query_encode_kwargs=model.bi_encoder_query_encode_config,
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
model_kwargs=model.bi_encoder_model_config,
)
}
)

new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
logger.info(f"New default Search model: {new_default_search_model_config}")
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
setting_id=search_model_config_id
).all()
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")

for user_config in user_search_model_configs_to_update:
affected_user = user_config.user
entry_filter = Q(user=affected_user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
user_config.setting = new_default_search_model_config
user_config.save()

logger.info(
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
)

except Exception as e:
logger.error(f"Error embedding documents: {e}")

logger.info("----")

# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
current_default = get_default_search_model()
if current_default.id != new_default_search_model_config.id:
users_without_user_search_model_config = KhojUser.objects.annotate(
user_search_model_config_count=Count("usersearchmodelconfig")
).filter(user_search_model_config_count=0)

logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
for user in users_without_user_search_model_config:
entry_filter = Q(user=user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)

UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)

logger.info(
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
else:
logger.info("Default is the same as search_model_config_id.")

all_agents = Agent.objects.all()
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
for agent in all_agents:
entry_filter = Q(agent=agent)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
if apply and current_default.id != new_default_search_model_config.id:
# Get the existing default SearchModelConfig object and update its name
current_default.name = f"prev_default_{current_default.id}"
current_default.save()

# Update the new default SearchModelConfig object's name
new_default_search_model_config.name = "default"
new_default_search_model_config.save()
if not apply:
logger.info("Run the command with the --apply flag to apply the new default Search model.")
24 changes: 24 additions & 0 deletions src/khoj/database/migrations/0072_entry_search_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 5.0.8 on 2024-10-21 21:09

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("database", "0071_subscription_enabled_trial_at_and_more"),
]

operations = [
migrations.AddField(
model_name="entry",
name="search_model",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to="database.searchmodelconfig",
),
),
]
2 changes: 2 additions & 0 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ class UserVoiceModelConfig(BaseModel):
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)


# TODO Delete this model once all users have been migrated to the server's default settings
class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
Expand Down Expand Up @@ -535,6 +536,7 @@ class EntrySource(models.TextChoices):
url = models.URLField(max_length=400, default=None, null=True, blank=True)
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)

def save(self, *args, **kwargs):
if self.user and self.agent:
Expand Down
6 changes: 4 additions & 2 deletions src/khoj/processor/content/text_to_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from khoj.database.adapters import (
EntryAdapters,
FileObjectAdapters,
get_user_search_model_or_default,
get_default_search_model,
get_user_default_search_model,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
Expand Down Expand Up @@ -148,10 +149,10 @@ def update_embeddings(
hashes_to_process |= hashes_for_file - existing_entry_hashes

embeddings = []
model = get_user_default_search_model(user=user)
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
model = get_user_search_model_or_default(user)
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)

added_entries: list[DbEntry] = []
Expand All @@ -177,6 +178,7 @@ def update_embeddings(
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
search_model=model,
)
)
try:
Expand Down
5 changes: 3 additions & 2 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
get_default_search_model,
get_user_default_search_model,
get_user_photo,
get_user_search_model_or_default,
)
from khoj.database.models import (
Agent,
Expand Down Expand Up @@ -149,7 +150,7 @@ async def execute_search(
encoded_asymmetric_query = None
if t != SearchType.Image:
with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_user_search_model_or_default)(user)
search_model = await sync_to_async(get_user_default_search_model)(user)
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)

with concurrent.futures.ThreadPoolExecutor() as executor:
Expand Down
Loading

0 comments on commit 5120597

Please sign in to comment.