Skip to content

Commit

Permalink
Deprecate the UserSearchModelConfig and remove all references
Browse files Browse the repository at this point in the history
- The server now standardizes on a single default search model for the embeddings generation workflow. Remove all references to per-user (differentiated) search model support.
- The migration script for a new model needs to be updated to accommodate full regeneration.
  • Loading branch information
sabaimran committed Nov 4, 2024
1 parent 99c1d28 commit 1e89bac
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 100 deletions.
10 changes: 0 additions & 10 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
Expand Down Expand Up @@ -481,15 +480,6 @@ def get_default_search_model() -> SearchModelConfig:
return SearchModelConfig.objects.first()


def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
if user:
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
if user_search_model:
return user_search_model.setting

return get_default_search_model()


def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
Expand Down
2 changes: 0 additions & 2 deletions src/khoj/database/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
WebScraper,
Expand Down Expand Up @@ -99,7 +98,6 @@ def get_email_login_url(self, request, queryset):
admin.site.register(ProcessLock)
admin.site.register(SpeechToTextModelOptions)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)
Expand Down
79 changes: 7 additions & 72 deletions src/khoj/database/management/commands/change_default_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,7 @@
from tqdm import tqdm

from khoj.database.adapters import get_default_search_model
from khoj.database.models import (
Agent,
Entry,
KhojUser,
SearchModelConfig,
UserSearchModelConfig,
)
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
from khoj.processor.embeddings import EmbeddingsModel

logging.basicConfig(level=logging.INFO)
Expand All @@ -30,15 +24,15 @@ def add_arguments(self, parser):
parser.add_argument(
"--search_model_id",
action="store",
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects.",
required=True,
)

# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
# Set the apply flag to apply the new default Search model to all existing Entry objects.
parser.add_argument(
"--apply",
action="store_true",
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
help="Apply the new default Search model to all existing Entry objects. Otherwise, only display the number of Entry objects that will be affected.",
)

def handle(self, *args, **options):
Expand Down Expand Up @@ -88,72 +82,12 @@ def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, searc

new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
logger.info(f"New default Search model: {new_default_search_model_config}")
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
setting_id=search_model_config_id
).all()
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")

for user_config in user_search_model_configs_to_update:
affected_user = user_config.user
entry_filter = Q(user=affected_user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
user_config.setting = new_default_search_model_config
user_config.save()

logger.info(
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
)

except Exception as e:
logger.error(f"Error embedding documents: {e}")

logger.info("----")

# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
current_default = get_default_search_model()
if current_default.id != new_default_search_model_config.id:
users_without_user_search_model_config = KhojUser.objects.annotate(
user_search_model_config_count=Count("usersearchmodelconfig")
).filter(user_search_model_config_count=0)

logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
for user in users_without_user_search_model_config:
entry_filter = Q(user=user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")

if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)

UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)

logger.info(
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
else:
logger.info("Default is the same as search_model_config_id.")

# TODO: Migrate all Entry objects to use the new default Search model

all_agents = Agent.objects.all()
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
Expand All @@ -174,6 +108,7 @@ def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, searc
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")

if apply and current_default.id != new_default_search_model_config.id:
# Get the existing default SearchModelConfig object and update its name
current_default.name = f"prev_default_{current_default.id}"
Expand Down
15 changes: 15 additions & 0 deletions src/khoj/database/migrations/0073_delete_usersearchmodelconfig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Generated by Django 5.0.9 on 2024-11-04 19:56

from django.db import migrations


class Migration(migrations.Migration):
dependencies = [
("database", "0072_entry_search_model"),
]

operations = [
migrations.DeleteModel(
name="UserSearchModelConfig",
),
]
6 changes: 0 additions & 6 deletions src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,6 @@ class UserVoiceModelConfig(BaseModel):
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)


# TODO Delete this model once all users have been migrated to the server's default settings
class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)


class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
Expand Down
3 changes: 1 addition & 2 deletions src/khoj/processor/content/text_to_entries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
EntryAdapters,
FileObjectAdapters,
get_default_search_model,
get_user_default_search_model,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
Expand Down Expand Up @@ -149,7 +148,7 @@ def update_embeddings(
hashes_to_process |= hashes_for_file - existing_entry_hashes

embeddings = []
model = get_user_default_search_model(user=user)
model = get_default_search_model()
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
Expand Down
3 changes: 1 addition & 2 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
ConversationAdapters,
EntryAdapters,
get_default_search_model,
get_user_default_search_model,
get_user_photo,
)
from khoj.database.models import (
Expand Down Expand Up @@ -151,7 +150,7 @@ async def execute_search(
encoded_asymmetric_query = None
if t != SearchType.Image:
with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_user_default_search_model)(user)
search_model = await sync_to_async(get_default_search_model)()
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)

with concurrent.futures.ThreadPoolExecutor() as executor:
Expand Down
8 changes: 2 additions & 6 deletions src/khoj/search_type/text_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
from asgiref.sync import sync_to_async
from sentence_transformers import util

from khoj.database.adapters import (
EntryAdapters,
get_default_search_model,
get_user_default_search_model,
)
from khoj.database.adapters import EntryAdapters, get_default_search_model
from khoj.database.models import Agent
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
Expand Down Expand Up @@ -114,7 +110,7 @@ async def query(
file_type = search_type_to_embeddings_type[type.value]

query = raw_query
search_model = await sync_to_async(get_user_default_search_model)(user)
search_model = await sync_to_async(get_default_search_model)()
if not max_distance:
if search_model.bi_encoder_confidence_threshold:
max_distance = search_model.bi_encoder_confidence_threshold
Expand Down

0 comments on commit 1e89bac

Please sign in to comment.