Skip to content

Commit

Permalink
jupyter notebook examples updated
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Feb 28, 2024
1 parent 2dbbbee commit 83ba098
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 37 deletions.
Binary file added examples/img/doc3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
252 changes: 219 additions & 33 deletions examples/lmm_example.ipynb

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions vision_agent/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
from pathlib import Path
from typing import Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, Union, cast, Callable

import faiss
import numpy as np
Expand Down Expand Up @@ -44,18 +44,19 @@ def add_lmm(self, lmm: LMM) -> Self:
self.lmm = lmm
return self

def add_column(self, name: str, prompt: str) -> Self:
def add_column(self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None) -> Self:
r"""Adds a new column to the DataFrame containing the generated metadata from the LMM.
Args:
name (str): The name of the column to be added.
prompt (str): The prompt to be used to generate the metadata.
func (Optional[Callable[[Any], Any]]): A Python function to be applied on the output of `lmm.generate`. Defaults to None.
"""
if self.lmm is None:
raise ValueError("LMM not set yet")

self.df[name] = self.df["image_paths"].progress_apply( # type: ignore
lambda x: self.lmm.generate(prompt, image=x)
lambda x: func(self.lmm.generate(prompt, image=x)) if func else self.lmm.generate(prompt, image=x)
)
return self

Expand Down
10 changes: 9 additions & 1 deletion vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@ class LLaVALMM(LMM):
def __init__(self, name: str):
self.name = name

def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
def generate(
self,
prompt: str,
image: Optional[Union[str, Path]] = None,
temperature: float = 0.2,
max_new_tokens: int = 256,
) -> str:
data = {"prompt": prompt}
if image:
data["image"] = encode_image(image)
data["temperature"] = temperature
data["max_new_tokens"] = max_new_tokens
res = requests.post(
BASETEN_URL,
headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
Expand Down

0 comments on commit 83ba098

Please sign in to comment.