Skip to content

Commit

Permalink
Add scripts to plot PCA based on PSI values
Browse files Browse the repository at this point in the history
  • Loading branch information
EricKutschera committed Oct 3, 2024
1 parent fb6e269 commit 449aa83
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 0 deletions.
153 changes: 153 additions & 0 deletions rMATS_P/extract_psi_for_pca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import argparse
import os
import os.path


def parse_args():
parser = argparse.ArgumentParser(
description=('Extract PSI values from rMATS-turbo output'))
parser.add_argument('--rmats-out-dir',
required=True,
help='the --od from rMATS')
parser.add_argument('--out-tsv',
required=True,
help='where to write the PSI values')
parser.add_argument('--average-read-count',
type=int,
default=10,
help=('Filter out events with lower average read count'
' (default %(default)s)'))
parser.add_argument(
'--sample-names',
required=False,
help=('A comma separated list of sample names corresponding to'
' the order from --b1 and then --b2'))

return parser.parse_args()


def write_tsv_line(columns, handle):
handle.write('{}\n'.format('\t'.join([str(x) for x in columns])))


def read_tsv_line(line):
columns = line.rstrip('\n').split('\t')
return columns


def parse_comma_values(string, func):
if string == '':
return list()

parts = string.split(',')
result = list()
for part in parts:
try:
value = func(part)
except ValueError:
value = None

result.append(value)

return result


def finalize_sample_names(sample_names, num_samples, file_name):
if not sample_names:
sample_names = list()
highest_sample_str = str(num_samples - 1)
num_digits = len(highest_sample_str)
for i in range(num_samples):
format_str = 'sample_{:0' + str(num_digits) + '}'
sample_names.append(format_str.format(i))
elif len(sample_names) != num_samples:
raise Exception('Expected {} samples based on IncLevel columns'
' in {}, but got {}'.format(num_samples, file_name,
sample_names))

return sample_names


def write_psi_for_file(file_name, event_type, read_count_threshold,
is_first_file, out_headers, sample_names, in_handle,
out_handle):
for line_i, line in enumerate(in_handle):
in_columns = read_tsv_line(line)
if line_i == 0:
in_headers = in_columns
continue

row = dict(zip(in_headers, in_columns))
event_id = row['ID']
ijc_1_str = row['IJC_SAMPLE_1']
sjc_1_str = row['SJC_SAMPLE_1']
ijc_2_str = row['IJC_SAMPLE_2']
sjc_2_str = row['SJC_SAMPLE_2']
inc_1_str = row['IncLevel1']
inc_2_str = row['IncLevel2']

ijc_1 = parse_comma_values(ijc_1_str, int)
sjc_1 = parse_comma_values(sjc_1_str, int)
ijc_2 = parse_comma_values(ijc_2_str, int)
sjc_2 = parse_comma_values(sjc_2_str, int)
inc_1 = parse_comma_values(inc_1_str, float)
inc_2 = parse_comma_values(inc_2_str, float)

ijc = ijc_1 + ijc_2
sjc = sjc_1 + sjc_2
inc = inc_1 + inc_2
num_samples = len(ijc)
if (line_i == 1) and is_first_file:
sample_names = finalize_sample_names(sample_names, num_samples,
file_name)
out_headers.extend(sample_names)
write_tsv_line(out_headers, out_handle)

total_reads = sum(ijc) + sum(sjc)
avg_reads = total_reads / num_samples
if avg_reads < read_count_threshold:
continue

if any([x is None for x in inc]):
continue

write_tsv_line([event_type, event_id] + inc, out_handle)


def parse_sample_names(sample_names):
if not sample_names:
return None

return sample_names.split(',')


def extract_psi_for_pca(rmats_dir, out_tsv, read_count_threshold,
sample_names):
sample_names = parse_sample_names(sample_names)

suffix = '.MATS.JC.txt'
with open(out_tsv, 'wt') as out_handle:
headers = ['event_type', 'id']
is_first_file = True
file_names = sorted(os.listdir(rmats_dir))
for name in file_names:
if not name.endswith(suffix):
continue

path = os.path.join(rmats_dir, name)
event_type = name[:-len(suffix)]
with open(path, 'rt') as in_handle:
write_psi_for_file(name, event_type, read_count_threshold,
is_first_file, headers, sample_names,
in_handle, out_handle)
is_first_file = False


def main():
args = parse_args()
extract_psi_for_pca(args.rmats_out_dir, args.out_tsv,
args.average_read_count, args.sample_names)


if __name__ == '__main__':
main()
74 changes: 74 additions & 0 deletions rMATS_R/plot_psi_pca.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
library('ggplot2')

## Rscript plot_psi_pca.R pca_psi.tsv pca_psi.png
args <- base::commandArgs(trailingOnly=TRUE)
in_file_name <- args[1]
out_file_name <- args[2]

create_ggplot_theme <- function() {
return(
ggplot2::theme_bw() +
ggplot2::theme(panel.grid.major = ggplot2::element_blank(),
panel.grid.minor = ggplot2::element_blank(),
panel.background = ggplot2::element_blank(),
panel.border = ggplot2::element_blank(),
axis.line = ggplot2::element_line(color='black'),
legend.title=ggplot2::element_blank(),
legend.text=ggplot2::element_text(size=4))
)
}

calculate_pca <- function(data) {
## data: columns are samples, rows are events
## convert so that rows are samples
transposed <- base::t(data)
variances <- base::apply(transposed, 2, stats::var)
transposed <- transposed[, variances != 0]
transposed <- base::log10(transposed + 1)
pca_results <- stats::prcomp(transposed, center=TRUE, scale.=TRUE)
variance_by_component <- base::apply(pca_results$x, 2, stats::var)
total_variance <- base::sum(variance_by_component)
percent_variance <- base::round(
100 * (variance_by_component / total_variance), digits=1)
return(list(pca=pca_results, percent_variance=percent_variance))
}

plot_pc_1_2 <- function(plot_df, pca_variance, out_path) {
point_size <- 1
point_alpha <- 0.8
width <- 5
height<- 4

title <- 'PCA by PSI value'
plot <- ggplot2::ggplot(data=plot_df,
ggplot2::aes(x=pc_1, y=pc_2, color=sample)) +
ggplot2::geom_point(size=point_size, alpha=point_alpha) +
ggplot2::labs(x=base::paste0('PC 1: ', pca_variance[1], '%'),
y=base::paste0('PC 2: ', pca_variance[2], '%'),
title=title) +
create_ggplot_theme() +
ggplot2::scale_x_continuous(breaks=NULL) +
ggplot2::scale_y_continuous(breaks=NULL) +
ggplot2::guides(color=ggplot2::guide_legend(title=NULL,
label.theme=ggplot2::element_text(size=4),
keywidth=0.5,
keyheight=0.5))

ggplot2::ggsave(plot=plot, out_path, width=width, height=height)
}

main <- function(in_file_name, out_file_name) {
data_frame <- utils::read.table(file=in_file_name, sep='\t', header=TRUE)
num_columns <- base::ncol(data_frame)
## first 2 columns are event_type, id
sample_names <- base::colnames(data_frame)[3:num_columns]
num_samples <- base::length(sample_names)
psi_values <- data_frame[, 3:num_columns]
pca_data <- calculate_pca(psi_values)
pc_1 <- pca_data$pca$x[, 1]
pc_2 <- pca_data$pca$x[, 2]
plot_df <- base::data.frame(pc_1=pc_1, pc_2=pc_2, sample=sample_names)
plot_pc_1_2(plot_df, pca_data$percent_variance, out_file_name)
}

main(in_file_name, out_file_name)

0 comments on commit 449aa83

Please sign in to comment.