Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ option(USE_HTSLIB "Use htslib" OFF)
option(USE_NETAM "Use netam/thrifty libtorch integration" OFF)
option(DISABLE_PARALLELISM "Disable all parallelism for debugging" OFF)
option(USE_SYSTEM_TBB "Use system-installed TBB instead of fetching from source" OFF)
option(USE_TEST_LOGS "Write debug artifacts (before_optimize.pb, after_optimize.pb)" OFF)

set(TBB_VERSION "v2022.1.0")

Expand Down Expand Up @@ -131,7 +132,7 @@ if(MPI_C_COMPILER)
# Get the first path (space-separated list)
string(REPLACE " " ";" MPI_INCDIRS_LIST "${MPI_INCDIRS}")
list(GET MPI_INCDIRS_LIST 0 MPI_PRIMARY_INCDIR)

set(MPI_C_HEADER_DIR "${MPI_PRIMARY_INCDIR}" CACHE PATH "")
set(MPI_CXX_HEADER_DIR "${MPI_PRIMARY_INCDIR}" CACHE PATH "")
endif()
Expand Down Expand Up @@ -339,6 +340,10 @@ function(larch_compile_opts PRODUCT)
target_compile_options(${PRODUCT} PUBLIC -DDISABLE_PARALLELISM)
endif()

if(USE_TEST_LOGS)
target_compile_options(${PRODUCT} PUBLIC -DUSE_TEST_LOGS)
endif()

if(USE_ASAN)
target_compile_options(${PRODUCT} PUBLIC -O0 -g3 -fsanitize=address -fno-sanitize-recover)
elseif(USE_TSAN)
Expand All @@ -350,14 +355,14 @@ function(larch_compile_opts PRODUCT)
# (see torch target modification after find_package(Torch))
target_link_libraries(${PRODUCT} PUBLIC protobuf::libprotobuf)
if(absl_FOUND)
target_link_libraries(${PRODUCT} PUBLIC
target_link_libraries(${PRODUCT} PUBLIC
absl::log
absl::log_internal_check_impl
)
endif()
target_include_directories(${PRODUCT} PUBLIC ${PROTO_OUT_DIR})
target_include_directories(${PRODUCT} PUBLIC ${PROTO_OUT_DIR}/deps/usher)

target_include_directories(${PRODUCT} PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/deps/range-v3/install/include)
target_compile_options(${PRODUCT} PUBLIC -DRANGES_DISABLE_DEPRECATED_WARNINGS)
add_dependencies(${PRODUCT} range-v3)
Expand Down Expand Up @@ -596,11 +601,11 @@ larch_executable(larch-dag2dot
)
larch_install(larch-dag2dot)

# # bcr-larch
larch_executable(bcr-larch
# # larch-bcr
larch_executable(larch-bcr
tools/bcr-larch.cpp
)
larch_install(bcr-larch)
larch_install(larch-bcr)

# # larch-usher
if(USE_USHER)
Expand All @@ -610,6 +615,10 @@ if(USE_USHER)
target_compile_options(larch-usher PRIVATE ${STRICT_WARNINGS})
target_link_libraries(larch-usher PRIVATE larch-usher-glue)
add_dependencies(larch-usher larch-usher-glue)
if(USE_NETAM)
target_compile_definitions(larch-usher PRIVATE USE_NETAM)
target_link_libraries(larch-usher PRIVATE netam)
endif()
larch_install(larch-usher)
endif()

12 changes: 11 additions & 1 deletion include/larch/impl/produce_mat_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ auto optimize_dag_direct(DAG dag, Move_Found_Callback& callback,
static_assert(DAG::template contains_element_feature<Component::Node, MATConversion>);
auto& tree = dag.GetMutableMAT();

#ifdef KEEP_ASSERTS
#ifdef USE_TEST_LOGS
Mutation_Annotated_Tree::save_mutation_annotated_tree(tree, "before_optimize.pb");
#endif
#ifdef KEEP_ASSERTS
check_MAT_MADAG_Eq(tree, dag);
#endif

Expand Down Expand Up @@ -92,6 +94,12 @@ auto optimize_dag_direct(DAG dag, Move_Found_Callback& callback,
"intermediate_newick", // intermediate newick name
callback // callback
);
#ifndef USE_TEST_LOGS
// matOptimize unconditionally writes intermediate_newick*.pb.gz when
// allow_drift is true. Clean up since we can't suppress it without
// changing search behavior.
std::remove("intermediate_newick1.pb.gz");
#endif
tree.uncondense_leaves();
tree.condense_leaves(condense_arg);
tree.fix_node_idx();
Expand All @@ -100,7 +108,9 @@ auto optimize_dag_direct(DAG dag, Move_Found_Callback& callback,

tree.uncondense_leaves();
tree.fix_node_idx();
#ifdef USE_TEST_LOGS
Mutation_Annotated_Tree::save_mutation_annotated_tree(tree, "after_optimize.pb");
#endif
auto result =
std::make_pair(AddMATConversion(MADAGStorage<>::EmptyDefault()), std::move(tree));
result.first.View().BuildFromMAT(result.second, dag.GetReferenceSequence());
Expand Down
4 changes: 4 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ description = "Build larch using settings from larch-build.env."
cmd = "bash scripts/run-tests.sh -tag slow"
description = "Run tests (excluding slow tests). Build first with 'pixi run build'."

[tasks.test-netam]
cmd = "bash scripts/run-tests.sh +tag netam"
description = "Run all netam tests (ML SPR, likelihood, model, S5F). Requires USE_NETAM=yes build."

[tasks.test-all]
cmd = "bash scripts/run-tests.sh"
description = "Run all tests including slow tests. WARNING: Very slow."
Expand Down
1 change: 1 addition & 0 deletions scripts/larch-build.env.template
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ LARCH_PROTOBUF_PATH=auto
# -DDISABLE_PARALLELISM=yes Disable all parallelism (debugging)
# -DKEEP_ASSERTS=ON Keep asserts in release builds
# -DUSE_CPPTRACE=ON Readable backtraces on exceptions
# -DUSE_TEST_LOGS=ON Write debug artifacts (before/after_optimize.pb)
# -DUSE_SYSTEM_TBB=yes Use system TBB instead of fetching from source
# (auto-enabled on macOS)
# Example: LARCH_CMAKE_EXTRA=-DUSE_ASAN=yes -DKEEP_ASSERTS=ON
Expand Down
17 changes: 10 additions & 7 deletions src/netam/kmer_sequence_encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ std::pair<torch::Tensor, torch::Tensor> kmer_sequence_encoder::encode_sequence(
std::string padded_sequence = std::string(overhang_length_, 'N') + upper_seq +
std::string(overhang_length_, 'N');

// Index for N-containing kmers (last index, 4^k)
auto n_index = signed_cast(narrowing_cast<std::uint32_t>(all_kmers_.size() - 1));
// Index for N-containing kmers (index 0, matching Python netam convention)
constexpr std::int32_t n_index = 0;

// Encode kmers
std::vector<std::int32_t> kmer_indices;
Expand Down Expand Up @@ -75,8 +75,8 @@ torch::Tensor kmer_sequence_encoder::encode_bases(const std::string& sequence) {
std::vector<std::string> kmer_sequence_encoder::generate_kmers(std::size_t length) {
std::vector<std::string> kmers;

// Generate all possible kmers of given length (4^k kmers)
// Uses lexicographic order: AAAAA=0, AAAAC=1, AAAAG=2, AAAAT=3, AAACA=4, ...
// Generate all possible kmers of given length (4^k kmers) in lexicographic order.
// Final indices: N=0, AAA=1, AAC=2, ..., TTT=4^k (N-kmer prepended below).
std::function<void(std::string, std::size_t)> generate = [&](std::string current,
std::size_t pos) {
if (pos == length) {
Expand All @@ -90,9 +90,12 @@ std::vector<std::string> kmer_sequence_encoder::generate_kmers(std::size_t lengt

generate("", 0);

// Add placeholder for N-containing kmers at the end (index 4^k)
kmers.push_back("N");
return kmers;
// N-kmer placeholder at index 0 (matches Python netam convention)
std::vector<std::string> result;
result.reserve(1 + kmers.size());
result.push_back("N");
result.insert(result.end(), kmers.begin(), kmers.end());
return result;
}

torch::Tensor kmer_sequence_encoder::compute_wt_base_modifier(
Expand Down
18 changes: 18 additions & 0 deletions test/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ static void get_usage() {
{"--list",
"Prints information about all selected tests (IDs, tags), but does not run "
"them"},
{"--list-tags", "List all available test tags and exit"},
{"nocatch", "Allow exceptions to escape for debugging"}};

std::cout << FormatUsage(program_desc, usage_examples, flag_desc_pairs);
Expand Down Expand Up @@ -91,6 +92,7 @@ int main(int argc, char* argv[]) {
#endif
bool no_catch = false;
bool opt_list_names = false;
bool opt_list_tags = false;
bool opt_test_range = false;

assert(std::filesystem::exists("./data/") && "Test data folder not found.");
Expand All @@ -108,6 +110,8 @@ int main(int argc, char* argv[]) {
no_catch = true;
} else if (std::string("--list") == argv[i]) {
opt_list_names = true;
} else if (std::string("--list-tags") == argv[i]) {
opt_list_tags = true;
} else if (std::string("--range") == argv[i]) {
opt_test_range = true;
range = parse_range(argv[++i]);
Expand Down Expand Up @@ -160,6 +164,20 @@ int main(int argc, char* argv[]) {
std::cout << std::endl;
}

if (opt_list_tags) {
std::set<std::string> all_tags;
for (const auto& test : get_all_tests()) {
for (const auto& tag : test.tags) {
all_tags.insert(tag);
}
}
std::cout << "AVAILABLE TAGS:" << std::endl;
for (const auto& tag : all_tags) {
std::cout << " " << tag << std::endl;
}
return EXIT_SUCCESS;
}

if (opt_list_names) {
return EXIT_SUCCESS;
}
Expand Down
3 changes: 2 additions & 1 deletion test/test_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ inline void print_peak_mem() {
}

inline const std::string test_output_folder = "data/_ignore/";
inline const std::string test_log_folder = test_output_folder + "/optimization_log";

inline std::pair<std::string, int> run_larch_usher(
std::string_view input_dag_path, std::string_view output_dag_path,
Expand All @@ -75,7 +76,7 @@ inline std::pair<std::string, int> run_larch_usher(
std::optional<std::string_view> other_options = std::nullopt,
bool do_print_stdout = true, bool do_print_stderr = false,
bool do_print_summary = true) {
std::string log_path = test_output_folder + "/optimization_log";
std::string log_path = test_log_folder;

std::stringstream ss;
// ss << "/usr/bin/time ";
Expand Down
9 changes: 3 additions & 6 deletions test/test_fileio_dagbin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,23 +104,20 @@
std::string output_dag_path_no_ext = test_output_folder + "/temp";
std::string output_dag_path_protobuf = test_output_folder + "/temp.pb";
std::string output_dag_path_dagbin = test_output_folder + "/temp.dagbin";
std::string inter_dag_path_protobuf = output_dag_path_protobuf + ".intermediate";
std::string inter_dag_path_dagbin = output_dag_path_dagbin + ".intermediate";
// Intermediate DAGs are written inside the log directory by larch-usher
std::string inter_dag_path_protobuf = test_log_folder + "/intermediate.pb";
std::string inter_dag_path_dagbin = test_log_folder + "/intermediate.dagbin";
std::string other_options = "--thread 1 ";
if (use_seed) {
other_options += "--seed 42 ";
}

if (save_both) {
inter_dag_path_protobuf = output_dag_path_no_ext + ".intermediate.pb";
inter_dag_path_dagbin = output_dag_path_no_ext + ".intermediate.dagbin";
other_options += " --output-format debug-all";
auto [command, result] = run_larch_usher(input_dag_path, output_dag_path_no_ext,
std::nullopt, iter, other_options);
TestAssert((result == 0) && "larch-usher debug-all run failed.");
} else {
inter_dag_path_protobuf = output_dag_path_protobuf + ".intermediate";
inter_dag_path_dagbin = output_dag_path_dagbin + ".intermediate";
auto [command1, result1] = run_larch_usher(input_dag_path, output_dag_path_protobuf,
std::nullopt, iter, other_options);
TestAssert((result1 == 0) && "larch-usher protobuf run failed.");
Expand Down
2 changes: 1 addition & 1 deletion test/test_larch_usher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static void test_larch_usher_merged_dag() {
std::string output_dag_path = test_output_folder + "/opt_dag.pb";
std::string command = "./bin/larch-usher -i data/larch_merged_dag.pb -o " +
output_dag_path +
" -c 1 -s 0 --max-subtree-clade-size 2000 --trim --quiet";
" -c 1 -s 0 --max-subtree-clade-size 2000 --trim";
std::cout << ">COMMAND_EXECUTE: \"" << command << "\"" << std::endl;
int result = std::system(command.c_str());
TestAssert(0 == result);
Expand Down
9 changes: 4 additions & 5 deletions test/test_netam_kmer_encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,20 +226,19 @@ void test_encode_sequence_kmer_indices_valid() {
}

void test_encode_sequence_padding_creates_zero_index() {
// K-mers containing N should map to the last index (kmer_count - 1)
// K-mers containing N should map to index 0 (matching Python netam convention)
auto yaml = make_config(3, 5);
kmer_sequence_encoder encoder{yaml};

auto [encoded, wt_modifier] = encoder.encode_sequence("ACG");

auto encoded_accessor = encoded.accessor<int32_t, 1>();
auto n_index =
static_cast<int32_t>(encoder.kmer_count() - 1);
constexpr int32_t n_index = 0;

// Position 0: k-mer is "NAC" (contains N from padding) -> last index
// Position 0: k-mer is "NAC" (contains N from padding) -> index 0
TestAssert(encoded_accessor[0] == n_index);

// Position 2: k-mer is "CGN" (contains N from padding) -> last index
// Position 2: k-mer is "CGN" (contains N from padding) -> index 0
TestAssert(encoded_accessor[2] == n_index);
}

Expand Down
Loading