Skip to content

Commit

Permalink
merged unitig bam
Browse files Browse the repository at this point in the history
  • Loading branch information
katerinakazantseva committed May 29, 2024
1 parent f28c950 commit e408590
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 43 deletions.
22 changes: 10 additions & 12 deletions strainy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import re
import subprocess
import argparse
import gfapy
import logging
import shutil

import cProfile
import gfapy


Expand Down Expand Up @@ -52,33 +51,33 @@ def get_processor_name():

def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

requiredNamed = parser.add_argument_group('Required named arguments')
requiredNamed.add_argument("-o", "--output", help="output directory",required=True)
requiredNamed.add_argument("-g", "--gfa", help="input gfa to uncollapse",required=True)
requiredNamed.add_argument("-m", "--mode", help="type of reads", choices=["hifi", "nano"], required=True)
requiredNamed.add_argument("-q", "--fastq",
help="fastq file with reads to phase / assemble",
requiredNamed.add_argument("-m", "--mode", help="type of reads", choices=["hifi", "nano"],
required=True)
requiredNamed.add_argument("-q", "--fastq",help="fastq file with reads to phase / assemble",
required=True)

parser.add_argument("-s", "--stage", help="stage to run: either phase, transform or e2e (phase + transform)",
choices=["phase", "transform", "e2e"], default="e2e")
parser.add_argument("--snp", help="path to vcf file with SNP calls to use", default=None)
parser.add_argument("-t", "--threads", help="number of threads to use", type=int, default=4)
parser.add_argument("-f", "--fasta", required=False, help=argparse.SUPPRESS)
parser.add_argument("-b", "--bam", help="path to indexed alignment in bam format",required=False)
parser.add_argument("--link-simplify", required=False, action="store_true", default=False, dest="link_simplify",
help="Enable agressive graph simplification")
parser.add_argument("--link-simplify", required=False, action="store_true", default=False,
dest="link_simplify",help="Enable agressive graph simplification")
parser.add_argument("--debug", required=False, action="store_true", default=False,
help="Generate extra output for debugging")
parser.add_argument("--unitig-split-length",
help="The length (in kb) which the unitigs that are longer will be split, set 0 to disable",
required=False,
type=int,
default=50)
parser.add_argument("--only-split",help="Do not run stRainy, only split long gfa unitigs", default='False', required=False)
parser.add_argument("--only-split",help="Do not run stRainy, only split long gfa unitigs", default='False',
required=False)
parser.add_argument("-d","--cluster-divergence",help="cluster divergence", type=float, default=0, required=False)
parser.add_argument("-a","--allele-frequency",help="Set allele frequency for internal caller only (pileup)", type=float, default=0.2, required=False)
parser.add_argument("-a","--allele-frequency",help="Set allele frequency for internal caller only (pileup)",
type=float, default=0.2, required=False)
parser.add_argument("--min-unitig-length",
help="The length (in kb) which the unitigs that are shorter will not be phased",
required=False,
Expand Down Expand Up @@ -129,7 +128,6 @@ def main():
elif args.stage == "transform":
sys.exit(transform_main(args))
elif args.stage == "e2e":
import cProfile
pr_phase = cProfile.Profile()
pr_phase.enable()
phase_main(args)
Expand Down
22 changes: 15 additions & 7 deletions strainy/color_bam.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,21 @@
from strainy.params import *


def write_bam(edge, I, AF):
def write_bam(edge, I, AF, cl_file=None,file=None):
infile = pysam.AlignmentFile(StRainyArgs().bam, "rb")
outfile = pysam.AlignmentFile("%s/bam/coloredBAM_unitig_%s.bam" % (StRainyArgs().output_intermediate, edge), "wb", template=infile)
cl = pd.read_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF),keep_default_na=False)
if file==None:
outfile = pysam.AlignmentFile("%s/bam/coloredBAM_unitig_%s.bam" % (StRainyArgs().output_intermediate, edge), "wb", template=infile)
else:
outfile = pysam.AlignmentFile(file,"wb", template=infile)
if cl_file==None:
cl = pd.read_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF),keep_default_na=False)
else:
cl = pd.read_csv(cl_file,keep_default_na=False)
iter = infile.fetch(edge,until_eof=True)
cmap = plt.get_cmap("viridis")
cl.loc[cl["Cluster"] == "NA", "Cluster"] = 0
clusters=sorted(set(cl["Cluster"].astype(int)))
#clusters=sorted(set(cl["Cluster"].astype(int)))
clusters = set(cl["Cluster"])
cmap = cmap(np.linspace(0, 1, len(clusters)))
colors={}
i=0
Expand All @@ -35,7 +42,8 @@ def write_bam(edge, I, AF):

for read in iter:
try:
clN = int(cl_dict[str(read).split()[0]])
#clN = int(cl_dict[str(read).split()[0]])
clN = cl_dict[str(read).split()[0]]
tag = colors[clN]
read.set_tag("YC", tag, replace=False)
outfile.write(read)
Expand All @@ -44,9 +52,9 @@ def write_bam(edge, I, AF):
outfile.close()


def color(edge):
def color(edge,cl_file=None,file=None):
try:
write_bam(edge, I, StRainyArgs().AF)
write_bam(edge, I, StRainyArgs().AF,cl_file,file)
except (FileNotFoundError):
pass

58 changes: 36 additions & 22 deletions strainy/phase.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import multiprocessing
import pysam
import pickle
import os
import pickle
import sys
import subprocess
import multiprocessing
import logging
import shutil
import traceback
import time
import pysam

from strainy.clustering.cluster import cluster
from strainy.color_bam import color
Expand All @@ -24,13 +23,14 @@ def _thread_fun(i, shared_flye_consensus, args):
init_global_args_storage(args)

set_thread_logging(StRainyArgs().log_phase, "phase", multiprocessing.current_process().pid)
logger.info("\n\n\t == == Processing unitig " + str(StRainyArgs().edges_to_phase[i]) + " == == ")
logger.info("\n\n\t == == Processing unitig "
+ str(StRainyArgs().edges_to_phase[i]) + " == == ")

try:
cluster(i, shared_flye_consensus)
except Exception as e:
logger.error("Worker thread exception! " + str(e) + "\n" + traceback.format_exc())
raise e
except Exception as excpt:
logger.error("Worker thread exception! " + str(excpt) + "\n" + traceback.format_exc())
raise excpt

logger.debug("Thread worker function finished!")

Expand All @@ -40,7 +40,8 @@ def phase(edges, args):

empty_consensus_dict = {}
default_manager = multiprocessing.Manager()
shared_flye_consensus = FlyeConsensus(StRainyArgs().bam, StRainyArgs().fa, 1, empty_consensus_dict, default_manager)
shared_flye_consensus = FlyeConsensus(StRainyArgs().bam, StRainyArgs().fa, 1,
empty_consensus_dict, default_manager)
if StRainyArgs().threads == 1:
for i in range(len(edges)):
cluster(i, shared_flye_consensus)
Expand All @@ -62,38 +63,52 @@ def phase(edges, args):
return shared_flye_consensus.get_consensus_dict()


def color_bam(edges):
def color_bam(edges, transfrom_stage=False):
logger.info("Creating phased bam")
for e in edges:
color(e)
if transfrom_stage == False:
for edge in edges:
color(edge)
out_bam_dir = os.path.join(StRainyArgs().output_intermediate, "bam")
final_aln = os.path.join(StRainyArgs().output, "alignment_phased.bam")

else:
for edge in edges:
color(edge, cl_file="%s/clusters/clusters_%s_%s_%s_MERGED.csv" %
(StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF),
file="%s/bam/merged/coloredBAM_unitig_%s_merged.bam" % (StRainyArgs().output_intermediate, edge))
out_bam_dir = os.path.join(StRainyArgs().output_intermediate, "bam/merged")
final_aln = os.path.join(StRainyArgs().output, "alignment_phased_merged.bam")

out_bam_dir = os.path.join(StRainyArgs().output_intermediate, "bam")
final_aln = os.path.join(StRainyArgs().output, "alignment_phased.bam")

files_to_be_merged = []
for fname in subprocess.check_output(f'find {out_bam_dir} -name "*unitig*.bam"', shell = True, universal_newlines = True).split("\n"):
for fname in subprocess.check_output(f'find {out_bam_dir} -name "*unitig*.bam"',
shell = True, universal_newlines = True).split("\n"):
if len(fname):
files_to_be_merged.append(fname)

# Number of file to be merged could be > 4092, in which case samtools merge throws too many open files error
# Number of file to be merged could be > 4092,
# in which case samtools merge throws too many open files error
for i, bam_file in enumerate(files_to_be_merged):
# fetch the header and put it at the top of the file, for the first bam_file only
if i == 0:
subprocess.check_output(f'samtools view -H {bam_file} > {out_bam_dir}/coloredSAM.sam', shell = True)
subprocess.check_output(f'samtools view -H {bam_file} > '
f'{out_bam_dir}/coloredSAM.sam',shell = True)

# convert bam to sam, append to the file
subprocess.check_output(f'samtools view {bam_file} >> {out_bam_dir}/coloredSAM.sam', shell = True)
subprocess.check_output(f'samtools view {bam_file} >> {out_bam_dir}/coloredSAM.sam',
shell = True)

# convert the file to bam and sort
subprocess.check_output(f'samtools view -b {out_bam_dir}/coloredSAM.sam >> {out_bam_dir}/unsortedBAM.bam', shell = True)
subprocess.check_output(f'samtools view -b {out_bam_dir}/coloredSAM.sam >> '
f'{out_bam_dir}/unsortedBAM.bam',shell = True)
pysam.samtools.sort(f'{out_bam_dir}/unsortedBAM.bam', "-o", final_aln)
pysam.samtools.index(final_aln)

# remove unnecessary files
os.remove(f'{out_bam_dir}/unsortedBAM.bam')
os.remove(f'{out_bam_dir}/coloredSAM.sam')
for f in files_to_be_merged:
os.remove(f)
for file in files_to_be_merged:
os.remove(file)


def phase_main(args):
Expand All @@ -105,8 +120,7 @@ def phase_main(args):
"%s/bam/clusters" % StRainyArgs().output_intermediate,
"%s/flye_inputs" % StRainyArgs().output_intermediate,
"%s/flye_outputs" % StRainyArgs().output_intermediate
)

)
debug_dirs = ("%s/graphs/" % StRainyArgs().output_intermediate,
"%s/adj_M/" % StRainyArgs().output_intermediate
)
Expand Down
2 changes: 1 addition & 1 deletion strainy/preprocessing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import subprocess
import os
import logging
import pysam
import gfapy
import subprocess

from strainy.params import StRainyArgs

Expand Down
28 changes: 27 additions & 1 deletion strainy/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import time
import traceback
import csv

from strainy.color_bam import color
import strainy.clustering.build_adj_matrix as matrix
import strainy.clustering.cluster_postprocess as postprocess
import strainy.simplification.simplify_links as smpl
Expand All @@ -24,6 +24,7 @@
from strainy.reports.strainy_stats import strain_stats_report
from strainy.reports.call_variants import produce_strainy_vcf
from strainy.preprocessing import gfa_to_fasta
from strainy.phase import color_bam

logger = logging.getLogger()

Expand Down Expand Up @@ -1035,8 +1036,12 @@ def transform_main(args):
shutil.copyfile(out_clusters, strainy_utgs)

phased_graph = gfapy.Gfa.from_file(out_clusters) #parsing again because gfapy can"t copy

segs_unmerged=phased_graph.segment_names
gfapy.GraphOperations.merge_linear_paths(phased_graph)
clean_graph(phased_graph)
segs_merged = phased_graph.segment_names

out_merged = os.path.join(StRainyArgs().output_intermediate, "20_extended_haplotypes.gfa")
gfapy.Gfa.to_file(phased_graph, out_merged)

Expand Down Expand Up @@ -1066,5 +1071,26 @@ def transform_main(args):
produce_strainy_vcf(StRainyArgs().fa, strain_utgs_fasta, StRainyArgs().threads,
strain_utgs_aln, open(vcf_strain_variants, "w"))

logger.info("Update clusters and colored BAM")
merged_clusters={}
AF = StRainyArgs().AF
#I = StRainyArgs().I
for seg in [i for i in segs_unmerged if i not in segs_merged]:
seg_merged = [k for k in segs_merged if re.search(seg, k) != None][0]
merged_clusters[seg] = seg_merged

for edge in StRainyArgs().edges:
try:
cl = pd.read_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF),
keep_default_na=False)
clusters = sorted(set(cl['Cluster']))
for cluster in clusters:
seg=str(edge)+"_"+str(cluster)
if seg in merged_clusters.keys():
cl.loc[cl['Cluster'] == cluster, 'Cluster'] = merged_clusters[seg]
cl.to_csv("%s/clusters/clusters_%s_%s_%s_MERGED.csv" % (StRainyArgs().output_intermediate, edge, I, AF))
except(FileNotFoundError): pass
os.makedirs("%s/bam/merged/" % StRainyArgs().output_intermediate, exist_ok=True)
color_bam(StRainyArgs().edges, transfrom_stage=True)
flye_consensus.print_cache_statistics()
logger.info("### Done!")

0 comments on commit e408590

Please sign in to comment.