diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..12900c1
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203,D10,D415,E501,E712
+docstring-convention = google
diff --git a/README.md b/README.md
index 7852268..7de2054 100644
--- a/README.md
+++ b/README.md
@@ -4,3 +4,24 @@ A tool for bulk labeling, built in Solara!
I'm trying to rebuild my original [bulk-labeling](https://github.com/rungalileo/bulk-labeling/) app, which was Streamlit, in [Solara](https://github.com/widgetti/solara) so it can be a bit more scalable, customizable, and robust to new features!
I also want to learn how to use solara :)
+
+
+## Development
+1. Set up a virtual env: `python -m venv .venv && source .venv/bin/activate`
+2. Install the package: `pip install -e . && pyenv rehash`
+3. Run: `solara run bulk_labeling/main.py`
+
+Any changes you make to the app should be reflected in real time.
+
+### Note: `SentenceTransformers` doesn't play nicely with solara
+If you are going to be developing, I strongly recommend commenting out
+the few lines in [ml.py](bulk_labeling/utils/ml.py):
+https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L5
+https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L9
+https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L13
+
+and uncomment
+https://github.com/Ben-Epstein/bulk-labeling-solara/blob/8281f618c33e298a0bb5de373b0087a49d58e938/bulk_labeling/utils/ml.py#L15
+
+For some reason, solara breaks on a page reload while these lines are active.
+Commenting them out also makes prototyping faster, because no strings are actually encoded.
diff --git a/bulk-labeling/main.py b/bulk-labeling/main.py
deleted file mode 100644
index 4332e14..0000000
--- a/bulk-labeling/main.py
+++ /dev/null
@@ -1,332 +0,0 @@
-import io
-import itertools
-import os
-from collections import defaultdict
-from time import sleep
-from typing import Callable, Dict, List, Optional, Set, cast
-
-import numpy as np
-import pandas as pd
-import plotly.express as px
-import solara
-from reacton import ipyvuetify as V
-from sentence_transformers import SentenceTransformer
-from solara.components.file_drop import FileInfo
-from solara.lab import Reactive
-from umap import UMAP
-
-DIR = f"{os.getcwd()}/bulk-labeling"
-PATH = f"{DIR}/conv_intent.csv"
-
-UMAP_MODEL = UMAP(n_neighbors=15, random_state=42, verbose=True)
-ENCODER = SentenceTransformer("paraphrase-MiniLM-L3-v2")
-
-INTERNAL_COLS = ["x", "y", "hovertext", "id"]
-NO_COLOR_COLS = INTERNAL_COLS + ["text"]
-BUTTON_KWARGS = dict(color="primary", text=True, outlined=True)
-
-
-class State:
- available_labels = Reactive[Set[str]](set())
- labeled_ids = Reactive[Dict[str, List[int]]](defaultdict(list))
- filtered_ids = Reactive[List[int]]([])
- chosen_label = Reactive[Optional[str]](None)
- assigned_new_label = Reactive[bool](False)
- filter_text = Reactive[str]("")
-
-
-class PlotState:
- point_size = Reactive[int](2)
- color = Reactive[str]("")
- # While we calculate embeddings and UMAP, we can manage the loading state
- loading = Reactive[bool](False)
-
-
-def has_df(df: pd.DataFrame) -> bool:
- return (len(df) != 0) and (len(df.columns) != 0)
-
-
-def filtered_df(df: pd.DataFrame) -> pd.DataFrame:
- dfc = df.copy()
- if not has_df(dfc):
- return dfc
- if State.filtered_ids.value:
- dfc = dfc[dfc["id"].isin(State.filtered_ids.value)]
- if State.filter_text.value:
- dfc = dfc[dfc["text"].str.contains(State.filter_text.value)]
- return dfc
-
-
-def add_new_label(new_label: str) -> None:
- if not new_label:
- return
- all_labels = State.available_labels.value.copy()
- all_labels.add(new_label)
- State.available_labels.set(all_labels)
- # So the "assign points" button is already pre-populated with your new label =]
- State.chosen_label.set(new_label)
-
-
-def _set_default_cols(df: pd.DataFrame) -> pd.DataFrame:
- df["text_length"] = df.text.str.len()
- df["id"] = list(range(len(df)))
- df["hovertext"] = df.text.str.wrap(30).str.replace("\n", "<br>")
- return df
-
-
-def apply_df_edits(df: pd.DataFrame) -> pd.DataFrame:
- print("Should be downloading!")
- df2 = df.copy()
- labeled_ids = State.labeled_ids.value
- # Map every ID to it's assigned labels
- # TODO: We can be smarter with conflicts and pick the label that an ID is
- # assigned to most frequently
- id_label = {id_: label for label, ids in labeled_ids.items() for id_ in ids}
- df2["label"] = df2["id"].apply(lambda id_: id_label.get(id_, "-1"))
- df2 = df2[df2["label"] != "-1"]
- cols = [c for c in df2.columns if c not in INTERNAL_COLS]
- return df2[cols]
-
-
-def encode_inputs(samples: List[str]) -> np.ndarray:
- return ENCODER.encode(samples)
- # When doing rapid development, it's faster to return a numpy array
- # return np.random.rand(len(samples), 20)
-
-
-def get_xy(embs: np.ndarray) -> np.ndarray:
- return UMAP_MODEL.fit_transform(embs)
-
-
-def get_text_embeddings(samples: List[str]) -> np.ndarray:
- return get_xy(encode_inputs(samples))
-
-
-def find_row_ids(fig, click_data) -> List[int]:
- """A very annoying function to get row IDs because Plotly is unhelpful
-
- Solara is going to do this for us in the future!
- """
- # goes from trace index and point index to row index in a dataframe
- # requires passing df.index as to custom_data
- trace_index = click_data["points"]["trace_indexes"]
- point_index = click_data["points"]["point_indexes"]
- point_ids = []
- for t, p in zip(trace_index, point_index):
- point_trace = fig.data[t]
- point_ids.append(point_trace.customdata[p][0])
- return point_ids
-
-
-def reset():
- """Removes any filters applied to the data"""
- State.filtered_ids.set([])
- State.filter_text.set("")
- State.chosen_label.set(None)
-
-
-@solara.component
-def assigned_label_view() -> None:
- State.assigned_new_label.use()
- if State.assigned_new_label.value:
- solara.Info(f"{len(State.filtered_ids.value)} labeled!")
- sleep(2)
- State.assigned_new_label.set(False)
-
-
-@solara.component
-def table(df: pd.DataFrame):
- solara.Markdown(f"## Data ({len(df):,} points)")
- solara.DataFrame(df[[i for i in df.columns if i not in INTERNAL_COLS]])
-
-
-@solara.component
-def embeddings(df: pd.DataFrame, color: str, point_size: int):
- solara.Markdown("## Embeddings")
- # We pass in df.id to custom_data so we can get back the correct points on a
- # lasso selection. Plotly makes this difficult
- # TODO: Solara will wrap and handle all of this logic for us in the future
- p = px.scatter(
- df,
- x="x",
- y="y",
- color=color or None,
- custom_data=[df["id"]],
- hover_data=["hovertext"],
- )
- p.update_layout(showlegend=False)
- p.update_xaxes(visible=False)
- p.update_yaxes(visible=False)
- p.update_traces(marker_size=point_size)
-
- # Plotly returns data in a weird way, we just want the ids
- # TODO: Solara to handle :)
- set_point_ids = lambda data: State.filtered_ids.set(find_row_ids(p, data))
- solara.FigurePlotly(p, on_selection=set_point_ids)
-
-
-@solara.component
-def df_view(df: pd.DataFrame) -> None:
- # TODO: Remove when solara updates
- PlotState.point_size.use()
- PlotState.color.use()
- State.filtered_ids.use()
- State.filter_text.use()
-
- fdf = filtered_df(df)
-
- with solara.Columns([1, 1]):
- table(fdf)
- embeddings(fdf, PlotState.color.value, PlotState.point_size.value)
-
-
-@solara.component
-def no_df() -> None:
- with solara.Columns([1, 1]):
- solara.Markdown("## DataFrame (Load Data)")
- solara.Markdown("## Embeddings (Load Data)")
-
-
-@solara.component
-def _emb_loading_state() -> None:
- solara.Markdown("## Embeddings")
- solara.Markdown("Loading your embeddings. Enjoy this fun animation for now")
- V.ProgressLinear(indeterminate=True)
-
-
-@solara.component
-def no_embs(df: pd.DataFrame) -> None:
- with solara.Columns([1, 1]):
- table(filtered_df(df))
- _emb_loading_state()
-
-
-@solara.component()
-def label_manager(df: pd.DataFrame) -> None:
- State.chosen_label.use()
- State.filtered_ids.use()
- State.filter_text.use()
- State.labeled_ids.use()
-
- def assign_labels() -> None:
- labeled_ids = State.labeled_ids.value.copy()
- new_ids = filtered_df(df)["id"].tolist()
- print(f"Setting {State.chosen_label.value} for " f"{len(new_ids)} points")
- labeled_ids[State.chosen_label.value].extend(new_ids)
- State.labeled_ids.set(labeled_ids)
- # State.assigned_new_label.set(True)
- # Reset the view so no points are selected
- reset()
-
- def export_edited_df() -> None:
- """Assigns the label and downloads the df to the user"""
- # TODO: Last thing! Allow the user to download the df
- exp_df = apply_df_edits(df)
- print(f"{len(exp_df)} rows edited")
-
- # TODO: Make a State.available_labels.append
- solara.InputText("Register new label", on_value=add_new_label)
- if State.available_labels.value:
- solara.Select("Available labels", list(State.available_labels.value)).connect(
- State.chosen_label
- )
- button_enabled = bool(State.chosen_label.value) and has_df(df)
- fdf = filtered_df(df)
- num = len(fdf) if len(fdf) != len(df) else "all"
- if button_enabled:
- btn_label = f"Assign {num} points to label {State.chosen_label.value}"
- elif not has_df(df):
- btn_label = "Add data"
- elif not State.available_labels.value:
- btn_label = "Create a label"
- else:
- btn_label = "Choose a label"
- solara.Button(
- btn_label, on_click=assign_labels, disabled=not button_enabled, **BUTTON_KWARGS
- )
- if State.labeled_ids.value:
- # Flatten all of the edits into a single set, so we know how many were edited
- num_edited = len(set(itertools.chain(*State.labeled_ids.value.values())))
- solara.Button(
- f"Export {num_edited} labeled points",
- on_click=export_edited_df,
- **BUTTON_KWARGS,
- )
-
-
-@solara.component()
-def menu(df: pd.DataFrame, set_df: Callable) -> None:
- # TODO: Remove when solara updates
- PlotState.point_size.use()
- PlotState.color.use()
- State.filter_text.use()
-
- # avl_cols is dependent on df, so any time it changes,
- # this will automatically update
- set_cols = lambda: [i for i in df.columns if i not in NO_COLOR_COLS]
- avl_cols = solara.use_memo(set_cols, [df])
-
- def load_demo_df():
- new_df = pd.read_csv(PATH)
- new_df = _set_default_cols(new_df)
- set_df(new_df)
-
- def load_file_df(file: FileInfo):
- data = io.BytesIO(file["data"])
- new_df = pd.read_csv(data)
- new_df = _set_default_cols(new_df)
- # Set it before embeddings so the user can see the df while embeddings load
- PlotState.loading.set(True)
- set_df(new_df)
- embs = get_text_embeddings(new_df["text"].tolist())
- new_df["x"] = embs[:, 0]
- new_df["y"] = embs[:, 1]
- # Set it again after embeddings so we can render the plotly graph
- set_df(new_df)
- PlotState.loading.set(False)
-
- solara.FileDrop(
- label="Drop CSV here (`text` col required)!",
- on_file=load_file_df,
- lazy=False,
- )
- with solara.Column():
- solara.Button(label="Load demo dataset", on_click=load_demo_df, **BUTTON_KWARGS)
- solara.Button(label="Reset view", on_click=reset, **BUTTON_KWARGS)
- label_manager(df)
- solara.InputText(
- "Filter by search", State.filter_text.value, on_value=State.filter_text.set
- )
- solara.Markdown("**Set point size**")
- solara.SliderInt("", PlotState.point_size.value, on_value=PlotState.point_size.set)
- # TODO: A drop down should have "remove selection" option
- # (esp if default state is None)
- solara.Select(
- "Color by",
- [None] + avl_cols,
- PlotState.color.value,
- on_value=PlotState.color.set,
- )
-
-
-@solara.component
-def Page():
- # TODO: Remove when solara updates
- PlotState.loading.use()
- State.filter_text.use()
- State.filtered_ids.use()
- # This `eq` makes it so every time we set the dataframe, solara thinks it's new
- df, set_df = solara.use_state(
- cast(pd.DataFrame, pd.DataFrame({})), eq=lambda *args: False
- )
- solara.Title("Bulk Labeling!")
- # TODO: Why cant i get this view to render?
- assigned_label_view()
- with solara.Sidebar():
- menu(df, set_df)
- if has_df(df) and PlotState.loading.value:
- no_embs(df)
- elif has_df(df):
- df_view(df)
- else:
- no_df()
diff --git a/bulk-labeling/requirements-dev.txt b/bulk-labeling/requirements-dev.txt
deleted file mode 100644
index aa7d98b..0000000
--- a/bulk-labeling/requirements-dev.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-jupyter
-black
-isort
diff --git a/bulk-labeling/requirements.txt b/bulk-labeling/requirements.txt
deleted file mode 100644
index 1c50096..0000000
--- a/bulk-labeling/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-solara
-pandas
-numpy
-plotly<=5.9.0
-sentence-transformers
-umap-learn
-numerize
diff --git a/bulk_labeling/__init__.py b/bulk_labeling/__init__.py
new file mode 100644
index 0000000..f102a9c
--- /dev/null
+++ b/bulk_labeling/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.0.1"
diff --git a/bulk_labeling/components/__init__.py b/bulk_labeling/components/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bulk_labeling/components/df.py b/bulk_labeling/components/df.py
new file mode 100644
index 0000000..05b4d00
--- /dev/null
+++ b/bulk_labeling/components/df.py
@@ -0,0 +1,58 @@
+import pandas as pd
+import solara
+from reacton import ipyvuetify as V
+
+from bulk_labeling.state import PlotState, State
+from bulk_labeling.utils.df import INTERNAL_COLS, filtered_df
+from bulk_labeling.utils.plotly import create_plotly_figure, find_row_ids
+
+
+@solara.component
+def _emb_loading_state() -> None:
+ solara.Markdown("## Embeddings")
+ solara.Markdown("Loading your embeddings. Enjoy this fun animation for now")
+ V.ProgressLinear(indeterminate=True)
+
+
+@solara.component
+def no_embs(df: pd.DataFrame) -> None:
+ with solara.Columns([1, 1]):
+ table(filtered_df(df))
+ _emb_loading_state()
+
+
+@solara.component
+def embeddings(df: pd.DataFrame, color: str, point_size: int) -> None:
+ State.filtered_ids.use()
+
+ solara.Markdown("## Embeddings")
+ fig = create_plotly_figure(df, color, point_size)
+
+ # Plotly returns data in a weird way, we just want the ids
+ # TODO: Solara to handle :)
+ set_point_ids = lambda data: State.filtered_ids.set(find_row_ids(fig, data))
+ solara.FigurePlotly(fig, on_selection=set_point_ids)
+
+
+@solara.component
+def table(df: pd.DataFrame) -> None:
+ solara.Markdown(f"## Data ({len(df):,} points)")
+ solara.DataFrame(df[[i for i in df.columns if i not in INTERNAL_COLS]])
+
+
+@solara.component
+def df_view(df: pd.DataFrame) -> None:
+ # TODO: Remove when solara updates
+ PlotState.point_size.use()
+ PlotState.color.use()
+ # We need these because of the `filtered_df` call.
+ # next iteration we will move the df into global state (solara blocking)
+ # and then filtered_df becomes always up to date
+ State.filtered_ids.use()
+ State.filter_text.use()
+
+ fdf = filtered_df(df)
+
+ with solara.Columns([1, 1]):
+ table(fdf)
+ embeddings(fdf, PlotState.color.value, PlotState.point_size.value)
diff --git a/bulk_labeling/components/menu.py b/bulk_labeling/components/menu.py
new file mode 100644
index 0000000..a415707
--- /dev/null
+++ b/bulk_labeling/components/menu.py
@@ -0,0 +1,167 @@
+import io
+import itertools
+from functools import partial
+from time import sleep
+from typing import Callable, List
+
+import pandas as pd
+import solara
+from solara.components.file_drop import FileInfo
+
+from bulk_labeling.state import PlotState, State, reset
+from bulk_labeling.utils.common import BUTTON_KWARGS
+from bulk_labeling.utils.df import INTERNAL_COLS, apply_df_edits, filtered_df, load_df
+from bulk_labeling.utils.menu import (
+ PATH,
+ add_new_label,
+ assign_labels,
+ get_assign_label_button_text,
+)
+from bulk_labeling.utils.ml import add_embeddings_to_df
+
+NO_COLOR_COLS = INTERNAL_COLS + ["text"]
+
+
+@solara.component
+def assigned_label_view() -> None:
+ State.assigned_new_label.use()
+ State.filtered_ids.use()
+
+ if State.assigned_new_label.value:
+ solara.Info(f"{len(State.filtered_ids.value)} labeled!")
+ sleep(2)
+ State.assigned_new_label.set(False)
+
+
+@solara.component
+def assign_label_button(df: pd.DataFrame) -> None:
+ fdf = filtered_df(df)
+ btn_label, button_enabled = get_assign_label_button_text(df)
+ solara.Button(
+ btn_label,
+ on_click=partial(assign_labels, fdf),
+ disabled=not button_enabled,
+ **BUTTON_KWARGS,
+ )
+
+
+@solara.component
+def register_new_label_button() -> None:
+ # TODO: Remove when solara updates
+ State.available_labels.use()
+ State.chosen_label.use()
+
+ # TODO: Make a State.available_labels.append
+ solara.InputText("Register new label", on_value=add_new_label)
+ if State.available_labels.value:
+ solara.Select("Available labels", list(State.available_labels.value)).connect(
+ State.chosen_label
+ )
+
+
+@solara.component
+def export_edits_button(df: pd.DataFrame) -> None:
+ # TODO: Remove when solara updates
+ State.labeled_ids.use()
+
+ def export_edited_df() -> None:
+ """Assigns the label and downloads the df to the user"""
+ # TODO: Last thing! Allow the user to download the df
+ exp_df = apply_df_edits(df)
+ print(f"{len(exp_df)} rows edited")
+
+ if State.labeled_ids.value:
+ # Flatten all of the edits into a single set, so we know how many were edited
+ num_edited = len(set(itertools.chain(*State.labeled_ids.value.values())))
+ solara.Button(
+ f"Export {num_edited} labeled points",
+ on_click=export_edited_df,
+ **BUTTON_KWARGS,
+ )
+
+
+@solara.component
+def label_manager(df: pd.DataFrame) -> None:
+ register_new_label_button()
+ assign_label_button(df)
+ export_edits_button(df)
+
+
+@solara.component
+def file_manager(set_df: Callable) -> None:
+ PlotState.color.use()
+ PlotState.loading.use()
+
+ def load_demo_df() -> None:
+ new_df = load_df(PATH)
+ set_df(new_df)
+ set_df(new_df)
+
+ def load_file_df(file: FileInfo) -> None:
+ if not file["data"]:
+ return
+ new_df = load_df(io.BytesIO(file["data"]))
+ # Set it before embeddings so the user can see the df while embeddings load
+ PlotState.loading.set(True)
+ set_df(new_df)
+ new_df = add_embeddings_to_df(new_df)
+ # Set it again after embeddings so we can render the plotly graph
+ set_df(new_df)
+ PlotState.loading.set(False)
+
+ solara.FileDrop(
+ label="Drop CSV here (`text` col required)!",
+ on_file=load_file_df,
+ lazy=False,
+ )
+ with solara.Column():
+ solara.Button(label="Load demo dataset", on_click=load_demo_df, **BUTTON_KWARGS)
+ solara.Button(label="Reset view", on_click=reset, **BUTTON_KWARGS)
+
+
+@solara.component
+def view_controller(avl_cols: List[str]) -> None:
+ # TODO: Remove when solara updates
+ PlotState.color.use()
+ PlotState.point_size.use()
+ State.filter_text.use()
+
+ solara.InputText(
+ "Filter by search", State.filter_text.value, on_value=State.filter_text.set
+ )
+ solara.Markdown("**Set point size**")
+ solara.SliderInt("", PlotState.point_size.value, on_value=PlotState.point_size.set)
+ # TODO: A drop down should have "remove selection" option
+ # (esp if default state is None)
+ solara.Select(
+ "Color by",
+ [None] + avl_cols,
+ PlotState.color.value,
+ on_value=PlotState.color.set,
+ )
+
+
+@solara.component
+def menu(df: pd.DataFrame, set_df: Callable) -> None:
+ State.reset_on_assignment.use()
+
+ # avl_cols is dependent on df, so any time it changes,
+ # this will automatically update
+ set_cols = lambda: [i for i in df.columns if i not in NO_COLOR_COLS]
+ avl_cols = solara.use_memo(set_cols, [df])
+
+ file_manager(set_df)
+ label_manager(df)
+ view_controller(avl_cols)
+ solara.Markdown(f"**Reset view on label assignment?**")
+ if State.reset_on_assignment.value:
+ label = "Reset"
+ else:
+ label = "Keep state"
+ # solara.Checkbox(
+ # label=label,
+ # value=State.reset_on_assignment.value,
+ # on_value=State.reset_on_assignment.set
+ # )
+ solara.Checkbox(label=label).connect(State.reset_on_assignment)
+ # solara.ToggleButtonsSingle(State.reset_on_assignment.value, values=[True, False], on_value=State.reset_on_assignment.set)
diff --git a/bulk-labeling/conv_intent.csv b/bulk_labeling/conv_intent.csv
similarity index 100%
rename from bulk-labeling/conv_intent.csv
rename to bulk_labeling/conv_intent.csv
diff --git a/bulk_labeling/main.py b/bulk_labeling/main.py
new file mode 100644
index 0000000..6c06c36
--- /dev/null
+++ b/bulk_labeling/main.py
@@ -0,0 +1,42 @@
+from typing import cast
+
+import pandas as pd
+import solara
+
+from bulk_labeling.components.df import df_view, no_embs
+from bulk_labeling.components.menu import assigned_label_view, menu
+from bulk_labeling.state import PlotState
+from bulk_labeling.utils.df import has_df
+
+
+@solara.component
+def no_df() -> None:
+ with solara.Columns([1, 1]):
+ solara.Markdown("## DataFrame (Load Data)")
+ solara.Markdown("## Embeddings (Load Data)")
+
+
+@solara.component
+def Page() -> None:
+ # TODO: Remove when solara updates
+ # PlotState.loading.use()
+
+ # This `eq` makes it so every time we set the dataframe, solara thinks it's new
+ df, set_df = solara.use_state(
+ cast(pd.DataFrame, pd.DataFrame({})), eq=lambda *args: False
+ )
+ solara.Title("Bulk Labeling!")
+ # TODO: Why can't I get this view to render?
+ assigned_label_view()
+ with solara.Sidebar():
+ menu(df, set_df)
+ if has_df(df) and PlotState.loading.value:
+ no_embs(df)
+ elif has_df(df):
+ df_view(df)
+ else:
+ no_df()
+
+
+if __name__ == "__main__":
+ Page()
diff --git a/bulk_labeling/state.py b/bulk_labeling/state.py
new file mode 100644
index 0000000..12a09e9
--- /dev/null
+++ b/bulk_labeling/state.py
@@ -0,0 +1,32 @@
+from collections import defaultdict
+from typing import Dict, List, Optional, Set
+
+from solara.lab import Reactive
+
+DEFAULT_POINT_SIZE = 2
+
+
+class State:
+ available_labels = Reactive[Set[str]](set())
+ labeled_ids = Reactive[Dict[str, List[int]]](defaultdict(list))
+ filtered_ids = Reactive[List[int]]([])
+ chosen_label = Reactive[Optional[str]](None)
+ assigned_new_label = Reactive[bool](False)
+ filter_text = Reactive[str]("")
+ reset_on_assignment = Reactive[bool](True)
+
+
+class PlotState:
+ point_size = Reactive[int](DEFAULT_POINT_SIZE)
+ color = Reactive[str]("")
+ # While we calculate embeddings and UMAP, we can manage the loading state
+ loading = Reactive[bool](False)
+
+
+def reset() -> None:
+ """Removes any filters applied to the data"""
+ State.filtered_ids.set([])
+ State.filter_text.set("")
+ State.chosen_label.set(None)
+ PlotState.point_size.set(DEFAULT_POINT_SIZE)
+ PlotState.color.set("")
diff --git a/bulk_labeling/utils/__init__.py b/bulk_labeling/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/bulk_labeling/utils/common.py b/bulk_labeling/utils/common.py
new file mode 100644
index 0000000..26df82b
--- /dev/null
+++ b/bulk_labeling/utils/common.py
@@ -0,0 +1 @@
+BUTTON_KWARGS = dict(color="primary", text=True, outlined=True)
diff --git a/bulk_labeling/utils/df.py b/bulk_labeling/utils/df.py
new file mode 100644
index 0000000..d0991b3
--- /dev/null
+++ b/bulk_labeling/utils/df.py
@@ -0,0 +1,50 @@
+from io import BytesIO
+from typing import Union
+
+import pandas as pd
+
+from bulk_labeling.state import State
+
+INTERNAL_COLS = ["x", "y", "hovertext", "id"]
+
+
+def has_df(df: pd.DataFrame) -> bool:
+ return (len(df) != 0) and (len(df.columns) != 0)
+
+
+def filtered_df(df: pd.DataFrame) -> pd.DataFrame:
+ dfc = df.copy()
+ if not has_df(dfc):
+ return dfc
+ if State.filtered_ids.value:
+ dfc = dfc[dfc["id"].isin(State.filtered_ids.value)]
+ if State.filter_text.value:
+ dfc = dfc[dfc["text"].str.contains(State.filter_text.value)]
+ return dfc
+
+
+def set_default_cols(df: pd.DataFrame) -> pd.DataFrame:
+ df["text_length"] = df.text.str.len()
+ df["id"] = list(range(len(df)))
+ df["hovertext"] = df.text.str.wrap(30).str.replace("\n", "<br>")
+ return df
+
+
+def load_df(data: Union[str, BytesIO]) -> pd.DataFrame:
+ new_df = pd.read_csv(data)
+ new_df = set_default_cols(new_df)
+ return new_df
+
+
+def apply_df_edits(df: pd.DataFrame) -> pd.DataFrame:
+ print("Should be downloading!")
+ df2 = df.copy()
+ labeled_ids = State.labeled_ids.value
+ # Map every ID to its assigned label
+ # TODO: We can be smarter with conflicts and pick the label that an ID is
+ # assigned to most frequently
+ id_label = {id_: label for label, ids in labeled_ids.items() for id_ in ids}
+ df2["label"] = df2["id"].apply(lambda id_: id_label.get(id_, "-1"))
+ df2 = df2[df2["label"] != "-1"]
+ cols = [c for c in df2.columns if c not in INTERNAL_COLS]
+ return df2[cols]
diff --git a/bulk_labeling/utils/menu.py b/bulk_labeling/utils/menu.py
new file mode 100644
index 0000000..b0ea4d8
--- /dev/null
+++ b/bulk_labeling/utils/menu.py
@@ -0,0 +1,49 @@
+import os
+from typing import Tuple
+
+import pandas as pd
+
+from bulk_labeling.state import State, reset
+from bulk_labeling.utils.df import filtered_df, has_df
+
+DIR = f"{os.getcwd()}/bulk_labeling"
+PATH = f"{DIR}/conv_intent.csv"
+
+
+def add_new_label(new_label: str) -> None:
+ if not new_label:
+ return
+ all_labels = State.available_labels.value.copy()
+ all_labels.add(new_label)
+ State.available_labels.set(all_labels)
+ # So the "assign points" button is already pre-populated with your new label =]
+ State.chosen_label.set(new_label)
+
+
+def assign_labels(df: pd.DataFrame) -> None:
+ labeled_ids = State.labeled_ids.value.copy()
+ new_ids = df["id"].tolist()
+ if State.chosen_label.value:
+ print(f"Setting {State.chosen_label.value} for {len(new_ids)} points")
+ labeled_ids[State.chosen_label.value].extend(new_ids)
+ State.labeled_ids.set(labeled_ids)
+ # State.assigned_new_label.set(True)
+ # Reset the view so no points are selected
+ if State.reset_on_assignment.value:
+ reset()
+
+
+def get_assign_label_button_text(df: pd.DataFrame) -> Tuple[str, bool]:
+ """Gets the label for the "assign label" button, and whether it's enabled"""
+ button_enabled = bool(State.chosen_label.value) and has_df(df)
+ fdf = filtered_df(df)
+ num = len(fdf) if len(fdf) != len(df) else "all"
+ if button_enabled:
+ btn_label = f"Assign {num} points to label {State.chosen_label.value}"
+ elif not has_df(df):
+ btn_label = "Add data"
+ elif not State.available_labels.value:
+ btn_label = "Create a label"
+ else:
+ btn_label = "Choose a label"
+ return btn_label, button_enabled
diff --git a/bulk_labeling/utils/ml.py b/bulk_labeling/utils/ml.py
new file mode 100644
index 0000000..8029025
--- /dev/null
+++ b/bulk_labeling/utils/ml.py
@@ -0,0 +1,30 @@
+from typing import List
+
+import numpy as np
+import pandas as pd
+from sentence_transformers import SentenceTransformer
+from umap import UMAP
+
+UMAP_MODEL = UMAP(n_neighbors=15, random_state=42, verbose=True)
+ENCODER = SentenceTransformer("paraphrase-MiniLM-L3-v2")
+
+
+def encode_inputs(samples: List[str]) -> np.ndarray:
+ return ENCODER.encode(samples)
+ # When doing rapid development, it's faster to return a numpy array
+ # return np.random.rand(len(samples), 20)
+
+
+def get_xy(embs: np.ndarray) -> np.ndarray:
+ return UMAP_MODEL.fit_transform(embs)
+
+
+def get_text_embeddings(samples: List[str]) -> np.ndarray:
+ return get_xy(encode_inputs(samples))
+
+
+def add_embeddings_to_df(df: pd.DataFrame) -> pd.DataFrame:
+ embs = get_text_embeddings(df["text"].tolist())
+ df["x"] = embs[:, 0]
+ df["y"] = embs[:, 1]
+ return df
diff --git a/bulk_labeling/utils/plotly.py b/bulk_labeling/utils/plotly.py
new file mode 100644
index 0000000..77a6870
--- /dev/null
+++ b/bulk_labeling/utils/plotly.py
@@ -0,0 +1,40 @@
+from typing import Dict, List
+
+import pandas as pd
+import plotly.express as px
+from plotly.graph_objs._figure import Figure
+
+
+def find_row_ids(fig: Figure, click_data: Dict) -> List[int]:
+ """A very annoying function to get row IDs because Plotly is unhelpful
+
+ Solara is going to do this for us in the future!
+ """
+ # goes from trace index and point index to row index in a dataframe
+ # requires passing df["id"] as custom_data
+ trace_index = click_data["points"]["trace_indexes"]
+ point_index = click_data["points"]["point_indexes"]
+ point_ids = []
+ for t, p in zip(trace_index, point_index):
+ point_trace = fig.data[t]
+ point_ids.append(point_trace.customdata[p][0])
+ return point_ids
+
+
+def create_plotly_figure(df: pd.DataFrame, color: str, point_size: int) -> Figure:
+ # We pass in df.id to custom_data so we can get back the correct points on a
+ # lasso selection. Plotly makes this difficult
+ # TODO: Solara will wrap and handle all of this logic for us in the future
+ fig = px.scatter(
+ df,
+ x="x",
+ y="y",
+ color=color or None,
+ custom_data=[df["id"]],
+ hover_data=["hovertext"],
+ )
+ fig.update_layout(showlegend=False)
+ fig.update_xaxes(visible=False)
+ fig.update_yaxes(visible=False)
+ fig.update_traces(marker_size=point_size)
+ return fig
diff --git a/mypy.ini b/mypy.ini
new file mode 100644
index 0000000..056621a
--- /dev/null
+++ b/mypy.ini
@@ -0,0 +1,3 @@
+[mypy]
+ignore_missing_imports = True
+disallow_untyped_defs = True
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..0867627
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,43 @@
+[build-system]
+requires = [
+ "setuptools >=65.0",
+ "wheel >=0.37",
+]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "bulk-labeling-solara"
+authors = [{name = "Ben Epstein", email = "ben.epstein97@gmail.com"}]
+license = {file = "LICENSE"}
+dynamic = ["version", "readme"]
+classifiers = ["License :: OSI Approved :: Apache Software License"]
+dependencies = [
+ "solara",
+ "pandas",
+ "numpy",
+ "plotly<=5.9.0",
+ "sentence-transformers",
+ "umap-learn",
+ "numerize",
+]
+
+[project.optional-dependencies]
+dev = [
+ "jupyter",
+ "black",
+ "isort",
+ "flake8",
+]
+
+[project.urls]
+Home = "https://www.github.com/ben-epstein/bulk-labeling-solara"
+
+[tool.isort]
+profile = "black"
+
+
+[tool.setuptools.dynamic]
+version = {attr = "bulk_labeling.__version__"}
+
+[tool.setuptools]
+py-modules = []
diff --git a/scripts/format.sh b/scripts/format.sh
new file mode 100755
index 0000000..73b82a5
--- /dev/null
+++ b/scripts/format.sh
@@ -0,0 +1,9 @@
+#!/bin/sh -ex
+
+# Sort imports one per line, so autoflake can remove unused imports
+isort --force-single-line-imports bulk_labeling
+
+autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place bulk_labeling --exclude=__init__.py
+# For some reason we need to run this again so that black can get it into the format we want
+isort bulk_labeling
+black bulk_labeling
diff --git a/scripts/lint.sh b/scripts/lint.sh
new file mode 100755
index 0000000..1e21c81
--- /dev/null
+++ b/scripts/lint.sh
@@ -0,0 +1,6 @@
+#!/bin/sh -ex
+
+mypy bulk_labeling
+flake8 bulk_labeling
+black bulk_labeling --check
+isort bulk_labeling --check-only
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..4bf19b8
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,6 @@
+[metadata]
+long_description = file: README.md
+long_description_content_type = text/markdown
+
+[options]
+packages = find:
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..316ca78
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,6 @@
+from setuptools import setup
+
+setup(
+ package_data={"": ["**conv_intent.csv"]},
+ include_package_data=True
+)