From d31c94bed4a8f3681ab642abef74d594bebe9732 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 26 Jun 2024 09:51:06 +0530 Subject: [PATCH] Add Validation logic to save PaintModel. Use API key from Paint Model Rename Paint Model, Adapters to TextToImage for consistency --- src/khoj/database/adapters/__init__.py | 16 ++++++----- ...exttoimagemodelconfig_api_key_and_more.py} | 17 +++++++++-- src/khoj/database/models/__init__.py | 28 ++++++++++++++++++- src/khoj/routers/api_config.py | 2 +- src/khoj/routers/helpers.py | 14 ++++++++-- src/khoj/routers/web_client.py | 2 +- 6 files changed, 64 insertions(+), 15 deletions(-) rename src/khoj/database/migrations/{0048_texttoimagemodelconfig_api_key_and_more.py => 0049_texttoimagemodelconfig_api_key_and_more.py} (73%) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index f00fe1794..1e43887a9 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -45,9 +45,9 @@ Subscription, TextToImageModelConfig, UserConversationConfig, - UserPaintModelConfig, UserRequests, UserSearchModelConfig, + UserTextToImageModelConfig, UserVoiceModelConfig, VoiceModelOption, ) @@ -897,25 +897,27 @@ def get_text_to_image_model_options(): return TextToImageModelConfig.objects.all() @staticmethod - def get_user_paint_model_config(user: KhojUser): - config = UserPaintModelConfig.objects.filter(user=user).first() + def get_user_text_to_image_model_config(user: KhojUser): + config = UserTextToImageModelConfig.objects.filter(user=user).first() if not config: return None return config.setting @staticmethod - async def aget_user_paint_model(user: KhojUser): - config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() + async def aget_user_text_to_image_model(user: KhojUser): + config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() if not config: return None return config.setting @staticmethod - async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int): + async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int): config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst() if not config: return None - new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) + new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create( + user=user, defaults={"setting": config} + ) return new_config diff --git a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py b/src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py similarity index 73% rename from src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py rename to src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py index d5d3e9729..c7ff9e81e 100644 --- a/src/khoj/database/migrations/0048_texttoimagemodelconfig_api_key_and_more.py +++ b/src/khoj/database/migrations/0049_texttoimagemodelconfig_api_key_and_more.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.11 on 2024-06-20 19:48 +# Generated by Django 4.2.11 on 2024-06-26 03:27 import django.db.models.deletion from django.conf import settings @@ -7,7 +7,7 @@ class Migration(migrations.Migration): dependencies = [ - ("database", "0047_alter_entry_file_type"), + ("database", "0048_voicemodeloption_uservoicemodelconfig"), ] operations = [ @@ -16,6 +16,17 @@ class Migration(migrations.Migration): name="api_key", field=models.CharField(blank=True, default=None, max_length=200, null=True), ), + migrations.AddField( + model_name="texttoimagemodelconfig", + name="openai_config", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to="database.openaiprocessorconversationconfig", + ), + ), migrations.AlterField( model_name="texttoimagemodelconfig", name="model_type", @@ -24,7 +35,7 @@ class Migration(migrations.Migration): ), ), migrations.CreateModel( - name="UserPaintModelConfig", + name="UserTextToImageModelConfig", fields=[ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ("created_at", models.DateTimeField(auto_now_add=True)), diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 6f91b5066..37ee911c2 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -239,6 +239,32 @@ class ModelType(models.TextChoices): model_name = models.CharField(max_length=200, default="dall-e-3") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) api_key = models.CharField(max_length=200, default=None, null=True, blank=True) + openai_config = models.ForeignKey( + OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True + ) + + def clean(self): + # Custom validation logic + error = {} + if self.model_type == self.ModelType.OPENAI: + if self.api_key and self.openai_config: + error[ + "api_key" + ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + error[ + "openai_config" + ] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them." + if self.model_type != self.ModelType.OPENAI: + if not self.api_key: + error["api_key"] = "The API key field must be set for non OpenAI models." + if self.openai_config: + error["openai_config"] = "OpenAI config cannot be set for non OpenAI models." + if error: + raise ValidationError(error) + + def save(self, *args, **kwargs): + self.clean() + super().save(*args, **kwargs) class SpeechToTextModelOptions(BaseModel): @@ -265,7 +291,7 @@ class UserSearchModelConfig(BaseModel): setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) -class UserPaintModelConfig(BaseModel): +class UserTextToImageModelConfig(BaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE) diff --git a/src/khoj/routers/api_config.py b/src/khoj/routers/api_config.py index f25ecacdb..65faf09c4 100644 --- a/src/khoj/routers/api_config.py +++ b/src/khoj/routers/api_config.py @@ -328,7 +328,7 @@ async def update_paint_model( if not subscribed: raise HTTPException(status_code=403, detail="User is not subscribed to premium") - new_config = await ConversationAdapters.aset_user_paint_model(user, int(id)) + new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id)) update_telemetry_state( request=request, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 79bb2365b..835bb8c16 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -762,7 +762,7 @@ async def text_to_image( image_url = None intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - text_to_image_config = await ConversationAdapters.aget_user_paint_model(user) + text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user) if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 @@ -796,9 +796,19 @@ async def text_to_image( if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: with timer("Generate image with OpenAI", logger): + if text_to_image_config.api_key: + api_key = text_to_image_config.api_key + elif text_to_image_config.openai_config: + api_key = text_to_image_config.openai_config.api_key + elif state.openai_client: + api_key = state.openai_client.api_key + auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {} try: response = state.openai_client.images.generate( - prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" + prompt=improved_image_prompt, + model=text2image_model, + response_format="b64_json", + extra_headers=auth_header, ) image = response.data[0].b64_json decoded_image = base64.b64decode(image) diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 25dd43dba..9550d8e71 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -251,7 +251,7 @@ def config_page(request: Request): current_search_model_option = adapters.get_user_search_model_or_default(user) - selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user) + selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user) paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() all_paint_model_options = list() for paint_model in paint_model_options: