diff --git a/CMakeLists.txt b/CMakeLists.txt index c512ac03..0a069828 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,6 +45,7 @@ option(USE_HTSLIB "Use htslib" OFF) option(USE_NETAM "Use netam/thrifty libtorch integration" OFF) option(DISABLE_PARALLELISM "Disable all parallelism for debugging" OFF) option(USE_SYSTEM_TBB "Use system-installed TBB instead of fetching from source" OFF) +option(USE_TEST_LOGS "Write debug artifacts (before_optimize.pb, after_optimize.pb)" OFF) set(TBB_VERSION "v2022.1.0") @@ -131,7 +132,7 @@ if(MPI_C_COMPILER) # Get the first path (space-separated list) string(REPLACE " " ";" MPI_INCDIRS_LIST "${MPI_INCDIRS}") list(GET MPI_INCDIRS_LIST 0 MPI_PRIMARY_INCDIR) - + set(MPI_C_HEADER_DIR "${MPI_PRIMARY_INCDIR}" CACHE PATH "") set(MPI_CXX_HEADER_DIR "${MPI_PRIMARY_INCDIR}" CACHE PATH "") endif() @@ -339,6 +340,10 @@ function(larch_compile_opts PRODUCT) target_compile_options(${PRODUCT} PUBLIC -DDISABLE_PARALLELISM) endif() + if(USE_TEST_LOGS) + target_compile_options(${PRODUCT} PUBLIC -DUSE_TEST_LOGS) + endif() + if(USE_ASAN) target_compile_options(${PRODUCT} PUBLIC -O0 -g3 -fsanitize=address -fno-sanitize-recover) elseif(USE_TSAN) @@ -350,14 +355,14 @@ function(larch_compile_opts PRODUCT) # (see torch target modification after find_package(Torch)) target_link_libraries(${PRODUCT} PUBLIC protobuf::libprotobuf) if(absl_FOUND) - target_link_libraries(${PRODUCT} PUBLIC + target_link_libraries(${PRODUCT} PUBLIC absl::log absl::log_internal_check_impl ) endif() target_include_directories(${PRODUCT} PUBLIC ${PROTO_OUT_DIR}) target_include_directories(${PRODUCT} PUBLIC ${PROTO_OUT_DIR}/deps/usher) - + target_include_directories(${PRODUCT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/deps/range-v3/install/include) target_compile_options(${PRODUCT} PUBLIC -DRANGES_DISABLE_DEPRECATED_WARNINGS) add_dependencies(${PRODUCT} range-v3) @@ -596,11 +601,11 @@ larch_executable(larch-dag2dot ) larch_install(larch-dag2dot) -# # bcr-larch -larch_executable(bcr-larch +# # larch-bcr +larch_executable(larch-bcr tools/bcr-larch.cpp ) -larch_install(bcr-larch) +larch_install(larch-bcr) # # larch-usher if(USE_USHER) @@ -610,6 +615,10 @@ if(USE_USHER) target_compile_options(larch-usher PRIVATE ${STRICT_WARNINGS}) target_link_libraries(larch-usher PRIVATE larch-usher-glue) add_dependencies(larch-usher larch-usher-glue) + if(USE_NETAM) + target_compile_definitions(larch-usher PRIVATE USE_NETAM) + target_link_libraries(larch-usher PRIVATE netam) + endif() larch_install(larch-usher) endif() diff --git a/include/larch/impl/produce_mat_impl.hpp b/include/larch/impl/produce_mat_impl.hpp index ca2d086c..f4f6c9ce 100644 --- a/include/larch/impl/produce_mat_impl.hpp +++ b/include/larch/impl/produce_mat_impl.hpp @@ -47,8 +47,10 @@ auto optimize_dag_direct(DAG dag, Move_Found_Callback& callback, static_assert(DAG::template contains_element_feature); auto& tree = dag.GetMutableMAT(); -#ifdef KEEP_ASSERTS +#ifdef USE_TEST_LOGS Mutation_Annotated_Tree::save_mutation_annotated_tree(tree, "before_optimize.pb"); +#endif +#ifdef KEEP_ASSERTS check_MAT_MADAG_Eq(tree, dag); #endif @@ -92,6 +94,12 @@ auto optimize_dag_direct(DAG dag, Move_Found_Callback& callback, "intermediate_newick", // intermediate newick name callback // callback ); +#ifndef USE_TEST_LOGS + // matOptimize unconditionally writes intermediate_newick*.pb.gz when + // allow_drift is true. Clean up since we can't suppress it without + // changing search behavior. + std::remove("intermediate_newick1.pb.gz"); +#endif tree.uncondense_leaves(); tree.condense_leaves(condense_arg); tree.fix_node_idx(); @@ -100,7 +108,9 @@ auto optimize_dag_direct(DAG dag, Move_Found_Callback& callback, tree.uncondense_leaves(); tree.fix_node_idx(); +#ifdef USE_TEST_LOGS Mutation_Annotated_Tree::save_mutation_annotated_tree(tree, "after_optimize.pb"); +#endif auto result = std::make_pair(AddMATConversion(MADAGStorage<>::EmptyDefault()), std::move(tree)); result.first.View().BuildFromMAT(result.second, dag.GetReferenceSequence()); diff --git a/pixi.toml b/pixi.toml index 1514335a..d14b6532 100644 --- a/pixi.toml +++ b/pixi.toml @@ -33,6 +33,10 @@ description = "Build larch using settings from larch-build.env." cmd = "bash scripts/run-tests.sh -tag slow" description = "Run tests (excluding slow tests). Build first with 'pixi run build'." +[tasks.test-netam] +cmd = "bash scripts/run-tests.sh +tag netam" +description = "Run all netam tests (ML SPR, likelihood, model, S5F). Requires USE_NETAM=yes build." + [tasks.test-all] cmd = "bash scripts/run-tests.sh" description = "Run all tests including slow tests. WARNING: Very slow." diff --git a/scripts/larch-build.env.template b/scripts/larch-build.env.template index f2e73838..613077d4 100644 --- a/scripts/larch-build.env.template +++ b/scripts/larch-build.env.template @@ -32,6 +32,7 @@ LARCH_PROTOBUF_PATH=auto # -DDISABLE_PARALLELISM=yes Disable all parallelism (debugging) # -DKEEP_ASSERTS=ON Keep asserts in release builds # -DUSE_CPPTRACE=ON Readable backtraces on exceptions +# -DUSE_TEST_LOGS=ON Write debug artifacts (before/after_optimize.pb) # -DUSE_SYSTEM_TBB=yes Use system TBB instead of fetching from source # (auto-enabled on macOS) # Example: LARCH_CMAKE_EXTRA=-DUSE_ASAN=yes -DKEEP_ASSERTS=ON diff --git a/src/netam/kmer_sequence_encoder.cpp b/src/netam/kmer_sequence_encoder.cpp index de944a6e..efd2e3f3 100644 --- a/src/netam/kmer_sequence_encoder.cpp +++ b/src/netam/kmer_sequence_encoder.cpp @@ -30,8 +30,8 @@ std::pair kmer_sequence_encoder::encode_sequence( std::string padded_sequence = std::string(overhang_length_, 'N') + upper_seq + std::string(overhang_length_, 'N'); - // Index for N-containing kmers (last index, 4^k) - auto n_index = signed_cast(narrowing_cast(all_kmers_.size() - 1)); + // Index for N-containing kmers (index 0, matching Python netam convention) + constexpr std::int32_t n_index = 0; // Encode kmers std::vector kmer_indices; @@ -75,8 +75,8 @@ torch::Tensor kmer_sequence_encoder::encode_bases(const std::string& sequence) { std::vector kmer_sequence_encoder::generate_kmers(std::size_t length) { std::vector kmers; - // Generate all possible kmers of given length (4^k kmers) - // Uses lexicographic order: AAAAA=0, AAAAC=1, AAAAG=2, AAAAT=3, AAACA=4, ... + // Generate all possible kmers of given length (4^k kmers) in lexicographic order. + // Final indices: N=0, AAA=1, AAC=2, ..., TTT=4^k (N-kmer prepended below). std::function generate = [&](std::string current, std::size_t pos) { if (pos == length) { @@ -90,9 +90,12 @@ std::vector kmer_sequence_encoder::generate_kmers(std::size_t lengt generate("", 0); - // Add placeholder for N-containing kmers at the end (index 4^k) - kmers.push_back("N"); - return kmers; + // N-kmer placeholder at index 0 (matches Python netam convention) + std::vector result; + result.reserve(1 + kmers.size()); + result.push_back("N"); + result.insert(result.end(), kmers.begin(), kmers.end()); + return result; } torch::Tensor kmer_sequence_encoder::compute_wt_base_modifier( diff --git a/test/main.cpp b/test/main.cpp index 7a38e677..389b251e 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -35,6 +35,7 @@ static void get_usage() { {"--list", "Prints information about all selected tests (IDs, tags), but does not run " "them"}, + {"--list-tags", "List all available test tags and exit"}, {"nocatch", "Allow exceptions to escape for debugging"}}; std::cout << FormatUsage(program_desc, usage_examples, flag_desc_pairs); @@ -91,6 +92,7 @@ int main(int argc, char* argv[]) { #endif bool no_catch = false; bool opt_list_names = false; + bool opt_list_tags = false; bool opt_test_range = false; assert(std::filesystem::exists("./data/") && "Test data folder not found."); @@ -108,6 +110,8 @@ int main(int argc, char* argv[]) { no_catch = true; } else if (std::string("--list") == argv[i]) { opt_list_names = true; + } else if (std::string("--list-tags") == argv[i]) { + opt_list_tags = true; } else if (std::string("--range") == argv[i]) { opt_test_range = true; range = parse_range(argv[++i]); @@ -160,6 +164,20 @@ int main(int argc, char* argv[]) { std::cout << std::endl; } + if (opt_list_tags) { + std::set all_tags; + for (const auto& test : get_all_tests()) { + for (const auto& tag : test.tags) { + all_tags.insert(tag); + } + } + std::cout << "AVAILABLE TAGS:" << std::endl; + for (const auto& tag : all_tags) { + std::cout << " " << tag << std::endl; + } + return EXIT_SUCCESS; + } + if (opt_list_names) { return EXIT_SUCCESS; } diff --git a/test/test_common.hpp b/test/test_common.hpp index 575239ae..0a21bd48 100644 --- a/test/test_common.hpp +++ b/test/test_common.hpp @@ -67,6 +67,7 @@ inline void print_peak_mem() { } inline const std::string test_output_folder = "data/_ignore/"; +inline const std::string test_log_folder = test_output_folder + "/optimization_log"; inline std::pair run_larch_usher( std::string_view input_dag_path, std::string_view output_dag_path, @@ -75,7 +76,7 @@ inline std::pair run_larch_usher( std::optional other_options = std::nullopt, bool do_print_stdout = true, bool do_print_stderr = false, bool do_print_summary = true) { - std::string log_path = test_output_folder + "/optimization_log"; + std::string log_path = test_log_folder; std::stringstream ss; // ss << "/usr/bin/time "; diff --git a/test/test_fileio_dagbin.cpp b/test/test_fileio_dagbin.cpp index 610e9fde..7acb653c 100644 --- a/test/test_fileio_dagbin.cpp +++ b/test/test_fileio_dagbin.cpp @@ -104,23 +104,20 @@ std::string output_dag_path_no_ext = test_output_folder + "/temp"; std::string output_dag_path_protobuf = test_output_folder + "/temp.pb"; std::string output_dag_path_dagbin = test_output_folder + "/temp.dagbin"; - std::string inter_dag_path_protobuf = output_dag_path_protobuf + ".intermediate"; - std::string inter_dag_path_dagbin = output_dag_path_dagbin + ".intermediate"; + // Intermediate DAGs are written inside the log directory by larch-usher + std::string inter_dag_path_protobuf = test_log_folder + "/intermediate.pb"; + std::string inter_dag_path_dagbin = test_log_folder + "/intermediate.dagbin"; std::string other_options = "--thread 1 "; if (use_seed) { other_options += "--seed 42 "; } if (save_both) { - inter_dag_path_protobuf = output_dag_path_no_ext + ".intermediate.pb"; - inter_dag_path_dagbin = output_dag_path_no_ext + ".intermediate.dagbin"; other_options += " --output-format debug-all"; auto [command, result] = run_larch_usher(input_dag_path, output_dag_path_no_ext, std::nullopt, iter, other_options); TestAssert((result == 0) && "larch-usher debug-all run failed."); } else { - inter_dag_path_protobuf = output_dag_path_protobuf + ".intermediate"; - inter_dag_path_dagbin = output_dag_path_dagbin + ".intermediate"; auto [command1, result1] = run_larch_usher(input_dag_path, output_dag_path_protobuf, std::nullopt, iter, other_options); TestAssert((result1 == 0) && "larch-usher protobuf run failed."); diff --git a/test/test_larch_usher.cpp b/test/test_larch_usher.cpp index cd2e003a..2128f262 100644 --- a/test/test_larch_usher.cpp +++ b/test/test_larch_usher.cpp @@ -45,7 +45,7 @@ static void test_larch_usher_merged_dag() { std::string output_dag_path = test_output_folder + "/opt_dag.pb"; std::string command = "./bin/larch-usher -i data/larch_merged_dag.pb -o " + output_dag_path + - " -c 1 -s 0 --max-subtree-clade-size 2000 --trim --quiet"; + " -c 1 -s 0 --max-subtree-clade-size 2000 --trim"; std::cout << ">COMMAND_EXECUTE: \"" << command << "\"" << std::endl; int result = std::system(command.c_str()); TestAssert(0 == result); diff --git a/test/test_netam_kmer_encoder.cpp b/test/test_netam_kmer_encoder.cpp index 34938cc0..af6362f3 100644 --- a/test/test_netam_kmer_encoder.cpp +++ b/test/test_netam_kmer_encoder.cpp @@ -226,20 +226,19 @@ void test_encode_sequence_kmer_indices_valid() { } void test_encode_sequence_padding_creates_zero_index() { - // K-mers containing N should map to the last index (kmer_count - 1) + // K-mers containing N should map to index 0 (matching Python netam convention) auto yaml = make_config(3, 5); kmer_sequence_encoder encoder{yaml}; auto [encoded, wt_modifier] = encoder.encode_sequence("ACG"); auto encoded_accessor = encoded.accessor(); - auto n_index = - static_cast(encoder.kmer_count() - 1); + constexpr int32_t n_index = 0; - // Position 0: k-mer is "NAC" (contains N from padding) -> last index + // Position 0: k-mer is "NAC" (contains N from padding) -> index 0 TestAssert(encoded_accessor[0] == n_index); - // Position 2: k-mer is "CGN" (contains N from padding) -> last index + // Position 2: k-mer is "CGN" (contains N from padding) -> index 0 TestAssert(encoded_accessor[2] == n_index); } diff --git a/tools/larch-usher.cpp b/tools/larch-usher.cpp index c13d2e1d..d0fc2898 100644 --- a/tools/larch-usher.cpp +++ b/tools/larch-usher.cpp @@ -25,6 +25,12 @@ #include "larch/usher_glue.hpp" +#ifdef USE_NETAM +#include +#include +#include +#endif + #include [[noreturn]] static void Usage() { @@ -43,9 +49,11 @@ {"-v,--VCF-input-file FILE", "Path to VCF file, containing ambiguous leaf sequence data"}, {"-c,--count INT", "Number of iterations (default: 1)"}, + {"-l,--logpath DIR", + "Enable optimization logging to the given directory. Writes logfile.csv\n" + "and intermediate DAG files to this directory."}, {"--inter-save INT", - "Saves a new intermediate DAG file once every given number of iterations \n" - "(default: no intermediate DAG files saved)"}, + "Save a numbered DAG snapshot every N iterations (requires -l)"}, {"-s,--switch-subtrees INT", "Switch to optimizing subtrees after the specified number of iterations \n" "(default: never)"}, @@ -85,7 +93,20 @@ {"--ignore-root_edge_mutations", "Ignore root edge mutations when computing parsimony\n"}, {"-S,--autodetect-stoptime", - "Set program to exit after parsimony improvement plateaus\n"}}; + "Set program to exit after parsimony improvement plateaus\n"}, + {"--scoring-backend ENUM", + "Scoring backend for evaluating SPR moves (default: parsimony)\n" + "[parsimony, ml] (ml requires building with -DUSE_NETAM=yes)"}, + {"--model-config FILE", + "Path to YAML model config (REQUIRED when --scoring-backend ml)\n" + "e.g., data/linearham/ThriftyHumV0.2-45.yml"}, + {"--model-weights FILE", + "Path to libtorch .pth weights (REQUIRED when --scoring-backend ml)\n" + "e.g., data/linearham/ThriftyHumV0.2-45-libtorch.pth"}, + {"--move-coeff-ml FLOAT", + "ML log-likelihood coefficient for scoring moves (default: 1.0 when\n" + "--scoring-backend ml, 0.0 otherwise). Higher values weight ML more\n" + "relative to parsimony."}}; std::cout << FormatUsage(program_desc, usage_examples, flag_desc_pairs); @@ -129,6 +150,315 @@ std::vector> clades_difference( return result; } +#ifdef USE_NETAM +// Expand a CompactGenome to a full sequence string using the reference sequence. +template +std::string ExpandCompactGenome(const RefSeq& ref_seq, const CG& compact_genome) { + std::string result{ref_seq.begin(), ref_seq.end()}; + for (const auto& [pos, base] : compact_genome) { + result[pos.value - 1] = base.ToChar(); // MutationPosition is 1-indexed + } + return result; +} + +// Compute edge log-likelihood from two compact genomes. +template +double ComputeEdgeLL(const RefSeq& ref_seq, const CG& parent_cg, const CG& child_cg, + netam::crepe& model) { + std::string parent_seq = ExpandCompactGenome(ref_seq, parent_cg); + std::string child_seq = ExpandCompactGenome(ref_seq, child_cg); + + auto [encoded_1d, wt_mod_2d] = model.encoder().encode_sequence(parent_seq); + auto encoded = encoded_1d.unsqueeze(0); + auto wt_mod = wt_mod_2d.unsqueeze(0); + auto mask = torch::ones({1, encoded.size(1)}, torch::kBool); + + auto [rates, csp_logits] = model->forward(encoded, mask, wt_mod); + auto csp = torch::softmax(csp_logits, -1); + + auto parent_indices = netam::kmer_sequence_encoder::encode_bases(parent_seq); + auto child_indices = netam::kmer_sequence_encoder::encode_bases(child_seq); + + auto ll = + netam::poisson_context_log_likelihood(rates, csp, parent_indices, child_indices); + return ll.item(); +} + +// Compute the ML log-likelihood CHANGE for an SPR move. +// +// Returns (new_LL - old_LL) for the edges affected by the move: +// - new_LL: sum of edge LLs for fragment edges using post-move compact genomes +// - old_LL: sum of edge LLs for the same child nodes' pre-move parent edges +// +// Only edges that changed are scored. The MoveNew node (created by the SPR) +// has no old edge, so it contributes only to new_LL. Its children (source and +// destination) have old parent edges that contribute to old_LL. +// +// Positive return = move improved likelihood. Negative = move worsened it. +template +double ComputeFragmentMLScore(const SPRView& spr, const FragmentView& fragment, + netam::crepe& model) { + torch::NoGradGuard no_grad; + const auto& ref_seq = spr.Const().GetReferenceSequence(); + double new_ll = 0.0; + double old_ll = 0.0; + + for (auto edge : fragment.GetEdges()) { + auto parent_id = edge.GetParent().GetId(); + auto child_id = edge.GetChild().GetId(); + auto parent_node = spr.Const().Get(parent_id); + auto child_node = spr.Const().Get(child_id); + + if (parent_node.IsUA()) { + continue; + } + + // New edge LL: post-move compact genomes from the SPR view + new_ll += ComputeEdgeLL(ref_seq, parent_node.GetCompactGenome(), + child_node.GetCompactGenome(), model); + + // Old edge LL: pre-move parent→child relationship + // MoveNew nodes didn't exist before the move, so they have no old edge. + if (not child_node.IsMoveNew()) { + auto old_child = child_node.GetOld(); + auto old_parent = old_child.GetSingleParent().GetParent(); + if (not old_parent.IsUA()) { + old_ll += ComputeEdgeLL(ref_seq, old_parent.GetCompactGenome(), + old_child.GetCompactGenome(), model); + } + } + } + + return new_ll - old_ll; +} +#endif // USE_NETAM + +struct MLScoringConfig { +#ifdef USE_NETAM + netam::crepe* model = nullptr; +#endif + double coeff = 0.0; + + template + double AdjustScore(double base_score, + [[maybe_unused]] const SPRView& spr, + [[maybe_unused]] const FragmentType& fragment) const { +#ifdef USE_NETAM + if (model != nullptr and coeff != 0.0) { + double fragment_ll = ComputeFragmentMLScore(spr, fragment, *model); + return base_score - coeff * fragment_ll; + } +#endif + return base_score; + } +}; + +class OptimizationLogger { + public: + OptimizationLogger(Merge& merge, bool use_ua_free_parsimony, FileFormat output_format) + : merge_{merge}, + use_ua_free_parsimony_{use_ua_free_parsimony}, + output_format_{output_format} {} + + void OpenLogfile(const std::string& log_dir) { + log_dir_ = log_dir; + std::filesystem::create_directory(log_dir); + logfile_.open(log_dir + "/logfile.csv"); + WriteHeader(); + } + + void SetIntermediateSaveInterval(uint interval) { + intermediate_save_interval_ = interval; + } + +#ifdef USE_NETAM + void SetMLModel(netam::crepe* model) { ml_model_ = model; } +#endif + + // Compute and log all DAG metrics for this iteration. + // Returns the min parsimony score. + size_t Log(size_t iteration) { + std::cout << "############ Logging for iteration " << iteration << " #######\n"; + merge_.ComputeResultEdgeMutations(); + + auto metrics = ComputeMetrics(); + PrintMetrics(metrics); + WriteRow(iteration, metrics); + SaveIntermediateDAG(iteration); + + return metrics.min_parsimony; + } + + private: + struct Metrics { + ArbitraryInt tree_count; + size_t min_parsimony; + ArbitraryInt min_parsimony_count; + size_t max_parsimony; + ArbitraryInt min_sum_rf; + ArbitraryInt max_sum_rf; + ArbitraryInt min_sum_rf_count; + ArbitraryInt max_sum_rf_count; + double best_tree_ll = 0.0; + bool has_ml_score = false; + }; + + Metrics ComputeMetrics() { + Metrics m{}; + const auto root = merge_.GetResult().GetRoot(); + + SubtreeWeight tree_counter{merge_.GetResult()}; + m.tree_count = tree_counter.ComputeWeightBelow(root, {}); + + if (use_ua_free_parsimony_) { + SubtreeWeight scorer{merge_.GetResult()}; + m.min_parsimony = scorer.ComputeWeightBelow(root, {}); + m.min_parsimony_count = scorer.MinWeightCount(root, {}); + } else { + SubtreeWeight scorer{merge_.GetResult()}; + m.min_parsimony = scorer.ComputeWeightBelow(root, {}); + m.min_parsimony_count = scorer.MinWeightCount(root, {}); + } + + // Max parsimony always uses BinaryParsimonyScore (no UA-free variant exists) + SubtreeWeight max_pars{merge_.GetResult()}; + m.max_parsimony = max_pars.ComputeWeightBelow(root, {}); + + SumRFDistance min_rf_ops{merge_, merge_}; + SubtreeWeight min_rf{merge_.GetResult()}; + m.min_sum_rf = + min_rf.ComputeWeightBelow(root, min_rf_ops) + min_rf_ops.GetOps().GetShiftSum(); + m.min_sum_rf_count = min_rf.MinWeightCount(root, min_rf_ops); + + MaxSumRFDistance max_rf_ops{merge_, merge_}; + SubtreeWeight max_rf{merge_.GetResult()}; + m.max_sum_rf = + max_rf.ComputeWeightBelow(root, max_rf_ops) + max_rf_ops.GetOps().GetShiftSum(); + m.max_sum_rf_count = max_rf.MinWeightCount(root, max_rf_ops); + +#ifdef USE_NETAM + if (ml_model_ != nullptr) { + m.best_tree_ll = ComputeBestParsimonyTreeLL(); + m.has_ml_score = true; + } +#endif + + timer_.stop(); + return m; + } + + void PrintMetrics(const Metrics& m) { + std::cout << "Min parsimony score in DAG: " << m.min_parsimony << "\n"; + std::cout << "Max parsimony score in DAG: " << m.max_parsimony << "\n"; + std::cout << ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Total trees in DAG: " << m.tree_count + << "\n"; + std::cout << "Optimal trees in DAG: " << m.min_parsimony_count << "\n"; + std::cout << "Min summed RF distance over trees: " << m.min_sum_rf << "\n"; + std::cout << "Max summed RF distance over trees: " << m.max_sum_rf << "\n"; + if (m.has_ml_score) { + std::cout << "Best tree log-likelihood: " << m.best_tree_ll << "\n"; + } + } + + void WriteHeader() { + logfile_ << "Iteration\tNTrees\tNNodes\tNEdges\tMaxParsimony\tNTreesMaxParsimony\t" + "WorstParsimony\tMinSumRFDistance\tMaxSumRFDistance\tMinSumRFCount\t" + "MaxSumRFCount\tSecondsElapsed"; +#ifdef USE_NETAM + if (ml_model_ != nullptr) { + logfile_ << "\tBestTreeLL"; + } +#endif + logfile_ << '\n'; + } + + void WriteRow(size_t iteration, const Metrics& m) { + if (not logfile_.is_open()) { + return; + } + logfile_ << iteration << '\t' << m.tree_count << '\t' + << merge_.GetResult().GetNodesCount() << '\t' + << merge_.GetResult().GetEdgesCount() << '\t' << m.min_parsimony << '\t' + << m.min_parsimony_count << '\t' << m.max_parsimony << '\t' << m.min_sum_rf + << '\t' << m.max_sum_rf << '\t' << m.min_sum_rf_count << '\t' + << m.max_sum_rf_count << '\t' << timer_.durationS(); + if (m.has_ml_score) { + logfile_ << '\t' << m.best_tree_ll; + } + logfile_ << '\n' << std::flush; + } + + void SaveIntermediateDAG(size_t iteration) { + if (not logfile_.is_open()) { + return; + } + std::string path = IntermediatePath(); + bool append_changes = (iteration > 0); + StoreDAG(merge_.GetResult(), path, output_format_, append_changes); + if (intermediate_save_interval_.has_value() and + (iteration % intermediate_save_interval_.value() == 0)) { + std::string snapshot = + log_dir_ + "/snapshot." + std::to_string(iteration) + FormatExt(); + std::cout << "############ Saving snapshot DAG to: " << snapshot << std::endl; + StoreDAG(merge_.GetResult(), snapshot, output_format_); + } + } + + // DebugAll appends its own extensions (.pb, .dagbin), so no ext for that format. + std::string FormatExt() const { + switch (output_format_) { + case FileFormat::Dagbin: + return ".dagbin"; + case FileFormat::DebugAll: + return ""; + case FileFormat::ProtobufDAG: + return ".pb"; + default: + Fail("FormatExt: unsupported output format for intermediate DAG"); + } + } + + std::string IntermediatePath() const { + return log_dir_ + "/intermediate" + FormatExt(); + } + +#ifdef USE_NETAM + // Compute the LL of the best-parsimony tree sampled from the DAG. + double ComputeBestParsimonyTreeLL() { + torch::NoGradGuard no_grad; + auto dag = merge_.GetResult(); + const auto& ref_seq = dag.GetReferenceSequence(); + + SubtreeWeight scorer{dag}; + auto tree_storage = scorer.MinWeightSampleTree({}, std::nullopt); + auto tree = tree_storage.View(); + tree.GetRoot().Validate(true); + tree.RecomputeCompactGenomes(true); + + double total_ll = 0.0; + for (auto edge : tree.GetEdges()) { + if (edge.GetParent().IsUA()) { + continue; + } + total_ll += ComputeEdgeLL(ref_seq, edge.GetParent().GetCompactGenome(), + edge.GetChild().GetCompactGenome(), *ml_model_); + } + return total_ll; + } +#endif + + Merge& merge_; + bool use_ua_free_parsimony_; + FileFormat output_format_; + std::string log_dir_; + std::optional intermediate_save_interval_; + std::ofstream logfile_; + Benchmark timer_; +#ifdef USE_NETAM + netam::crepe* ml_model_ = nullptr; +#endif +}; + using Storage = MergeDAGStorage<>; using ReassignedStatesStorage = decltype(AddMappedNodes(AddMATConversion(Storage::EmptyDefault()))); @@ -233,14 +563,18 @@ struct Treebased_Move_Found_Callback } move.score_change = move_score_coeffs_.second * move.score_change - move_score_coeffs_.first * node_id_map_count; - return {false, move.score_change <= 0}; + double effective_score = ml_config_.AdjustScore(move.score_change, spr, fragment); + return {false, effective_score <= 0}; } void OnRadius() {}; std::pair move_score_coeffs_; + MLScoringConfig ml_config_; }; +// Accepts all moves unconditionally (--callback-option all-moves). +// No ML scoring — this callback is for maximum DAG exploration regardless of score. struct Merge_All_Moves_Found_Callback : public BatchingCallback { MOVE_ONLY_VIRT_DTOR(Merge_All_Moves_Found_Callback); @@ -373,7 +707,8 @@ struct Merge_All_Profitable_Moves_Found_Callback } move.score_change = move_score_coeffs_.second * move.score_change - move_score_coeffs_.first * node_id_map_count; - return {move.score_change <= 0, move.score_change <= 0}; + double effective_score = ml_config_.AdjustScore(move.score_change, spr, fragment); + return {effective_score <= 0, effective_score <= 0}; } void OnRadius() {}; @@ -384,6 +719,7 @@ struct Merge_All_Profitable_Moves_Found_Callback std::atomic sample_mat_ = nullptr; std::mutex merge_mtx_; std::pair move_score_coeffs_; + MLScoringConfig ml_config_; }; struct Merge_All_Profitable_Moves_Found_Fixed_Tree_Callback @@ -495,7 +831,8 @@ struct Merge_All_Profitable_Moves_Found_Fixed_Tree_Callback } move.score_change = move_score_coeffs_.second * move.score_change - move_score_coeffs_.first * node_id_map_count; - return {move.score_change <= 0, false}; + double effective_score = ml_config_.AdjustScore(move.score_change, spr, fragment); + return {effective_score <= 0, false}; } void OnRadius() {}; @@ -506,6 +843,7 @@ struct Merge_All_Profitable_Moves_Found_Fixed_Tree_Callback std::atomic sample_mat_ = nullptr; std::mutex merge_mtx_; std::pair move_score_coeffs_; + MLScoringConfig ml_config_; }; enum class SampleMethod { @@ -524,17 +862,17 @@ enum class CallbackMethod { AllMoves }; +enum class ScoringBackend { Parsimony, ML }; + int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) Arguments args = GetArguments(argc, argv); int ignored{}; std::string input_dag_path; std::string output_dag_path; - std::string intermediate_dag_path; - std::string logfile_path = "optimization_log"; + std::string logfile_path; std::string refseq_path; std::string vcf_path; CallbackMethod callback_config = CallbackMethod::BestMoves; - bool write_intermediate_dag = true; std::optional write_intermediate_every_x_iters = std::nullopt; FileFormat input_format = FileFormat::Infer; FileFormat output_format = FileFormat::Infer; @@ -542,8 +880,8 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) bool sample_uniformly = false; size_t iter_count = 1; unsigned int thread_count = 0; - int move_coeff_nodes = 1; - int move_coeff_pscore = 1; + std::optional user_move_coeff_nodes = std::nullopt; + std::optional user_move_coeff_pscore = std::nullopt; size_t switch_subtrees = std::numeric_limits::max(); size_t min_subtree_clade_size = 100; // NOLINT size_t max_subtree_clade_size = 1000; // NOLINT @@ -557,6 +895,10 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) size_t current_best_parsimony = NoId; size_t time_limit = NoId; std::optional user_seed = std::nullopt; + ScoringBackend scoring_backend = ScoringBackend::Parsimony; + std::string model_config_path; + std::string model_weights_path; + std::optional user_move_coeff_ml = std::nullopt; Benchmark total_timer; @@ -576,13 +918,17 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) } else if (name == "-l" or name == "--logpath") { ParseOption(name, params, logfile_path, 1); } else if (name == "--move-coeff-pscore") { - ParseOption(name, params, move_coeff_pscore, 1); + int temp; + ParseOption(name, params, temp, 1); + user_move_coeff_pscore = temp; } else if (name == "--min-subtree-clade-size") { ParseOption(name, params, min_subtree_clade_size, 1); } else if (name == "--max-subtree-clade-size") { ParseOption(name, params, max_subtree_clade_size, 1); } else if (name == "--move-coeff-nodes") { - ParseOption(name, params, move_coeff_nodes, 1); + int temp; + ParseOption(name, params, temp, 1); + user_move_coeff_nodes = temp; } else if (name == "--sample-method") { std::string temp; ParseOption(name, params, temp, 1); @@ -631,9 +977,6 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) } else if (name == "--keep-fragment-uncollapsed") { ParseOption(name, params, collapse_empty_fragment_edges, 0); collapse_empty_fragment_edges = false; - } else if (name == "--quiet") { - ParseOption(name, params, write_intermediate_dag, 0); - write_intermediate_dag = false; } else if (name == "--inter-save") { uint temp; ParseOption(name, params, temp, 1); @@ -663,6 +1006,26 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) time_limit = ParseNumber(*params.begin()); } else if (name == "-S" or name == "--autodetect-stoptime") { plateau_stopping_condition = true; + } else if (name == "--scoring-backend") { + std::string temp; + ParseOption(name, params, temp, 1); + if (temp == "parsimony") { + scoring_backend = ScoringBackend::Parsimony; + } else if (temp == "ml") { + scoring_backend = ScoringBackend::ML; + } else { + std::cerr << "ERROR: Unknown `--scoring-backend` option '" << temp << "'." + << std::endl; + Fail(); + } + } else if (name == "--model-config") { + ParseOption(name, params, model_config_path, 1); + } else if (name == "--model-weights") { + ParseOption(name, params, model_weights_path, 1); + } else if (name == "--move-coeff-ml") { + double temp; + ParseOption(name, params, temp, 1); + user_move_coeff_ml = temp; } else { std::cerr << "Unknown argument '" << name << "'.\n"; Fail(); @@ -686,6 +1049,51 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) Fail(); } + // Validate logging arguments + if (write_intermediate_every_x_iters.has_value() and logfile_path.empty()) { + std::cerr << "ERROR: --inter-save requires -l,--logpath.\n"; + Fail(); + } + + // Validate ML scoring backend arguments + if (user_move_coeff_ml.has_value() and scoring_backend != ScoringBackend::ML) { + std::cerr << "ERROR: --move-coeff-ml requires --scoring-backend ml.\n"; + Fail(); + } + if (scoring_backend == ScoringBackend::ML) { + if (model_config_path.empty() or model_weights_path.empty()) { + std::cerr << "ERROR: --model-config and --model-weights are required when " + "--scoring-backend ml is selected.\n"; + Fail(); + } +#ifndef USE_NETAM + std::cerr << "ERROR: ML scoring backend requires building with -DUSE_NETAM=yes.\n"; + Fail(); +#endif + } + + // Load ML model if requested +#ifdef USE_NETAM + std::optional ml_model; + if (scoring_backend == ScoringBackend::ML) { + std::cout << "Loading ML model..." << std::flush; + ml_model.emplace(model_weights_path, model_config_path); + (*ml_model)->eval(); + std::cout << " done.\n"; + std::cout << " Model config: " << model_config_path << "\n"; + std::cout << " Model weights: " << model_weights_path << "\n"; + } +#endif + + // Resolve scoring coefficients based on backend. + // ML backend: default to ML-only (pscore=0, nodes=0, ml=1) + // Parsimony backend: default to parsimony-only (pscore=1, nodes=1, ml=0) + // User-specified values always override these defaults. + bool is_ml = (scoring_backend == ScoringBackend::ML); + int move_coeff_pscore = user_move_coeff_pscore.value_or(is_ml ? 0 : 1); + int move_coeff_nodes = user_move_coeff_nodes.value_or(is_ml ? 0 : 1); + double move_coeff_ml = user_move_coeff_ml.value_or(is_ml ? 1.0 : 0.0); + bool is_input_dag = refseq_path.empty(); if (input_format == FileFormat::Infer) { input_format = InferFileFormat(input_dag_path); @@ -699,16 +1107,9 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) output_format = FileFormat::ProtobufDAG; } } - if (intermediate_dag_path.empty()) { - intermediate_dag_path = output_dag_path + ".intermediate"; - } - RandomNumberGenerator main_rng{user_seed}; std::cout << "Random seed: " << main_rng.seed_ << "\n"; - std::filesystem::create_directory(logfile_path); - std::string logfile_name = logfile_path + "/logfile.csv"; - MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &ignored); #ifdef DISABLE_PARALLELISM tbb::global_control c(tbb::global_control::max_allowed_parallelism, 1); @@ -721,12 +1122,6 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) std::cout << "Thread count: " << c.active_value(tbb::global_control::max_allowed_parallelism) << "\n"; - std::ofstream logfile; - logfile.open(logfile_name); - logfile << "Iteration\tNTrees\tNNodes\tNEdges\tMaxParsimony\tNTreesMaxParsimony\tWors" - "tParsimony\tMinSumRFDistance\tMaxSumRFDistance\tMinSumRFCount\tMaxSumRFCo" - "unt\tSecondsElapsed"; - Benchmark load_timer; std::cout << "Loading input DAG..." << std::flush; MADAGStorage<> input_dag_storage = LoadDAG(input_dag_path, input_format, refseq_path); @@ -748,91 +1143,27 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) optimized_dags; merge.ComputeResultEdgeMutations(); - Benchmark log_timer; - auto logger = [&merge, &logfile, &log_timer, &intermediate_dag_path, - &write_intermediate_dag, &write_intermediate_every_x_iters, - &output_format, &use_ua_free_parsimony](size_t iteration) { - std::cout << "############ Logging for iteration " << iteration << " #######\n"; - merge.ComputeResultEdgeMutations(); - - const auto root_node = merge.GetResult().GetRoot(); - // Tree count - SubtreeWeight tree_counter{merge.GetResult()}; - auto tree_count = tree_counter.ComputeWeightBelow(root_node, {}); - // Min Parsimony score - SubtreeWeight min_parsimony_scorer{ - merge.GetResult()}; - auto min_parsimony_score = min_parsimony_scorer.ComputeWeightBelow(root_node, {}); - auto min_parsimony_count = min_parsimony_scorer.MinWeightCount(root_node, {}); - // Max Parsimony score - SubtreeWeight max_parsimony_scorer{ - merge.GetResult()}; - auto max_parsimony_score = max_parsimony_scorer.ComputeWeightBelow(root_node, {}); - // Min UA-FREE Parsimony score - SubtreeWeight min_ua_free_parsimony_scorer{ - merge.GetResult()}; - auto min_ua_free_parsimony_score = - min_ua_free_parsimony_scorer.ComputeWeightBelow(root_node, {}); - auto min_ua_free_parsimony_count = - min_ua_free_parsimony_scorer.MinWeightCount(root_node, {}); - // Min Sum RF Distance - SumRFDistance min_sum_rf_dist_weight_ops{merge, merge}; - auto min_shift_sum = min_sum_rf_dist_weight_ops.GetOps().GetShiftSum(); - SubtreeWeight min_sum_rf_dist_scorer{merge.GetResult()}; - auto min_sum_rf_distance = min_sum_rf_dist_scorer.ComputeWeightBelow( - root_node, min_sum_rf_dist_weight_ops); - min_sum_rf_distance += min_shift_sum; - auto min_sum_rf_count = - min_sum_rf_dist_scorer.MinWeightCount(root_node, min_sum_rf_dist_weight_ops); - // Max Sum RF Distance - MaxSumRFDistance max_sum_rf_dist_weight_ops{merge, merge}; - auto max_shift_sum = max_sum_rf_dist_weight_ops.GetOps().GetShiftSum(); - SubtreeWeight max_sum_rf_dist_scorer{merge.GetResult()}; - auto max_sum_rf_distance = max_sum_rf_dist_scorer.ComputeWeightBelow( - root_node, max_sum_rf_dist_weight_ops); - max_sum_rf_distance += max_shift_sum; - auto max_sum_rf_count = - max_sum_rf_dist_scorer.MinWeightCount(root_node, max_sum_rf_dist_weight_ops); - - log_timer.stop(); - - if (use_ua_free_parsimony) { - min_parsimony_score = min_ua_free_parsimony_score; - min_parsimony_count = min_ua_free_parsimony_count; - } - std::cout << "Min parsimony score in DAG: " << min_parsimony_score << "\n"; - std::cout << "Max parsimony score in DAG: " << max_parsimony_score << "\n"; - std::cout << ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Total trees in DAG: " << tree_count - << "\n"; - std::cout << "Optimal trees in DAG: " << min_parsimony_count << "\n"; - std::cout << "Min summed RF distance over trees: " << min_sum_rf_distance << "\n"; - std::cout << "Max summed RF distance over trees: " << max_sum_rf_distance << "\n"; - - logfile << '\n' - << iteration << '\t' << tree_count << '\t' - << merge.GetResult().GetNodesCount() << '\t' - << merge.GetResult().GetEdgesCount() << '\t' << min_parsimony_score << '\t' - << min_parsimony_count << '\t' << max_parsimony_score << '\t' - << min_sum_rf_distance << '\t' << max_sum_rf_distance << '\t' - << min_sum_rf_count << '\t' << max_sum_rf_count << '\t' - << log_timer.durationS() << std::flush; - - if (write_intermediate_dag) { - bool append_changes = (iteration > 0); - StoreDAG(merge.GetResult(), intermediate_dag_path, output_format, append_changes); - if (write_intermediate_every_x_iters.has_value() and - (iteration % write_intermediate_every_x_iters.value() == 0)) { - std::string intermediate_dag_path_final = - intermediate_dag_path + "." + std::to_string(iteration); - std::cout << "############ Saving intermediate DAG file to: " - << intermediate_dag_path_final << std::endl; - std::filesystem::copy_file(intermediate_dag_path, intermediate_dag_path_final, - std::filesystem::copy_options::overwrite_existing); - } + OptimizationLogger logger{merge, use_ua_free_parsimony, output_format}; +#ifdef USE_NETAM + if (ml_model.has_value()) { + logger.SetMLModel(&(*ml_model)); + } +#endif + if (not logfile_path.empty()) { + logger.OpenLogfile(logfile_path); + if (write_intermediate_every_x_iters.has_value()) { + logger.SetIntermediateSaveInterval(write_intermediate_every_x_iters.value()); } - return min_parsimony_score; - }; - current_best_parsimony = logger(0); + } + current_best_parsimony = logger.Log(0); + + MLScoringConfig ml_config; +#ifdef USE_NETAM + if (ml_model.has_value()) { + ml_config.model = &(*ml_model); + } +#endif + ml_config.coeff = move_coeff_ml; bool subtrees = false; for (size_t i = 0; i < iter_count; ++i) { @@ -980,16 +1311,19 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) } else if (callback_config == CallbackMethod::BestMovesFixedTree) { Merge_All_Profitable_Moves_Found_Fixed_Tree_Callback callback{ merge, {move_coeff_nodes, move_coeff_pscore}, collapse_empty_fragment_edges}; + callback.ml_config_ = ml_config; optimized_dags.push_back( optimize_dag_direct(sample_tree.View(), callback, callback, callback)); } else if (callback_config == CallbackMethod::BestMovesTreeBased) { Treebased_Move_Found_Callback callback{ merge, {move_coeff_nodes, move_coeff_pscore}, collapse_empty_fragment_edges}; + callback.ml_config_ = ml_config; optimized_dags.push_back( optimize_dag_direct(sample_tree.View(), callback, callback, callback)); } else if (callback_config == CallbackMethod::BestMoves) { Merge_All_Profitable_Moves_Found_Callback callback{ merge, {move_coeff_nodes, move_coeff_pscore}, collapse_empty_fragment_edges}; + callback.ml_config_ = ml_config; optimized_dags.push_back( optimize_dag_direct(sample_tree.View(), callback, callback, callback)); } else { @@ -1000,7 +1334,7 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) auto optimized_view = optimized_dags.back().first.View(); optimized_view.RecomputeCompactGenomes(false); merge.AddDAG(optimized_view); - auto this_parsimony = logger(i + 1); + auto this_parsimony = logger.Log(i + 1); if (time_limit > 0) { total_timer.stop(); if (size_t(total_timer.durationFloorMinutes()) >= time_limit) { @@ -1022,8 +1356,11 @@ int main(int argc, char** argv) { // NOLINT(bugprone-exception-escape) std::cout << "new node coefficient: " << move_coeff_nodes << "\n"; std::cout << "parsimony score coefficient: " << move_coeff_pscore << "\n"; - - logfile.close(); + std::cout << "scoring backend: " + << (scoring_backend == ScoringBackend::ML ? "ml" : "parsimony") << "\n"; + if (move_coeff_ml != 0.0) { + std::cout << "ML score coefficient: " << move_coeff_ml << "\n"; + } Benchmark save_timer; std::cout << "Saving final DAG..." << std::flush;