diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..12900c1 --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] +max-line-length = 88 +extend-ignore = E203,D10,D415,E501,E712 +docstring-convention = google diff --git a/README.md b/README.md index 7852268..7de2054 100644 --- a/README.md +++ b/README.md @@ -4,3 +4,24 @@ A tool for bulk labeling, built in Solara! I'm trying to rebuild my original [bulk-labeling](https://github.com/rungalileo/bulk-labeling/) app, which was Streamlit, in [Solara](https://github.com/widgetti/solara) so it can be a bit more scalable, customizable, and robust to new features! I also want to learn how to use solara :) + + +## Development +1. Setup a virtual env: `python -m venv .venv && source .venv/bin/activate` +2. Install the package: `pip install -e . && pyenv rehash` +3. Run: `solara run bulk_labeling/main.py` + +Any changes you make to the app should reflect in realtime + +### Note: `SentenceTransformers` doesn't play nicely with solara +If you are going to be developing, I strongly recommend commenting out +the few lines in [ml.py](bulk_labeling/utils/ml.py): +https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L5 +https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L9 +https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L13 + +and uncomment +https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L15 + +For some reason, on a page reload, solara breaks if these lines are running. +It will also make prototyping faster because you won't be actually encoding strings. diff --git a/bulk-labeling/main.py b/bulk-labeling/main.py deleted file mode 100644 index 4332e14..0000000 --- a/bulk-labeling/main.py +++ /dev/null @@ -1,332 +0,0 @@ -import io -import itertools -import os -from collections import defaultdict -from time import sleep -from typing import Callable, Dict, List, Optional, Set, cast - -import numpy as np -import pandas as pd -import plotly.express as px -import solara -from reacton import ipyvuetify as V -from sentence_transformers import SentenceTransformer -from solara.components.file_drop import FileInfo -from solara.lab import Reactive -from umap import UMAP - -DIR = f"{os.getcwd()}/bulk-labeling" -PATH = f"{DIR}/conv_intent.csv" - -UMAP_MODEL = UMAP(n_neighbors=15, random_state=42, verbose=True) -ENCODER = SentenceTransformer("paraphrase-MiniLM-L3-v2") - -INTERNAL_COLS = ["x", "y", "hovertext", "id"] -NO_COLOR_COLS = INTERNAL_COLS + ["text"] -BUTTON_KWARGS = dict(color="primary", text=True, outlined=True) - - -class State: - available_labels = Reactive[Set[str]](set()) - labeled_ids = Reactive[Dict[str, List[int]]](defaultdict(list)) - filtered_ids = Reactive[List[int]]([]) - chosen_label = Reactive[Optional[str]](None) - assigned_new_label = Reactive[bool](False) - filter_text = Reactive[str]("") - - -class PlotState: - point_size = Reactive[int](2) - color = Reactive[str]("") - # While we calculate embeddings and UMAP, we can manage the loading state - loading = Reactive[bool](False) - - -def has_df(df: pd.DataFrame) -> bool: - return (len(df) != 0) and (len(df.columns) != 0) - - -def filtered_df(df: pd.DataFrame) -> pd.DataFrame: - dfc = df.copy() - if not has_df(dfc): - return dfc - if State.filtered_ids.value: - dfc = dfc[dfc["id"].isin(State.filtered_ids.value)] - if State.filter_text.value: - dfc = dfc[dfc["text"].str.contains(State.filter_text.value)] - return dfc - - -def add_new_label(new_label: str) -> None: - if not new_label: - return - all_labels = State.available_labels.value.copy() - all_labels.add(new_label) - State.available_labels.set(all_labels) - # So the "assign points" button is already pre-populated with your new label =] - State.chosen_label.set(new_label) - - -def _set_default_cols(df: pd.DataFrame) -> pd.DataFrame: - df["text_length"] = df.text.str.len() - df["id"] = list(range(len(df))) - df["hovertext"] = df.text.str.wrap(30).str.replace("\n", "
") - return df - - -def apply_df_edits(df: pd.DataFrame) -> pd.DataFrame: - print("Should be downloading!") - df2 = df.copy() - labeled_ids = State.labeled_ids.value - # Map every ID to it's assigned labels - # TODO: We can be smarter with conflicts and pick the label that an ID is - # assigned to most frequently - id_label = {id_: label for label, ids in labeled_ids.items() for id_ in ids} - df2["label"] = df2["id"].apply(lambda id_: id_label.get(id_, "-1")) - df2 = df2[df2["label"] != "-1"] - cols = [c for c in df2.columns if c not in INTERNAL_COLS] - return df2[cols] - - -def encode_inputs(samples: List[str]) -> np.ndarray: - return ENCODER.encode(samples) - # When doing rapid development, it's faster to return a numpy array - # return np.random.rand(len(samples), 20) - - -def get_xy(embs: np.ndarray) -> np.ndarray: - return UMAP_MODEL.fit_transform(embs) - - -def get_text_embeddings(samples: List[str]) -> np.ndarray: - return get_xy(encode_inputs(samples)) - - -def find_row_ids(fig, click_data) -> List[int]: - """A very annoying function to get row IDs because Plotly is unhelpful - - Solara is going to do this for us in the future! - """ - # goes from trace index and point index to row index in a dataframe - # requires passing df.index as to custom_data - trace_index = click_data["points"]["trace_indexes"] - point_index = click_data["points"]["point_indexes"] - point_ids = [] - for t, p in zip(trace_index, point_index): - point_trace = fig.data[t] - point_ids.append(point_trace.customdata[p][0]) - return point_ids - - -def reset(): - """Removes any filters applied to the data""" - State.filtered_ids.set([]) - State.filter_text.set("") - State.chosen_label.set(None) - - -@solara.component -def assigned_label_view() -> None: - State.assigned_new_label.use() - if State.assigned_new_label.value: - solara.Info(f"{len(State.filtered_ids.value)} labeled!") - sleep(2) - State.assigned_new_label.set(False) - - -@solara.component -def table(df: pd.DataFrame): - solara.Markdown(f"## Data ({len(df):,} points)") - solara.DataFrame(df[[i for i in df.columns if i not in INTERNAL_COLS]]) - - -@solara.component -def embeddings(df: pd.DataFrame, color: str, point_size: int): - solara.Markdown("## Embeddings") - # We pass in df.id to custom_data so we can get back the correct points on a - # lasso selection. Plotly makes this difficult - # TODO: Solara will wrap and handle all of this logic for us in the future - p = px.scatter( - df, - x="x", - y="y", - color=color or None, - custom_data=[df["id"]], - hover_data=["hovertext"], - ) - p.update_layout(showlegend=False) - p.update_xaxes(visible=False) - p.update_yaxes(visible=False) - p.update_traces(marker_size=point_size) - - # Plotly returns data in a weird way, we just want the ids - # TODO: Solara to handle :) - set_point_ids = lambda data: State.filtered_ids.set(find_row_ids(p, data)) - solara.FigurePlotly(p, on_selection=set_point_ids) - - -@solara.component -def df_view(df: pd.DataFrame) -> None: - # TODO: Remove when solara updates - PlotState.point_size.use() - PlotState.color.use() - State.filtered_ids.use() - State.filter_text.use() - - fdf = filtered_df(df) - - with solara.Columns([1, 1]): - table(fdf) - embeddings(fdf, PlotState.color.value, PlotState.point_size.value) - - -@solara.component -def no_df() -> None: - with solara.Columns([1, 1]): - solara.Markdown("## DataFrame (Load Data)") - solara.Markdown("## Embeddings (Load Data)") - - -@solara.component -def _emb_loading_state() -> None: - solara.Markdown("## Embeddings") - solara.Markdown("Loading your embeddings. Enjoy this fun animation for now") - V.ProgressLinear(indeterminate=True) - - -@solara.component -def no_embs(df: pd.DataFrame) -> None: - with solara.Columns([1, 1]): - table(filtered_df(df)) - _emb_loading_state() - - -@solara.component() -def label_manager(df: pd.DataFrame) -> None: - State.chosen_label.use() - State.filtered_ids.use() - State.filter_text.use() - State.labeled_ids.use() - - def assign_labels() -> None: - labeled_ids = State.labeled_ids.value.copy() - new_ids = filtered_df(df)["id"].tolist() - print(f"Setting {State.chosen_label.value} for " f"{len(new_ids)} points") - labeled_ids[State.chosen_label.value].extend(new_ids) - State.labeled_ids.set(labeled_ids) - # State.assigned_new_label.set(True) - # Reset the view so no points are selected - reset() - - def export_edited_df() -> None: - """Assigns the label and downloads the df to the user""" - # TODO: Last thing! Allow the user to download the df - exp_df = apply_df_edits(df) - print(f"{len(exp_df)} rows edited") - - # TODO: Make a State.available_labels.append - solara.InputText("Register new label", on_value=add_new_label) - if State.available_labels.value: - solara.Select("Available labels", list(State.available_labels.value)).connect( - State.chosen_label - ) - button_enabled = bool(State.chosen_label.value) and has_df(df) - fdf = filtered_df(df) - num = len(fdf) if len(fdf) != len(df) else "all" - if button_enabled: - btn_label = f"Assign {num} points to label {State.chosen_label.value}" - elif not has_df(df): - btn_label = "Add data" - elif not State.available_labels.value: - btn_label = "Create a label" - else: - btn_label = "Choose a label" - solara.Button( - btn_label, on_click=assign_labels, disabled=not button_enabled, **BUTTON_KWARGS - ) - if State.labeled_ids.value: - # Flatten all of the edits into a single set, so we know how many were edited - num_edited = len(set(itertools.chain(*State.labeled_ids.value.values()))) - solara.Button( - f"Export {num_edited} labeled points", - on_click=export_edited_df, - **BUTTON_KWARGS, - ) - - -@solara.component() -def menu(df: pd.DataFrame, set_df: Callable) -> None: - # TODO: Remove when solara updates - PlotState.point_size.use() - PlotState.color.use() - State.filter_text.use() - - # avl_cols is dependent on df, so any time it changes, - # this will automatically update - set_cols = lambda: [i for i in df.columns if i not in NO_COLOR_COLS] - avl_cols = solara.use_memo(set_cols, [df]) - - def load_demo_df(): - new_df = pd.read_csv(PATH) - new_df = _set_default_cols(new_df) - set_df(new_df) - - def load_file_df(file: FileInfo): - data = io.BytesIO(file["data"]) - new_df = pd.read_csv(data) - new_df = _set_default_cols(new_df) - # Set it before embeddings so the user can see the df while embeddings load - PlotState.loading.set(True) - set_df(new_df) - embs = get_text_embeddings(new_df["text"].tolist()) - new_df["x"] = embs[:, 0] - new_df["y"] = embs[:, 1] - # Set it again after embeddings so we can render the plotly graph - set_df(new_df) - PlotState.loading.set(False) - - solara.FileDrop( - label="Drop CSV here (`text` col required)!", - on_file=load_file_df, - lazy=False, - ) - with solara.Column(): - solara.Button(label="Load demo dataset", on_click=load_demo_df, **BUTTON_KWARGS) - solara.Button(label="Reset view", on_click=reset, **BUTTON_KWARGS) - label_manager(df) - solara.InputText( - "Filter by search", State.filter_text.value, on_value=State.filter_text.set - ) - solara.Markdown("**Set point size**") - solara.SliderInt("", PlotState.point_size.value, on_value=PlotState.point_size.set) - # TODO: A drop down should have "remove selection" option - # (esp if default state is None) - solara.Select( - "Color by", - [None] + avl_cols, - PlotState.color.value, - on_value=PlotState.color.set, - ) - - -@solara.component -def Page(): - # TODO: Remove when solara updates - PlotState.loading.use() - State.filter_text.use() - State.filtered_ids.use() - # This `eq` makes it so every time we set the dataframe, solara thinks it's new - df, set_df = solara.use_state( - cast(pd.DataFrame, pd.DataFrame({})), eq=lambda *args: False - ) - solara.Title("Bulk Labeling!") - # TODO: Why cant i get this view to render? - assigned_label_view() - with solara.Sidebar(): - menu(df, set_df) - if has_df(df) and PlotState.loading.value: - no_embs(df) - elif has_df(df): - df_view(df) - else: - no_df() diff --git a/bulk-labeling/requirements-dev.txt b/bulk-labeling/requirements-dev.txt deleted file mode 100644 index aa7d98b..0000000 --- a/bulk-labeling/requirements-dev.txt +++ /dev/null @@ -1,3 +0,0 @@ -jupyter -black -isort diff --git a/bulk-labeling/requirements.txt b/bulk-labeling/requirements.txt deleted file mode 100644 index 1c50096..0000000 --- a/bulk-labeling/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -solara -pandas -numpy -plotly<=5.9.0 -sentence-transformers -umap-learn -numerize diff --git a/bulk_labeling/__init__.py b/bulk_labeling/__init__.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/bulk_labeling/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/bulk_labeling/components/__init__.py b/bulk_labeling/components/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bulk_labeling/components/df.py b/bulk_labeling/components/df.py new file mode 100644 index 0000000..05b4d00 --- /dev/null +++ b/bulk_labeling/components/df.py @@ -0,0 +1,58 @@ +import pandas as pd +import solara +from reacton import ipyvuetify as V + +from bulk_labeling.state import PlotState, State +from bulk_labeling.utils.df import INTERNAL_COLS, filtered_df +from bulk_labeling.utils.plotly import create_plotly_figure, find_row_ids + + +@solara.component +def _emb_loading_state() -> None: + solara.Markdown("## Embeddings") + solara.Markdown("Loading your embeddings. Enjoy this fun animation for now") + V.ProgressLinear(indeterminate=True) + + +@solara.component +def no_embs(df: pd.DataFrame) -> None: + with solara.Columns([1, 1]): + table(filtered_df(df)) + _emb_loading_state() + + +@solara.component +def embeddings(df: pd.DataFrame, color: str, point_size: int) -> None: + State.filtered_ids.use() + + solara.Markdown("## Embeddings") + fig = create_plotly_figure(df, color, point_size) + + # Plotly returns data in a weird way, we just want the ids + # TODO: Solara to handle :) + set_point_ids = lambda data: State.filtered_ids.set(find_row_ids(fig, data)) + solara.FigurePlotly(fig, on_selection=set_point_ids) + + +@solara.component +def table(df: pd.DataFrame) -> None: + solara.Markdown(f"## Data ({len(df):,} points)") + solara.DataFrame(df[[i for i in df.columns if i not in INTERNAL_COLS]]) + + +@solara.component +def df_view(df: pd.DataFrame) -> None: + # TODO: Remove when solara updates + PlotState.point_size.use() + PlotState.color.use() + # We need these because of the `filtered_df` call. + # next itertation we will move the df into global state (solara blocking) + # and then filtered_df becomes always up to date + State.filtered_ids.use() + State.filter_text.use() + + fdf = filtered_df(df) + + with solara.Columns([1, 1]): + table(fdf) + embeddings(fdf, PlotState.color.value, PlotState.point_size.value) diff --git a/bulk_labeling/components/menu.py b/bulk_labeling/components/menu.py new file mode 100644 index 0000000..a415707 --- /dev/null +++ b/bulk_labeling/components/menu.py @@ -0,0 +1,167 @@ +import io +import itertools +from functools import partial +from time import sleep +from typing import Callable, List + +import pandas as pd +import solara +from solara.components.file_drop import FileInfo + +from bulk_labeling.state import PlotState, State, reset +from bulk_labeling.utils.common import BUTTON_KWARGS +from bulk_labeling.utils.df import INTERNAL_COLS, apply_df_edits, filtered_df, load_df +from bulk_labeling.utils.menu import ( + PATH, + add_new_label, + assign_labels, + get_assign_label_button_text, +) +from bulk_labeling.utils.ml import add_embeddings_to_df + +NO_COLOR_COLS = INTERNAL_COLS + ["text"] + + +@solara.component +def assigned_label_view() -> None: + State.assigned_new_label.use() + State.filtered_ids.use() + + if State.assigned_new_label.value: + solara.Info(f"{len(State.filtered_ids.value)} labeled!") + sleep(2) + State.assigned_new_label.set(False) + + +@solara.component +def assign_label_button(df: pd.DataFrame) -> None: + fdf = filtered_df(df) + btn_label, button_enabled = get_assign_label_button_text(df) + solara.Button( + btn_label, + on_click=partial(assign_labels, fdf), + disabled=not button_enabled, + **BUTTON_KWARGS, + ) + + +@solara.component +def register_new_label_button() -> None: + # TODO: Remove when solara updates + State.available_labels.use() + State.chosen_label.use() + + # TODO: Make a State.available_labels.append + solara.InputText("Register new label", on_value=add_new_label) + if State.available_labels.value: + solara.Select("Available labels", list(State.available_labels.value)).connect( + State.chosen_label + ) + + +@solara.component +def export_edits_button(df: pd.DataFrame) -> None: + # TODO: Remove when solara updates + State.labeled_ids.use() + + def export_edited_df() -> None: + """Assigns the label and downloads the df to the user""" + # TODO: Last thing! Allow the user to download the df + exp_df = apply_df_edits(df) + print(f"{len(exp_df)} rows edited") + + if State.labeled_ids.value: + # Flatten all of the edits into a single set, so we know how many were edited + num_edited = len(set(itertools.chain(*State.labeled_ids.value.values()))) + solara.Button( + f"Export {num_edited} labeled points", + on_click=export_edited_df, + **BUTTON_KWARGS, + ) + + +@solara.component +def label_manager(df: pd.DataFrame) -> None: + register_new_label_button() + assign_label_button(df) + export_edits_button(df) + + +@solara.component +def file_manager(set_df: Callable) -> None: + PlotState.color.use() + PlotState.loading.use() + + def load_demo_df() -> None: + new_df = load_df(PATH) + set_df(new_df) + set_df(new_df) + + def load_file_df(file: FileInfo) -> None: + if not file["data"]: + return + new_df = load_df(io.BytesIO(file["data"])) + # Set it before embeddings so the user can see the df while embeddings load + PlotState.loading.set(True) + set_df(new_df) + new_df = add_embeddings_to_df(new_df) + # Set it again after embeddings so we can render the plotly graph + set_df(new_df) + PlotState.loading.set(False) + + solara.FileDrop( + label="Drop CSV here (`text` col required)!", + on_file=load_file_df, + lazy=False, + ) + with solara.Column(): + solara.Button(label="Load demo dataset", on_click=load_demo_df, **BUTTON_KWARGS) + solara.Button(label="Reset view", on_click=reset, **BUTTON_KWARGS) + + +@solara.component +def view_controller(avl_cols: List[str]) -> None: + # TODO: Remove when solara updates + PlotState.color.use() + PlotState.point_size.use() + State.filter_text.use() + + solara.InputText( + "Filter by search", State.filter_text.value, on_value=State.filter_text.set + ) + solara.Markdown("**Set point size**") + solara.SliderInt("", PlotState.point_size.value, on_value=PlotState.point_size.set) + # TODO: A drop down should have "remove selection" option + # (esp if default state is None) + solara.Select( + "Color by", + [None] + avl_cols, + PlotState.color.value, + on_value=PlotState.color.set, + ) + + +@solara.component +def menu(df: pd.DataFrame, set_df: Callable) -> None: + State.reset_on_assignment.use() + + # avl_cols is dependent on df, so any time it changes, + # this will automatically update + set_cols = lambda: [i for i in df.columns if i not in NO_COLOR_COLS] + avl_cols = solara.use_memo(set_cols, [df]) + + file_manager(set_df) + label_manager(df) + view_controller(avl_cols) + solara.Markdown(f"**Reset view on label assignment?**") + if State.reset_on_assignment.value: + label = "Reset" + else: + label = "Keep state" + # solara.Checkbox( + # label=label, + # value=State.reset_on_assignment.value, + # on_value=State.reset_on_assignment.set + # ) + solara.Checkbox(label=label).connect(State.reset_on_assignment) + # solara.ToggleButtonsSingle(State.reset_on_assignment.value, values=[True, False], on_value=State.reset_on_assignment.set) diff --git a/bulk-labeling/conv_intent.csv b/bulk_labeling/conv_intent.csv similarity index 100% rename from bulk-labeling/conv_intent.csv rename to bulk_labeling/conv_intent.csv diff --git a/bulk_labeling/main.py b/bulk_labeling/main.py new file mode 100644 index 0000000..6c06c36 --- /dev/null +++ b/bulk_labeling/main.py @@ -0,0 +1,42 @@ +from typing import cast + +import pandas as pd +import solara + +from bulk_labeling.components.df import df_view, no_embs +from bulk_labeling.components.menu import assigned_label_view, menu +from bulk_labeling.state import PlotState +from bulk_labeling.utils.df import has_df + + +@solara.component +def no_df() -> None: + with solara.Columns([1, 1]): + solara.Markdown("## DataFrame (Load Data)") + solara.Markdown("## Embeddings (Load Data)") + + +@solara.component +def Page() -> None: + # TODO: Remove when solara updates + # PlotState.loading.use() + + # This `eq` makes it so every time we set the dataframe, solara thinks it's new + df, set_df = solara.use_state( + cast(pd.DataFrame, pd.DataFrame({})), eq=lambda *args: False + ) + solara.Title("Bulk Labeling!") + # TODO: Why cant i get this view to render? + assigned_label_view() + with solara.Sidebar(): + menu(df, set_df) + if has_df(df) and PlotState.loading.value: + no_embs(df) + elif has_df(df): + df_view(df) + else: + no_df() + + +if __name__ == "__main__": + Page() diff --git a/bulk_labeling/state.py b/bulk_labeling/state.py new file mode 100644 index 0000000..12a09e9 --- /dev/null +++ b/bulk_labeling/state.py @@ -0,0 +1,32 @@ +from collections import defaultdict +from typing import Dict, List, Optional, Set + +from solara.lab import Reactive + +DEFAULT_POINT_SIZE = 2 + + +class State: + available_labels = Reactive[Set[str]](set()) + labeled_ids = Reactive[Dict[str, List[int]]](defaultdict(list)) + filtered_ids = Reactive[List[int]]([]) + chosen_label = Reactive[Optional[str]](None) + assigned_new_label = Reactive[bool](False) + filter_text = Reactive[str]("") + reset_on_assignment = Reactive[bool](True) + + +class PlotState: + point_size = Reactive[int](DEFAULT_POINT_SIZE) + color = Reactive[str]("") + # While we calculate embeddings and UMAP, we can manage the loading state + loading = Reactive[bool](False) + + +def reset() -> None: + """Removes any filters applied to the data""" + State.filtered_ids.set([]) + State.filter_text.set("") + State.chosen_label.set(None) + PlotState.point_size.set(DEFAULT_POINT_SIZE) + PlotState.color.set("") diff --git a/bulk_labeling/utils/__init__.py b/bulk_labeling/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bulk_labeling/utils/common.py b/bulk_labeling/utils/common.py new file mode 100644 index 0000000..26df82b --- /dev/null +++ b/bulk_labeling/utils/common.py @@ -0,0 +1 @@ +BUTTON_KWARGS = dict(color="primary", text=True, outlined=True) diff --git a/bulk_labeling/utils/df.py b/bulk_labeling/utils/df.py new file mode 100644 index 0000000..d0991b3 --- /dev/null +++ b/bulk_labeling/utils/df.py @@ -0,0 +1,50 @@ +from io import BytesIO +from typing import Union + +import pandas as pd + +from bulk_labeling.state import State + +INTERNAL_COLS = ["x", "y", "hovertext", "id"] + + +def has_df(df: pd.DataFrame) -> bool: + return (len(df) != 0) and (len(df.columns) != 0) + + +def filtered_df(df: pd.DataFrame) -> pd.DataFrame: + dfc = df.copy() + if not has_df(dfc): + return dfc + if State.filtered_ids.value: + dfc = dfc[dfc["id"].isin(State.filtered_ids.value)] + if State.filter_text.value: + dfc = dfc[dfc["text"].str.contains(State.filter_text.value)] + return dfc + + +def set_default_cols(df: pd.DataFrame) -> pd.DataFrame: + df["text_length"] = df.text.str.len() + df["id"] = list(range(len(df))) + df["hovertext"] = df.text.str.wrap(30).str.replace("\n", "
") + return df + + +def load_df(data: Union[str, BytesIO]) -> pd.DataFrame: + new_df = pd.read_csv(data) + new_df = set_default_cols(new_df) + return new_df + + +def apply_df_edits(df: pd.DataFrame) -> pd.DataFrame: + print("Should be downloading!") + df2 = df.copy() + labeled_ids = State.labeled_ids.value + # Map every ID to it's assigned labels + # TODO: We can be smarter with conflicts and pick the label that an ID is + # assigned to most frequently + id_label = {id_: label for label, ids in labeled_ids.items() for id_ in ids} + df2["label"] = df2["id"].apply(lambda id_: id_label.get(id_, "-1")) + df2 = df2[df2["label"] != "-1"] + cols = [c for c in df2.columns if c not in INTERNAL_COLS] + return df2[cols] diff --git a/bulk_labeling/utils/menu.py b/bulk_labeling/utils/menu.py new file mode 100644 index 0000000..b0ea4d8 --- /dev/null +++ b/bulk_labeling/utils/menu.py @@ -0,0 +1,49 @@ +import os +from typing import Tuple + +import pandas as pd + +from bulk_labeling.state import State, reset +from bulk_labeling.utils.df import filtered_df, has_df + +DIR = f"{os.getcwd()}/bulk_labeling" +PATH = f"{DIR}/conv_intent.csv" + + +def add_new_label(new_label: str) -> None: + if not new_label: + return + all_labels = State.available_labels.value.copy() + all_labels.add(new_label) + State.available_labels.set(all_labels) + # So the "assign points" button is already pre-populated with your new label =] + State.chosen_label.set(new_label) + + +def assign_labels(df: pd.DataFrame) -> None: + labeled_ids = State.labeled_ids.value.copy() + new_ids = df["id"].tolist() + if State.chosen_label.value: + print(f"Setting {State.chosen_label.value} for {len(new_ids)} points") + labeled_ids[State.chosen_label.value].extend(new_ids) + State.labeled_ids.set(labeled_ids) + # State.assigned_new_label.set(True) + # Reset the view so no points are selected + if State.reset_on_assignment.value: + reset() + + +def get_assign_label_button_text(df: pd.DataFrame) -> Tuple[str, bool]: + """Gets the label for the "assign label" button, and whether it's enabled""" + button_enabled = bool(State.chosen_label.value) and has_df(df) + fdf = filtered_df(df) + num = len(fdf) if len(fdf) != len(df) else "all" + if button_enabled: + btn_label = f"Assign {num} points to label {State.chosen_label.value}" + elif not has_df(df): + btn_label = "Add data" + elif not State.available_labels.value: + btn_label = "Create a label" + else: + btn_label = "Choose a label" + return btn_label, button_enabled diff --git a/bulk_labeling/utils/ml.py b/bulk_labeling/utils/ml.py new file mode 100644 index 0000000..8029025 --- /dev/null +++ b/bulk_labeling/utils/ml.py @@ -0,0 +1,30 @@ +from typing import List + +import numpy as np +import pandas as pd +from sentence_transformers import SentenceTransformer +from umap import UMAP + +UMAP_MODEL = UMAP(n_neighbors=15, random_state=42, verbose=True) +ENCODER = SentenceTransformer("paraphrase-MiniLM-L3-v2") + + +def encode_inputs(samples: List[str]) -> np.ndarray: + return ENCODER.encode(samples) + # When doing rapid development, it's faster to return a numpy array + # return np.random.rand(len(samples), 20) + + +def get_xy(embs: np.ndarray) -> np.ndarray: + return UMAP_MODEL.fit_transform(embs) + + +def get_text_embeddings(samples: List[str]) -> np.ndarray: + return get_xy(encode_inputs(samples)) + + +def add_embeddings_to_df(df: pd.DataFrame) -> pd.DataFrame: + embs = get_text_embeddings(df["text"].tolist()) + df["x"] = embs[:, 0] + df["y"] = embs[:, 1] + return df diff --git a/bulk_labeling/utils/plotly.py b/bulk_labeling/utils/plotly.py new file mode 100644 index 0000000..77a6870 --- /dev/null +++ b/bulk_labeling/utils/plotly.py @@ -0,0 +1,40 @@ +from typing import Dict, List + +import pandas as pd +import plotly.express as px +from plotly.graph_objs._figure import Figure + + +def find_row_ids(fig: Figure, click_data: Dict) -> List[int]: + """A very annoying function to get row IDs because Plotly is unhelpful + + Solara is going to do this for us in the future! + """ + # goes from trace index and point index to row index in a dataframe + # requires passing df.index as to custom_data + trace_index = click_data["points"]["trace_indexes"] + point_index = click_data["points"]["point_indexes"] + point_ids = [] + for t, p in zip(trace_index, point_index): + point_trace = fig.data[t] + point_ids.append(point_trace.customdata[p][0]) + return point_ids + + +def create_plotly_figure(df: pd.DataFrame, color: str, point_size: int) -> Figure: + # We pass in df.id to custom_data so we can get back the correct points on a + # lasso selection. Plotly makes this difficult + # TODO: Solara will wrap and handle all of this logic for us in the future + fig = px.scatter( + df, + x="x", + y="y", + color=color or None, + custom_data=[df["id"]], + hover_data=["hovertext"], + ) + fig.update_layout(showlegend=False) + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + fig.update_traces(marker_size=point_size) + return fig diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..056621a --- /dev/null +++ b/mypy.ini @@ -0,0 +1,3 @@ +[mypy] +ignore_missing_imports = True +disallow_untyped_defs = True diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0867627 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,43 @@ +[build-system] +requires = [ + "setuptools >=65.0", + "wheel >=0.37", +] +build-backend = "setuptools.build_meta" + +[project] +name = "bulk-labeling-solara" +authors = [{name = "Ben Epstein", email = "ben.epstein97@gmail.com"}] +license = {file = "LICENSE"} +dynamic = ["version", "readme"] +classifiers = ["License :: OSI Approved :: Apache License"] +dependencies = [ + "solara", + "pandas", + "numpy", + "plotly<=5.9.0", + "sentence-transformers", + "umap-learn", + "numerize", +] + +[project.optional-dependencies] +dev = [ + "jupyter", + "black", + "isort", + "flake8", +] + +[project.urls] +Home = "https://www.github.com/ben-epstein/bulk-labeling-solara" + +[tool.isort] +profile = "black" + + +[tool.setuptools.dynamic] +version = {attr = "bulk_labeling.__version__"} + +[tool.setuptools] +py-modules = [] diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100755 index 0000000..73b82a5 --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,9 @@ +#!/bin/sh -ex + +# Sort imports one per line, so autoflake can remove unused imports +isort --force-single-line-imports bulk_labeling + +autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place bulk_labeling --exclude=__init__.py +# For some reason we need to run this again so that black can get it into the format we want +isort bulk_labeling +black bulk_labeling diff --git a/scripts/lint.sh b/scripts/lint.sh new file mode 100755 index 0000000..1e21c81 --- /dev/null +++ b/scripts/lint.sh @@ -0,0 +1,6 @@ +#!/bin/sh -ex + +mypy bulk_labeling +flake8 bulk_labeling +black bulk_labeling --check +isort bulk_labeling --check-only diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..4bf19b8 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,6 @@ +[metadata] +long_description = file: README.md +long_description_content_type = text/markdown + +[options] +packages = find: diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..316ca78 --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +from setuptools import setup + +setup( + package_data={"": ["**conv_intent.csv"]}, + include_package_data=True +)