Skip to content

Commit

Permalink
ref
Browse files Browse the repository at this point in the history
  • Loading branch information
katerinakazantseva committed Jun 5, 2024
1 parent e408590 commit dbfd603
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 62 deletions.
50 changes: 18 additions & 32 deletions strainy/clustering/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,28 @@
import multiprocessing
import pandas as pd
import pysam

from strainy.clustering.community_detection import find_communities
from strainy.clustering.cluster_postprocess import postprocess
import strainy.clustering.build_adj_matrix as matrix
import strainy.clustering.build_data as build_data
from strainy.clustering import build_data as build_data
from strainy.clustering import build_adj_matrix as matrix
from strainy.params import *
import strainy.gfa_operations.gfa_ops as gfa_ops


logger = logging.getLogger()


def clusters_vis_stats(G, cl, clN, uncl, bam, edge, I, AF):
def clusters_vis_stats(G, cl, clN, uncl, edge, I):
"""Creates connection graph vis and statistics"""
cl.loc[cl['Cluster'] == 'NA', 'Cluster'] = 0
cmap = plt.get_cmap('viridis')
clusters=sorted(set(cl['Cluster'].astype(int)))
cmap = cmap(np.linspace(0, 1, len(clusters)))
colors = {}
try:
clusters.remove('0')
except:
KeyError
except KeyError:
pass
colors[0] = "#505050"
i = 0

Expand All @@ -44,14 +44,14 @@ def clusters_vis_stats(G, cl, clN, uncl, bam, edge, I, AF):
except AttributeError: #incompatibility with scipy < 1.8
pass

ln = pysam.samtools.coverage("-r", edge, bam, "--no-header").split()[4]
cov = pysam.samtools.coverage("-r", edge, bam, "--no-header").split()[6]
ln = pysam.samtools.coverage("-r", edge, StRainyArgs().bam, "--no-header").split()[4]
cov = pysam.samtools.coverage("-r", edge, StRainyArgs().bam, "--no-header").split()[6]
plt.suptitle(str(edge) + " coverage:" + str(cov) + " length:" + str(ln) + " clN:" + str(clN))
plt.savefig("%s/graphs/graph_%s_%s_%s.png" % (StRainyArgs().output_intermediate, edge, I, AF), format="PNG", dpi=300)
plt.savefig("%s/graphs/graph_%s_%s_%s.png" % (StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF), format="PNG", dpi=300)
plt.close()

# Calculate statistics
logger.debug("Summary for: " + edge)
logger.debug("Summary for: %s" + edge)
logger.debug("Clusters found: " + str(clN))
logger.debug("Reads unclassified: " + str(uncl))
logger.debug("Number of reads in each cluster: ")
Expand All @@ -61,22 +61,17 @@ def clusters_vis_stats(G, cl, clN, uncl, bam, edge, I, AF):
def cluster(i, flye_consensus):
edge = StRainyArgs().edges_to_phase[i]
Rcl=StRainyArgs().Rcl
AF=StRainyArgs().AF
R=Rcl/2
logger.info("### Reading SNPs...")
SNP_pos = build_data.read_snp(StRainyArgs().snp, edge, StRainyArgs().bam, AF)


SNP_pos = build_data.read_snp(StRainyArgs().snp, edge, StRainyArgs().bam, StRainyArgs().AF)
logger.info("### Reading Reads...")

data = build_data.read_bam(StRainyArgs().bam, edge, SNP_pos, min_mapping_quality,min_base_quality, min_al_len, de_max[StRainyArgs().mode])
cl = pd.DataFrame(columns=['ReadName', 'Cluster', 'Start'])

for key, value in data.items():
row = pd.DataFrame({'ReadName':[key], 'Cluster':['NA'], 'Start':[value['Start']]})
cl = pd.concat([cl, row])
cl = cl.reset_index(drop=True)


total_coverage = 0
edge_length = len(build_data.read_fasta_seq(StRainyArgs().fa, edge))
num_reads = len(data)
Expand All @@ -88,33 +83,26 @@ def cluster(i, flye_consensus):
if num_reads == 0:
return
if len(SNP_pos) == 0:
#data = read_bam(StRainyArgs().bam, edge, SNP_pos, min_mapping_quality, min_al_len, de_max[StRainyArgs().mode])
cl = pd.DataFrame(columns=['ReadName', 'Cluster', 'Start'])
for key, value in data.items():
row = pd.DataFrame({'ReadName':[key], 'Cluster':['NA'], 'Start':[value['Start']]})
cl = pd.concat([cl, row])
cl = cl.reset_index(drop=True)

cl['Cluster'] = 1
cl.to_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF))
cl.to_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF))
return

#CALCULATE DISTANCE and ADJ MATRIX
logger.info("### Calculatind distances/Building adj matrix...")
#try:
# m = pd.read_csv("%s/adj_M/adj_M_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF), index_col='ReadName')
#except FileNotFoundError:
m = matrix.build_adj_matrix(cl, data, SNP_pos, I, StRainyArgs().bam, edge, R)
if StRainyArgs().debug:
m.to_csv("%s/adj_M/adj_M_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF))
m.to_csv("%s/adj_M/adj_M_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF))
logger.info("### Removing overweighed egdes...")
m = matrix.remove_edges(m, R)
# BUILD graph and find clusters
logger.info("### Creating graph...")
m1 = m
m1.columns = range(0,len(cl['ReadName']))
m1.index=range(0,len(cl['ReadName']))
G = gfa_ops.from_pandas_adjacency_notinplace(matrix.change_w(m.transpose(), R))

logger.info("### Searching clusters...")
cluster_membership = find_communities(G)
clN = 0
Expand All @@ -125,13 +113,12 @@ def cluster(i, flye_consensus):
if len(group) > 3:
clN = clN + 1
cl.loc[group, 'Cluster'] = value
#cl['Cluster'][group] = value
else:
uncl = uncl + 1

logger.info(str(clN)+" clusters found")
if StRainyArgs().debug:
cl.to_csv("%s/clusters/clusters_before_splitting_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF))
cl.to_csv("%s/clusters/clusters_before_splitting_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF))

cl.loc[cl['Cluster'] == 'NA', 'Cluster'] = UNCLUSTERED_GROUP_N
if clN != 0:
Expand All @@ -140,10 +127,9 @@ def cluster(i, flye_consensus):
else:
counts = cl['Cluster'].value_counts(dropna=False)
cl = cl[~cl['Cluster'].isin(counts[counts < 6].index)]
#clN = len(set(cl.loc[cl['Cluster']!='NA']['Cluster'].values))
logger.info(str(clN) + " clusters after post-processing")
cl.to_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF))
cl.to_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF))

if StRainyArgs().debug:
logger.info("### Graph viz...")
clusters_vis_stats(G, cl, clN,uncl, StRainyArgs().bam, edge, I, AF)
clusters_vis_stats(G, cl, clN,uncl, edge, I)
61 changes: 31 additions & 30 deletions strainy/color_bam.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,61 @@
import csv
import pysam
import pandas as pd
import numpy as np
import os
import logging
import gfapy
import matplotlib.pyplot as plt
logging.getLogger("matplotlib.font_manager").disabled = True
import matplotlib as mt
import gfapy
import os

import pysam
import pandas as pd
import numpy as np
from strainy.params import *
logging.getLogger("matplotlib.font_manager").disabled = True


def write_bam(edge, I, AF, cl_file=None,file=None):
infile = pysam.AlignmentFile(StRainyArgs().bam, "rb")
if file==None:
outfile = pysam.AlignmentFile("%s/bam/coloredBAM_unitig_%s.bam" % (StRainyArgs().output_intermediate, edge), "wb", template=infile)
else:
outfile = pysam.AlignmentFile(file,"wb", template=infile)
if cl_file==None:
cl = pd.read_csv("%s/clusters/clusters_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, AF),keep_default_na=False)
else:
cl = pd.read_csv(cl_file,keep_default_na=False)
iter = infile.fetch(edge,until_eof=True)
def write_bam(edge, cl, infile,outfile):
"""Creates new bam file based on ifnfile and add YC tag to the alignment based on csv file"""
iterbam = infile.fetch(edge,until_eof=True)
cmap = plt.get_cmap("viridis")
cl.loc[cl["Cluster"] == "NA", "Cluster"] = 0
#clusters=sorted(set(cl["Cluster"].astype(int)))
clusters = set(cl["Cluster"])
cmap = cmap(np.linspace(0, 1, len(clusters)))
colors={}
i=0
colors[0] = "#505050"

try:
clusters.remove("0")
except: KeyError
except KeyError:
pass
for cluster in clusters:
colors[cluster] = mt.colors.to_hex(cmap[i])
i = i+1
cl_dict = dict(zip(cl.ReadName, cl.Cluster))

for read in iter:
for read in iterbam:
try:
#clN = int(cl_dict[str(read).split()[0]])
clN = cl_dict[str(read).split()[0]]
tag = colors[clN]
cl_n = cl_dict[str(read).split()[0]]
tag = colors[cl_n]
read.set_tag("YC", tag, replace=False)
outfile.write(read)
except (KeyError):
except KeyError:
continue
outfile.close()


def color(edge,cl_file=None,file=None):
"""Creates colored edge bam based on strainy csv file with clusters IDs by default"""
try:
write_bam(edge, I, StRainyArgs().AF,cl_file,file)
except (FileNotFoundError):
infile = pysam.AlignmentFile(StRainyArgs().bam, "rb")
if file is None:
outfile = pysam.AlignmentFile(
f"{StRainyArgs().output_intermediate}/bam/coloredBAM_unitig_{edge}.bam",
"wb", template=infile)
else:
outfile = pysam.AlignmentFile(file,"wb", template=infile)
if cl_file is None:
cl = pd.read_csv(
f"{StRainyArgs().output_intermediate}/clusters/clusters_{edge}_{I}_{StRainyArgs().AF}.csv",
keep_default_na=False)
else:
cl = pd.read_csv(cl_file,keep_default_na=False)
write_bam(edge,cl,infile,outfile)
except FileNotFoundError:
pass

0 comments on commit dbfd603

Please sign in to comment.