Codebase list ariba / HEAD ariba / cluster.py
HEAD

Tree @HEAD (Download .tar.gz)

cluster.py @HEADraw · history · blame

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
import signal
import traceback
import os
import atexit
import random
import math
import sys
import pyfastaq
from ariba import assembly, assembly_compare, assembly_variants, common, external_progs, flag, mapping, report, samtools_variants

class Error(Exception):
    """Exception raised for all errors specific to this module."""
    pass

# Module-level switch flipped on by the unit tests: when True, Cluster.__init__
# points self.log_fh at stdout instead of leaving it None for run() to set
# (see the explanatory comment inside __init__).
unittest=False

class Cluster:
    """One ARIBA read cluster.

    Stores all per-cluster options and file paths at construction time; the
    actual work (read subsampling, assembly, read mapping, comparison to the
    reference and variant calling) happens in the run() method.
    """
    def __init__(self,
      root_dir,
      name,
      refdata,
      all_ref_seqs_fasta=None,
      total_reads=None,
      total_reads_bases=None,
      fail_file=None,
      read_store=None,
      reference_names=None,
      logfile=None,
      assembly_coverage=50,
      assembly_kmer=21,
      assembler='fermilite',
      max_insert=1000,
      min_scaff_depth=10,
      nucmer_min_id=90,
      nucmer_min_len=20,
      nucmer_breaklen=200,
      reads_insert=500,
      sspace_k=20,
      sspace_sd=0.4,
      threads=1,
      assembled_threshold=0.95,
      min_var_read_depth=5,
      min_second_var_read_depth=2,
      max_allele_freq=0.90,
      unique_threshold=0.03,
      max_gene_nt_extend=30,
      spades_mode="rna", #["rna","wgs"]
      spades_options=None,
      clean=True,
      extern_progs=None,
      random_seed=42,
      threads_total=1
    ):
        """Record options and derive per-cluster file paths under root_dir.

        root_dir may already exist (input files were made by the caller and
        are checked now) or not (they are created later by
        _set_up_input_files() from read_store/refdata). Also registers
        signal handlers and, outside unit tests, an atexit hook.
        """
        self.root_dir = os.path.abspath(root_dir)
        self.read_store = read_store
        self.refdata = refdata
        self.name = name
        self.fail_file = fail_file
        # Input/output files for this cluster all live inside root_dir.
        self.reference_fa = os.path.join(self.root_dir, 'reference.fa')
        self.reference_names = reference_names
        self.all_reads1 = os.path.join(self.root_dir, 'reads_1.fq')
        self.all_reads2 = os.path.join(self.root_dir, 'reads_2.fq')
        self.references_fa = os.path.join(self.root_dir, 'references.fa')

        # If the directory is already there, the caller must have provided the
        # input files; verify now rather than failing mid-run.
        if os.path.exists(self.root_dir):
            self._input_files_exist()

        self.total_reads = total_reads
        self.total_reads_bases = total_reads_bases
        self.logfile = logfile
        self.assembly_coverage = assembly_coverage
        self.assembly_kmer = assembly_kmer
        self.assembler = assembler
        self.sspace_k = sspace_k
        self.sspace_sd = sspace_sd
        self.reads_insert = reads_insert
        self.spades_mode = spades_mode
        self.spades_options = spades_options

        self.reads_for_assembly1 = os.path.join(self.root_dir, 'reads_for_assembly_1.fq')
        self.reads_for_assembly2 = os.path.join(self.root_dir, 'reads_for_assembly_2.fq')

        # Filled in by _run() once the closest reference has been chosen.
        self.ref_sequence = None

        self.max_insert = max_insert
        self.min_scaff_depth = min_scaff_depth

        # nucmer settings, used when comparing the assembly to the reference.
        self.nucmer_min_id = nucmer_min_id
        self.nucmer_min_len = nucmer_min_len
        self.nucmer_breaklen = nucmer_breaklen

        # samtools variant-calling thresholds.
        self.min_var_read_depth = min_var_read_depth
        self.min_second_var_read_depth = min_second_var_read_depth
        self.max_allele_freq = max_allele_freq

        self.threads = threads
        self.assembled_threshold = assembled_threshold
        self.unique_threshold = unique_threshold
        self.max_gene_nt_extend = max_gene_nt_extend
        self.status_flag = flag.Flag()
        self.clean = clean

        # self.threads is later recomputed from threads_total and the shared
        # remaining-clusters counter by _update_threads() during run().
        self.threads_total = threads_total
        self.remaining_clusters = None

        self.assembly_dir = os.path.join(self.root_dir, 'Assembly')
        self.final_assembly_fa = os.path.join(self.root_dir, 'assembly.fa')
        self.final_assembly_bam = os.path.join(self.root_dir, 'assembly.reads_mapped.bam')
        self.final_assembly_read_depths = os.path.join(self.root_dir, 'assembly.reads_mapped.bam.read_depths.gz')
        self.final_assembly_vcf = os.path.join(self.root_dir, 'assembly.reads_mapped.bam.vcf')
        self.samtools_vars_prefix = self.final_assembly_bam
        self.assembly_compare = None
        self.variants_from_samtools = {}
        self.assembly_compare_prefix = os.path.join(self.root_dir, 'assembly_compare')

        self.mummer_variants = {}
        self.variant_depths = {}
        self.percent_identities = {}

        # The log filehandle self.log_fh is set at the start of the run() method.
        # Lots of other methods use self.log_fh. But for unit testing, run() isn't
        # run. So we need to set this to something for unit testing.
        # On the other hand, setting it here breaks a real run of ARIBA because
        # multiprocessing complains with the error:
        # TypeError: cannot serialize '_io.TextIOWrapper' object.
        # Hence the following two lines...
        if unittest:
            self.log_fh = sys.stdout
        else:
            atexit.register(self._atexit)
            self.log_fh = None

        if extern_progs is None:
            self.extern_progs = external_progs.ExternalProgs(using_spades=self.assembler == 'spades')
        else:
            self.extern_progs = extern_progs

        if all_ref_seqs_fasta is None:
            self.all_refs_fasta = self.references_fa
        else:
            self.all_refs_fasta = os.path.abspath(all_ref_seqs_fasta)

        self.random_seed = random_seed
        # Tidy up (close log, touch fail_file) if this worker gets killed.
        wanted_signals = [signal.SIGABRT, signal.SIGINT, signal.SIGSEGV, signal.SIGTERM]
        for s in wanted_signals:
            signal.signal(s, self._receive_signal)

    def _update_threads(self):
        """Update available thread count post-construction.
        To be called any number of times from run() method"""
        if self.remaining_clusters is not None:
            self.threads = max(1,self.threads_total//self.remaining_clusters.value)
        #otherwise just keep the current (initial) value
        print("{} detected {} threads available to it".format(self.name,self.threads), file = self.log_fh)

    def _report_completion(self):
        """Update shared counters to signal that we are done with this cluster.
        Call just before exiting run() method (in a finally clause)"""
        rem_clust = self.remaining_clusters
        if rem_clust is not None:
            # -= is non-atomic, need to acquire a lock
            with self.remaining_clusters_lock:
                rem_clust.value -= 1
            # we do not need this object anymore
            self.remaining_clusters = None
            print("{} reported completion".format(self.name), file=self.log_fh)

    def _atexit(self):
        if self.log_fh is not None:
            pyfastaq.utils.close(self.log_fh)
            self.log_fh = None

    def _receive_signal(self, signum, stack):
        print('Signal', signum, 'received in cluster', self.name + '... Stopping!', file=sys.stderr, flush=True)
        if self.log_fh is not None:
            pyfastaq.utils.close(self.log_fh)
            self.log_fh = None
        if self.fail_file is not None:
            with open(self.fail_file, 'w') as f:
                pass
        sys.exit(1)


    def _input_files_exist(self):
        assert self.read_store is None
        if not os.path.exists(self.references_fa):
            raise Error('Error making cluster. References fasta not found')


    def _set_up_input_files(self):
        """Prepare reads, references fasta and the log file handle.

        Two modes:
          * self.root_dir already exists: the input files were made by the
            caller; verify them and load the reference names from the fasta.
          * otherwise: make the directory, write the reference sequences and
            extract this cluster's reads from self.read_store.

        Side effects: sets self.logfile, self.log_fh, self.reference_names
        (first mode), self.total_reads / self.total_reads_bases (second mode)
        and self.longest_ref_length. Raises Error if the directory cannot be
        made.
        """
        if self.logfile is None:
            self.logfile = os.path.join(self.root_dir, 'log.txt')

        if os.path.exists(self.root_dir):
            self._input_files_exist()
            seqreader = pyfastaq.sequences.file_reader(self.references_fa)
            self.reference_names = set([x.id for x in seqreader])
            self.log_fh = pyfastaq.utils.open_file_write(self.logfile)
        else:
            assert self.read_store is not None
            assert self.reference_names is not None
            try:
                os.mkdir(self.root_dir)
            except OSError as err:
                # Narrowed from a bare except so unrelated errors propagate.
                raise Error('Error making directory ' + self.root_dir) from err

            self.refdata.write_seqs_to_fasta(self.references_fa, self.reference_names)
            self.log_fh = pyfastaq.utils.open_file_write(self.logfile)
            self.total_reads, self.total_reads_bases = self.read_store.get_reads(self.name, self.all_reads1, self.all_reads2, log_fh=self.log_fh)
            # NOTE(review): a second, identical write_seqs_to_fasta() call used
            # to follow here; it rewrote the same file and has been removed.

        self.longest_ref_length = max([len(self.refdata.sequence(name)) for name in self.reference_names])


    def _clean_file(self, filename):
        if self.clean:
            print('Deleting file', filename, file=self.log_fh)
            try: #protect against OSError: [Errno 16] Device or resource busy: '.nfs0000000010f0f04f000003c9' and such
                os.unlink(filename)
            except:
                pass


    def _clean(self):
        if not self.clean:
            print('   ... not deleting anything because --noclean used', file=self.log_fh, flush=True)
            return


        to_delete = [
            'assembly.fa',
            'assembly.fa.fai',
            'assembly_compare.nucmer.coords',
            'assembly_compare.nucmer.coords.snps',
            'assembly.reads_mapped.bam.bai',
            'assembly.reads_mapped.bam.vcf',
            'assembly.reads_mapped.bam',
            'assembly.reads_mapped.bam.read_depths.gz',
            'assembly.reads_mapped.bam.read_depths.gz.tbi',
            'reads_1.fq',
            'reads_2.fq',
            'reference.fa',
        ]

        to_delete = [os.path.join(self.root_dir, x) for x in to_delete]

        for filename in to_delete:
            if os.path.exists(filename):
                self._clean_file(filename)


    @staticmethod
    def _number_of_reads_for_assembly(ref_length, insert_size, total_bases, total_reads, coverage):
        assert ref_length > 0
        ref_length += 2 * insert_size
        mean_read_length = total_bases / total_reads
        wanted_bases = coverage * ref_length
        wanted_reads = int(math.ceil(wanted_bases / mean_read_length))
        wanted_reads += wanted_reads % 2
        return wanted_reads


    @staticmethod
    def _make_reads_for_assembly(number_of_wanted_reads, total_reads, reads_in1, reads_in2, reads_out1, reads_out2, random_seed=None):
        '''Makes fastq files that are random subset of input files. Returns total number of reads in output files.
           If the number of wanted reads is >= total reads, then just makes symlinks instead of making
           new copies of the input files.'''
        random.seed(random_seed)

        if number_of_wanted_reads < total_reads:
            reads_written = 0
            percent_wanted = 100 * number_of_wanted_reads / total_reads
            file_reader1 = pyfastaq.sequences.file_reader(reads_in1)
            file_reader2 = pyfastaq.sequences.file_reader(reads_in2)
            out1 = pyfastaq.utils.open_file_write(reads_out1)
            out2 = pyfastaq.utils.open_file_write(reads_out2)

            for read1 in file_reader1:
                try:
                    read2 = next(file_reader2)
                except StopIteration:
                    pyfastaq.utils.close(out1)
                    pyfastaq.utils.close(out2)
                    raise Error('Error subsetting reads. No mate found for read ' + read1.id)

                if random.randint(0, 100) <= percent_wanted:
                    print(read1, file=out1)
                    print(read2, file=out2)
                    reads_written += 2

            pyfastaq.utils.close(out1)
            pyfastaq.utils.close(out2)
            return reads_written
        else:
            os.symlink(reads_in1, reads_out1)
            os.symlink(reads_in2, reads_out2)
            return total_reads


    def run(self, remaining_clusters=None, remaining_clusters_lock=None):
        """Top-level entry point: set up input files, run the pipeline, close
        the log, and always report completion on the shared counter.

        remaining_clusters: optional shared counter of clusters still running,
            used to apportion threads; decremented on exit via
            _report_completion().
        remaining_clusters_lock: lock guarding remaining_clusters.

        Raises Error when the pipeline fails. The log file handle is closed on
        both success and failure so this object stays picklable (TextIOWrapper
        cannot be serialized by multiprocessing).
        """
        try:
            self.remaining_clusters = remaining_clusters
            self.remaining_clusters_lock = remaining_clusters_lock
            self._update_threads()
            self._set_up_input_files()

            for fname in [self.all_reads1, self.all_reads2, self.references_fa]:
                if not os.path.exists(fname):
                    raise Error('File ' + fname + ' not found. Cannot continue')

            original_dir = os.getcwd()
            os.chdir(self.root_dir)

            try:
                self._run()
            except Error as err:
                print('Error running cluster! Error was:', err, sep='\n', file=self.log_fh)
                pyfastaq.utils.close(self.log_fh)
                self.log_fh = None
                raise Error('Error running cluster ' + self.name + '!') from err
            finally:
                # Restore cwd on every path, including unexpected exceptions
                # (previously it was only restored on success or on Error).
                os.chdir(original_dir)

            print('Finished', file=self.log_fh, flush=True)
            print('{:_^79}'.format(' LOG FILE END ' + self.name + ' '), file=self.log_fh, flush=True)

            # Close the handle so multiprocessing can serialize this object
            # when returning it from a worker process.
            pyfastaq.utils.close(self.log_fh)
            self.log_fh = None
        finally:
            self._report_completion()


    def _run(self):
        """Run the per-cluster pipeline. Called from run(), with the current
        directory set to self.root_dir and self.log_fh open.

        Stages: subsample reads -> assemble -> map all reads back to the
        assembly -> compare assembly with the reference (nucmer) -> call
        variants with samtools -> build self.report_lines for the final
        report. Updates self.status_flag along the way.
        """
        print('{:_^79}'.format(' LOG FILE START ' + self.name + ' '), file=self.log_fh, flush=True)

        if self.total_reads == 0:
            print('No reads left after filtering with cdhit', file=self.log_fh, flush=True)
            self.assembled_ok = False
        else:
            # Subsample the reads to roughly self.assembly_coverage x coverage
            # of the longest reference, then assemble only the subsample.
            wanted_reads = self._number_of_reads_for_assembly(self.longest_ref_length, self.reads_insert, self.total_reads_bases, self.total_reads, self.assembly_coverage)
            made_reads = self._make_reads_for_assembly(wanted_reads, self.total_reads, self.all_reads1, self.all_reads2, self.reads_for_assembly1, self.reads_for_assembly2, random_seed=self.random_seed)
            print('\nUsing', made_reads, 'from a total of', self.total_reads, 'for assembly.', file=self.log_fh, flush=True)
            print('Assembling reads:', file=self.log_fh, flush=True)

            self._update_threads()
            self.assembly = assembly.Assembly(
              self.reads_for_assembly1,
              self.reads_for_assembly2,
              self.reference_fa,
              self.references_fa,
              self.assembly_dir,
              self.final_assembly_fa,
              self.final_assembly_bam,
              self.log_fh,
              self.all_refs_fasta,
              contig_name_prefix=self.name,
              assembler=self.assembler,
              extern_progs=self.extern_progs,
              clean=self.clean,
              spades_mode=self.spades_mode,
              spades_options=self.spades_options,
              threads=self.threads
            )

            self.assembly.run()
            self.assembled_ok = self.assembly.assembled_ok
            # The subsampled fastq files are only needed for assembly.
            self._clean_file(self.reads_for_assembly1)
            self._clean_file(self.reads_for_assembly2)
            if self.clean:
                print('Deleting Assembly directory', self.assembly_dir, file=self.log_fh, flush=True)
                common.rmtree(self.assembly_dir)


        if self.assembled_ok and self.assembly.ref_seq_name is not None:
            # Assembly succeeded and a closest reference sequence was chosen.
            self.ref_sequence = self.refdata.sequence(self.assembly.ref_seq_name)
            is_gene, is_variant_only = self.refdata.sequence_type(self.ref_sequence.id)
            self.is_gene = '1' if is_gene == 'p' else '0'
            self.is_variant_only = '1' if is_variant_only else '0'

            print('\nAssembly was successful\n\nMapping reads to assembly:', file=self.log_fh, flush=True)
            self._update_threads()
            # Map ALL the cluster's reads (not just the assembly subsample)
            # back to the assembly for depth and variant information.
            mapping.run_bowtie2(
                self.all_reads1,
                self.all_reads2,
                self.final_assembly_fa,
                self.final_assembly_bam[:-4],
                threads=self.threads,
                sort=True,
                bowtie2=self.extern_progs.exe('bowtie2'),
                bowtie2_preset='very-sensitive-local',
                bowtie2_version=self.extern_progs.version('bowtie2'),
                verbose=True,
                verbose_filehandle=self.log_fh
            )

            if self.assembly.has_contigs_on_both_strands:
                self.status_flag.add('hit_both_strands')

            print('\nMaking and checking scaffold graph', file=self.log_fh, flush=True)

            if not self.assembly.scaff_graph_ok:
                self.status_flag.add('scaffold_graph_bad')

            print('Comparing assembly against reference sequence', file=self.log_fh, flush=True)
            self.assembly_compare = assembly_compare.AssemblyCompare(
              self.final_assembly_fa,
              self.assembly.sequences,
              self.reference_fa,
              self.ref_sequence,
              self.assembly_compare_prefix,
              self.refdata,
              nucmer_min_id=self.nucmer_min_id,
              nucmer_min_len=self.nucmer_min_len,
              nucmer_breaklen=self.nucmer_breaklen,
              assembled_threshold=self.assembled_threshold,
              unique_threshold=self.unique_threshold,
              max_gene_nt_extend=self.max_gene_nt_extend,
            )
            self.assembly_compare.run()
            self.status_flag = self.assembly_compare.update_flag(self.status_flag)

            # Variants implied by the nucmer alignments (assembly vs reference).
            allowed_ctg_pos, allowed_ref_pos = assembly_compare.AssemblyCompare.nucmer_hits_to_ref_and_qry_coords(self.assembly_compare.nucmer_hits)
            assembly_variants_obj = assembly_variants.AssemblyVariants(self.refdata, self.assembly_compare.nucmer_snps_file)
            self.assembly_variants = assembly_variants_obj.get_variants(self.ref_sequence.id, allowed_ctg_pos, allowed_ref_pos)

            # Flag the cluster as soon as one non-synonymous variant is seen.
            for var_list in self.assembly_variants.values():
                for var in var_list:
                    if var[3] not in ['.', 'SYN', None]:
                        self.status_flag.add('has_variant')
                        break

                if self.status_flag.has('has_variant'):
                    break


            print('\nCalling variants with samtools:', file=self.log_fh, flush=True)

            self.samtools_vars = samtools_variants.SamtoolsVariants(
                self.final_assembly_fa,
                self.final_assembly_bam,
                self.samtools_vars_prefix,
                log_fh=self.log_fh,
                min_var_read_depth=self.min_var_read_depth,
                min_second_var_read_depth=self.min_second_var_read_depth,
                max_allele_freq=self.max_allele_freq
            )
            self.samtools_vars.run()

            self.total_contig_depths = self.samtools_vars.total_depth_per_contig(self.samtools_vars.contig_depths_file)

            # Samtools variants lying inside the matched coords mean the reads
            # disagree with the assembly there - possible collapsed repeat.
            self.variants_from_samtools =  self.samtools_vars.variants_in_coords(self.assembly_compare.assembly_match_coords(), self.samtools_vars.vcf_file)
            if len(self.variants_from_samtools):
                self.status_flag.add('variants_suggest_collapsed_repeat')
        elif not self.assembled_ok:
            print('\nAssembly failed\n', file=self.log_fh, flush=True)
            self.status_flag.add('assembly_fail')
        elif self.assembly.ref_seq_name is None:
            print('\nCould not get closest reference sequence\n', file=self.log_fh, flush=True)
            self.status_flag.add('ref_seq_choose_fail')

        try:
            self.report_lines = report.report_lines(self)
        except:
            print('Error making report for cluster ', self.name, '... traceback:', file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            raise Error('Error making report for cluster ' + self.name)

        self._clean()
        atexit.unregister(self._atexit)