From 8665d637c452ace1d0bf7248838237ee000a47e4 Mon Sep 17 00:00:00 2001 From: Kai Zhang Date: Fri, 18 Oct 2024 22:12:26 +0800 Subject: [PATCH] use RwLock --- precellar/src/align.rs | 8 ++-- python/src/pyseqspec.rs | 4 +- seqspec/src/lib.rs | 53 +++++++++++++++----------- seqspec/src/read.rs | 9 ++++- seqspec/src/region.rs | 84 +++++++++++++++++++++++------------------ 5 files changed, 91 insertions(+), 67 deletions(-) diff --git a/precellar/src/align.rs b/precellar/src/align.rs index 214f497..07d6891 100644 --- a/precellar/src/align.rs +++ b/precellar/src/align.rs @@ -240,14 +240,14 @@ impl FastqProcessor { } fn get_whitelist(&self) -> Result { - let regions: Vec<_> = self.assay.library_spec.get_modality(&self.modality()).unwrap() - .subregions.iter().filter(|r| r.region_type.is_barcode()).collect(); + let regions = self.assay.library_spec.get_modality(&self.modality()).unwrap().read().unwrap(); + let regions: Vec<_> = regions.subregions.iter().filter(|r| r.read().unwrap().region_type.is_barcode()).collect(); if regions.len() != 1 { bail!("Expecting exactly one barcode region, found {}", regions.len()); } let region = regions[0]; - if region.sequence_type == SequenceType::Onlist { - Ok(Whitelist::new(region.onlist.as_ref().unwrap().read()?)) + if region.read().unwrap().sequence_type == SequenceType::Onlist { + Ok(Whitelist::new(region.read().unwrap().onlist.as_ref().unwrap().read()?)) } else { Ok(Whitelist::empty()) } diff --git a/python/src/pyseqspec.rs b/python/src/pyseqspec.rs index 4a46363..6225b37 100644 --- a/python/src/pyseqspec.rs +++ b/python/src/pyseqspec.rs @@ -188,7 +188,7 @@ impl Assay { } let tree = Tree::new("".to_string()).with_leaves( - assay.library_spec.modalities().map(|region| build_tree(region, &read_list)) + assay.library_spec.modalities().map(|region| build_tree(®ion.read().unwrap(), &read_list)) ); format!("{}", tree) } @@ -212,7 +212,7 @@ fn build_tree(region: &Region, read_list: &HashMap>) -> Tree< format!("{}({})", id, len) }; Tree::new(label) - .with_leaves(region.subregions.iter().map(|child| build_tree(child, read_list))) + .with_leaves(region.subregions.iter().map(|child| build_tree(&child.read().unwrap(), read_list))) } fn format_read(read: &Read) -> String { diff --git a/seqspec/src/lib.rs b/seqspec/src/lib.rs index 5f98ec9..2dc0bd1 100644 --- a/seqspec/src/lib.rs +++ b/seqspec/src/lib.rs @@ -13,7 +13,7 @@ use noodles::fastq; use serde::{Deserialize, Deserializer, Serialize}; use serde_yaml::{self, Value}; use utils::open_file_for_write; -use std::{fs, path::PathBuf, str::FromStr, sync::Arc}; +use std::{fs, path::PathBuf, str::FromStr, sync::{Arc, RwLock}}; use anyhow::{bail, anyhow, Result}; use std::path::Path; @@ -65,8 +65,8 @@ impl Assay { }); } }); - self.library_spec.regions_mut().for_each(|region| { - if let Some(onlist) = &mut Arc::::get_mut(region).unwrap().onlist { + self.library_spec.regions().for_each(|region| { + if let Some(onlist) = &mut region.write().unwrap().onlist { if let Err(e) = onlist.normalize_path(base_dir) { warn!("{}", e); } @@ -85,23 +85,22 @@ impl Assay { }); } }); - /* - self.library_spec.regions_mut().for_each(|region| { - if let Some(onlist) = &mut Arc::::get_mut(region).unwrap().onlist { + self.library_spec.regions().for_each(|region| { + if let Some(onlist) = &mut region.write().unwrap().onlist { if let Err(e) = onlist.unnormalize_path(base_dir.as_ref()) { warn!("Failed to unnormalize path: {}", e); } } }); - */ } /// Add default Illumina reads to the sequence spec. pub fn add_illumina_reads(&mut self, modality: Modality, read_len: usize, forward_strand_workflow: bool) -> Result<()> { - fn advance_until(iterator: &mut std::slice::Iter<'_, Arc>, f: fn(&Region) -> bool) -> Option<(Arc, Vec>)> { + fn advance_until(iterator: &mut std::slice::Iter<'_, Arc>>, f: fn(&Region) -> bool) -> Option<(Arc>, Vec>>)> { let mut regions = Vec::new(); while let Some(next_region) = iterator.next() { - if f(next_region) { + let r = next_region.read().unwrap(); + if f(&r) { return Some((next_region.clone(), regions)) } else { regions.push(next_region.clone()); @@ -110,15 +109,15 @@ impl Assay { None } - fn get_length(regions: &[Arc], reverse: bool) -> usize { + fn get_length(regions: &[Arc>], reverse: bool) -> usize { if reverse { regions.iter() - .skip_while(|region| region.sequence_type == SequenceType::Fixed) - .map(|region| region.len().unwrap() as usize).sum() + .skip_while(|region| region.read().unwrap().sequence_type == SequenceType::Fixed) + .map(|region| region.read().unwrap().len().unwrap() as usize).sum() } else { regions.iter().rev() - .skip_while(|region| region.sequence_type == SequenceType::Fixed) - .map(|region| region.len().unwrap() as usize).sum() + .skip_while(|region| region.read().unwrap().sequence_type == SequenceType::Fixed) + .map(|region| region.read().unwrap().len().unwrap() as usize).sum() } } @@ -140,10 +139,13 @@ impl Assay { self.delete_all_reads(modality); let regions = self.library_spec.get_modality(&modality).ok_or_else(|| anyhow!("Modality not found: {:?}", modality))?.clone(); + let regions = regions.read().unwrap(); let mut regions = regions.subregions.iter(); while let Some(current_region) = regions.next() { + let current_region = current_region.read().unwrap(); if is_p5(¤t_region) { if let Some((next_region, acc)) = advance_until(&mut regions, is_read1) { + let next_region = next_region.read().unwrap(); self.update_read::( &format!("{}-R1", modality.to_string()), Some(modality), @@ -155,7 +157,7 @@ impl Assay { let acc_len = get_length(acc.as_slice(), false); if acc_len > 0 { self.update_read::( - &format!("{}-I1", modality.to_string()), + &format!("{}-I2", modality.to_string()), Some(modality), Some(¤t_region.region_id), Some(false), @@ -166,7 +168,7 @@ impl Assay { let acc_len = get_length(acc.as_slice(), true); if acc_len > 0 { self.update_read::( - &format!("{}-I1", modality.to_string()), + &format!("{}-I2", modality.to_string()), Some(modality), Some(&next_region.region_id), Some(true), @@ -187,7 +189,7 @@ impl Assay { )?; if acc_len > 0 { self.update_read::( - &format!("{}-I2", modality.to_string()), + &format!("{}-I1", modality.to_string()), Some(modality), Some(¤t_region.region_id), Some(false), @@ -276,7 +278,7 @@ impl Assay { pub fn get_index(&self, read_id: &str) -> Option { let read = self.sequence_spec.get(read_id)?; let region = self.library_spec.get_parent(&read.primer_id)?; - read.get_index(region) + read.get_index(®ion.read().unwrap()) } pub fn iter_reads(&self, modality: Modality) -> impl Iterator { @@ -287,7 +289,8 @@ impl Assay { pub fn validate>(&self, read: &Read, dir: P) -> Result<()> { let region = self.library_spec.get_parent(&read.primer_id) .ok_or_else(|| anyhow!("Primer not found: {}", read.primer_id))?; - if let Some(index) = read.get_index(region) { + if let Some(index) = read.get_index(®ion.read().unwrap()) { + fs::create_dir_all(&dir)?; let output_valid = dir.as_ref().join(format!("{}.fq.zst", read.read_id)); let output_valid = open_file_for_write(output_valid, None, None, 8)?; let mut output_valid = fastq::io::Writer::new(output_valid); @@ -295,8 +298,11 @@ impl Assay { let output_other = open_file_for_write(output_other, None, None, 8)?; let mut output_other = fastq::io::Writer::new(output_other); if let Some(mut reader) = read.open() { - let mut validators: Vec<_> = index.index.iter().map(|(region_id, _, range)| { + let regions: Vec<_> = index.index.iter().map(|(region_id, _, range)| { let region = self.library_spec.get(region_id).unwrap(); + (region.read().unwrap(), range) + }).collect(); + let mut validators: Vec<_> = regions.iter().map(|(region, range)| { ReadValidator::new(region) .with_range(range.start as usize ..range.end as usize) .with_strand(read.strand) @@ -326,7 +332,7 @@ impl Assay { let region = self.library_spec.get_parent(&read.primer_id) .ok_or_else(|| anyhow!("Primer not found: {}", read.primer_id))?; // Check if the primer exists - if let Some(index) = read.get_index(region) { + if let Some(index) = read.get_index(®ion.read().unwrap()) { match index.readlen_info { ReadSpan::Covered | ReadSpan::Span(_) => {}, ReadSpan::NotEnough => { @@ -344,8 +350,11 @@ impl Assay { } if let Some(mut reader) = read.open() { - let mut validators: Vec<_> = index.index.iter().map(|(region_id, _, range)| { + let regions = index.index.iter().map(|(region_id, _, range)| { let region = self.library_spec.get(region_id).unwrap(); + (region.read().unwrap(), range) + }).collect::>(); + let mut validators: Vec<_> = regions.iter().map(|(region, range)| { ReadValidator::new(region) .with_range(range.start as usize ..range.end as usize) .with_strand(read.strand) diff --git a/seqspec/src/read.rs b/seqspec/src/read.rs index bc5e463..a8ea143 100644 --- a/seqspec/src/read.rs +++ b/seqspec/src/read.rs @@ -7,7 +7,7 @@ use indexmap::IndexMap; use serde::{Deserialize, Serialize, Serializer}; use std::collections::HashSet; use std::ops::{Deref, DerefMut}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::{io::{BufRead, BufReader}, ops::Range}; use anyhow::Result; use std::path::Path; @@ -107,6 +107,7 @@ impl Read { self.get_read_span( region.subregions.iter().rev() .skip_while(|region| { + let region = region.read().unwrap(); let found = region.region_type.is_sequencing_primer() && region.region_id == self.primer_id; if found { found_primer = true; @@ -118,6 +119,7 @@ impl Read { self.get_read_span( region.subregions.iter() .skip_while(|region| { + let region = region.read().unwrap(); let found = region.region_type.is_sequencing_primer() && region.region_id == self.primer_id; if found { found_primer = true; @@ -137,13 +139,14 @@ impl Read { /// Get the regions of the read. fn get_read_span<'a, I>(&self, mut regions: I) -> RegionIndex where - I: Iterator>, + I: Iterator>>, { let mut index = Vec::new(); let read_len = self.max_len; let mut cur_pos = 0; let mut readlen_info = ReadSpan::Covered; while let Some(region) = regions.next() { + let region = region.read().unwrap(); let region_id = region.region_id.clone(); let region_type = region.region_type; if region.is_fixed_length() { // Fixed-length region @@ -165,12 +168,14 @@ impl Read { } else if cur_pos + region.max_len < read_len { // Variable-length region and read is longer than max length index.push((region_id, region_type, cur_pos..cur_pos + region.max_len)); if let Some(next_region) = regions.next() { + let next_region = next_region.read().unwrap(); readlen_info = ReadSpan::ReadThrough(next_region.region_id.clone()); } break; } else { // Variable-length region and read is within the length range index.push((region_id, region_type, cur_pos..read_len)); if let Some(next_region) = regions.next() { + let next_region = next_region.read().unwrap(); readlen_info = ReadSpan::MayReadThrough(next_region.region_id.clone()); } break; diff --git a/seqspec/src/region.rs b/seqspec/src/region.rs index 56b7c6f..c803c4b 100644 --- a/seqspec/src/region.rs +++ b/seqspec/src/region.rs @@ -4,14 +4,23 @@ use crate::read::UrlType; use cached_path::Cache; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use std::{collections::{HashMap, HashSet}, io::BufRead, path::Path, sync::Arc}; +use std::{collections::{HashMap, HashSet}, io::BufRead, ops::Deref, path::Path, sync::{Arc, RwLock}}; use anyhow::Result; -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct LibSpec { - modalities: IndexMap>, - parent_map: HashMap>, - region_map: HashMap>, + modalities: IndexMap>>, + parent_map: HashMap>>, + region_map: HashMap>>, +} + +impl PartialEq for LibSpec { + fn eq(&self, other: &Self) -> bool { + self.modalities.keys().all(|k| { + self.modalities.get(k).unwrap().read().unwrap().deref() == + other.modalities.get(k).unwrap().read().unwrap().deref() + }) + } } impl Serialize for LibSpec { @@ -36,18 +45,20 @@ impl LibSpec { let mut parent_map = HashMap::new(); for region in regions { if let RegionType::Modality(modality) = region.region_type { - let region = Arc::new(region); + let region = Arc::new(RwLock::new(region)); + let region_id = region.read().unwrap().region_id.clone(); if modalities.insert(modality, region.clone()).is_some() { return Err(anyhow::anyhow!("Duplicate modality: {:?}", modality)); } - if region_map.insert(region.region_id.clone(), region.clone()).is_some() { - return Err(anyhow::anyhow!("Duplicate region id: {}", region.region_id)); + if region_map.insert(region_id.clone(), region.clone()).is_some() { + return Err(anyhow::anyhow!("Duplicate region id: {}", region_id)); } - for subregion in region.subregions.iter() { - if region_map.insert(subregion.region_id.clone(), subregion.clone()).is_some() { - return Err(anyhow::anyhow!("Duplicate region id: {}", subregion.region_id)); + for subregion in region.read().unwrap().subregions.iter() { + let id = subregion.read().unwrap().region_id.clone(); + if region_map.insert(id.clone(), subregion.clone()).is_some() { + return Err(anyhow::anyhow!("Duplicate region id: {}", id)); } - parent_map.insert(subregion.region_id.clone(), region.clone()); + parent_map.insert(id, region.clone()); } } else { return Err(anyhow::anyhow!("Top-level regions must be modalities")); @@ -57,45 +68,29 @@ impl LibSpec { } /// Iterate over all regions with modality type in the library. - pub fn modalities(&self) -> impl Iterator> { + pub fn modalities(&self) -> impl Iterator>> { self.modalities.values() } /// Iterate over all regions in the library. - pub fn regions(&self) -> impl Iterator> { + pub fn regions(&self) -> impl Iterator>> { self.region_map.values() } - pub fn regions_mut(&mut self) -> impl Iterator> { - self.region_map.values_mut() - } - - pub fn get_modality(&self, modality: &Modality) -> Option<&Arc> { + pub fn get_modality(&self, modality: &Modality) -> Option<&Arc>> { self.modalities.get(modality) } - pub fn get_modality_mut(&mut self, modality: &Modality) -> Option<&mut Arc> { - self.modalities.get_mut(modality) - } - - pub fn get(&self, region_id: &str) -> Option<&Arc> { + pub fn get(&self, region_id: &str) -> Option<&Arc>> { self.region_map.get(region_id) } - pub fn get_mut(&mut self, region_id: &str) -> Option<&mut Arc> { - self.region_map.get_mut(region_id) - } - - pub fn get_parent(&self, region_id: &str) -> Option<&Arc> { + pub fn get_parent(&self, region_id: &str) -> Option<&Arc>> { self.parent_map.get(region_id) } - - pub fn get_parent_mut(&mut self, region_id: &str) -> Option<&mut Arc> { - self.parent_map.get_mut(region_id) - } } -#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub struct Region { pub region_id: String, pub region_type: RegionType, @@ -106,7 +101,22 @@ pub struct Region { pub max_len: u32, pub onlist: Option, #[serde(rename = "regions", deserialize_with = "deserialize_regions")] - pub subregions: Vec>, + pub subregions: Vec>>, +} + +impl PartialEq for Region { + fn eq(&self, other: &Self) -> bool { + self.region_id == other.region_id && + self.region_type == other.region_type && + self.name == other.name && + self.sequence_type == other.sequence_type && + self.sequence == other.sequence && + self.min_len == other.min_len && + self.max_len == other.max_len && + self.onlist == other.onlist && + self.subregions.iter().zip(other.subregions.iter()) + .all(|(a, b)| a.read().unwrap().deref() == b.read().unwrap().deref()) + } } impl Region { @@ -127,12 +137,12 @@ impl Region { } } -fn deserialize_regions<'de, D>(deserializer: D) -> Result>, D::Error> +fn deserialize_regions<'de, D>(deserializer: D) -> Result>>, D::Error> where D: serde::Deserializer<'de>, { if let Some(regions) = Option::>::deserialize(deserializer)? { - Ok(regions.into_iter().map(Arc::new).collect()) + Ok(regions.into_iter().map(|x| Arc::new(RwLock::new(x))).collect()) } else { Ok(Vec::new()) }