diff --git a/vision_agent/data/data.py b/vision_agent/data/data.py index cc069fa3..bdf381a9 100644 --- a/vision_agent/data/data.py +++ b/vision_agent/data/data.py @@ -44,7 +44,9 @@ def add_lmm(self, lmm: LMM) -> Self: self.lmm = lmm return self - def add_column(self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None) -> Self: + def add_column( + self, name: str, prompt: str, func: Optional[Callable[[str], str]] = None + ) -> Self: r"""Adds a new column to the DataFrame containing the generated metadata from the LMM. Args: @@ -56,7 +58,11 @@ def add_column(self, name: str, prompt: str, func: Optional[Callable[[str], str] raise ValueError("LMM not set yet") self.df[name] = self.df["image_paths"].progress_apply( # type: ignore - lambda x: func(self.lmm.generate(prompt, image=x)) if func else self.lmm.generate(prompt, image=x) + lambda x: ( + func(self.lmm.generate(prompt, image=x)) + if func + else self.lmm.generate(prompt, image=x) + ) ) return self diff --git a/vision_agent/lmm/lmm.py b/vision_agent/lmm/lmm.py index 1ab74bba..de4d255d 100644 --- a/vision_agent/lmm/lmm.py +++ b/vision_agent/lmm/lmm.py @@ -32,11 +32,11 @@ def __init__(self, name: str): self.name = name def generate( - self, - prompt: str, - image: Optional[Union[str, Path]] = None, - temperature: float = 0.2, - max_new_tokens: int = 256, + self, + prompt: str, + image: Optional[Union[str, Path]] = None, + temperature: float = 0.2, + max_new_tokens: int = 256, ) -> str: data = {"prompt": prompt} if image: