Skip to content

Commit

Permalink
generalize barcode corrector to >1 mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhang committed Oct 17, 2024
1 parent d6206c9 commit dde1547
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 64 deletions.
30 changes: 17 additions & 13 deletions precellar/src/align.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub struct FastqProcessor<A> {
barcode_correct_prob: f64, // if the posterior probability of a correction
// exceeds this threshold, the barcode will be corrected.
// cellrange uses 0.975 for ATAC and 0.9 for multiome.
mismatch_in_barcode: usize, // The number of mismatches allowed in barcode
}

impl<A: Alinger> FastqProcessor<A> {
Expand All @@ -68,6 +69,7 @@ impl<A: Alinger> FastqProcessor<A> {
assay, aligner, current_modality: None, metrics: HashMap::new(),
align_qc: HashMap::new(), mito_dna: HashSet::new(),
barcode_correct_prob: 0.975,
mismatch_in_barcode: 1,
}
}

Expand Down Expand Up @@ -127,22 +129,23 @@ impl<A: Alinger> FastqProcessor<A> {
let results = barcodes.into_par_iter().zip(alignments).map(|(barcode, (ali1, ali2))| {
let corrected_barcode = corrector.correct(
whitelist.get_barcode_counts(),
std::str::from_utf8(barcode.sequence()).unwrap(),
barcode.quality_scores()
barcode.sequence(),
barcode.quality_scores(),
self.mismatch_in_barcode,
).ok();
let ali1_ = add_cell_barcode(
&header,
&ali1,
std::str::from_utf8(barcode.sequence()).unwrap(),
barcode.sequence(),
barcode.quality_scores(),
corrected_barcode.as_ref().map(|x| x.as_str())
corrected_barcode,
).unwrap();
let ali2_ = add_cell_barcode(
&header,
&ali2,
std::str::from_utf8(barcode.sequence()).unwrap(),
barcode.sequence(),
barcode.quality_scores(),
corrected_barcode.as_ref().map(|x| x.as_str())
corrected_barcode,
).unwrap();
{
let mut align_qc_lock = align_qc.lock().unwrap();
Expand All @@ -160,16 +163,17 @@ impl<A: Alinger> FastqProcessor<A> {
let results = barcodes.into_par_iter().zip(alignments).map(|(barcode, alignment)| {
let corrected_barcode = corrector.correct(
whitelist.get_barcode_counts(),
std::str::from_utf8(barcode.sequence()).unwrap(),
barcode.quality_scores()
barcode.sequence(),
barcode.quality_scores(),
self.mismatch_in_barcode,
).ok();
let ali = add_cell_barcode(
&header,
&alignment,
std::str::from_utf8(barcode.sequence()).unwrap(),
barcode.sequence(),

barcode.quality_scores(),
corrected_barcode.as_ref().map(|x| x.as_str())
corrected_barcode,
).unwrap();
{ align_qc.lock().unwrap().update(&ali, &header); }
ali
Expand Down Expand Up @@ -226,7 +230,7 @@ impl<A: Alinger> FastqProcessor<A> {
if read.is_reverse() {
record = rev_compl_fastq_record(record);
}
whitelist.count_barcode(std::str::from_utf8(record.sequence()).unwrap(), record.quality_scores());
whitelist.count_barcode(record.sequence(), record.quality_scores());
});

self.metrics.entry(modality).or_insert_with(Metrics::default)
Expand Down Expand Up @@ -387,9 +391,9 @@ impl<R: BufRead> Iterator for FastqRecordChunk<R> {
fn add_cell_barcode<R: Record>(
header: &sam::Header,
record: &R,
ori_barcode: &str,
ori_barcode: &[u8],
ori_qual: &[u8],
correct_barcode: Option<&str>,
correct_barcode: Option<&[u8]>,
) -> std::io::Result<RecordBuf> {
let mut record_buf = RecordBuf::try_from_alignment_record(header, record)?;
let data = record_buf.data_mut();
Expand Down
232 changes: 181 additions & 51 deletions precellar/src/barcode.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,173 @@
use core::f64;
use std::collections::HashMap;
use std::{collections::HashMap, ops::{Deref, DerefMut}};

const BC_MAX_QV: u8 = 66; // This is the illumina quality value
pub(crate) const BASE_OPTS: [u8; 4] = [b'A', b'C', b'G', b'T'];

/// A map of oligo species to their frequency in a given library.
#[derive(Debug, Clone)]
pub struct OligoFrequncy(HashMap<Vec<u8>, usize>);

impl Deref for OligoFrequncy {
type Target = HashMap<Vec<u8>, usize>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl DerefMut for OligoFrequncy {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

impl FromIterator<(Vec<u8>, usize)> for OligoFrequncy {
fn from_iter<I: IntoIterator<Item = (Vec<u8>, usize)>>(iter: I) -> Self {
Self(iter.into_iter().collect())
}
}

impl OligoFrequncy {
pub fn new() -> Self {
Self(HashMap::new())
}

/// The likelihood of a query oligo being generated by the library.
/// If the query is present in the library, the likelihood is 1.0.
/// Otherwise, the likelihood is calculated as
pub fn likelihood<'a>(&'a self, query: &'a [u8], qual: &[u8], n_mismatch: usize) -> (&'a [u8], f64) {
if n_mismatch == 0 {
if self.0.contains_key(query) {
(query, 1.0)
} else {
(query, 0.0)
}
} else if n_mismatch == 1 {
self.likelihood1(query, qual)
} else if n_mismatch == 2 {
self.likelihood2(query, qual)
} else {
todo!()
}
}

/// The likelihood up to 2 mismatches.
fn likelihood2<'a>(&'a self, query: &'a [u8], qual: &[u8]) -> (&'a [u8], f64) {
if self.0.contains_key(query) {
return (query, 1.0);
}

let mut best_option = None;
let mut total_likelihood = 0.0;
let mut query_bytes = query.to_vec();

// Single mismatch loop
for (pos1, &qv1) in qual.iter().enumerate() {
let qv1 = qv1.min(BC_MAX_QV);
let original1 = query_bytes[pos1];

for base1 in BASE_OPTS {
if base1 != original1 {
query_bytes[pos1] = base1;

// Check for 1-mismatch barcode match
if let Some((key, raw_count)) = self.0.get_key_value(&query_bytes) {
let bc_count = 1 + raw_count;
let likelihood = bc_count as f64 * error_probability(qv1);
update_best_option(&mut best_option, likelihood, key);
total_likelihood += likelihood;
}

// Loop for the second mismatch
for (pos2, &qv2) in qual.iter().enumerate().skip(pos1 + 1) {
let qv2 = qv2.min(BC_MAX_QV);
let original2 = query_bytes[pos2];

for val2 in BASE_OPTS {
if val2 != original2 {
query_bytes[pos2] = val2;

// Check for 2-mismatch barcode match
if let Some((key, raw_count)) = self.0.get_key_value(&query_bytes) {
let bc_count = 1 + raw_count;
let likelihood = bc_count as f64 * error_probability(qv1) * error_probability(qv2);
update_best_option(&mut best_option, likelihood, key);
total_likelihood += likelihood;
}
}
}
// Restore original value for second position
query_bytes[pos2] = original2;
}
}
}
// Restore original value for first position
query_bytes[pos1] = original1;
}

if let Some((best_like, best_bc)) = best_option {
(best_bc, best_like / total_likelihood)
} else {
(query, 0.0)
}
}


/// The likehood up to 1 mismatch.
fn likelihood1<'a>(&'a self, query: &'a [u8], qual: &[u8]) -> (&'a [u8], f64) {
if self.0.contains_key(query) {
return (query, 1.0)
}

let mut best_option = None;
let mut total_likelihood = 0.0;
let mut query_bytes = query.to_vec();
for (pos, &qv) in qual.iter().enumerate() {
let qv = qv.min(BC_MAX_QV);
let existing = query_bytes[pos];
for val in BASE_OPTS {
if val != existing {
query_bytes[pos] = val;
if let Some((key, raw_count)) = self.0.get_key_value(&query_bytes) {
let bc_count = 1 + raw_count;
let likelihood = bc_count as f64 * error_probability(qv);
update_best_option(&mut best_option, likelihood, key);
total_likelihood += likelihood;
}
}
}
query_bytes[pos] = existing;
}

if let Some((best_like, best_bc)) = best_option {
(best_bc, best_like / total_likelihood)
} else {
(query, 0.0)
}
}
}

// Helper function to update the best option
fn update_best_option<'a>(
best_option: &mut Option<(f64, &'a [u8])>,
likelihood: f64,
key: &'a [u8],
) {
match best_option {
None => *best_option = Some((likelihood, key)),
Some(ref old_best) => {
if old_best.0 < likelihood {
*best_option = Some((likelihood, key));
}
},
}
}

#[derive(Debug)]
pub struct Whitelist {
whitelist_exists: bool,
barcode_counts: HashMap<String, usize>,
barcode_counts: OligoFrequncy,
mismatch_count: usize,
pub(crate) total_count: usize,
pub(crate) total_base_count: u64,
Expand All @@ -19,7 +179,7 @@ impl Whitelist {
pub fn empty() -> Self {
Self {
whitelist_exists: false,
barcode_counts: HashMap::new(),
barcode_counts: OligoFrequncy::new(),
mismatch_count: 0,
total_count: 0,
total_base_count: 0,
Expand All @@ -29,23 +189,23 @@ impl Whitelist {
}

/// Create a new whitelist from an iterator of strings.
pub fn new<I: IntoIterator<Item = S>, S: ToString>(iter: I) -> Self {
pub fn new<I: IntoIterator<Item = S>, S: Into<Vec<u8>>>(iter: I) -> Self {
let mut whitelist = Self::empty();
whitelist.whitelist_exists = true;
whitelist.barcode_counts = iter.into_iter().map(|x| (x.to_string(), 0)).collect();
whitelist.barcode_counts = iter.into_iter().map(|x| (x.into(), 0)).collect();
whitelist
}

/// Update the barcode counter with a barcode and its quality scores.
pub fn count_barcode(&mut self, barcode: &str, barcode_qual: &[u8]) {
pub fn count_barcode(&mut self, barcode: &[u8], barcode_qual: &[u8]) {
if self.whitelist_exists {
if let Some(count) = self.barcode_counts.get_mut(barcode) {
*count += 1;
} else {
self.mismatch_count += 1;
}
} else if barcode.len() > 1 {
*self.barcode_counts.entry(barcode.to_string()).or_insert(0) += 1;
*self.barcode_counts.entry(barcode.to_vec()).or_insert(0) += 1;
} else {
self.mismatch_count += 1;
}
Expand All @@ -62,7 +222,7 @@ impl Whitelist {
}
}

pub fn get_barcode_counts(&self) -> &HashMap<String, usize> {
pub fn get_barcode_counts(&self) -> &OligoFrequncy {
&self.barcode_counts
}

Expand Down Expand Up @@ -131,60 +291,30 @@ impl BarcodeCorrector {
/// 3) If the whitelist does not exist, the barcode is always valid.
///
/// Return the corrected barcode
pub fn correct(&self, barcode_counts: &HashMap<String, usize>, barcode: &str, qual: &[u8]) -> Result<String, BarcodeError> {
let expected_errors: f64 = qual.iter().map(|&q| probability(q)).sum();
pub fn correct<'a>(&'a self, barcode_counts: &'a OligoFrequncy, barcode: &'a [u8], qual: &[u8], n_mismatch: usize) -> Result<&[u8], BarcodeError> {
let expected_errors: f64 = qual.iter().map(|&q| error_probability(q)).sum();
if expected_errors >= self.max_expected_errors {
return Err(BarcodeError::ExceedExpectedError(expected_errors));
}
if barcode_counts.contains_key(barcode) {
return Ok(barcode.to_string());
}

let mut best_option = None;
let mut total_likelihood = 0.0;
let mut barcode_bytes = barcode.as_bytes().to_vec();
for (pos, &qv) in qual.iter().enumerate() {
let qv = qv.min(BC_MAX_QV);
let existing = barcode_bytes[pos];
for val in BASE_OPTS {
if val != existing {
barcode_bytes[pos] = val;
let bc = std::str::from_utf8(&barcode_bytes).unwrap();
if let Some(raw_count) = barcode_counts.get(bc) {
let bc_count = 1 + raw_count;
let prob_edit = probability(qv);
let likelihood = bc_count as f64 * prob_edit;
match best_option {
None => best_option = Some((likelihood, bc.to_string())),
Some(ref old_best) => {
if old_best.0 < likelihood {
best_option = Some((likelihood, bc.to_string()));
}
},
}
total_likelihood += likelihood;
}
}
}
barcode_bytes[pos] = existing;
}

if let Some((best_like, best_bc)) = best_option {
if best_like / total_likelihood >= self.bc_confidence_threshold {
return Ok(best_bc)
}
let (bc, prob) = barcode_counts.likelihood(barcode, qual, n_mismatch);
if prob <= 0.0 {
Err(BarcodeError::NoMatch)
} else if prob >= self.bc_confidence_threshold {
Ok(bc)
} else {
Err(BarcodeError::LowConfidence(prob))
}
Err(BarcodeError::NoMatch)
}
}

/// Barcode correction problem: Given a whitelist of barcodes, and a sequenced barcode with quality scores,
/// decide which barcode in the whitelist generated the sequenced barcode.

/// Convert quality scores to base-calling error probabilities.
/// This is interpreted as the likelihood of being a valid barcode base.
/// Convert Illumina quality scores to base-calling error probabilities, i.e.,
/// the probability of an incorrect base call.
#[inline(always)]
fn probability(qual: u8) -> f64 {
fn error_probability(qual: u8) -> f64 {
let offset = 33.0; // Illumina quality score offset
10f64.powf(-((qual as f64 - offset) / 10.0))
}
}

0 comments on commit dde1547

Please sign in to comment.