Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow multiple snp_transcripts in plot_diplotype_clustering_advanced() #703

Merged
84 changes: 69 additions & 15 deletions malariagen_data/anoph/dipclust.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@
cnv_params,
)
from .snp_frq import AnophelesSnpFrequencyAnalysis
from .cnv_data import AnophelesCnvData
from .cnv_frq import AnophelesCnvFrequencyAnalysis
leehart marked this conversation as resolved.
Show resolved Hide resolved

# Query expression that selects records whose `effect` is one of the four
# listed classes. It is used as the default value of the `snp_query`
# parameter of the SNP-trace methods in this module.
AA_CHANGE_QUERY = (
"effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']"
)


class AnophelesDipClustAnalysis(AnophelesSnpFrequencyAnalysis, AnophelesCnvData):
class AnophelesDipClustAnalysis(
AnophelesCnvFrequencyAnalysis, AnophelesSnpFrequencyAnalysis
):
def __init__(
self,
**kwargs,
Expand Down Expand Up @@ -190,7 +192,7 @@ def plot_diplotype_clustering(
else:
return {
"figure": fig,
"dendro_sample_id_order": leaf_data["sample_id"].to_list(),
"dendro_sample_id_order": np.asarray(leaf_data["sample_id"].to_list()),
"n_snps": n_snps_used,
}

Expand Down Expand Up @@ -319,7 +321,7 @@ def _dipclust_het_bar_trace(
sample_sets: Optional[base_params.sample_sets],
sample_query: Optional[base_params.sample_query],
sample_query_options: Optional[base_params.sample_query_options],
site_mask: base_params.site_mask,
site_mask: Optional[base_params.site_mask],
cohort_size: Optional[base_params.cohort_size],
random_seed: base_params.random_seed,
color_continuous_scale: Optional[plotly_params.color_continuous_scale],
Expand Down Expand Up @@ -547,11 +549,52 @@ def _dipclust_concat_subplots(

return fig

def _insert_dipclust_snp_trace(
    self,
    *,
    figures,
    subplot_heights,
    snp_row_height: plotly_params.height = 25,
    transcript: base_params.transcript,
    snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
    sample_sets: Optional[base_params.sample_sets],
    sample_query: Optional[base_params.sample_query],
    sample_query_options: Optional[base_params.sample_query_options],
    site_mask: Optional[base_params.site_mask],
    dendro_sample_id_order: np.ndarray,
    snp_filter_min_maf: float,
    snp_colorscale: Optional[plotly_params.color_continuous_scale],
    chunks: base_params.chunks = base_params.native_chunks,
    inline_array: base_params.inline_array = base_params.inline_array_default,
):
    """Build the SNP trace for one transcript and add it to the subplot stack.

    Delegates to ``self._dipclust_snp_trace`` and, when a trace comes back,
    appends it to ``figures`` and appends the matching row height
    (``snp_row_height * n_snps``) to ``subplot_heights``. When no trace is
    returned, a message naming the transcript is printed and both lists are
    left untouched.

    Parameters
    ----------
    figures
        List of traces/figures collected so far. Modified in place.
    subplot_heights
        List of subplot heights, parallel to ``figures``. Modified in place.
    snp_row_height
        Height allotted to each SNP row in the subplot.
    transcript
        Transcript whose variants should be plotted.
    snp_filter_min_maf
        Minimum allele frequency threshold forwarded to
        ``_dipclust_snp_trace``.

    All remaining keyword arguments are forwarded unchanged to
    ``_dipclust_snp_trace``.

    Returns
    -------
    tuple
        ``(figures, subplot_heights)`` - the same list objects that were
        passed in, returned for caller convenience.
    """
    snp_trace, n_snps_transcript = self._dipclust_snp_trace(
        transcript=transcript,
        sample_sets=sample_sets,
        sample_query=sample_query,
        sample_query_options=sample_query_options,
        snp_query=snp_query,
        site_mask=site_mask,
        dendro_sample_id_order=dendro_sample_id_order,
        snp_filter_min_maf=snp_filter_min_maf,
        snp_colorscale=snp_colorscale,
        chunks=chunks,
        inline_array=inline_array,
    )

    # NOTE(review): truthiness check kept as in the original; presumably
    # _dipclust_snp_trace returns None when nothing survives filtering -
    # confirm against its implementation.
    if snp_trace:
        figures.append(snp_trace)
        subplot_heights.append(snp_row_height * n_snps_transcript)
    else:
        # Name the transcript: this helper may be called once per transcript,
        # so the user needs to know which track was dropped.
        print(
            f"No SNPs remain for transcript {transcript} after filtering at "
            f"minimum allele frequency {snp_filter_min_maf}. "
            "Omitting SNP genotype plot."
        )
    return figures, subplot_heights

@doc(
summary="Perform diplotype clustering, annotated with heterozygosity, gene copy number and amino acid variants.",
parameters=dict(
heterozygosity="Plot heterozygosity track.",
snp_transcript="Plot amino acid variants for this transcript.",
snp_transcript="Plot amino acid variants for these transcripts.",
cnv_region="Plot gene CNV calls for this region.",
snp_filter_min_maf="Filter amino acid variants with alternate allele frequency below this threshold.",
),
Expand All @@ -561,7 +604,7 @@ def plot_diplotype_clustering_advanced(
region: base_params.regions,
heterozygosity: bool = True,
heterozygosity_colorscale: plotly_params.color_continuous_scale = "Greys",
snp_transcript: Optional[base_params.transcript] = None,
snp_transcript: Optional[dipclust_params.snp_transcript] = None,
snp_colorscale: plotly_params.color_continuous_scale = "Greys",
snp_filter_min_maf: float = 0.05,
snp_query: Optional[base_params.snp_query] = AA_CHANGE_QUERY,
Expand Down Expand Up @@ -682,9 +725,11 @@ def plot_diplotype_clustering_advanced(
figures.append(cnv_trace)
subplot_heights.append(cnv_row_height * n_cnv_genes)

if snp_transcript:
snp_trace, n_snps_transcript = self._dipclust_snp_trace(
if isinstance(snp_transcript, str):
figures, subplot_heights = self._insert_dipclust_snp_trace(
transcript=snp_transcript,
figures=figures,
subplot_heights=subplot_heights,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
Expand All @@ -696,13 +741,22 @@ def plot_diplotype_clustering_advanced(
chunks=chunks,
inline_array=inline_array,
)

if snp_trace:
figures.append(snp_trace)
subplot_heights.append(snp_row_height * n_snps_transcript)
else:
print(
f"No SNPs were found below {snp_filter_min_maf} allele frequency. Omitting SNP genotype plot."
elif isinstance(snp_transcript, list):
for st in snp_transcript:
figures, subplot_heights = self._insert_dipclust_snp_trace(
transcript=st,
figures=figures,
subplot_heights=subplot_heights,
sample_sets=sample_sets,
sample_query=sample_query,
sample_query_options=sample_query_options,
snp_query=snp_query,
site_mask=site_mask,
dendro_sample_id_order=dendro_sample_id_order,
snp_filter_min_maf=snp_filter_min_maf,
snp_colorscale=snp_colorscale,
chunks=chunks,
inline_array=inline_array,
)

# Calculate total height based on subplot heights, plus a fixed
Expand Down
8 changes: 8 additions & 0 deletions malariagen_data/anoph/dipclust_params.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
"""Parameters for diplotype clustering functions."""

from typing_extensions import Annotated, TypeAlias, Union, Sequence

from .distance_params import distance_metric
from .clustering_params import linkage_method
from .base_params import transcript


linkage_method_default: linkage_method = "complete"

distance_metric_default: distance_metric = "cityblock"

# Parameter type for functions that accept either a single transcript or
# several of them (a sequence of transcripts).
# NOTE(review): plot_diplotype_clustering_advanced() dispatches on
# isinstance(..., str) / isinstance(..., list), so a non-list Sequence such
# as a tuple appears to be silently skipped - confirm and either narrow this
# alias or broaden the caller's check.
snp_transcript: TypeAlias = Annotated[
Union[transcript, Sequence[transcript]],
"A transcript or a list of transcripts",
]
9 changes: 0 additions & 9 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
import plotly.graph_objects as go # type: ignore
from numpydoc_decorator import doc # type: ignore

from malariagen_data.anoph.snp_frq import (
AnophelesSnpFrequencyAnalysis,
)

from .anoph.cnv_frq import AnophelesCnvFrequencyAnalysis

from .anoph import (
aim_params,
Expand All @@ -32,7 +27,6 @@
from .anoph.karyotype import AnophelesKaryotypeAnalysis
from .anoph.aim_data import AnophelesAimData
from .anoph.base import AnophelesBase
from .anoph.cnv_data import AnophelesCnvData
from .anoph.genome_features import AnophelesGenomeFeaturesData
from .anoph.genome_sequence import AnophelesGenomeSequenceData
from .anoph.hap_data import AnophelesHapData, hap_params
Expand Down Expand Up @@ -88,8 +82,6 @@ class AnophelesDataResource(
AnophelesH12Analysis,
AnophelesG123Analysis,
AnophelesFstAnalysis,
AnophelesCnvFrequencyAnalysis,
AnophelesSnpFrequencyAnalysis,
AnophelesHapFrequencyAnalysis,
AnophelesDistanceAnalysis,
AnophelesPca,
Expand All @@ -99,7 +91,6 @@ class AnophelesDataResource(
AnophelesAimData,
AnophelesHapData,
AnophelesSnpData,
AnophelesCnvData,
AnophelesSampleMetadata,
AnophelesGenomeFeaturesData,
AnophelesGenomeSequenceData,
Expand Down
Loading
Loading