From 449aa835fd441c6b09069826d4cee986574c5cc4 Mon Sep 17 00:00:00 2001
From: Eric Kutschera
Date: Thu, 3 Oct 2024 10:35:43 -0400
Subject: [PATCH] Add scripts to plot PCA based on PSI values

---
 rMATS_P/extract_psi_for_pca.py | 191 +++++++++++++++++++++++++++++++++
 rMATS_R/plot_psi_pca.R         | 108 ++++++++++++++++++
 2 files changed, 299 insertions(+)
 create mode 100644 rMATS_P/extract_psi_for_pca.py
 create mode 100644 rMATS_R/plot_psi_pca.R

diff --git a/rMATS_P/extract_psi_for_pca.py b/rMATS_P/extract_psi_for_pca.py
new file mode 100644
index 0000000..0ae04ba
--- /dev/null
+++ b/rMATS_P/extract_psi_for_pca.py
@@ -0,0 +1,191 @@
+import argparse
+import os
+import os.path
+
+
+def parse_args():
+    """Parse the command line arguments."""
+    parser = argparse.ArgumentParser(
+        description=('Extract PSI values from rMATS-turbo output'))
+    parser.add_argument('--rmats-out-dir',
+                        required=True,
+                        help='the --od from rMATS')
+    parser.add_argument('--out-tsv',
+                        required=True,
+                        help='where to write the PSI values')
+    parser.add_argument('--average-read-count',
+                        type=int,
+                        default=10,
+                        help=('Filter out events with lower average read count'
+                              ' (default %(default)s)'))
+    parser.add_argument(
+        '--sample-names',
+        required=False,
+        help=('A comma separated list of sample names corresponding to'
+              ' the order from --b1 and then --b2'))
+
+    return parser.parse_args()
+
+
+def write_tsv_line(columns, handle):
+    """Write columns to handle as one tab separated line."""
+    handle.write('{}\n'.format('\t'.join([str(x) for x in columns])))
+
+
+def read_tsv_line(line):
+    """Split one tab separated line into its column strings."""
+    columns = line.rstrip('\n').split('\t')
+    return columns
+
+
+def parse_comma_values(string, func):
+    """Parse a comma separated string into a list using func.
+
+    Parts where func raises ValueError (for example 'NA') become None.
+    An empty string gives an empty list.
+    """
+    if string == '':
+        return list()
+
+    parts = string.split(',')
+    result = list()
+    for part in parts:
+        try:
+            value = func(part)
+        except ValueError:
+            value = None
+
+        result.append(value)
+
+    return result
+
+
+def finalize_sample_names(sample_names, num_samples, file_name):
+    """Return the sample names to use as output column headers.
+
+    If no names were given then zero padded names (sample_0, ...) are
+    generated. Otherwise the number of names must equal num_samples.
+    """
+    if not sample_names:
+        sample_names = list()
+        highest_sample_str = str(num_samples - 1)
+        num_digits = len(highest_sample_str)
+        for i in range(num_samples):
+            format_str = 'sample_{:0' + str(num_digits) + '}'
+            sample_names.append(format_str.format(i))
+    elif len(sample_names) != num_samples:
+        raise Exception('Expected {} samples based on IncLevel columns'
+                        ' in {}, but got {}'.format(num_samples, file_name,
+                                                    sample_names))
+
+    return sample_names
+
+
+def write_psi_for_file(file_name, event_type, read_count_threshold,
+                       is_first_file, out_headers, sample_names, in_handle,
+                       out_handle):
+    """Write the PSI values of the events in one rMATS output file.
+
+    The header line is written when the first event row is seen and
+    is_first_file is set. Events are skipped if any read count or PSI
+    value could not be parsed or if the average read count per sample
+    is below read_count_threshold.
+
+    Returns True if the header line was written by this call.
+    """
+    wrote_header = False
+    for line_i, line in enumerate(in_handle):
+        in_columns = read_tsv_line(line)
+        if line_i == 0:
+            in_headers = in_columns
+            continue
+
+        row = dict(zip(in_headers, in_columns))
+        event_id = row['ID']
+        ijc_1_str = row['IJC_SAMPLE_1']
+        sjc_1_str = row['SJC_SAMPLE_1']
+        ijc_2_str = row['IJC_SAMPLE_2']
+        sjc_2_str = row['SJC_SAMPLE_2']
+        inc_1_str = row['IncLevel1']
+        inc_2_str = row['IncLevel2']
+
+        ijc_1 = parse_comma_values(ijc_1_str, int)
+        sjc_1 = parse_comma_values(sjc_1_str, int)
+        ijc_2 = parse_comma_values(ijc_2_str, int)
+        sjc_2 = parse_comma_values(sjc_2_str, int)
+        inc_1 = parse_comma_values(inc_1_str, float)
+        inc_2 = parse_comma_values(inc_2_str, float)
+
+        ijc = ijc_1 + ijc_2
+        sjc = sjc_1 + sjc_2
+        inc = inc_1 + inc_2
+        num_samples = len(ijc)
+        if (line_i == 1) and is_first_file:
+            sample_names = finalize_sample_names(sample_names, num_samples,
+                                                 file_name)
+            out_headers.extend(sample_names)
+            write_tsv_line(out_headers, out_handle)
+            wrote_header = True
+
+        # A value that failed to parse is None and can not be used
+        if any([x is None for x in (ijc + sjc + inc)]):
+            continue
+
+        # Without samples there is no average read count to check
+        if num_samples == 0:
+            continue
+
+        total_reads = sum(ijc) + sum(sjc)
+        avg_reads = total_reads / num_samples
+        if avg_reads < read_count_threshold:
+            continue
+
+        write_tsv_line([event_type, event_id] + inc, out_handle)
+
+    return wrote_header
+
+
+def parse_sample_names(sample_names):
+    """Split the --sample-names value on commas (None if not given)."""
+    if not sample_names:
+        return None
+
+    return sample_names.split(',')
+
+
+def extract_psi_for_pca(rmats_dir, out_tsv, read_count_threshold,
+                        sample_names):
+    """Write PSI values from each *.MATS.JC.txt in rmats_dir to out_tsv."""
+    sample_names = parse_sample_names(sample_names)
+
+    suffix = '.MATS.JC.txt'
+    with open(out_tsv, 'wt') as out_handle:
+        headers = ['event_type', 'id']
+        # Stays True until some file has written the header. A file
+        # with no events has no row to count the samples from.
+        is_first_file = True
+        file_names = sorted(os.listdir(rmats_dir))
+        for name in file_names:
+            if not name.endswith(suffix):
+                continue
+
+            path = os.path.join(rmats_dir, name)
+            event_type = name[:-len(suffix)]
+            with open(path, 'rt') as in_handle:
+                wrote_header = write_psi_for_file(
+                    name, event_type, read_count_threshold, is_first_file,
+                    headers, sample_names, in_handle, out_handle)
+
+            if wrote_header:
+                is_first_file = False
+
+
+def main():
+    """Run the extraction using the command line arguments."""
+    args = parse_args()
+    extract_psi_for_pca(args.rmats_out_dir, args.out_tsv,
+                        args.average_read_count, args.sample_names)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/rMATS_R/plot_psi_pca.R b/rMATS_R/plot_psi_pca.R
new file mode 100644
index 0000000..50e59dd
--- /dev/null
+++ b/rMATS_R/plot_psi_pca.R
@@ -0,0 +1,108 @@
+library('ggplot2')
+
+## Rscript plot_psi_pca.R pca_psi.tsv pca_psi.png
+args <- base::commandArgs(trailingOnly=TRUE)
+if (base::length(args) != 2) {
+  base::stop('usage: Rscript plot_psi_pca.R pca_psi.tsv pca_psi.png',
+             call.=FALSE)
+}
+
+in_file_name <- args[1]
+out_file_name <- args[2]
+
+## Theme for the plot: no grid lines or panel border, black axis lines,
+## no legend title, and small legend text
+create_ggplot_theme <- function() {
+  return(
+    ggplot2::theme_bw() +
+    ggplot2::theme(panel.grid.major = ggplot2::element_blank(),
+                   panel.grid.minor = ggplot2::element_blank(),
+                   panel.background = ggplot2::element_blank(),
+                   panel.border = ggplot2::element_blank(),
+                   axis.line = ggplot2::element_line(color='black'),
+                   legend.title=ggplot2::element_blank(),
+                   legend.text=ggplot2::element_text(size=4))
+  )
+}
+
+## Run PCA with the samples as observations and the events as variables.
+## Returns a list with the prcomp result (pca) and the percent of the
+## total variance for each component (percent_variance).
+calculate_pca <- function(data) {
+  ## data: columns are samples, rows are events
+  ## convert so that rows are samples
+  transposed <- base::t(data)
+  variances <- base::apply(transposed, 2, stats::var)
+  ## drop=FALSE keeps a matrix even if only 1 event has variance
+  transposed <- transposed[, variances != 0, drop=FALSE]
+  if (base::ncol(transposed) == 0) {
+    base::stop('no events with varying PSI values', call.=FALSE)
+  }
+
+  transposed <- base::log10(transposed + 1)
+  pca_results <- stats::prcomp(transposed, center=TRUE, scale.=TRUE)
+  variance_by_component <- base::apply(pca_results$x, 2, stats::var)
+  total_variance <- base::sum(variance_by_component)
+  percent_variance <- base::round(
+    100 * (variance_by_component / total_variance), digits=1)
+  return(list(pca=pca_results, percent_variance=percent_variance))
+}
+
+## Save a scatter plot of PC 1 against PC 2 with 1 point per sample.
+## plot_df: data frame with columns pc_1, pc_2, sample
+## pca_variance: percent of variance by component, for the axis labels
+plot_pc_1_2 <- function(plot_df, pca_variance, out_path) {
+  point_size <- 1
+  point_alpha <- 0.8
+  width <- 5
+  height <- 4
+
+  title <- 'PCA by PSI value'
+  plot <- ggplot2::ggplot(data=plot_df,
+                          ggplot2::aes(x=pc_1, y=pc_2, color=sample)) +
+    ggplot2::geom_point(size=point_size, alpha=point_alpha) +
+    ggplot2::labs(x=base::paste0('PC 1: ', pca_variance[1], '%'),
+                  y=base::paste0('PC 2: ', pca_variance[2], '%'),
+                  title=title) +
+    create_ggplot_theme() +
+    ggplot2::scale_x_continuous(breaks=NULL) +
+    ggplot2::scale_y_continuous(breaks=NULL) +
+    ggplot2::guides(color=ggplot2::guide_legend(
+      title=NULL,
+      label.theme=ggplot2::element_text(size=4),
+      keywidth=0.5,
+      keyheight=0.5))
+
+  ggplot2::ggsave(filename=out_path, plot=plot, width=width, height=height)
+}
+
+## Read the PSI table, run the PCA, and save the plot of PC 1 and PC 2
+main <- function(in_file_name, out_file_name) {
+  ## check.names=FALSE keeps the sample names exactly as written.
+  ## quote and comment.char are disabled since the file is plain TSV
+  data_frame <- utils::read.table(file=in_file_name, sep='\t', header=TRUE,
+                                  check.names=FALSE, quote='',
+                                  comment.char='')
+  num_columns <- base::ncol(data_frame)
+  ## first 2 columns are event_type, id
+  num_samples <- num_columns - 2
+  if (num_samples < 2) {
+    base::stop('need at least 2 sample columns in ', in_file_name,
+               call.=FALSE)
+  }
+
+  sample_names <- base::colnames(data_frame)[3:num_columns]
+  psi_values <- data_frame[, 3:num_columns, drop=FALSE]
+  pca_data <- calculate_pca(psi_values)
+  if (base::ncol(pca_data$pca$x) < 2) {
+    base::stop('need at least 2 events with varying PSI values',
+               call.=FALSE)
+  }
+
+  pc_1 <- pca_data$pca$x[, 1]
+  pc_2 <- pca_data$pca$x[, 2]
+  plot_df <- base::data.frame(pc_1=pc_1, pc_2=pc_2, sample=sample_names)
+  plot_pc_1_2(plot_df, pca_data$percent_variance, out_file_name)
+}
+
+main(in_file_name, out_file_name)