-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtigmint-cut
executable file
·301 lines (256 loc) · 12.9 KB
/
tigmint-cut
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
#!/usr/bin/env python3
"""
Given a BED file of molecule extents, scan the input draft assembly with windows of a fixed size to
find windows with few spanning molecules. These areas are likely misassembled areas. Cut the input
assembly at the potentially misassembled regions.
@author: Lauren Coombe and Shaun Jackman
"""
import os
import sys
import argparse
import subprocess
import shlex
import pybedtools
from datetime import datetime
from multiprocessing import Queue, Process
from intervaltree import IntervalTree
class NoSpanningRun:
    """
    Record of one run of consecutive windows that lack enough spanning molecules.

    Stores the two candidate cut positions for the run: the end of the last window before the run
    (beforeRun_bp) and the start of the first window after the run (afterRun_bp).
    """

    def __init__(self):
        # Both positions start at 0; the window-scanning code assigns the real values.
        self.beforeRun_bp = self.afterRun_bp = 0
def tallyBreakpoints(bps_queue, contig, noSpanningRun):
    """
    Put the breakpoint(s) of a finished run on the Process queue as (contig, position) tuples.

    The position before the run is always queued; the position after the run is queued as well
    only when it differs from the first one.
    """
    first_cut = noSpanningRun.beforeRun_bp
    second_cut = noSpanningRun.afterRun_bp
    bps_queue.put((contig, first_cut))
    if second_cut != first_cut:
        bps_queue.put((contig, second_cut))
def checkSpanningMolecules(intervals, window, contigLengths, contig, num_spanning, bps_queue, attempted_corr_queue, min_length=3000):
    """
    Slide a fixed-size window along one contig and queue breakpoints around runs of windows
    that are spanned by fewer than num_spanning molecules.

    Cut 1: the end of the last spanned window before a run of under-spanned windows.
    Cut 2: the start of the first spanned window after that run.
    Runs are only recorded once at least one spanned window has been seen on the contig, and a
    run still open when the scan reaches the contig end is not reported.

    intervals: position-indexable collection; intervals[pos] yields the molecule intervals
        overlapping pos (each with its end coordinate at index 1 / attribute .end).
    window: window length in bp.
    contigLengths: dict of contig name -> length; contigLengths[contig] bounds the scan.
    num_spanning: minimum number of spanning molecules for a window to count as spanned.
    bps_queue: queue receiving (contig, position) breakpoints via tallyBreakpoints.
    attempted_corr_queue: queue receiving a 1 per reported run when the contig is >= min_length.
    """
    contigLength = contigLengths[contig]
    win_start, win_end = 0, window
    seen_spanned_window = False
    open_run = None

    while win_end < contigLength:
        spanning_count = 0
        lowest_spanning_end = float("inf")
        # Longest-reaching molecules first, so the scan can stop at the first non-spanning one.
        by_end_desc = sorted(intervals[win_start], key=lambda iv: iv[1], reverse=True)
        for molecule in by_end_desc:
            # Interval ends are exclusive: the last covered base is molecule.end - 1.
            if molecule.end <= win_end:
                break
            spanning_count += 1
            lowest_spanning_end = min(lowest_spanning_end, molecule.end - 1)
            if spanning_count >= num_spanning:
                break

        if spanning_count < num_spanning:
            # Under-spanned window: open a run if none is open yet, then advance by one base.
            if seen_spanned_window and open_run is None:
                open_run = NoSpanningRun()
                open_run.beforeRun_bp = win_end
            win_start += 1
            win_end += 1
            continue

        # Spanned window: close any open run and report its breakpoints.
        if seen_spanned_window and open_run is not None:
            open_run.afterRun_bp = win_start
            tallyBreakpoints(bps_queue, contig, open_run)
            # Keep track of attempted corrections
            if contigLength >= min_length:
                attempted_corr_queue.put(1)
        open_run = None
        seen_spanned_window = True
        # Jump past the shortest spanning molecule considered; every window ending before
        # that point is spanned by the same molecules.
        win_end = max(lowest_spanning_end + 1, win_end + 1)
        win_start = win_end - window
def findBreakpoints(bed_name, window_length, contigLengths, num_spanning, bps_queue, attempted_corr_queue, min_length=3000):
    """
    Stream a sorted BED file of molecule extents and run checkSpanningMolecules on each contig.

    Records whose contig is not a key of contigLengths are skipped. Molecule extents are gathered
    in an interval tree until the contig name changes, at which point the finished contig is
    scanned and the tree is reset; this relies on all records of a contig being adjacent.
    """
    current_contig = ""
    molecules = IntervalTree()
    for record in pybedtools.BedTool(bed_name):
        if bed_is_foreign := record.chrom not in contigLengths:
            continue
        if record.chrom != current_contig:
            # New contig: process the one just completed (if any) and start afresh.
            if current_contig != "":
                checkSpanningMolecules(molecules, window_length, contigLengths, current_contig,
                                       num_spanning, bps_queue, attempted_corr_queue, min_length)
            molecules.clear()
            current_contig = record.chrom
        molecules[record.start:record.stop] = record.score

    # Ensuring final contig in the bed file is also checked for spanning molecules
    if current_contig != "":
        checkSpanningMolecules(molecules, window_length, contigLengths, current_contig,
                               num_spanning, bps_queue, attempted_corr_queue, min_length)
def launchFindBreakpoints(bedfile, window, num_processes, partitioned_contigLengths, num_spanning, min_length=3000):
    """
    Launch processes to find breakpoints for partitioned contigs in parallel.

    One worker running findBreakpoints is started per partition (partitioned_contigLengths[i]
    for i in range(num_processes)). The parent drains the two result queues while the workers
    run, prints the number of attempted corrections, and returns the breakpoints.

    Returns a dictionary mapping contig name -> list of breakpoint positions (ints).
    Exits with status 1 if any worker process terminated with a non-zero exit code, since the
    collected breakpoints would then be incomplete.
    """
    import time  # local import: only needed for the polling back-off below

    bp_queue = Queue()
    attempted_corr_queue = Queue()
    breakpoints = {}
    total_attempted_corr = 0

    processes = []
    for i in range(0, num_processes):
        p = Process(target=findBreakpoints,
                    args=(bedfile, window, partitioned_contigLengths[i], num_spanning,
                          bp_queue, attempted_corr_queue, min_length))
        processes.append(p)
        p.start()

    while True:
        # Sample liveness BEFORE draining, so that a final drain pass always happens after
        # the last worker has been observed dead.
        running = any(p.is_alive() for p in processes)
        drained_something = False

        while not bp_queue.empty():
            chrom, pos = bp_queue.get()
            breakpoints.setdefault(chrom, []).append(int(pos))
            drained_something = True

        while not attempted_corr_queue.empty():
            attempted_corr_queue.get()
            total_attempted_corr += 1
            drained_something = True

        if not running:
            break
        if not drained_something:
            # Nothing to collect right now: back off briefly instead of spinning a CPU core.
            time.sleep(0.01)

    for p in processes:
        p.join()

    # A worker that died abnormally leaves its partition only partly processed.
    num_failed = sum(1 for p in processes if p.exitcode != 0)
    if num_failed:
        print("tigmint-cut: error: %d breakpoint-finding process(es) exited abnormally" % num_failed,
              file=sys.stderr)
        sys.exit(1)

    print("Attempted corrections: %i" % (total_attempted_corr))
    return breakpoints
def findContigLengths(fasta, num_processes):
    """
    Read the fai index of the fasta file (fasta + ".fai") and partition the contigs round-robin.

    Each line of the index is split on tabs; column 1 is the contig name and column 2 its length.
    Contig number k (0-based, in file order) goes to partition k % num_processes. Blank lines
    are ignored.

    Returns a list of num_processes dictionaries, each mapping contig name -> contig length.
    Exits with status 1 (message on stderr) if num_processes < 1, if the index cannot be opened
    or read, or if a line is malformed.
    """
    if num_processes < 1:
        print("tigmint-cut: error: number of processes must be at least 1", file=sys.stderr)
        sys.exit(1)

    fasta_faidx_name = fasta + ".fai"
    contig_lengths = [{} for _ in range(num_processes)]
    try:
        with open(fasta_faidx_name, 'r') as fasta_faidx:
            count = 0
            for line_num, line in enumerate(fasta_faidx, start=1):
                fields = line.strip().split("\t")
                if fields == [""]:
                    continue
                try:
                    ctg_name = fields[0]
                    ctg_length = int(fields[1])
                except (IndexError, ValueError):
                    # Report the real problem instead of a misleading "cannot open" message.
                    print("tigmint-cut: error: malformed line %d in fai index file %s"
                          % (line_num, fasta_faidx_name), file=sys.stderr)
                    sys.exit(1)
                contig_lengths[count % num_processes][ctg_name] = ctg_length
                count += 1
    except OSError:
        # Only I/O failures land here; SystemExit raised above is not an OSError.
        print("Error when trying to open %s.\nGenerate the fai index file for %s with: samtools faidx %s"
              % (fasta_faidx_name, fasta, fasta), file=sys.stderr)
        sys.exit(1)
    return contig_lengths
def printBreakpoints(breakpoints, partitioned_contigLengths, bedout, len_trim):
    """
    Given the breakpoints and the partitioned contig lengths, write a BED file representing the
    cuts to the assembly.

    breakpoints: dict of contig name -> list of breakpoint positions.
    partitioned_contigLengths: list of dicts of contig name -> contig length.
    bedout: path of the BED file to write.
    len_trim: number of bases removed on each side of every breakpoint.

    A contig without breakpoints yields one record covering its whole length, named after the
    contig. A contig with breakpoints yields one record per piece, named "<contig>-<n>" with n
    counting pieces from 1. Pieces that trimming (or duplicate / closely spaced breakpoints)
    reduces to zero or negative length are not written, because a record with end <= start is
    not a usable BED interval; their number is still consumed so the names of the remaining
    pieces do not shift.
    """
    bed_lines = []  # collected and joined once, rather than growing one string with +=
    for part in partitioned_contigLengths:
        for contig, contigLength in part.items():
            if contig not in breakpoints:
                # No breakpoints, have BED for entire length of the contig.
                bed_lines.append("%s\t%d\t%d\t%s\n" % (contig, 0, contigLength, contig))
                continue

            start = 0
            contigNum = 1
            for bp in sorted(breakpoints[contig]):
                end = bp - len_trim
                if end > start:
                    bed_lines.append("%s\t%d\t%d\t%s-%d\n" % (contig, start, end, contig, contigNum))
                start = bp + len_trim
                contigNum += 1
            if start < contigLength:  # Make sure to get the final bed region
                bed_lines.append("%s\t%d\t%d\t%s-%d\n" % (contig, start, contigLength, contig, contigNum))

    breakpoints_bed = pybedtools.BedTool("".join(bed_lines), from_string=True)
    breakpoints_bed.saveas(bedout)
def cutAssembly(fasta, bedfile, out_fasta_filename):
    """
    Use bedtools to cut the assembly based on the breakpoints identified in the breaktigs bed file,
    stripping Ns from beginning and end of the sequences.

    Runs `bedtools getfasta -fi <fasta> -bed <bedfile> -name` and writes its output to
    out_fasta_filename. Header lines (starting with ">") are copied unchanged; sequence lines
    have leading and trailing "N"/"n" removed, and a sequence that becomes empty is written as a
    single "N". Blank output lines are skipped.

    Raises subprocess.CalledProcessError if bedtools exits with a non-zero status.
    """
    # Argument list rather than a split command string, so paths containing spaces or quotes
    # reach bedtools intact.
    cmd = ["bedtools", "getfasta", "-fi", fasta, "-bed", bedfile, "-name"]
    with open(out_fasta_filename, 'w') as out_fasta, \
            subprocess.Popen(cmd, stdout=subprocess.PIPE, universal_newlines=True) as cutFasta:
        for line in cutFasta.stdout:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):  # Header line
                out_fasta.write(line + "\n")
            else:  # Sequence line, strip leading or trailing "N"s
                # Just give single N if the sequence will become empty
                strippedNs = line.strip("Nn") or "N"
                out_fasta.write(strippedNs + "\n")
    # Leaving the Popen context waits for the process, so returncode is set here.
    if cutFasta.returncode != 0:
        raise subprocess.CalledProcessError(cutFasta.returncode, cmd)
def ensure_writable(f):
    """
    Ensure that the file f is writable.

    Returns silently when f is already writable, or when f does not exist but its parent
    directory exists and is writable. Otherwise prints an error to stderr and exits with status 1.
    """
    if os.access(f, os.W_OK):
        return

    if "/" in f:
        parent = os.path.dirname(f)
    else:
        parent = "."

    # Checks are ordered from most to least specific; the first failing one is reported.
    if os.access(f, os.F_OK):
        complaint = "tigmint-cut: error: file exists and is not writable:"
    elif not os.access(parent, os.F_OK):
        complaint = "tigmint-cut: error: parent directory does not exist:"
    elif not os.access(parent, os.W_OK):
        complaint = "tigmint-cut: error: cannot write to parent directory:"
    else:
        return

    print(complaint, f, file=sys.stderr)
    sys.exit(1)
def get_span(filename):
    """
    Read calculated span parameter from a file.

    The file is scanned line by line; each line is split on tabs and the first line whose first
    field is "span" supplies the result (its second field, converted to int).
    Prints an error to stderr and exits with status 1 if the file does not exist, cannot be read,
    or holds no "span" line.
    """
    if not os.access(filename, os.F_OK):
        print("tigmint-cut: error: parameter file '%s' does not exist" % filename, file=sys.stderr)
        sys.exit(1)

    if not os.access(filename, os.R_OK):
        print("tigmint-cut: error: parameter file '%s' cannot be read" % filename, file=sys.stderr)
        sys.exit(1)

    with open(filename) as param_file:
        for record in param_file:
            fields = record.strip().split("\t")
            if fields[0] == "span":
                return int(fields[1])

    print("tigmint-cut: error: calculated span parameter not found in parameter file '%s'" % filename, file=sys.stderr)
    sys.exit(1)
def main():
    """
    Command-line entry point: parse arguments, find breakpoints in the assembly from the
    molecule-extent BED file, write the BED of cut pieces, and cut the FASTA accordingly.
    """
    cli = argparse.ArgumentParser(description="Find misassembled regions in assembly using Chromium molecule extents")
    cli.add_argument("--version", action="version", version="tigmint-cut 1.2.5")
    cli.add_argument("fasta", type=str, help="Reference genome fasta file (must have FAI index generated)")
    cli.add_argument("bed", type=str, help="Sorted bed file of molecule extents")
    cli.add_argument("-o", "--fastaout", type=str, help="The output FASTA file.", required=True)
    cli.add_argument("-b", "--bedout", type=str, help="The output BED file. Default is the output FASTA filename plus .bed")
    cli.add_argument("-p", "--processes", type=int, help="Number of parallel processes to launch [8]", default=8)
    cli.add_argument("-w", "--window", type=int, help="Window size used to check for spanning molecules (bp) [1000]", default=1000)
    cli.add_argument("-n", "--spanning", type=int,
                     help="Spanning molecules threshold (no misassembly in window if num. "
                          "spanning molecules >= n) [2].",
                     default=2)
    cli.add_argument("-t", "--trim", type=int, help="Number of base pairs to trim at contig cuts (bp) [0]", default=0)
    cli.add_argument("-f", "--param_file", help="File containing calculated parameters for tigmint-long", default=None)
    cli.add_argument("-m", "--min_length", type=int, help="Minimum contig length for tallying attempted corrections (bp) [3000]", default=3000)
    args = cli.parse_args()

    print("Started at: %s" % datetime.now())

    # A parameter file, when given, overrides the -n/--spanning value.
    if args.param_file:
        args.spanning = get_span(args.param_file)

    if args.bedout is None:
        args.bedout = args.fastaout + ".bed"

    # Check both outputs up front, before any of the work is done.
    for out_path in (args.fastaout, args.bedout):
        ensure_writable(out_path)

    print("Reading contig lengths...")
    partitions = findContigLengths(args.fasta, args.processes)

    print("Finding breakpoints...")
    breakpoints = launchFindBreakpoints(args.bed, args.window, args.processes, partitions, args.spanning, args.min_length)
    printBreakpoints(breakpoints, partitions, args.bedout, args.trim)

    print("Cutting assembly at breakpoints...")
    cutAssembly(args.fasta, args.bedout, args.fastaout)

    print("DONE!")
    print("Ended at: %s" % datetime.now())
# Run the command-line interface only when executed as a script, not when imported.
if __name__ == "__main__":
    main()