diff --git a/precellar/src/align.rs b/precellar/src/align.rs index 33d6f68..214f497 100644 --- a/precellar/src/align.rs +++ b/precellar/src/align.rs @@ -60,6 +60,7 @@ pub struct FastqProcessor { barcode_correct_prob: f64, // if the posterior probability of a correction // exceeds this threshold, the barcode will be corrected. // cellrange uses 0.975 for ATAC and 0.9 for multiome. + mismatch_in_barcode: usize, // The number of mismatches allowed in barcode } impl FastqProcessor { @@ -68,6 +69,7 @@ impl FastqProcessor { assay, aligner, current_modality: None, metrics: HashMap::new(), align_qc: HashMap::new(), mito_dna: HashSet::new(), barcode_correct_prob: 0.975, + mismatch_in_barcode: 1, } } @@ -127,22 +129,23 @@ impl FastqProcessor { let results = barcodes.into_par_iter().zip(alignments).map(|(barcode, (ali1, ali2))| { let corrected_barcode = corrector.correct( whitelist.get_barcode_counts(), - std::str::from_utf8(barcode.sequence()).unwrap(), - barcode.quality_scores() + barcode.sequence(), + barcode.quality_scores(), + self.mismatch_in_barcode, ).ok(); let ali1_ = add_cell_barcode( &header, &ali1, - std::str::from_utf8(barcode.sequence()).unwrap(), + barcode.sequence(), barcode.quality_scores(), - corrected_barcode.as_ref().map(|x| x.as_str()) + corrected_barcode, ).unwrap(); let ali2_ = add_cell_barcode( &header, &ali2, - std::str::from_utf8(barcode.sequence()).unwrap(), + barcode.sequence(), barcode.quality_scores(), - corrected_barcode.as_ref().map(|x| x.as_str()) + corrected_barcode, ).unwrap(); { let mut align_qc_lock = align_qc.lock().unwrap(); @@ -160,16 +163,17 @@ impl FastqProcessor { let results = barcodes.into_par_iter().zip(alignments).map(|(barcode, alignment)| { let corrected_barcode = corrector.correct( whitelist.get_barcode_counts(), - std::str::from_utf8(barcode.sequence()).unwrap(), - barcode.quality_scores() + barcode.sequence(), + barcode.quality_scores(), + self.mismatch_in_barcode, ).ok(); let ali = add_cell_barcode( &header, &alignment, - std::str::from_utf8(barcode.sequence()).unwrap(), + barcode.sequence(), barcode.quality_scores(), - corrected_barcode.as_ref().map(|x| x.as_str()) + corrected_barcode, ).unwrap(); { align_qc.lock().unwrap().update(&ali, &header); } ali @@ -226,7 +230,7 @@ impl FastqProcessor { if read.is_reverse() { record = rev_compl_fastq_record(record); } - whitelist.count_barcode(std::str::from_utf8(record.sequence()).unwrap(), record.quality_scores()); + whitelist.count_barcode(record.sequence(), record.quality_scores()); }); self.metrics.entry(modality).or_insert_with(Metrics::default) @@ -387,9 +391,9 @@ impl Iterator for FastqRecordChunk { fn add_cell_barcode( header: &sam::Header, record: &R, - ori_barcode: &str, + ori_barcode: &[u8], ori_qual: &[u8], - correct_barcode: Option<&str>, + correct_barcode: Option<&[u8]>, ) -> std::io::Result { let mut record_buf = RecordBuf::try_from_alignment_record(header, record)?; let data = record_buf.data_mut(); diff --git a/precellar/src/barcode.rs b/precellar/src/barcode.rs index e2f60d1..0073e08 100644 --- a/precellar/src/barcode.rs +++ b/precellar/src/barcode.rs @@ -1,13 +1,173 @@ use core::f64; -use std::collections::HashMap; +use std::{collections::HashMap, ops::{Deref, DerefMut}}; const BC_MAX_QV: u8 = 66; // This is the illumina quality value pub(crate) const BASE_OPTS: [u8; 4] = [b'A', b'C', b'G', b'T']; +/// A map of oligo species to their frequency in a given library. +#[derive(Debug, Clone)] +pub struct OligoFrequncy(HashMap, usize>); + +impl Deref for OligoFrequncy { + type Target = HashMap, usize>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for OligoFrequncy { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl FromIterator<(Vec, usize)> for OligoFrequncy { + fn from_iter, usize)>>(iter: I) -> Self { + Self(iter.into_iter().collect()) + } +} + +impl OligoFrequncy { + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// The likelihood of a query oligo being generated by the library. + /// If the query is present in the library, the likelihood is 1.0. + /// Otherwise, the likelihood is calculated as + pub fn likelihood<'a>(&'a self, query: &'a [u8], qual: &[u8], n_mismatch: usize) -> (&'a [u8], f64) { + if n_mismatch == 0 { + if self.0.contains_key(query) { + (query, 1.0) + } else { + (query, 0.0) + } + } else if n_mismatch == 1 { + self.likelihood1(query, qual) + } else if n_mismatch == 2 { + self.likelihood2(query, qual) + } else { + todo!() + } + } + + /// The likelihood up to 2 mismatches. + fn likelihood2<'a>(&'a self, query: &'a [u8], qual: &[u8]) -> (&'a [u8], f64) { + if self.0.contains_key(query) { + return (query, 1.0); + } + + let mut best_option = None; + let mut total_likelihood = 0.0; + let mut query_bytes = query.to_vec(); + + // Single mismatch loop + for (pos1, &qv1) in qual.iter().enumerate() { + let qv1 = qv1.min(BC_MAX_QV); + let original1 = query_bytes[pos1]; + + for base1 in BASE_OPTS { + if base1 != original1 { + query_bytes[pos1] = base1; + + // Check for 1-mismatch barcode match + if let Some((key, raw_count)) = self.0.get_key_value(&query_bytes) { + let bc_count = 1 + raw_count; + let likelihood = bc_count as f64 * error_probability(qv1); + update_best_option(&mut best_option, likelihood, key); + total_likelihood += likelihood; + } + + // Loop for the second mismatch + for (pos2, &qv2) in qual.iter().enumerate().skip(pos1 + 1) { + let qv2 = qv2.min(BC_MAX_QV); + let original2 = query_bytes[pos2]; + + for val2 in BASE_OPTS { + if val2 != original2 { + query_bytes[pos2] = val2; + + // Check for 2-mismatch barcode match + if let Some((key, raw_count)) = self.0.get_key_value(&query_bytes) { + let bc_count = 1 + raw_count; + let likelihood = bc_count as f64 * error_probability(qv1) * error_probability(qv2); + update_best_option(&mut best_option, likelihood, key); + total_likelihood += likelihood; + } + } + } + // Restore original value for second position + query_bytes[pos2] = original2; + } + } + } + // Restore original value for first position + query_bytes[pos1] = original1; + } + + if let Some((best_like, best_bc)) = best_option { + (best_bc, best_like / total_likelihood) + } else { + (query, 0.0) + } + } + + + /// The likehood up to 1 mismatch. + fn likelihood1<'a>(&'a self, query: &'a [u8], qual: &[u8]) -> (&'a [u8], f64) { + if self.0.contains_key(query) { + return (query, 1.0) + } + + let mut best_option = None; + let mut total_likelihood = 0.0; + let mut query_bytes = query.to_vec(); + for (pos, &qv) in qual.iter().enumerate() { + let qv = qv.min(BC_MAX_QV); + let existing = query_bytes[pos]; + for val in BASE_OPTS { + if val != existing { + query_bytes[pos] = val; + if let Some((key, raw_count)) = self.0.get_key_value(&query_bytes) { + let bc_count = 1 + raw_count; + let likelihood = bc_count as f64 * error_probability(qv); + update_best_option(&mut best_option, likelihood, key); + total_likelihood += likelihood; + } + } + } + query_bytes[pos] = existing; + } + + if let Some((best_like, best_bc)) = best_option { + (best_bc, best_like / total_likelihood) + } else { + (query, 0.0) + } + } +} + +// Helper function to update the best option +fn update_best_option<'a>( + best_option: &mut Option<(f64, &'a [u8])>, + likelihood: f64, + key: &'a [u8], +) { + match best_option { + None => *best_option = Some((likelihood, key)), + Some(ref old_best) => { + if old_best.0 < likelihood { + *best_option = Some((likelihood, key)); + } + }, + } +} + #[derive(Debug)] pub struct Whitelist { whitelist_exists: bool, - barcode_counts: HashMap, + barcode_counts: OligoFrequncy, mismatch_count: usize, pub(crate) total_count: usize, pub(crate) total_base_count: u64, @@ -19,7 +179,7 @@ impl Whitelist { pub fn empty() -> Self { Self { whitelist_exists: false, - barcode_counts: HashMap::new(), + barcode_counts: OligoFrequncy::new(), mismatch_count: 0, total_count: 0, total_base_count: 0, @@ -29,15 +189,15 @@ impl Whitelist { } /// Create a new whitelist from an iterator of strings. - pub fn new, S: ToString>(iter: I) -> Self { + pub fn new, S: Into>>(iter: I) -> Self { let mut whitelist = Self::empty(); whitelist.whitelist_exists = true; - whitelist.barcode_counts = iter.into_iter().map(|x| (x.to_string(), 0)).collect(); + whitelist.barcode_counts = iter.into_iter().map(|x| (x.into(), 0)).collect(); whitelist } /// Update the barcode counter with a barcode and its quality scores. - pub fn count_barcode(&mut self, barcode: &str, barcode_qual: &[u8]) { + pub fn count_barcode(&mut self, barcode: &[u8], barcode_qual: &[u8]) { if self.whitelist_exists { if let Some(count) = self.barcode_counts.get_mut(barcode) { *count += 1; @@ -45,7 +205,7 @@ impl Whitelist { self.mismatch_count += 1; } } else if barcode.len() > 1 { - *self.barcode_counts.entry(barcode.to_string()).or_insert(0) += 1; + *self.barcode_counts.entry(barcode.to_vec()).or_insert(0) += 1; } else { self.mismatch_count += 1; } @@ -62,7 +222,7 @@ impl Whitelist { } } - pub fn get_barcode_counts(&self) -> &HashMap { + pub fn get_barcode_counts(&self) -> &OligoFrequncy { &self.barcode_counts } @@ -131,60 +291,30 @@ impl BarcodeCorrector { /// 3) If the whitelist does not exist, the barcode is always valid. /// /// Return the corrected barcode - pub fn correct(&self, barcode_counts: &HashMap, barcode: &str, qual: &[u8]) -> Result { - let expected_errors: f64 = qual.iter().map(|&q| probability(q)).sum(); + pub fn correct<'a>(&'a self, barcode_counts: &'a OligoFrequncy, barcode: &'a [u8], qual: &[u8], n_mismatch: usize) -> Result<&[u8], BarcodeError> { + let expected_errors: f64 = qual.iter().map(|&q| error_probability(q)).sum(); if expected_errors >= self.max_expected_errors { return Err(BarcodeError::ExceedExpectedError(expected_errors)); } - if barcode_counts.contains_key(barcode) { - return Ok(barcode.to_string()); - } - - let mut best_option = None; - let mut total_likelihood = 0.0; - let mut barcode_bytes = barcode.as_bytes().to_vec(); - for (pos, &qv) in qual.iter().enumerate() { - let qv = qv.min(BC_MAX_QV); - let existing = barcode_bytes[pos]; - for val in BASE_OPTS { - if val != existing { - barcode_bytes[pos] = val; - let bc = std::str::from_utf8(&barcode_bytes).unwrap(); - if let Some(raw_count) = barcode_counts.get(bc) { - let bc_count = 1 + raw_count; - let prob_edit = probability(qv); - let likelihood = bc_count as f64 * prob_edit; - match best_option { - None => best_option = Some((likelihood, bc.to_string())), - Some(ref old_best) => { - if old_best.0 < likelihood { - best_option = Some((likelihood, bc.to_string())); - } - }, - } - total_likelihood += likelihood; - } - } - } - barcode_bytes[pos] = existing; - } - if let Some((best_like, best_bc)) = best_option { - if best_like / total_likelihood >= self.bc_confidence_threshold { - return Ok(best_bc) - } + let (bc, prob) = barcode_counts.likelihood(barcode, qual, n_mismatch); + if prob <= 0.0 { + Err(BarcodeError::NoMatch) + } else if prob >= self.bc_confidence_threshold { + Ok(bc) + } else { + Err(BarcodeError::LowConfidence(prob)) } - Err(BarcodeError::NoMatch) } } /// Barcode correction problem: Given a whitelist of barcodes, and a sequenced barcode with quality scores, /// decide which barcode in the whitelist generated the sequenced barcode. -/// Convert quality scores to base-calling error probabilities. -/// This is interpreted as the likelihood of being a valid barcode base. +/// Convert Illumina quality scores to base-calling error probabilities, i.e., +/// the probability of an incorrect base call. #[inline(always)] -fn probability(qual: u8) -> f64 { +fn error_probability(qual: u8) -> f64 { let offset = 33.0; // Illumina quality score offset 10f64.powf(-((qual as f64 - offset) / 10.0)) -} \ No newline at end of file +}