Skip to content

Commit

Permalink
Add Validation logic to save PaintModel. Use API key from Paint Model
Browse files Browse the repository at this point in the history
Rename Paint Model, Adapters to TextToImage for consistency
  • Loading branch information
debanjum committed Jun 26, 2024
1 parent 5900051 commit d31c94b
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 15 deletions.
16 changes: 9 additions & 7 deletions src/khoj/database/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserPaintModelConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
)
Expand Down Expand Up @@ -897,25 +897,27 @@ def get_text_to_image_model_options():
return TextToImageModelConfig.objects.all()

@staticmethod
def get_user_paint_model_config(user: KhojUser):
config = UserPaintModelConfig.objects.filter(user=user).first()
def get_user_text_to_image_model_config(user: KhojUser):
config = UserTextToImageModelConfig.objects.filter(user=user).first()
if not config:
return None
return config.setting

@staticmethod
async def aget_user_paint_model(user: KhojUser):
config = await UserPaintModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
async def aget_user_text_to_image_model(user: KhojUser):
config = await UserTextToImageModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting

@staticmethod
async def aset_user_paint_model(user: KhojUser, text_to_image_model_config_id: int):
async def aset_user_text_to_image_model(user: KhojUser, text_to_image_model_config_id: int):
config = await TextToImageModelConfig.objects.filter(id=text_to_image_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserPaintModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
new_config, _ = await UserTextToImageModelConfig.objects.aupdate_or_create(
user=user, defaults={"setting": config}
)
return new_config


Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Generated by Django 4.2.11 on 2024-06-20 19:48
# Generated by Django 4.2.11 on 2024-06-26 03:27

import django.db.models.deletion
from django.conf import settings
Expand All @@ -7,7 +7,7 @@

class Migration(migrations.Migration):
dependencies = [
("database", "0047_alter_entry_file_type"),
("database", "0048_voicemodeloption_uservoicemodelconfig"),
]

operations = [
Expand All @@ -16,6 +16,17 @@ class Migration(migrations.Migration):
name="api_key",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="texttoimagemodelconfig",
name="openai_config",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to="database.openaiprocessorconversationconfig",
),
),
migrations.AlterField(
model_name="texttoimagemodelconfig",
name="model_type",
Expand All @@ -24,7 +35,7 @@ class Migration(migrations.Migration):
),
),
migrations.CreateModel(
name="UserPaintModelConfig",
name="UserTextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
Expand Down
28 changes: 27 additions & 1 deletion src/khoj/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,32 @@ class ModelType(models.TextChoices):
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
openai_config = models.ForeignKey(
OpenAIProcessorConversationConfig, on_delete=models.CASCADE, default=None, null=True, blank=True
)

def clean(self):
# Custom validation logic
error = {}
if self.model_type == self.ModelType.OPENAI:
if self.api_key and self.openai_config:
error[
"api_key"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
error[
"openai_config"
] = "Both API key and OpenAI config cannot be set for OpenAI models. Please set only one of them."
if self.model_type != self.ModelType.OPENAI:
if not self.api_key:
error["api_key"] = "The API key field must be set for non OpenAI models."
if self.openai_config:
error["openai_config"] = "OpenAI config cannot be set for non OpenAI models."
if error:
raise ValidationError(error)

def save(self, *args, **kwargs):
self.clean()
super().save(*args, **kwargs)


class SpeechToTextModelOptions(BaseModel):
Expand All @@ -265,7 +291,7 @@ class UserSearchModelConfig(BaseModel):
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)


class UserPaintModelConfig(BaseModel):
class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)

Expand Down
2 changes: 1 addition & 1 deletion src/khoj/routers/api_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ async def update_paint_model(
if not subscribed:
raise HTTPException(status_code=403, detail="User is not subscribed to premium")

new_config = await ConversationAdapters.aset_user_paint_model(user, int(id))
new_config = await ConversationAdapters.aset_user_text_to_image_model(user, int(id))

update_telemetry_state(
request=request,
Expand Down
14 changes: 12 additions & 2 deletions src/khoj/routers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ async def text_to_image(
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

text_to_image_config = await ConversationAdapters.aget_user_paint_model(user)
text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
Expand Down Expand Up @@ -796,9 +796,19 @@ async def text_to_image(

if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
with timer("Generate image with OpenAI", logger):
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
prompt=improved_image_prompt,
model=text2image_model,
response_format="b64_json",
extra_headers=auth_header,
)
image = response.data[0].b64_json
decoded_image = base64.b64decode(image)
Expand Down
2 changes: 1 addition & 1 deletion src/khoj/routers/web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def config_page(request: Request):

current_search_model_option = adapters.get_user_search_model_or_default(user)

selected_paint_model_config = ConversationAdapters.get_user_paint_model_config(user)
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
for paint_model in paint_model_options:
Expand Down

0 comments on commit d31c94b

Please sign in to comment.