Skip to content

Commit

Permalink
Add example notebook (#7)
Browse files Browse the repository at this point in the history
* add example folder and files

* jupyter notebook examples updated

* fix linting

* rename notebook and fix mypy errors
  • Loading branch information
shankar-vision-eng authored Feb 28, 2024
1 parent d9f0670 commit 4c39999
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 4 deletions.
Binary file added examples/img/ct_scan1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/img/ct_scan2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/img/doc1.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/img/doc2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/img/doc3.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
384 changes: 384 additions & 0 deletions examples/va_example.ipynb

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions vision_agent/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
from pathlib import Path
from typing import Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, Union, cast, Callable

import faiss
import numpy as np
Expand Down Expand Up @@ -44,18 +44,25 @@ def add_lmm(self, lmm: LMM) -> Self:
self.lmm = lmm
return self

def add_column(self, name: str, prompt: str) -> Self:
def add_column(
self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None
) -> Self:
r"""Adds a new column to the DataFrame containing the generated metadata from the LMM.
Args:
name (str): The name of the column to be added.
prompt (str): The prompt to be used to generate the metadata.
func (Optional[Callable[[str], str]]): A Python function to be applied to the output of `lmm.generate`. Defaults to None.
"""
if self.lmm is None:
raise ValueError("LMM not set yet")

self.df[name] = self.df["image_paths"].progress_apply( # type: ignore
lambda x: self.lmm.generate(prompt, image=x)
lambda x: (
func(self.lmm.generate(prompt, image=x))
if func
else self.lmm.generate(prompt, image=x)
)
)
return self

Expand Down
10 changes: 9 additions & 1 deletion vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,18 @@ class LLaVALMM(LMM):
def __init__(self, name: str):
self.name = name

def generate(self, prompt: str, image: Optional[Union[str, Path]] = None) -> str:
def generate(
self,
prompt: str,
image: Optional[Union[str, Path]] = None,
temperature: float = 0.1,
max_new_tokens: int = 1500,
) -> str:
data = {"prompt": prompt}
if image:
data["image"] = encode_image(image)
data["temperature"] = temperature # type: ignore
data["max_new_tokens"] = max_new_tokens # type: ignore
res = requests.post(
_LLAVA_ENDPOINT,
headers={"Content-Type": "application/json"},
Expand Down

0 comments on commit 4c39999

Please sign in to comment.