From a545e5e2eb46b6fdf84e95ff5445b6fd53f60c33 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Tue, 5 Nov 2024 22:34:07 +0800 Subject: [PATCH] add StarAligner --- precellar/Cargo.toml | 1 + precellar/src/align.rs | 265 ++++++++++++++++++++++++++--------------- python/Cargo.toml | 1 + python/src/lib.rs | 2 +- 4 files changed, 174 insertions(+), 95 deletions(-) diff --git a/precellar/Cargo.toml b/precellar/Cargo.toml index 8b7e287..1adfcca 100644 --- a/precellar/Cargo.toml +++ b/precellar/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" anyhow = "1.0" bed-utils = "0.5.1" bwa-mem2 = { git = "https://github.com/regulatory-genomics/bwa-mem2-rust.git", rev = "07eda9b9c2815ae52b3fa30b01de0e19fae31fe0" } +star-aligner = { git = "https://github.com/regulatory-genomics/star-aligner", rev = "f9915ea3afbac1e8f4773e2e7c22376f1549c3c7" } bstr = "1.0" either = "1.13" itertools = "0.13" diff --git a/precellar/src/align.rs b/precellar/src/align.rs index d186080..bb46fd1 100644 --- a/precellar/src/align.rs +++ b/precellar/src/align.rs @@ -10,32 +10,57 @@ use log::info; use noodles::sam::alignment::record_buf::data::field::value::Value; use noodles::sam::alignment::{record::data::field::tag::Tag, record_buf::RecordBuf, Record}; use noodles::{bam, fastq, sam}; +use rayon::iter::ParallelIterator; use seqspec::{Assay, Modality, Read, RegionId, RegionIndex, RegionType}; use smallvec::SmallVec; +use star_aligner::StarAligner; use std::collections::{HashMap, HashSet}; use std::io::BufRead; use std::ops::Range; use std::sync::{Arc, Mutex}; -pub trait Alinger { +pub trait AsIterator { + type Item; + type AsIter<'a>: Iterator where Self: 'a; + + fn as_iter(&self) -> Self::AsIter<'_>; +} + +impl AsIterator for RecordBuf { + type Item = RecordBuf; + type AsIter<'a> = std::iter::Once<&'a RecordBuf>; + + fn as_iter(&self) -> Self::AsIter<'_> { + std::iter::once(&self) + } +} + +impl AsIterator for Vec { + type Item = RecordBuf; + type AsIter<'a> = std::slice::Iter<'a, RecordBuf>; + + fn as_iter(&self) -> Self::AsIter<'_> { + self.iter() + } +} + +pub trait Aligner { + type AlignOutput: AsIterator; + fn chunk_size(&self) -> usize; fn header(&self) -> sam::Header; - fn align_reads( - &mut self, - records: &mut [fastq::Record], - ) -> impl ExactSizeIterator; + fn align_reads(&mut self, header: &sam::Header, records: Vec) -> Vec; - fn align_read_pairs( - &mut self, - records: &mut [(fastq::Record, fastq::Record)], - ) -> impl ExactSizeIterator; + fn align_read_pairs(&mut self, header: &sam::Header, records: Vec) -> Vec<(Self::AlignOutput, Self::AlignOutput)>; } pub struct DummyAligner; -impl Alinger for DummyAligner { +impl Aligner for DummyAligner { + type AlignOutput = RecordBuf; + fn chunk_size(&self) -> usize { 0 } @@ -44,22 +69,18 @@ impl Alinger for DummyAligner { sam::Header::default() } - fn align_reads( - &mut self, - _records: &mut [fastq::Record], - ) -> impl ExactSizeIterator { - std::iter::empty() + fn align_reads(&mut self, _: &sam::Header, _: Vec) -> Vec { + Vec::new() } - fn align_read_pairs( - &mut self, - _records: &mut [(fastq::Record, fastq::Record)], - ) -> impl ExactSizeIterator { - std::iter::empty() + fn align_read_pairs(&mut self, _: &sam::Header, _: Vec) -> Vec<(Self::AlignOutput, Self::AlignOutput)> { + Vec::new() } } -impl Alinger for BurrowsWheelerAligner { +impl Aligner for BurrowsWheelerAligner { + type AlignOutput = RecordBuf; + fn chunk_size(&self) -> usize { self.chunk_size() } @@ -68,18 +89,125 @@ impl Alinger for BurrowsWheelerAligner { self.get_sam_header() } - fn align_reads( - &mut self, - records: &mut [fastq::Record], - ) -> impl ExactSizeIterator { - self.align_reads(records) + fn align_reads(&mut self, header: &sam::Header, records: Vec) -> Vec { + let (info, mut reads): (Vec<_>, Vec<_>) = records + .into_iter() + .map(|rec| ((rec.barcode.unwrap(), rec.umi), rec.read1.unwrap())) + .unzip(); + + // TODO: add UMI + self.align_reads(reads.as_mut_slice()).enumerate().map(|(i, alignment)| { + let (bc, umi) = info.get(i).unwrap(); + add_cell_barcode( + header, + &alignment, + bc.raw.sequence(), + bc.raw.quality_scores(), + bc.corrected.as_deref(), + ) + .unwrap() + }).collect() } - fn align_read_pairs( - &mut self, - records: &mut [(fastq::Record, fastq::Record)], - ) -> impl ExactSizeIterator { - self.align_read_pairs(records) + fn align_read_pairs(&mut self, header: &sam::Header, records: Vec) -> Vec<(Self::AlignOutput, Self::AlignOutput)> { + let (info, mut reads): (Vec<_>, Vec<_>) = records + .into_iter() + .map(|rec| { + ( + (rec.barcode.unwrap(), rec.umi), + (rec.read1.unwrap(), rec.read2.unwrap()), + ) + }) + .unzip(); + self.align_read_pairs(&mut reads).enumerate().map(|(i, (ali1, ali2))| { + let (bc, umi) = info.get(i).unwrap(); + let ali1_ = add_cell_barcode( + &header, + &ali1, + bc.raw.sequence(), + bc.raw.quality_scores(), + bc.corrected.as_deref(), + ) + .unwrap(); + let ali2_ = add_cell_barcode( + &header, + &ali2, + bc.raw.sequence(), + bc.raw.quality_scores(), + bc.corrected.as_deref(), + ) + .unwrap(); + (ali1_, ali2_) + }).collect() + } +} + +impl Aligner for StarAligner { + type AlignOutput = Vec; + + fn chunk_size(&self) -> usize { + 0 + } + + fn header(&self) -> sam::Header { + self.get_header().clone() + } + + fn align_reads(&mut self, header: &sam::Header, records: Vec) -> Vec { + let (info, mut reads): (Vec<_>, Vec<_>) = records + .into_iter() + .map(|rec| ((rec.barcode.unwrap(), rec.umi), rec.read1.unwrap())) + .unzip(); + + // TODO: StarAligner can expose a method to align a single read instead of a batch, + // so that barcode and UMI processing can be done in parallel. + StarAligner::align_reads(self, reads.as_mut_slice()) + .collect::>().into_iter().enumerate().map(|(i, alignment)| { + let (bc, umi) = info.get(i).unwrap(); + alignment.into_iter().map(|x| + add_cell_barcode( + header, + &x, + bc.raw.sequence(), + bc.raw.quality_scores(), + bc.corrected.as_deref(), + ) + .unwrap() + ).collect() + }).collect() + } + + fn align_read_pairs(&mut self, header: &sam::Header, records: Vec) -> Vec<(Self::AlignOutput, Self::AlignOutput)> { + let (info, mut reads): (Vec<_>, Vec<_>) = records + .into_iter() + .map(|rec| { + ( + (rec.barcode.unwrap(), rec.umi), + (rec.read1.unwrap(), rec.read2.unwrap()), + ) + }) + .unzip(); + StarAligner::align_read_pairs(self, &mut reads) + .collect::>().into_iter().enumerate().map(|(i, (ali1, ali2))| { + let (bc, umi) = info.get(i).unwrap(); + let ali1_ = ali1.into_iter().map(|x| add_cell_barcode( + &header, + &x, + bc.raw.sequence(), + bc.raw.quality_scores(), + bc.corrected.as_deref(), + ) + .unwrap()).collect(); + let ali2_ = ali2.into_iter().map(|x| add_cell_barcode( + &header, + &x, + bc.raw.sequence(), + bc.raw.quality_scores(), + bc.corrected.as_deref(), + ) + .unwrap()).collect(); + (ali1_, ali2_) + }).collect() } } @@ -96,7 +224,7 @@ pub struct FastqProcessor { mismatch_in_barcode: usize, // The number of mismatches allowed in barcode } -impl FastqProcessor { +impl FastqProcessor { pub fn new(assay: Assay, aligner: A) -> Self { Self { assay, @@ -146,7 +274,7 @@ impl FastqProcessor { pub fn gen_barcoded_alignments( &mut self, - ) -> impl Iterator, Vec<(RecordBuf, RecordBuf)>>> + '_ { + ) -> impl Iterator, Vec<(A::AlignOutput, A::AlignOutput)>>> + '_ { let fq_reader = self.gen_barcoded_fastq(true); let is_paired = fq_reader.is_paired_end(); @@ -164,71 +292,20 @@ impl FastqProcessor { let mut progress_bar = tqdm!(total = fq_reader.total_reads.unwrap_or(0)); let fq_reader = VectorChunk::new(fq_reader, self.aligner.chunk_size()); fq_reader.map(move |data| { + let mut align_qc_lock = align_qc.lock().unwrap(); if is_paired { - let (barcodes, mut reads): (Vec<_>, Vec<_>) = data - .into_iter() - .map(|rec| { - ( - rec.barcode.unwrap(), - (rec.read1.unwrap(), rec.read2.unwrap()), - ) - }) - .unzip(); - let alignments: Vec<_> = self.aligner.align_read_pairs(&mut reads).collect(); - let results = barcodes - .into_iter() - .zip(alignments) - .map(|(barcode, (ali1, ali2))| { - let ali1_ = add_cell_barcode( - &header, - &ali1, - barcode.raw.sequence(), - barcode.raw.quality_scores(), - barcode.corrected.as_deref(), - ) - .unwrap(); - let ali2_ = add_cell_barcode( - &header, - &ali2, - barcode.raw.sequence(), - barcode.raw.quality_scores(), - barcode.corrected.as_deref(), - ) - .unwrap(); - { - let mut align_qc_lock = align_qc.lock().unwrap(); - align_qc_lock.update(&ali1_, &header); - align_qc_lock.update(&ali2_, &header); - } - (ali1_, ali2_) - }) - .collect::>(); + let results: Vec<_> = self.aligner.align_read_pairs(&header, data); + results.iter().for_each(|(ali1, ali2)| { + ali1.as_iter().for_each(|x| align_qc_lock.update(x, &header)); + ali2.as_iter().for_each(|x| align_qc_lock.update(x, &header)); + }); progress_bar.update(results.len()).unwrap(); Either::Right(results) } else { - let (barcodes, mut reads): (Vec<_>, Vec<_>) = data - .into_iter() - .map(|rec| (rec.barcode.unwrap(), rec.read1.unwrap())) - .unzip(); - let alignments: Vec<_> = self.aligner.align_reads(&mut reads).collect(); - let results = barcodes - .into_iter() - .zip(alignments) - .map(|(barcode, alignment)| { - let ali = add_cell_barcode( - &header, - &alignment, - barcode.raw.sequence(), - barcode.raw.quality_scores(), - barcode.corrected.as_deref(), - ) - .unwrap(); - { - align_qc.lock().unwrap().update(&ali, &header); - } - ali - }) - .collect::>(); + let results: Vec<_> = self.aligner.align_reads(&header, data); + results.iter().for_each(|ali| { + ali.as_iter().for_each(|x| align_qc_lock.update(x, &header)); + }); progress_bar.update(results.len()).unwrap(); Either::Left(results) } diff --git a/python/Cargo.toml b/python/Cargo.toml index eaab163..51711d2 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0" bwa-mem2 = { git = "https://github.com/regulatory-genomics/bwa-mem2-rust.git", rev = "07eda9b9c2815ae52b3fa30b01de0e19fae31fe0" } +star-aligner = { git = "https://github.com/regulatory-genomics/star-aligner", rev = "f9915ea3afbac1e8f4773e2e7c22376f1549c3c7" } bstr = "1.0" either = "1.13" itertools = "0.13" diff --git a/python/src/lib.rs b/python/src/lib.rs index a823397..620913c 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -14,7 +14,7 @@ use std::{collections::HashMap, io::BufWriter, path::PathBuf, str::FromStr}; use ::precellar::{ align::{ - extend_fastq_record, Alinger, Barcode, DummyAligner, FastqProcessor, NameCollatedRecords, + extend_fastq_record, Aligner, Barcode, DummyAligner, FastqProcessor, NameCollatedRecords, }, fragment::FragmentGenerator, qc::{AlignQC, FragmentQC, Metrics},