Skip to content

Commit

Permalink
ref
Browse files Browse the repository at this point in the history
  • Loading branch information
katerinakazantseva committed Jun 19, 2024
1 parent 0b9fc06 commit 508730d
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 52 deletions.
10 changes: 5 additions & 5 deletions strainy/clustering/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def cluster(i, flye_consensus):
Rcl=StRainyArgs().Rcl
R=Rcl/2
logger.info("### Reading SNPs...")
SNP_pos = build_data.read_snp(StRainyArgs().snp, edge, StRainyArgs().bam, StRainyArgs().AF)
snp_pos = build_data.read_snp(StRainyArgs().snp, edge, StRainyArgs().bam, StRainyArgs().AF)
logger.info("### Reading Reads...")
data = build_data.read_bam(StRainyArgs().bam, edge, SNP_pos, min_mapping_quality,min_base_quality, min_al_len, de_max[StRainyArgs().mode])
data = build_data.read_bam(StRainyArgs().bam, edge, snp_pos, min_mapping_quality,min_base_quality, min_al_len, de_max[StRainyArgs().mode])
cl = pd.DataFrame(columns=['ReadName', 'Cluster', 'Start'])

for key, value in data.items():
Expand All @@ -82,7 +82,7 @@ def cluster(i, flye_consensus):

if num_reads == 0:
return
if len(SNP_pos) == 0:
if len(snp_pos) == 0:
cl = pd.DataFrame(columns=['ReadName', 'Cluster', 'Start'])
for key, value in data.items():
row = pd.DataFrame({'ReadName':[key], 'Cluster':['NA'], 'Start':[value['Start']]})
Expand All @@ -94,7 +94,7 @@ def cluster(i, flye_consensus):

#CALCULATE DISTANCE and ADJ MATRIX
logger.info("### Calculatind distances/Building adj matrix...")
m = matrix.build_adj_matrix(cl, data, SNP_pos, I, StRainyArgs().bam, edge, R)
m = matrix.build_adj_matrix(cl, data, snp_pos, I, StRainyArgs().bam, edge, R)
if StRainyArgs().debug:
m.to_csv("%s/adj_M/adj_M_%s_%s_%s.csv" % (StRainyArgs().output_intermediate, edge, I, StRainyArgs().AF))
logger.info("### Removing overweighed egdes...")
Expand Down Expand Up @@ -127,7 +127,7 @@ def cluster(i, flye_consensus):
cl.loc[cl['Cluster'] == 'NA', 'Cluster'] = UNCLUSTERED_GROUP_N
if clN != 0:
logger.info("### Cluster post-processing...")
cl = postprocess(StRainyArgs().bam, cl, SNP_pos, data, edge, R,Rcl, I, flye_consensus,mean_edge_cov)
cl = postprocess(StRainyArgs().bam, cl, snp_pos, data, edge, R,Rcl, I, flye_consensus,mean_edge_cov)
else:
counts = cl['Cluster'].value_counts(dropna=False)
cl = cl[~cl['Cluster'].isin(counts[counts < 6].index)]
Expand Down
97 changes: 50 additions & 47 deletions strainy/clustering/cluster_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@
logger = logging.getLogger()


def split_cluster(cl,cluster, data,cons,clSNP, bam, edge, R, I,only_with_common_snip=True):
def split_cluster(cl,cluster, data,cons,clust_snp, bam, edge, R, I,only_with_common_snip=True):
"""Split cluster based on snp set (clust_snp)"""
#logging.debug("Split cluster: " + str(cluster)+ " "+ str(only_with_common_snip))
child_clusters = []
reads = sorted(set(cl.loc[cl["Cluster"] == cluster,"ReadName"].values))
if cluster == UNCLUSTERED_GROUP_N or cluster==UNCLUSTERED_GROUP_N2 \
or only_with_common_snip is False: #NA cluster
m = matrix.build_adj_matrix(cl[cl["Cluster"] == cluster],
data, clSNP, I, bam, edge, R, only_with_common_snip=False)
data, clust_snp, I, bam, edge, R, only_with_common_snip=False)
else:
m = matrix.build_adj_matrix(cl[cl["Cluster"] == cluster], data, clSNP, I, bam,edge,R)
m = matrix.build_adj_matrix(cl[cl["Cluster"] == cluster], data, clust_snp, I, bam,edge,R)
m = matrix.remove_edges(m, 1)
m.columns = range(0,len(cl[cl["Cluster"] == cluster]["ReadName"]))
m.index = range(0,len(cl[cl["Cluster"] == cluster]["ReadName"]))
Expand Down Expand Up @@ -93,13 +94,12 @@ def build_adj_matrix_clusters(edge,cons,cl,flye_consensus,
if m[second_cl][first_cl] == -1:
m.loc[first_cl, second_cl] = matrix.distance_clusters\
(edge, first_cl, second_cl, cons, cl,flye_consensus, only_with_common_snip)
#m[second_cl][first_cl] = matrix.distance_clusters(edge, first_cl, second_cl, cons, cl,flye_consensus, only_with_common_snip)
return m



def join_clusters(cons, cl, Rcl, edge, consensus,
only_with_common_snip=True,set_clusters=None, only_nested=False, transitive=False):
def join_clusters(cons, cl, Rcl, edge, consensus,only_with_common_snip=True,
set_clusters=None, only_nested=False, transitive=False):
MAX_VIS_SIZE = 500
CUT_OFF=3

Expand Down Expand Up @@ -141,7 +141,7 @@ def join_clusters(cons, cl, Rcl, edge, consensus,
for n_path in path_remove:
try:
G_vis.remove_edge(n_path[0], n_path[len(n_path)-1])
except:
except nx.exception.NetworkXError:
continue

lis = list(nx.topological_sort(nx.line_graph(G_vis)))
Expand All @@ -167,7 +167,7 @@ def join_clusters(cons, cl, Rcl, edge, consensus,


if transitive is True:
graph_str = str(nx.nx_agraph.to_agraph(G))
#graph_str = str(nx.nx_agraph.to_agraph(G))
not_visited=list(G.nodes).copy()
while not_visited:
node=not_visited[0]
Expand Down Expand Up @@ -205,7 +205,7 @@ def join_clusters(cons, cl, Rcl, edge, consensus,
for n_path in path_remove:
try:
G.remove_edge(n_path[0], n_path[len(n_path)-1])
except :
except nx.exception.NetworkXError:
continue
G.remove_edges_from(ebunch = to_remove)

Expand All @@ -232,7 +232,7 @@ def join_clusters(cons, cl, Rcl, edge, consensus,
nested[neighbor] = nodes
except KeyError:
continue
except:
except nx.exception.NetworkXError:
continue

groups = list(nx.connected_components(G))
Expand All @@ -253,53 +253,52 @@ def join_clusters(cons, cl, Rcl, edge, consensus,
return cl


def split_all(cl, cluster, data, cons,bam, edge, R, I, SNP_pos,reference_seq,type):
def split_all(cl, cluster, data, cons,bam, edge, R, I, snp_pos,reference_seq,type):
if type=="unclustered":
factor="Strange"
snp_set="clSNP"
snp_set="clust_snp"
if type=="lowheterozygosity":
factor="Strange2"
snp_set="clSNP2"
snp_set="clust_snp2"


if cons[cluster][factor] == 1:
clSNP = cons[cluster][snp_set]
res = split_cluster(cl, cluster, data,cons, clSNP, bam, edge, R, I)
clust_snp = cons[cluster][snp_set]
res = split_cluster(cl, cluster, data,cons, clust_snp, bam, edge, R, I)
new_cl_id_na=res[0]
clN =res[1]
build_data.cluster_consensuns(cl, new_cl_id_na, SNP_pos, data, cons, edge, reference_seq)
build_data.cluster_consensuns(cl, new_cl_id_na, snp_pos, data, cons, edge, reference_seq)

if clN != 0: #if clN==0 we dont need split NA cluster
split_cluster(cl, new_cl_id_na, data, cons,
cons[new_cl_id_na][snp_set], bam, edge, R, I, False)
clusters = sorted(set(cl.loc[cl["Cluster"] != "NA", "Cluster"].values))

if clN == 1: #STOP LOOP IF EXIST
build_data.cluster_consensuns(cl, new_cl_id_na + clN, SNP_pos, data, cons, edge, reference_seq)

build_data.cluster_consensuns(cl,
new_cl_id_na + clN, snp_pos, data, cons, edge, reference_seq)
for clstr in clusters:
if clstr not in cons:
build_data.cluster_consensuns(cl, clstr, SNP_pos, data, cons, edge, reference_seq)
split_all(cl, clstr, data, cons,bam, edge, R, I, SNP_pos,reference_seq,"unclustered")


build_data.cluster_consensuns(cl,
clstr, snp_pos, data, cons, edge, reference_seq)
split_all(cl, clstr, data, cons,bam, edge,
R, I, snp_pos,reference_seq,"unclustered")




def postprocess(bam, cl, SNP_pos, data, edge, R,Rcl, I, flye_consensus,mean_edge_cov):
def postprocess(bam, cl, snp_pos, data, edge, R,Rcl, I, flye_consensus,mean_edge_cov):
reference_seq = build_data.read_fasta_seq(StRainyArgs().fa, edge)
cons = build_data.build_data_cons(cl, SNP_pos, data, edge, reference_seq)
cons = build_data.build_data_cons(cl, snp_pos, data, edge, reference_seq)
if StRainyArgs().debug:
cl.to_csv(f"{StRainyArgs().output_intermediate}/clusters/{edge}_1.csv")
clusters = sorted(set(cl.loc[cl["Cluster"] != "NA","Cluster"].values))



cl.loc[cl["Cluster"] == "NA", "Cluster"] = UNCLUSTERED_GROUP_N
build_data.cluster_consensuns(cl, UNCLUSTERED_GROUP_N, SNP_pos, data, cons, edge, reference_seq)
clSNP = cons[UNCLUSTERED_GROUP_N]["clSNP2"]
splitna = split_cluster(cl, UNCLUSTERED_GROUP_N, data, cons, clSNP, bam, edge, R, I,False)
build_data.cluster_consensuns(cl, UNCLUSTERED_GROUP_N, snp_pos, data, cons, edge, reference_seq)
clust_snp = cons[UNCLUSTERED_GROUP_N]["clust_snp2"]
splitna = split_cluster(cl, UNCLUSTERED_GROUP_N, data, cons, clust_snp, bam, edge, R, I,False)

#Remove unclustered reads after splitting NA cluster
splitna[0]
Expand All @@ -309,20 +308,20 @@ def postprocess(bam, cl, SNP_pos, data, edge, R,Rcl, I, flye_consensus,mean_edge
clusters = sorted(set(cl.loc[cl["Cluster"] != splitna[0], "Cluster"].values))
clusters = sorted(set(cl.loc[cl["Cluster"] != UNCLUSTERED_GROUP_N, "Cluster"].values))

build_data.cluster_consensuns(cl, UNCLUSTERED_GROUP_N, SNP_pos, data, cons, edge, reference_seq)
build_data.cluster_consensuns(cl, UNCLUSTERED_GROUP_N, snp_pos, data, cons, edge, reference_seq)
counts = cl["Cluster"].value_counts(dropna = False)

for cluster in clusters:
if cluster not in cons:
build_data.cluster_consensuns(cl, cluster, SNP_pos, data, cons, edge, reference_seq)
build_data.cluster_consensuns(cl, cluster, snp_pos, data, cons, edge, reference_seq)

cl = join_clusters(cons, cl, Rcl, edge, flye_consensus)
cons = build_data.build_data_cons(cl, SNP_pos, data, edge, reference_seq)
cons = build_data.build_data_cons(cl, snp_pos, data, edge, reference_seq)

clusters = sorted(set(cl.loc[cl["Cluster"] != "NA","Cluster"].values))
prev_clusters = clusters
for cluster in clusters:
split_all(cl, cluster, data, cons,bam, edge, R, I, SNP_pos,reference_seq,"unclustered")
split_all(cl, cluster, data, cons,bam, edge, R, I, snp_pos,reference_seq,"unclustered")
clusters = sorted(set(cl.loc[cl["Cluster"] != "NA", "Cluster"].values))
new_clusters = list(set(clusters) - set(prev_clusters))
prev_clusters = clusters
Expand All @@ -331,20 +330,21 @@ def postprocess(bam, cl, SNP_pos, data, edge, R,Rcl, I, flye_consensus,mean_edge

for cluster in clusters:
if cluster not in cons:
build_data.cluster_consensuns(cl, cluster, SNP_pos, data, cons, edge, reference_seq)
build_data.cluster_consensuns(cl, cluster, snp_pos, data, cons, edge, reference_seq)
clusters = sorted(set(cl.loc[cl["Cluster"] != "NA", "Cluster"].values))

logging.info("Split stage2: Break regions of low heterozygosity")
for cluster in clusters:
split_all(cl, cluster, data, cons,bam, edge, R, I, SNP_pos,reference_seq,"lowheterozygosity")
split_all(cl, cluster, data, cons,bam, edge,
R, I, snp_pos,reference_seq,"lowheterozygosity")




cl.loc[cl["Cluster"] == "NA", "Cluster"] = UNCLUSTERED_GROUP_N
build_data.cluster_consensuns(cl, UNCLUSTERED_GROUP_N, SNP_pos, data, cons, edge, reference_seq)
clSNP = cons[UNCLUSTERED_GROUP_N]["clSNP2"]
splitna = split_cluster(cl, UNCLUSTERED_GROUP_N, data, cons, clSNP, bam, edge, R, I,False)
build_data.cluster_consensuns(cl, UNCLUSTERED_GROUP_N, snp_pos, data, cons, edge, reference_seq)
clust_snp = cons[UNCLUSTERED_GROUP_N]["clust_snp2"]
splitna = split_cluster(cl, UNCLUSTERED_GROUP_N, data, cons, clust_snp, bam, edge, R, I,False)
#Remove unclustered reads after splitting NA cluster
splitna[0]
cl = cl[cl["Cluster"] != splitna[0]]
Expand All @@ -353,30 +353,33 @@ def postprocess(bam, cl, SNP_pos, data, edge, R,Rcl, I, flye_consensus,mean_edge
clusters = sorted(set(cl.loc[cl["Cluster"] != splitna[0], "Cluster"].values))
clusters = sorted(set(cl.loc[cl["Cluster"] != UNCLUSTERED_GROUP_N, "Cluster"].values))

cl=update_cluster_set(cl, cluster, SNP_pos, data, cons, edge, reference_seq,mean_edge_cov)
cl=update_cluster_set(cl, cluster, snp_pos, data, cons, edge, reference_seq,mean_edge_cov)

cl = join_clusters(cons, cl, Rcl, edge, flye_consensus)
cl=update_cluster_set(cl, cluster, SNP_pos, data, cons, edge, reference_seq,mean_edge_cov)
cl=update_cluster_set(cl, cluster, snp_pos, data, cons, edge, reference_seq,mean_edge_cov)
cl = join_clusters(cons, cl, Rcl, edge, flye_consensus, transitive=True)
cl=update_cluster_set(cl, cluster, SNP_pos, data, cons, edge, reference_seq,mean_edge_cov)
cl=update_cluster_set(cl, cluster, snp_pos, data, cons, edge, reference_seq,mean_edge_cov)
cl = join_clusters(cons, cl, Rcl, edge, flye_consensus)
cl=update_cluster_set(cl, cluster, SNP_pos, data, cons, edge,
reference_seq,mean_edge_cov,fraction=0.05)
cl = join_clusters(cons, cl, Rcl, edge, flye_consensus, only_with_common_snip=False,only_nested=True)
cl=update_cluster_set(cl, cluster, snp_pos, data, cons,
edge,reference_seq,mean_edge_cov,0.05)
cl = join_clusters(cons, cl, Rcl, edge, flye_consensus,
only_with_common_snip=False,only_nested=True)
counts = cl["Cluster"].value_counts(dropna = False)
cl = cl[~cl["Cluster"].isin(counts[counts < 6].index)] #TODO change for cov*01.
cl=update_cluster_set(cl, cluster, SNP_pos, data, cons, edge,
reference_seq,mean_edge_cov,fraction=0.05)
cl=update_cluster_set(cl, cluster, snp_pos, data, cons, edge,
reference_seq,mean_edge_cov,0.05)
return cl



def update_cluster_set(cl, cluster, SNP_pos, data, cons, edge, reference_seq,mean_edge_cov,fraction=0.01):

def update_cluster_set(cl, snp_pos, data, cons, edge,
reference_seq,mean_edge_cov,fraction=0.01):
"""Update consensus and remove small clusters (less 1% of unitig coverage)"""
clusters = sorted(set(cl.loc[cl["Cluster"] != "NA", "Cluster"].values))
for clstr in clusters:
if clstr not in cons:
build_data.cluster_consensuns(cl, clstr, SNP_pos, data, cons, edge, reference_seq)
build_data.cluster_consensuns(cl, clstr, snp_pos, data, cons, edge, reference_seq)

for clstr in clusters:
if cons[clstr]['Cov']<mean_edge_cov*fraction:
Expand Down

0 comments on commit 508730d

Please sign in to comment.