Skip to content

Commit

Permalink
fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
shankar-vision-eng committed Feb 28, 2024
1 parent 96bfeb8 commit 8df3e6b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
10 changes: 8 additions & 2 deletions vision_agent/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def add_lmm(self, lmm: LMM) -> Self:
self.lmm = lmm
return self

def add_column(self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None) -> Self:
def add_column(
self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None
) -> Self:
r"""Adds a new column to the DataFrame containing the generated metadata from the LMM.
Args:
Expand All @@ -56,7 +58,11 @@ def add_column(self, name: str, prompt: str, func: Optional[Callable[[str], str]
raise ValueError("LMM not set yet")

self.df[name] = self.df["image_paths"].progress_apply( # type: ignore
lambda x: func(self.lmm.generate(prompt, image=x)) if func else self.lmm.generate(prompt, image=x)
lambda x: (
func(self.lmm.generate(prompt, image=x))
if func
else self.lmm.generate(prompt, image=x)
)
)
return self

Expand Down
10 changes: 5 additions & 5 deletions vision_agent/lmm/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def __init__(self, name: str):
self.name = name

def generate(
self,
prompt: str,
image: Optional[Union[str, Path]] = None,
temperature: float = 0.2,
max_new_tokens: int = 256,
self,
prompt: str,
image: Optional[Union[str, Path]] = None,
temperature: float = 0.2,
max_new_tokens: int = 256,
) -> str:
data = {"prompt": prompt}
if image:
Expand Down

0 comments on commit 8df3e6b

Please sign in to comment.