88from pathlib import Path
99import argparse
1010import subprocess
11+ import shutil
1112from ete3 import Tree
1213import pandas as pd
1314from collections import defaultdict
1415from matplotlib import pyplot as plt
1516from matplotlib .colors import LogNorm
1617import seaborn as sns
1718import tempfile
19+ import logging
1820
1921
2022#####################################################################
@@ -98,8 +100,9 @@ def read_tree(input_path):
98100 tree_string = f .read ()
99101 formatted = re .sub (r";[^:]+:" , ":" , tree_string )
100102 is_duplicated = check_formatted_tree (formatted )
103+ is_small = formatted .count ("," ) < 3
101104
102- return Tree (formatted ), is_duplicated
105+ return Tree (formatted ), is_duplicated , is_small
103106
104107
105108#####################################################################
@@ -111,33 +114,38 @@ def read_tree(input_path):
111114#####################################################################
112115
113116
114- def root_tree (input_path , basename , output_path ):
115- tre ,is_duplicated = read_tree (input_path )
117+ def root_one_tree (input_path , basename , output_path ):
118+ tre ,is_duplicated , is_small = read_tree (input_path )
116119 midpoint = tre .get_midpoint_outgroup ()
117120 tre .set_outgroup (midpoint )
118121 if is_duplicated :
119122 outdir = Path (output_path ) / "multiple"
120123 Path (outdir ).mkdir (exist_ok = True , parents = True )
121124 output_path = outdir / basename
122125 output_path = str (output_path ).replace (".tre" , ".tre.multiple" )
126+ elif is_small :
127+ outdir = Path (output_path ) / "small"
128+ Path (outdir ).mkdir (exist_ok = True , parents = True )
129+ output_path = outdir / basename
130+ output_path = str (output_path ).replace (".tre" , ".tre.small" )
123131 else :
124132 outdir = Path (output_path ) / "unique"
125133 Path (outdir ).mkdir (exist_ok = True , parents = True )
126134 output_path = outdir / basename
127135
128136 tre .write (outfile = output_path )
129- return tre .write (), len (tre .get_leaves ()), output_path , is_duplicated
137+ return tre .write (), len (tre .get_leaves ()), output_path , is_duplicated , is_small
130138
131139def root_reference_tree (input_path , output_path ):
132- tre , _ = read_tree (input_path )
140+ tre , _ , _ = read_tree (input_path )
133141 midpoint = tre .get_midpoint_outgroup ()
134142 tre .set_outgroup (midpoint )
135143 tre .write (outfile = output_path )
136144 return tre .write (), len (tre .get_leaves ())
137145
138146
139147#####################################################################
140- ### FUNCTION ROOT_TREE
148+ ### FUNCTION ROOT_ALL_TREES
141149### Root all the unrooted input trees in directory
142150### core_tree: path of the core tree
143151### gene_trees: path of the csv file containing all the gene tree paths
@@ -148,8 +156,7 @@ def root_reference_tree(input_path, output_path):
148156#####################################################################
149157
150158
151- def root_trees (core_tree , gene_trees_path , output_dir , results , merge_pair = False ):
152- print ("Rooting trees" )
159+ def root_all_trees (core_tree , gene_trees_path , output_dir , results , merge_pair = False ):
153160 #'''
154161 reference_tree = core_tree
155162
@@ -165,11 +172,11 @@ def root_trees(core_tree, gene_trees_path, output_dir, results, merge_pair=False
165172 rooted_gene_trees_path = os .path .join (output_dir , "rooted_gene_trees" )
166173 for filename in df_gene_trees ["path" ]:
167174 basename = Path (filename ).name
168- gene_content , gene_tree_size , gene_tree_path , is_duplicated = root_tree (
175+ gene_content , gene_tree_size , gene_tree_path , is_duplicated , is_small = root_one_tree (
169176 filename ,
170177 basename ,
171178 rooted_gene_trees_path )
172- if not is_duplicated :
179+ if not ( is_duplicated or is_small ) :
173180 results .loc [basename , "tree_size" ] = gene_tree_size
174181 if merge_pair :
175182 with open (gene_tree_path , "w" ) as f2 :
@@ -205,6 +212,9 @@ def extract_approx_distance(text):
205212
206213def run_approx_rspr (results , input_file , lst_filename , rspr_path ):
207214 input_file .seek (0 )
215+
216+ command_exists = shutil .which (rspr_path [0 ])
217+
208218 result = subprocess .run (
209219 rspr_path , stdin = input_file , capture_output = True , text = True
210220 )
@@ -231,7 +241,6 @@ def run_approx_rspr(results, input_file, lst_filename, rspr_path):
231241def approx_rspr (
232242 rooted_gene_trees_path , results , min_branch_len = 0 , max_support_threshold = 0.7
233243):
234- print ("Calculating approx distance" )
235244 rspr_path = [
236245 "rspr" ,
237246 "-approx" ,
@@ -245,20 +254,73 @@ def approx_rspr(
245254 lst_filename = []
246255 with tempfile .TemporaryFile (mode = 'w+' ) as temp_file :
247256 for filename in os .listdir (rooted_gene_trees_path ):
248- if cur_count == group_size :
249- run_approx_rspr (results , temp_file , lst_filename , rspr_path )
250- temp_file .seek (0 )
251- temp_file .truncate ()
252- lst_filename .clear ()
253- cur_count = 0
254-
255- gene_tree_path = os .path .join (rooted_gene_trees_path , filename )
256- with open (gene_tree_path , "r" ) as infile :
257- temp_file .write (infile .read () + "\n " )
258- lst_filename .append (filename )
259- cur_count += 1
260- if cur_count > 0 :
261- run_approx_rspr (results , temp_file , lst_filename , rspr_path )
257+ if str (filename ) in results .index :
258+ print ("Found " + str (filename ))
259+ if cur_count == group_size :
260+ run_approx_rspr (results , temp_file , lst_filename , rspr_path )
261+ temp_file .seek (0 )
262+ temp_file .truncate ()
263+ lst_filename .clear ()
264+ cur_count = 0
265+
266+ gene_tree_path = os .path .join (rooted_gene_trees_path , filename )
267+ with open (gene_tree_path , "r" ) as infile :
268+ lines = infile .readlines ()
269+ if len (lines ) < 2 :
270+ print (f"File { filename } does not have enough lines." )
271+ continue
272+ tree = Tree (lines [1 ].strip ())
273+ # Calculate N: number of nodes at or above the support threshold
274+ # num_resolved = sum(1 for node in tree.traverse() if node.support >= max_support_threshold and not node.is_leaf())
275+ num_resolved = - 1
276+ for node in tree .traverse ():
277+ if node .support is not None and node .support >= max_support_threshold and not node .is_leaf ():
278+ num_resolved += 1
279+
280+ tree_size = len (tree .get_leaves ())
281+ results .loc [filename , "Num resolved" ] = num_resolved
282+ results .loc [filename , "N/tree_size" ] = num_resolved / tree_size if tree_size > 0 else 0
283+ lst_filename .append (filename )
284+ temp_file .write (lines [0 ].strip () + "\n " + lines [1 ].strip () + "\n " )
285+ cur_count += 1
286+ if cur_count > 0 :
287+ run_approx_rspr (results , temp_file , lst_filename , rspr_path )
288+
289+ # Add the approx_drSPR/N column
290+ results ["approx_drSPR/N" ] = results .apply (lambda row : float (row ["approx_drSPR" ]) / row ["Num resolved" ] if row ["Num resolved" ] > 0 else 0 , axis = 1 )
291+ print ("CBA " + str (results ))
292+
293+ #def approx_rspr_old(
294+ # rooted_gene_trees_path, results, min_branch_len=0, max_support_threshold=0.7
295+ #):
296+ # print("Calculating approx distance")
297+ # rspr_path = [
298+ # "rspr",
299+ # "-approx",
300+ # "-multifurcating",
301+ # "-length " + str(min_branch_len),
302+ # "-support " + str(max_support_threshold),
303+ # ]
304+ #
305+ # group_size = 10000
306+ # cur_count = 0
307+ # lst_filename = []
308+ # with tempfile.TemporaryFile(mode='w+') as temp_file:
309+ # for filename in os.listdir(rooted_gene_trees_path):
310+ # if cur_count == group_size:
311+ # run_approx_rspr(results, temp_file, lst_filename, rspr_path)
312+ # temp_file.seek(0)
313+ # temp_file.truncate()
314+ # lst_filename.clear()
315+ # cur_count = 0
316+ #
317+ # gene_tree_path = os.path.join(rooted_gene_trees_path, filename)
318+ # with open(gene_tree_path, "r") as infile:
319+ # temp_file.write(infile.read() + "\n")
320+ # lst_filename.append(filename)
321+ # cur_count += 1
322+ # if cur_count > 0:
323+ # run_approx_rspr(results, temp_file, lst_filename, rspr_path)
262324
263325
264326#####################################################################
@@ -289,7 +351,6 @@ def generate_heatmap(freq_table, output_path, log_scale=False):
289351#####################################################################
290352
291353def make_heatmap (results , output_path , min_distance , max_distance ):
292- print ("Generating heatmap" )
293354
294355 # create sub dataframe
295356 sub_results = results [(results ["approx_drSPR" ] >= min_distance )]
@@ -306,7 +367,6 @@ def make_heatmap(results, output_path, min_distance, max_distance):
306367
307368
308369def make_heatmap_from_tsv (input_path , output_path , min_distance , max_distance ):
309- print ("Generating heatmap from CSV" )
310370 results = pd .read_table (input_path )
311371 make_heatmap (results , output_path , min_distance , max_distance )
312372
@@ -339,7 +399,6 @@ def get_heatmap_group_size(all_values, max_groups=15):
339399#####################################################################
340400
341401def make_group_heatmap (results , output_path , min_distance , max_distance ):
342- print ("Generating group heatmap" )
343402
344403 # create sub dataframe
345404 sub_results = results [(results ["approx_drSPR" ] >= min_distance )]
@@ -383,7 +442,7 @@ def make_group_heatmap(results, output_path, min_distance, max_distance):
383442### RETURN groups of trees
384443#####################################################################
385444
386- def generate_group_sizes (target_sum , max_groups = 500 ):
445+ def generate_group_sizes (target_sum , max_groups = 1000 ):
387446 degree = 1
388447 current_sum = 0
389448 group_sizes = []
@@ -410,7 +469,6 @@ def generate_group_sizes(target_sum, max_groups=500):
410469#####################################################################
411470
412471def make_groups_v1 (results , min_limit = 10 ):
413- print ("Generating groups" )
414472 min_group = results [results ["approx_drSPR" ] <= min_limit ]["file_name" ].tolist ()
415473 groups = defaultdict ()
416474 first_group = "group_0"
@@ -438,7 +496,6 @@ def make_groups_v1(results, min_limit=10):
438496#####################################################################
439497
440498def make_groups (results , min_limit = 10 ):
441- print ("Generating groups" )
442499 min_group = results [results ["approx_drSPR" ] <= min_limit ]["file_name" ].tolist ()
443500 groups = defaultdict ()
444501 first_group = "group_0"
@@ -463,7 +520,6 @@ def make_groups(results, min_limit=10):
463520
464521
465522def make_groups_from_csv (input_df , min_limit ):
466- print ("Generating groups from CSV" )
467523 groups = make_groups_v1 (input_df , min_limit )
468524 tidy_data = [
469525 (key , val )
@@ -476,6 +532,24 @@ def make_groups_from_csv(input_df, min_limit):
476532 return merged
477533
478534
535+ # def join_annotation_data(df, annotation_data):
536+ # ann_df = pd.read_table(annotation_data, dtype={"genome_id": "str"})
537+ # ann_df.columns = map(str.lower, ann_df.columns)
538+ # ann_df.columns = ann_df.columns.str.replace(" ", "_")
539+ # ann_subset = ann_df[["gene", "product"]]
540+ #
541+ # df["tree_name"] = [f.split(".")[0] for f in df["file_name"]]
542+ #
543+ # merged = df.merge(ann_subset, how="left", left_on="tree_name", right_on="gene")
544+ #
545+ # if merged["gene"].isnull().all():
546+ # ann_subset = ann_df[["locus_tag", "gene", "product"]]
547+ # merged = df.merge(
548+ # ann_subset, how="left", left_on="tree_name", right_on="locus_tag"
549+ # )
550+ #
551+ # return merged.fillna(value="NULL").drop("tree_name", axis=1).drop_duplicates()
552+
479553def join_annotation_data (df , annotation_data ):
480554 ann_df = pd .read_table (annotation_data , dtype = {"genome_id" : "str" })
481555 ann_df .columns = map (str .lower , ann_df .columns )
@@ -492,8 +566,23 @@ def join_annotation_data(df, annotation_data):
492566 ann_subset , how = "left" , left_on = "tree_name" , right_on = "locus_tag"
493567 )
494568
495- return merged .fillna (value = "NULL" ).drop ("tree_name" , axis = 1 ). drop_duplicates ( )
569+ merged = merged .fillna ("NULL" ).drop ("tree_name" , axis = 1 )
496570
571+ # Group by all columns except 'product' and aggregate 'product'
572+ grouped = (
573+ merged .groupby (list (merged .columns .difference (['product' ])))
574+ .agg ({'product' : lambda x : '||' .join (sorted (set (x )))})
575+ .reset_index ()
576+ )
577+
578+ # Reorder columns
579+ desired_order = [
580+ "file_name" , "gene" , "tree_size" , "product" , "N/tree_size" ,
581+ "Num resolved" , "approx_drSPR" , "approx_drSPR/N"
582+ ]
583+ grouped = grouped [desired_order ]
584+
585+ return grouped .drop_duplicates ()
497586
498587def main (args = None ):
499588 args = parse_args (args )
@@ -502,7 +591,7 @@ def main(args=None):
502591 #'''
503592 results = pd .DataFrame (columns = ["file_name" , "tree_size" , "approx_drSPR" ])
504593 results .set_index ("file_name" , inplace = True )
505- rooted_paths = root_trees (
594+ rooted_paths = root_all_trees (
506595 args .CORE_TREE , args .GENE_TREES , args .OUTPUT_DIR , results , True
507596 )
508597 approx_rspr (
@@ -512,7 +601,10 @@ def main(args=None):
512601 args .MAX_SUPPORT_THRESHOLD ,
513602 )
514603
604+ #exit(11)
605+
515606 # Generate standard heatmap
607+ # results["approx_drSPR"] = pd.to_numeric(results["approx_drSPR"]).fillna(1000000)
516608 results ["approx_drSPR" ] = pd .to_numeric (results ["approx_drSPR" ])
517609 fig_path = os .path .join (args .OUTPUT_DIR , "output.png" )
518610 make_heatmap (
0 commit comments