diff --git a/lmm_tools/data/data.py b/lmm_tools/data/data.py index 8893ae72..e531629b 100644 --- a/lmm_tools/data/data.py +++ b/lmm_tools/data/data.py @@ -26,7 +26,6 @@ def __init__(self, df: pd.DataFrame): Args: df (pd.DataFrame): The DataFrame containing "image_paths" and "image_id" columns. - """ self.df = df self.lmm: Optional[LMM] = None @@ -46,7 +45,12 @@ def add_lmm(self, lmm: LMM) -> Self: return self def add_column(self, name: str, prompt: str) -> Self: - r"""Adds a new column to the DataFrame containing the generated metadata from the LMM.""" + r"""Adds a new column to the DataFrame containing the generated metadata from the LMM. + + Args: + name (str): The name of the column to be added. + prompt (str): The prompt to be used to generate the metadata. + """ if self.lmm is None: raise ValueError("LMM not set yet") @@ -57,6 +61,7 @@ def add_column(self, name: str, prompt: str) -> Self: def build_index(self, target_col: str) -> Self: r"""This will generate embeddings for the `target_col` and build a searchable index over them, so next time you run search it will search over this index. + Args: target_col (str): The column name containing the data to be indexed.""" if self.emb is None: @@ -81,6 +86,7 @@ def get_embeddings(self) -> npt.NDArray[np.float32]: def search(self, query: str, top_k: int = 10) -> List[Dict]: r"""Searches the index for the most similar images to the query and returns the top_k results. + Args: query (str): The query to search for. top_k (int, optional): The number of results to return. Defaults to 10."""