-
-
Notifications
You must be signed in to change notification settings - Fork 664
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove user customized search model (#946)
- Use a single standard search model across the server. There are diminishing benefits to having multiple user-customizable search models. - We may want to add server-level customization for specific tasks. - Store the search model used to generate a given entry on the `Entry` object. - Remove user-facing APIs and view. - Add a management command for migrating the default search model on the server. In a future PR (after running the migration), we'll also remove the `UserSearchModelConfig`.
- Loading branch information
Showing
11 changed files
with
237 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
182 changes: 182 additions & 0 deletions
182
src/khoj/database/management/commands/change_default_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
import logging | ||
from typing import List | ||
|
||
from django.core.management.base import BaseCommand | ||
from django.db import transaction | ||
from django.db.models import Count, Q | ||
from tqdm import tqdm | ||
|
||
from khoj.database.adapters import get_default_search_model | ||
from khoj.database.models import ( | ||
Agent, | ||
Entry, | ||
KhojUser, | ||
SearchModelConfig, | ||
UserSearchModelConfig, | ||
) | ||
from khoj.processor.embeddings import EmbeddingsModel | ||
|
||
# Module-level logger for this command; the root logger is configured at INFO
# so the progress messages emitted below are shown when the command runs.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
|
||
|
||
class Command(BaseCommand):
    """Re-embed existing ``Entry`` objects with a new default search model.

    Without ``--apply`` the command only logs how many objects would be
    affected. With ``--apply`` it regenerates embeddings for the entries of
    every affected user and every agent, points (or creates) the matching
    ``UserSearchModelConfig`` rows at the new model and finally renames the
    chosen ``SearchModelConfig`` to ``"default"``.
    """

    help = "Convert all existing Entry objects to use a new default Search model."

    def add_arguments(self, parser):
        """Register the ``--search_model_id`` and ``--apply`` options."""
        # Pass default SearchModelConfig ID
        parser.add_argument(
            "--search_model_id",
            action="store",
            # Reject non-numeric IDs at parse time instead of failing inside the ORM lookup.
            type=int,
            help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
            required=True,
        )

        # Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
        parser.add_argument(
            "--apply",
            action="store_true",
            help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
        )

    @staticmethod
    def _load_embeddings_model(search_model: SearchModelConfig) -> EmbeddingsModel:
        """Build the ``EmbeddingsModel`` described by ``search_model``."""
        return EmbeddingsModel(
            search_model.bi_encoder,
            search_model.embeddings_inference_endpoint,
            search_model.embeddings_inference_endpoint_api_key,
            query_encode_kwargs=search_model.bi_encoder_query_encode_config,
            docs_encode_kwargs=search_model.bi_encoder_docs_encode_config,
            model_kwargs=search_model.bi_encoder_model_config,
        )

    @staticmethod
    def _regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig) -> int:
        """Re-embed every ``Entry`` matching ``entry_filter`` with ``embeddings_model``.

        The work runs inside a single transaction. Any failure (including an
        embedding error) propagates to the caller, so the caller never records
        a migration that did not actually happen.

        Returns:
            The number of entries that were updated.

        Raises:
            ValueError: If the embeddings model returns a different number of
                embeddings than there are entries.
        """
        with transaction.atomic():
            entries: List[Entry] = list(Entry.objects.filter(entry_filter))
            if not entries:
                # Nothing to embed; skip the (possibly remote) embedding call entirely.
                return 0

            embeddings = embeddings_model.embed_documents([entry.compiled for entry in entries])
            if len(embeddings) != len(entries):
                raise ValueError(f"Expected {len(entries)} embeddings but received {len(embeddings)}")

            for entry, embedding in zip(tqdm(entries), embeddings):
                entry.embeddings = embedding
                entry.search_model_id = search_model.id

            # Only the fields modified above are written back.
            Entry.objects.bulk_update(entries, ["embeddings", "search_model_id"])
            return len(entries)

    def handle(self, *args, **options):
        """Run the migration, or a dry run when ``--apply`` is not given."""
        search_model_config_id = options.get("search_model_id")
        apply = options.get("apply")

        logger.info(f"SearchModelConfig ID: {search_model_config_id}")
        logger.info(f"Apply: {apply}")

        new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
        logger.info(f"New default Search model: {new_default_search_model_config}")

        # Only the target model is ever used, and only when changes are applied,
        # so avoid loading any other model (or any model at all on a dry run).
        embeddings_model = self._load_embeddings_model(new_default_search_model_config) if apply else None

        user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(setting_id=search_model_config_id)
        logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")

        for user_config in user_search_model_configs_to_update:
            affected_user = user_config.user
            entry_filter = Q(user=affected_user)
            entry_count = Entry.objects.filter(entry_filter).count()
            logger.info(f"Number of Entry objects to update for user {affected_user}: {entry_count}")

            if apply:
                try:
                    updated_count = self._regenerate_entries(
                        entry_filter,
                        embeddings_model,
                        new_default_search_model_config,
                    )
                    # Reached only when re-embedding succeeded.
                    user_config.setting = new_default_search_model_config
                    user_config.save()

                    logger.info(
                        f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
                    )
                    logger.info(
                        f"Updated {updated_count} Entry objects for user {affected_user} to use the new default Search model."
                    )
                except Exception as e:
                    logger.error(f"Error embedding documents: {e}")

            logger.info("----")

        # Many users have indexed documents without ever creating a UserSearchModelConfig object.
        # They must be migrated as well when the current default differs from search_model_config_id.
        current_default = get_default_search_model()
        if current_default.id != new_default_search_model_config.id:
            users_without_user_search_model_config = KhojUser.objects.annotate(
                user_search_model_config_count=Count("usersearchmodelconfig")
            ).filter(user_search_model_config_count=0)

            logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
            for user in users_without_user_search_model_config:
                entry_filter = Q(user=user)
                entry_count = Entry.objects.filter(entry_filter).count()
                logger.info(f"Number of Entry objects to update for user {user}: {entry_count}")

                if apply:
                    try:
                        updated_count = self._regenerate_entries(
                            entry_filter,
                            embeddings_model,
                            new_default_search_model_config,
                        )

                        UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)

                        logger.info(
                            f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
                        )
                        logger.info(
                            f"Updated {updated_count} Entry objects for user {user} to use the new default Search model."
                        )
                    except Exception as e:
                        logger.error(f"Error embedding documents: {e}")
        else:
            logger.info("Default is the same as search_model_config_id.")

        all_agents = Agent.objects.all()
        logger.info(f"Number of Agent objects to update: {all_agents.count()}")
        for agent in all_agents:
            entry_filter = Q(agent=agent)
            entry_count = Entry.objects.filter(entry_filter).count()
            logger.info(f"Number of Entry objects to update for agent {agent}: {entry_count}")

            if apply:
                try:
                    updated_count = self._regenerate_entries(
                        entry_filter,
                        embeddings_model,
                        new_default_search_model_config,
                    )
                    logger.info(
                        f"Updated {updated_count} Entry objects for agent {agent} to use the new default Search model."
                    )
                except Exception as e:
                    logger.error(f"Error embedding documents: {e}")

        if apply and current_default.id != new_default_search_model_config.id:
            # Get the existing default SearchModelConfig object and update its name
            current_default.name = f"prev_default_{current_default.id}"
            current_default.save()

            # Update the new default SearchModelConfig object's name
            new_default_search_model_config.name = "default"
            new_default_search_model_config.save()

        if not apply:
            logger.info("Run the command with the --apply flag to apply the new default Search model.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Generated by Django 5.0.8 on 2024-10-21 21:09 | ||
|
||
import django.db.models.deletion | ||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration):
    """Add an optional ``search_model`` foreign key from ``Entry`` to ``SearchModelConfig``."""

    dependencies = [("database", "0071_subscription_enabled_trial_at_and_more")]

    operations = [
        migrations.AddField(
            model_name="entry",
            name="search_model",
            # Nullable with a None default, so existing rows need no backfill;
            # deleting the referenced config nulls the reference (SET_NULL).
            field=models.ForeignKey(
                to="database.searchmodelconfig",
                on_delete=django.db.models.deletion.SET_NULL,
                null=True,
                blank=True,
                default=None,
            ),
        ),
    ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.