
Commit

Introduce filter to limit search to a single file. Streamline the interface
DL committed Oct 4, 2024
1 parent 2ed455c commit 1eb2d90
Showing 3 changed files with 71 additions and 19 deletions.
31 changes: 29 additions & 2 deletions src/llmsearch/embeddings.py
@@ -1,5 +1,6 @@
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Tuple

import pandas as pd
@@ -75,17 +76,42 @@ def get_embedding_model(config: EmbeddingModel):

def create_embeddings(config: Config, vs: VectorStore):
splitter = DocumentSplitter(config)
all_docs, all_hash_filename_mappings, all_hash_docid_mappings = splitter.split()
all_docs, all_hash_filename_mappings, all_hash_docid_mappings, all_labels = splitter.split()

vs.create_index_from_documents(all_docs=all_docs)

splade = SparseEmbeddingsSplade(config)
splade.generate_embeddings_from_docs(docs=all_docs)

save_document_hashes(config, all_hash_filename_mappings, all_hash_docid_mappings)
update_document_labels(config, all_labels)
logger.info("ALL DONE.")


def update_document_labels(config: Config, all_labels: List[str]):
logger.info("Updating document labels...")
labels_fn = Path(os.path.join(config.embeddings.embeddings_path, "labels.txt"))
labels = load_document_labels(labels_fn)

labels = list(set(labels + all_labels))
save_document_labels(labels_fn, labels)


def load_document_labels(path: Path) -> List[str]:
try:
with open(path, 'r', encoding='utf-8') as f:
strings = [line.strip() for line in f.readlines()]
return strings
except FileNotFoundError:
logger.warning("List of labels wasn't found, returning []")
return []

def save_document_labels(path: Path, labels: List[str]) -> None:
logger.info(f"Saving document labels to {path}")
with open(path, 'w', encoding='utf-8') as f:
for string in labels:
f.write(string + '\n') # Write each string followed by a newline

def update_embeddings(config: Config, vs: VectorStore) -> dict:
splitter = DocumentSplitter(config)
new_hashes_df = splitter.get_hashes()
@@ -179,7 +205,7 @@ def update_embeddings(config: Config, vs: VectorStore) -> dict:
if len(changed_or_new_df) > 0:
splitter = DocumentSplitter(config)

new_docs, new_fn_hash_mappings, new_docid_hash_mappings = splitter.split(
new_docs, new_fn_hash_mappings, new_docid_hash_mappings, all_labels = splitter.split(
restrict_filenames=changed_or_new_df.loc[:, "filename"].tolist()
)

@@ -194,6 +220,7 @@ def update_embeddings(config: Config, vs: VectorStore) -> dict:
splade.add_embeddings(new_docs)
stats["scanned_files"] = len(changed_or_new_df)
stats["scanned_chunks"] = len(new_docs)
update_document_labels(config, all_labels)

stats["updated_n_files"] = len(existing_fn_hash_mappings)
# Save changed mappings
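
Taken together, the new helpers in embeddings.py persist a de-duplicated label list to a labels.txt file alongside the embeddings. A minimal sketch of that round trip, assuming an illustrative embeddings directory and a hand-written label list standing in for the ones returned by DocumentSplitter.split():

from pathlib import Path

from llmsearch.embeddings import load_document_labels, save_document_labels

# Illustrative location; in the app this comes from config.embeddings.embeddings_path
embeddings_path = Path("/tmp/embeddings")
embeddings_path.mkdir(parents=True, exist_ok=True)
labels_fn = embeddings_path / "labels.txt"

# Stand-in for the labels returned by DocumentSplitter.split()
new_labels = ["docs/report.pdf", "docs/manual.pdf"]

# The same merge-and-deduplicate step that update_document_labels() performs
labels = list(set(load_document_labels(labels_fn) + new_labels))
save_document_labels(labels_fn, labels)  # writes one label per line to labels.txt
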
17 changes: 15 additions & 2 deletions src/llmsearch/parsers/splitter.py
@@ -63,7 +63,7 @@ def get_hashes(self) -> pd.DataFrame:

def split(
self, restrict_filenames: Optional[List[str]] = None
) -> Tuple[List[Document], pd.DataFrame, pd.DataFrame]:
) -> Tuple[List[Document], pd.DataFrame, pd.DataFrame, List[str]]:
"""Splits documents based on document path settings
Returns:
@@ -77,12 +77,16 @@ def split(
# Mapping between hash and document ids
hash_docid_mappings = []

# Collect all labels so they can be persisted later
all_labels = []

for setting in self.document_path_settings:
passage_prefix = setting.passage_prefix
docs_path = Path(setting.doc_path)
documents_label = setting.label
exclusion_paths = [str(e) for e in setting.exclude_paths]


for scan_extension in setting.scan_extensions:
extension = scan_extension
for chunk_size in self.chunk_sizes: # type: ignore
@@ -132,10 +136,12 @@
hash_filename_mappings.extend(hf_mappings)
hash_docid_mappings.extend(hd_mappings)

all_labels += list(set(d.metadata['label'] for d in docs))

all_hash_filename_mappings = pd.DataFrame(hash_filename_mappings)
all_hash_docid_mappings = pd.concat(hash_docid_mappings, axis=0)

return all_docs, all_hash_filename_mappings, all_hash_docid_mappings
return all_docs, all_hash_filename_mappings, all_hash_docid_mappings, all_labels

def is_exclusion(self, path: Path, exclusions: List[str]) -> bool:
"""Checks if path has parent folders in list of exclusions
@@ -173,6 +179,7 @@ def _get_documents_from_custom_splitter(
Examples: https://gpt-index.readthedocs.io/en/stable/guides/primer/usage_pattern.html
"""

original_label = label
all_docs = []

# Maps between a file name and its hash
@@ -218,6 +225,12 @@ def _get_documents_from_custom_splitter(
path = urllib.parse.quote(str(path)) # type: ignore
logger.info(path)

# If no label was configured for this document set, fall back to the document path.
# The path is currently assigned unconditionally so results can be restricted to a single file.
label = str(path)

docs = [
Document(
page_content=passage_prefix + d["text"],
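
split() now returns a fourth element containing the labels collected from the chunked documents, and, per the change above, each document's label is set to its path. A minimal sketch of consuming the extended signature, assuming a Config object has already been built elsewhere:

from llmsearch.parsers.splitter import DocumentSplitter

splitter = DocumentSplitter(config)  # config: a previously built llmsearch Config instance
docs, hash_filename_mappings, hash_docid_mappings, labels = splitter.split()

# Each label is a document path, so passing one of them as a filter later
# restricts retrieval to that single file.
for label in sorted(set(labels)):
    print(label)
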
42 changes: 27 additions & 15 deletions src/llmsearch/webapp.py
@@ -3,7 +3,7 @@
import os
from io import StringIO
from pathlib import Path
from typing import List
from typing import Dict, List

import langchain
import streamlit as st
@@ -17,7 +17,7 @@
from llmsearch.chroma import VectorStoreChroma
from llmsearch.config import Config
from llmsearch.embeddings import (EmbeddingsHashNotExistError,
create_embeddings, update_embeddings)
create_embeddings, update_embeddings, load_document_labels)
from llmsearch.process import get_and_parse_response
from llmsearch.utils import get_llm_bundle, set_cache_folder

@@ -120,6 +120,11 @@ def load_config(doc_config, model_config) -> Config:
config_dict = {**doc_config_dict, **model_config_dict}
return Config(**config_dict)

@st.cache_data
def load_labels(embedding_path: str) -> Dict[str, str]:
labels_fn = Path(os.path.join(embedding_path, "labels.txt"))
all_labels = {Path(label).name: label for label in load_document_labels(labels_fn)}
return all_labels

@st.cache_data
def load_yaml_file(config) -> dict:
@@ -184,10 +189,11 @@ def generate_response(


@st.cache_data
def get_config_paths(config_dir: str) -> List[str]:
def get_config_paths(config_dir: str) -> Dict[str, str]:
root = Path(config_dir)
config_paths = sorted([str(p) for p in root.glob("*.yaml")])
return config_paths
config_paths = sorted([p for p in root.glob("*.yaml")])
config_path_names = {p.name: str(p) for p in root.glob("*.yaml")}
return config_path_names


def reload_model(doc_config_path: str, model_config_file: str):
@@ -247,9 +253,10 @@ def reload_model(doc_config_path: str, model_config_file: str):

if Path(args.cli_doc_config_path).is_dir():
config_paths = get_config_paths(args.cli_doc_config_path)
doc_config_path = st.sidebar.selectbox(
label="Choose config", options=config_paths, index=0
doc_config_name = st.sidebar.selectbox(
label="Choose config", options=sorted(list(config_paths.keys())), index=0
)
doc_config_path = config_paths[doc_config_name] # type: ignore
model_config_file = args.cli_model_config_path
logger.debug(f"CONFIG FILE: {doc_config_path}")

@@ -294,18 +301,23 @@ def reload_model(doc_config_path: str, model_config_file: str):
f"**Max char size (semantic search):** {config.semantic_search.max_char_size}"
)
label_filter = ""
if config.embeddings.labels:
document_labels = load_labels(config.embeddings.embeddings_path)

if document_labels:
label_filter = st.sidebar.selectbox(
label="Filter by label", options=["-"] + config.embeddings.labels
label="Restrict search to:", options=["-"] + sorted(list(document_labels.keys()))
)
if label_filter is None or label_filter == "-":
label_filter = ""


tables_only_filter = st.sidebar.checkbox(label="Prioritize tables")
if tables_only_filter:
source_chunk_type_filter="table"
else:
source_chunk_type_filter=""
# tables_only_filter = st.sidebar.checkbox(label="Prioritize tables")
# if tables_only_filter:
# source_chunk_type_filter="table"
# else:
# source_chunk_type_filter=""

source_chunk_type_filter=""

text = st.chat_input("Enter text", disabled=False)
is_hyde = st.sidebar.checkbox(
@@ -359,7 +371,7 @@ def reload_model(doc_config_path: str, model_config_file: str):
use_multiquery=st.session_state["llm_bundle"].multiquery_enabled,
_bundle=st.session_state["llm_bundle"],
_config=config,
label_filter=label_filter,
label_filter=document_labels.get(label_filter,""),
source_chunk_type_filter = source_chunk_type_filter
)

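On the webapp side, the sidebar now shows short file names as options and maps the selection back to the full label (the document path) before handing it to get_and_parse_response as label_filter. A minimal sketch of that mapping, with an illustrative embeddings path:

import os
from pathlib import Path

from llmsearch.embeddings import load_document_labels

embedding_path = "/tmp/embeddings"  # illustrative; the app uses config.embeddings.embeddings_path
labels_fn = Path(os.path.join(embedding_path, "labels.txt"))

# The same short-name -> full-label mapping that load_labels() builds above
document_labels = {Path(label).name: label for label in load_document_labels(labels_fn)}

selected = "report.pdf"  # value a user would pick in the "Restrict search to:" selectbox
label_filter = document_labels.get(selected, "")  # full path handed to get_and_parse_response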
