diff --git a/precellar/src/align.rs b/precellar/src/align.rs index d186080..9eb23db 100644 --- a/precellar/src/align.rs +++ b/precellar/src/align.rs @@ -10,32 +10,58 @@ use log::info; use noodles::sam::alignment::record_buf::data::field::value::Value; use noodles::sam::alignment::{record::data::field::tag::Tag, record_buf::RecordBuf, Record}; use noodles::{bam, fastq, sam}; +use rayon::prelude::*; use seqspec::{Assay, Modality, Read, RegionId, RegionIndex, RegionType}; use smallvec::SmallVec; +use star_aligner::StarAligner; use std::collections::{HashMap, HashSet}; use std::io::BufRead; use std::ops::Range; use std::sync::{Arc, Mutex}; +// GenomeAligner and TranscriptomeAligner pub trait Alinger { + type AlignOutput; + type ModRecordBuf; + fn chunk_size(&self) -> usize; fn header(&self) -> sam::Header; - fn align_reads( - &mut self, - records: &mut [fastq::Record], - ) -> impl ExactSizeIterator; + fn align_reads(&mut self, records: &mut [fastq::Record]) -> Vec; fn align_read_pairs( &mut self, records: &mut [(fastq::Record, fastq::Record)], - ) -> impl ExactSizeIterator; + ) -> Vec<(Self::AlignOutput, Self::AlignOutput)>; + + // modify the record by adding cell barcode + fn add_cell_barcode( + &self, + header: &sam::Header, + record: &Self::AlignOutput, + ori_barcode: &[u8], + ori_qual: &[u8], + correct_barcode: Option<&[u8]>, + ) -> Result; + + fn update_qc( + &self, + _qc: &Arc>, + _header: &sam::Header, + _record: &Self::ModRecordBuf, + ) { + } + + // crumple into a vector of RecordBuf to unify the output + fn crumple(&self, record: Self::ModRecordBuf) -> Vec; } pub struct DummyAligner; impl Alinger for DummyAligner { + type AlignOutput = sam::Record; + type ModRecordBuf = RecordBuf; fn chunk_size(&self) -> usize { 0 } @@ -44,22 +70,108 @@ impl Alinger for DummyAligner { sam::Header::default() } - fn align_reads( - &mut self, - _records: &mut [fastq::Record], - ) -> impl ExactSizeIterator { - std::iter::empty() + fn align_reads(&mut self, _records: &mut [fastq::Record]) -> Vec { + Vec::new() } fn align_read_pairs( &mut self, _records: &mut [(fastq::Record, fastq::Record)], - ) -> impl ExactSizeIterator { - std::iter::empty() + ) -> Vec<(Self::AlignOutput, Self::AlignOutput)> { + Vec::new() + } + + fn add_cell_barcode( + &self, + _header: &sam::Header, + _record: &Self::AlignOutput, + _ori_barcode: &[u8], + _ori_qual: &[u8], + _correct_barcode: Option<&[u8]>, + ) -> Result { + Ok(RecordBuf::default()) + } + + fn crumple(&self, _record: Self::ModRecordBuf) -> Vec { + Vec::new() + } +} + +impl Alinger for StarAligner { + type AlignOutput = Vec; + type ModRecordBuf = Vec; + + fn chunk_size(&self) -> usize { + 0 + } + + fn header(&self) -> sam::Header { + self.get_header().clone() + } + + fn align_reads(&mut self, records: &mut [fastq::Record]) -> Vec { + StarAligner::align_reads(self, records).collect() + } + + fn align_read_pairs<'a>( + &'a mut self, + records: &'a mut [(fastq::Record, fastq::Record)], + ) -> Vec<(Self::AlignOutput, Self::AlignOutput)> { + StarAligner::align_read_pairs(self, records).collect() + } + + // add cell barcode to the all records + fn add_cell_barcode( + &self, + header: &sam::Header, + record: &Self::AlignOutput, + ori_barcode: &[u8], + ori_qual: &[u8], + correct_barcode: Option<&[u8]>, + ) -> Result { + record + .iter() + .map(|rec| { + let mut record_buf = RecordBuf::try_from_alignment_record(header, rec)?; + let data = record_buf.data_mut(); + data.insert( + Tag::CELL_BARCODE_SEQUENCE, + Value::String(ori_barcode.into()), + ); + data.insert( + Tag::CELL_BARCODE_QUALITY_SCORES, + Value::String(ori_qual.into()), + ); + if let Some(barcode) = correct_barcode { + data.insert(Tag::CELL_BARCODE_ID, Value::String(barcode.into())); + } + Ok(record_buf) + }) + .collect() + } + + // update qc for each record + fn update_qc( + &self, + align_qc: &Arc>, + header: &sam::Header, + record: &Self::ModRecordBuf, + ) { + for recbuf in record { + align_qc.lock().unwrap().update(recbuf, header); + } + } + + // flatten the record + fn crumple(&self, record: Self::ModRecordBuf) -> Vec { + record } } impl Alinger for BurrowsWheelerAligner { + type AlignOutput = sam::Record; + type ModRecordBuf = RecordBuf; + fn chunk_size(&self) -> usize { self.chunk_size() } @@ -68,18 +180,52 @@ impl Alinger for BurrowsWheelerAligner { self.get_sam_header() } - fn align_reads( - &mut self, - records: &mut [fastq::Record], - ) -> impl ExactSizeIterator { - self.align_reads(records) + fn align_reads(&mut self, records: &mut [fastq::Record]) -> Vec { + self.align_reads(records).collect() } fn align_read_pairs( &mut self, records: &mut [(fastq::Record, fastq::Record)], - ) -> impl ExactSizeIterator { - self.align_read_pairs(records) + ) -> Vec<(Self::AlignOutput, Self::AlignOutput)> { + self.align_read_pairs(records).collect() + } + + fn add_cell_barcode( + &self, + header: &sam::Header, + record: &Self::AlignOutput, + ori_barcode: &[u8], + ori_qual: &[u8], + correct_barcode: Option<&[u8]>, + ) -> Result { + let mut record_buf = RecordBuf::try_from_alignment_record(header, record)?; + let data = record_buf.data_mut(); + data.insert( + Tag::CELL_BARCODE_SEQUENCE, + Value::String(ori_barcode.into()), + ); + data.insert( + Tag::CELL_BARCODE_QUALITY_SCORES, + Value::String(ori_qual.into()), + ); + if let Some(barcode) = correct_barcode { + data.insert(Tag::CELL_BARCODE_ID, Value::String(barcode.into())); + } + Ok(record_buf) + } + + fn update_qc( + &self, + align_qc: &Arc>, + header: &sam::Header, + record: &Self::ModRecordBuf, + ) { + align_qc.lock().unwrap().update(record, header); + } + + fn crumple(&self, record: Self::ModRecordBuf) -> Vec { + vec![record] } } @@ -174,34 +320,43 @@ impl FastqProcessor { ) }) .unzip(); - let alignments: Vec<_> = self.aligner.align_read_pairs(&mut reads).collect(); + let alignments: Vec<_> = self.aligner.align_read_pairs(&mut reads); let results = barcodes .into_iter() .zip(alignments) .map(|(barcode, (ali1, ali2))| { - let ali1_ = add_cell_barcode( - &header, - &ali1, - barcode.raw.sequence(), - barcode.raw.quality_scores(), - barcode.corrected.as_deref(), - ) - .unwrap(); - let ali2_ = add_cell_barcode( - &header, - &ali2, - barcode.raw.sequence(), - barcode.raw.quality_scores(), - barcode.corrected.as_deref(), - ) - .unwrap(); - { - let mut align_qc_lock = align_qc.lock().unwrap(); - align_qc_lock.update(&ali1_, &header); - align_qc_lock.update(&ali2_, &header); - } - (ali1_, ali2_) + let ali1_ = self + .aligner + .add_cell_barcode( + &header, + &ali1, + barcode.raw.sequence(), + barcode.raw.quality_scores(), + barcode.corrected.as_deref(), + ) + .unwrap(); + let ali2_ = self + .aligner + .add_cell_barcode( + &header, + &ali2, + barcode.raw.sequence(), + barcode.raw.quality_scores(), + barcode.corrected.as_deref(), + ) + .unwrap(); + + // let mut align_qc_lock = align_qc.lock().unwrap(); + // align_qc_lock.update(&ali1_, &header); + // align_qc_lock.update(&ali2_, &header); + self.aligner.update_qc(&align_qc, &header, &ali1_); + self.aligner.update_qc(&align_qc, &header, &ali2_); + let res1 = self.aligner.crumple(ali1_); + let res2 = self.aligner.crumple(ali2_); + + (res1, res2) }) + .flat_map(|(res1, res2)| res1.into_iter().zip(res2)) .collect::>(); progress_bar.update(results.len()).unwrap(); Either::Right(results) @@ -210,24 +365,30 @@ impl FastqProcessor { .into_iter() .map(|rec| (rec.barcode.unwrap(), rec.read1.unwrap())) .unzip(); - let alignments: Vec<_> = self.aligner.align_reads(&mut reads).collect(); + let alignments: Vec<_> = self.aligner.align_reads(&mut reads); let results = barcodes .into_iter() .zip(alignments) .map(|(barcode, alignment)| { - let ali = add_cell_barcode( - &header, - &alignment, - barcode.raw.sequence(), - barcode.raw.quality_scores(), - barcode.corrected.as_deref(), - ) - .unwrap(); - { - align_qc.lock().unwrap().update(&ali, &header); - } - ali + let ali = self + .aligner + .add_cell_barcode( + &header, + &alignment, + barcode.raw.sequence(), + barcode.raw.quality_scores(), + barcode.corrected.as_deref(), + ) + .unwrap(); + + // update qc + // align_qc.lock().unwrap().update(&ali, &header); + self.aligner.update_qc(&align_qc, &header, &ali); + self.aligner.crumple(ali) }) + .collect::>() + .into_iter() + .flatten() .collect::>(); progress_bar.update(results.len()).unwrap(); Either::Left(results)