Skip to content

Commit

Permalink
Merge pull request #390 from MannLabs/volcanoplot
Browse files Browse the repository at this point in the history
New backend for results display and interface for volcano plot
  • Loading branch information
JuliaS92 authored Jan 20, 2025
2 parents 567aeb9 + 9ae0325 commit 3c5ac58
Show file tree
Hide file tree
Showing 14 changed files with 1,331 additions and 164 deletions.
18 changes: 13 additions & 5 deletions alphastats/gui/pages/05_Analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import streamlit as st

from alphastats.gui.utils.analysis import PlottingOptions, StatisticOptions
from alphastats.gui.utils.analysis import (
NewAnalysisOptions,
PlottingOptions,
StatisticOptions,
)
from alphastats.gui.utils.analysis_helper import (
display_analysis_result_with_buttons,
gather_parameters_and_do_analysis,
Expand Down Expand Up @@ -47,16 +51,18 @@
with c1:
plotting_options = PlottingOptions.get_values()
statistic_options = StatisticOptions.get_values()
new_options = NewAnalysisOptions.get_values()
analysis_method = st.selectbox(
"Analysis",
options=["<select>"]
+ new_options
+ ["------- plots ------------"]
+ plotting_options
+ ["------- statistics -------"]
+ statistic_options,
)

if analysis_method in plotting_options:
if analysis_method in plotting_options or analysis_method in new_options:
analysis_result, analysis_object, parameters = (
gather_parameters_and_do_analysis(analysis_method)
)
Expand Down Expand Up @@ -85,15 +91,17 @@ def show_start_llm_button(analysis_method: str) -> None:

submitted = st.button(
f"Analyse with LLM ... {msg}",
disabled=(analysis_method != PlottingOptions.VOLCANO_PLOT),
help="Interpret the current analysis with an LLM (available for 'Volcano Plot' only).",
disabled=(
analysis_method != NewAnalysisOptions.DIFFERENTIAL_EXPRESSION_TWO_GROUPS
),
help="Interpret the current analysis with an LLM (available for 'Differential Analysis Two Groups' only).",
)
if submitted:
if StateKeys.LLM_INTEGRATION in st.session_state:
del st.session_state[StateKeys.LLM_INTEGRATION]
st.session_state[StateKeys.SELECTED_GENES_UP] = None
st.session_state[StateKeys.SELECTED_GENES_DOWN] = None
st.session_state[StateKeys.LLM_INPUT] = (analysis_object, parameters)
st.session_state[StateKeys.LLM_INPUT] = (analysis_result, parameters)

st.toast("LLM analysis created!", icon="✅")
st.page_link("pages/06_LLM.py", label="=> Go to LLM page..")
Expand Down
11 changes: 8 additions & 3 deletions alphastats/gui/pages/06_LLM.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import os
from typing import Dict

import pandas as pd
import streamlit as st
from openai import AuthenticationError

from alphastats.dataset.keys import Cols
from alphastats.dataset.plotting import plotly_object
from alphastats.gui.utils.analysis import ResultComponent
from alphastats.gui.utils.analysis_helper import (
display_figure,
gather_uniprot_data,
Expand Down Expand Up @@ -98,7 +100,8 @@ def llm_config():
st.info("Create a Volcano plot first using the 'Analysis' page.")
st.stop()

volcano_plot, plot_parameters = st.session_state[StateKeys.LLM_INPUT]
volcano_plot: ResultComponent = st.session_state[StateKeys.LLM_INPUT][0]
plot_parameters: Dict = st.session_state[StateKeys.LLM_INPUT][1]

st.markdown(f"Parameters used for analysis: `{plot_parameters}`")

Expand All @@ -108,9 +111,11 @@ def llm_config():
st.markdown("##### Volcano plot")
display_figure(volcano_plot.plot)

regulated_genes_df = volcano_plot.res[volcano_plot.res["label"] != ""]
regulated_genes_df = volcano_plot.annotated_dataframe[
volcano_plot.annotated_dataframe["significant"] != "non_sig"
]
regulated_genes_dict = dict(
zip(regulated_genes_df[Cols.INDEX], regulated_genes_df["color"].tolist())
zip(regulated_genes_df[Cols.INDEX], regulated_genes_df["significant"].tolist())
)

if not regulated_genes_dict:
Expand Down
1 change: 1 addition & 0 deletions alphastats/gui/pages/07_Results.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@
parameters=parameters,
show_save_button=False,
name=name,
editable_annotation=False,
)
158 changes: 131 additions & 27 deletions alphastats/gui/utils/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,18 @@

from alphastats.dataset.dataset import DataSet
from alphastats.dataset.keys import Cols, ConstantsClass
from alphastats.dataset.preprocessing import PreprocessingStateKeys
from alphastats.gui.utils.result import (
DifferentialExpressionTwoGroupsResult,
ResultComponent,
)
from alphastats.gui.utils.ui_helper import AnalysisParameters
from alphastats.plots.plot_utils import PlotlyObject
from alphastats.plots.volcano_plot import VolcanoPlot
from alphastats.tl.differential_expression_analysis import (
DeaTestTypes,
DifferentialExpressionAnalysisTTest,
)


class PlottingOptions(metaclass=ConstantsClass):
Expand All @@ -35,8 +45,13 @@ class StatisticOptions(metaclass=ConstantsClass):
ANCOVA = "ANCOVA"


# TODO rename to AnalysisComponent
class AbstractAnalysis(ABC):
class NewAnalysisOptions(metaclass=ConstantsClass):
"""Keys for the new analysis options, the order determines order in UI."""

DIFFERENTIAL_EXPRESSION_TWO_GROUPS = "Differential Expression Analysis (Two Groups)"


class AnalysisComponent(ABC):
"""Abstract class for analysis widgets."""

_works_with_nans = True
Expand Down Expand Up @@ -73,7 +88,9 @@ def do_analysis(
@abstractmethod
def _do_analysis(
self,
) -> Tuple[Union[PlotlyObject, pd.DataFrame], Optional[VolcanoPlot]]:
) -> Tuple[
Union[PlotlyObject, pd.DataFrame, ResultComponent], Optional[VolcanoPlot]
]:
pass

def _nan_check(self) -> None: # noqa: B027
Expand All @@ -87,7 +104,7 @@ def _pre_analysis_check(self) -> None: # noqa: B027
pass


class AbstractGroupCompareAnalysis(AbstractAnalysis, ABC):
class AbstractGroupCompareAnalysis(AnalysisComponent, ABC):
"""Abstract class for group comparison analysis widgets."""

def show_widget(self):
Expand All @@ -96,13 +113,20 @@ def show_widget(self):
metadata = self._dataset.metadata

default_option = "<select>"
metadata_groups = metadata.columns.to_list()
custom_group_option = "Custom groups from samples .."

options = [default_option] + metadata_groups + [custom_group_option]
grouping_variable = st.selectbox(
"Grouping variable",
options=[default_option]
+ metadata.columns.to_list()
+ [custom_group_option],
options=options,
index=options.index(
st.session_state.get(
AnalysisParameters.TWOGROUP_COLUMN,
default_option if len(metadata_groups) == 0 else metadata_groups[0],
)
),
key=AnalysisParameters.TWOGROUP_COLUMN,
)

column = None
Expand All @@ -114,18 +138,28 @@ def show_widget(self):
unique_values = metadata[grouping_variable].unique().tolist()

column = grouping_variable
group1 = st.selectbox("Group 1", options=unique_values)
group2 = st.selectbox("Group 2", options=list(reversed(unique_values)))
group1 = st.selectbox(
"Group 1",
options=unique_values,
key=AnalysisParameters.TWOGROUP_GROUP1,
)
group2 = st.selectbox(
"Group 2",
options=list(reversed(unique_values)),
key=AnalysisParameters.TWOGROUP_GROUP2,
)

else:
group1 = st.multiselect(
"Group 1 samples:",
options=metadata[Cols.SAMPLE].to_list(),
key=AnalysisParameters.TWOGROUP_GROUP1 + "multi",
)

group2 = st.multiselect(
"Group 2 samples:",
options=list(reversed(metadata[Cols.SAMPLE].to_list())),
key=AnalysisParameters.TWOGROUP_GROUP2 + "multi",
)

intersection_list = list(set(group1).intersection(set(group2)))
Expand All @@ -135,21 +169,29 @@ def show_widget(self):
+ str(intersection_list)
)

self._parameters.update({"group1": group1, "group2": group2})
self._parameters.update(
{
AnalysisParameters.TWOGROUP_GROUP1: group1,
AnalysisParameters.TWOGROUP_GROUP2: group2,
}
)
if column is not None:
self._parameters["column"] = column
self._parameters[AnalysisParameters.TWOGROUP_COLUMN] = column

def _pre_analysis_check(self):
"""Raise if selected groups are different."""
if self._parameters["group1"] == self._parameters["group2"]:
if (
self._parameters[AnalysisParameters.TWOGROUP_GROUP1]
== self._parameters[AnalysisParameters.TWOGROUP_GROUP2]
):
raise (
ValueError(
"Group 1 and Group 2 can not be the same. Please select different groups."
)
)


class AbstractDimensionReductionAnalysis(AbstractAnalysis, ABC):
class AbstractDimensionReductionAnalysis(AnalysisComponent, ABC):
"""Abstract class for dimension reduction analysis widgets."""

def show_widget(self):
Expand All @@ -165,7 +207,7 @@ def show_widget(self):
self._parameters.update({"circle": circle, "group": group})


class AbstractIntensityPlot(AbstractAnalysis, ABC):
class AbstractIntensityPlot(AnalysisComponent, ABC):
"""Abstract class for intensity plot analysis widgets."""

def show_widget(self):
Expand Down Expand Up @@ -339,10 +381,10 @@ def _do_analysis(self):
metadata=self._dataset.metadata,
preprocessing_info=self._dataset.preprocessing_info,
feature_to_repr_map=self._dataset._feature_to_repr_map,
group1=self._parameters["group1"],
group2=self._parameters["group2"],
column=self._parameters["column"],
method=self._parameters["method"],
group1=self._parameters[AnalysisParameters.TWOGROUP_GROUP1],
group2=self._parameters[AnalysisParameters.TWOGROUP_GROUP2],
column=self._parameters[AnalysisParameters.TWOGROUP_COLUMN],
method=self._parameters[AnalysisParameters.DEA_TWOGROUPS_METHOD],
labels=self._parameters["labels"],
min_fc=self._parameters["min_fc"],
alpha=self._parameters["alpha"],
Expand All @@ -358,7 +400,7 @@ def _do_analysis(self):
return volcano_plot.plot, volcano_plot


class ClustermapAnalysis(AbstractAnalysis):
class ClustermapAnalysis(AnalysisComponent):
"""Widget for Clustermap analysis."""

_works_with_nans = False
Expand All @@ -369,7 +411,7 @@ def _do_analysis(self):
return clustermap, None


class DendrogramAnalysis(AbstractAnalysis):
class DendrogramAnalysis(AnalysisComponent):
"""Widget for Dendrogram analysis."""

_works_with_nans = False
Expand Down Expand Up @@ -398,20 +440,20 @@ def show_widget(self):

super().show_widget()

self._parameters.update({"method": method})
self._parameters.update({AnalysisParameters.DEA_TWOGROUPS_METHOD: method})

def _do_analysis(self):
"""Perform T-test analysis."""
diff_exp_analysis = self._dataset.diff_expression_analysis(
method=self._parameters["method"],
group1=self._parameters["group1"],
group2=self._parameters["group2"],
column=self._parameters["column"],
method=self._parameters[AnalysisParameters.DEA_TWOGROUPS_METHOD],
group1=self._parameters[AnalysisParameters.TWOGROUP_GROUP1],
group2=self._parameters[AnalysisParameters.TWOGROUP_GROUP2],
column=self._parameters[AnalysisParameters.TWOGROUP_COLUMN],
)
return diff_exp_analysis, None


class TukeyTestAnalysis(AbstractAnalysis):
class TukeyTestAnalysis(AnalysisComponent):
"""Widget for Tukey-Test analysis."""

def show_widget(self):
Expand Down Expand Up @@ -467,7 +509,7 @@ def _do_analysis(self):
return anova_analysis, None


class AncovaAnalysis(AbstractAnalysis):
class AncovaAnalysis(AnalysisComponent):
"""Widget for Ancova analysis."""

def show_widget(self):
Expand Down Expand Up @@ -500,6 +542,67 @@ def _do_analysis(self):
return ancova_analysis, None


class DifferentialExpressionTwoGroupsAnalysis(AbstractGroupCompareAnalysis):
"""Widget for Differential expression analysis between two groups."""

def show_widget(self):
"""Show the widget and gather parameters."""
super().show_widget()

parameters = {}
method = st.selectbox(
"Differential Analysis using:",
options=["independent t-test", "paired t-test"],
key=AnalysisParameters.DEA_TWOGROUPS_METHOD,
)
parameters[AnalysisParameters.DEA_TWOGROUPS_METHOD] = method

fdr_method = st.selectbox(
"FDR method",
options=["fdr_bh", "bonferroni"],
index=0,
format_func=lambda x: {
"fdr_bh": "Benjamini-Hochberg",
"bonferroni": "Bonferroni",
}[x],
key=AnalysisParameters.DEA_TWOGROUPS_FDR_METHOD,
)
parameters[AnalysisParameters.DEA_TWOGROUPS_FDR_METHOD] = fdr_method

self._parameters.update(parameters)

def _do_analysis(self) -> Tuple[ResultComponent, None]:
"""Run the differential expression analysis between two groups and return the corresponding results object."""

test_type = {
"independent t-test": DeaTestTypes.INDEPENDENT,
"paired t-test": DeaTestTypes.PAIRED,
}[self._parameters[AnalysisParameters.DEA_TWOGROUPS_METHOD]]

dea = DifferentialExpressionAnalysisTTest(
self._dataset.mat,
is_log2_transformed=self._dataset.preprocessing_info[
PreprocessingStateKeys.LOG2_TRANSFORMED
],
)
dea_result = dea.perform(
test_type=test_type,
group1=self._parameters[AnalysisParameters.TWOGROUP_GROUP1],
group2=self._parameters[AnalysisParameters.TWOGROUP_GROUP2],
grouping_column=self._parameters[AnalysisParameters.TWOGROUP_COLUMN],
metadata=self._dataset.metadata,
fdr_method=self._parameters[AnalysisParameters.DEA_TWOGROUPS_FDR_METHOD],
)

return DifferentialExpressionTwoGroupsResult(
dea_result,
preprocessing=self._dataset.preprocessing_info,
method=self._parameters,
feature_to_repr_map=self._dataset._feature_to_repr_map,
is_plottable=True,
), None # None is for backwards compatibility


ANALYSIS_OPTIONS = {
PlottingOptions.VOLCANO_PLOT: VolcanoPlotAnalysis,
PlottingOptions.PCA_PLOT: PCAPlotAnalysis,
Expand All @@ -513,4 +616,5 @@ def _do_analysis(self):
StatisticOptions.TUKEY_TEST: TukeyTestAnalysis,
StatisticOptions.ANOVA: AnovaAnalysis,
StatisticOptions.ANCOVA: AncovaAnalysis,
NewAnalysisOptions.DIFFERENTIAL_EXPRESSION_TWO_GROUPS: DifferentialExpressionTwoGroupsAnalysis,
}
Loading

0 comments on commit 3c5ac58

Please sign in to comment.