diff --git a/fhe-cmplr/ckks/include/min_cut_region.h b/fhe-cmplr/ckks/include/min_cut_region.h index d688bc93..554d480b 100644 --- a/fhe-cmplr/ckks/include/min_cut_region.h +++ b/fhe-cmplr/ckks/include/min_cut_region.h @@ -17,6 +17,7 @@ #include "dfg_region_container.h" #include "fhe/ckks/config.h" #include "fhe/core/lower_ctx.h" +#include namespace fhe { namespace ckks { @@ -156,6 +157,12 @@ class MIN_CUT_REGION { //! @brief Return the cost of the current cut, which is the sum of the //! cost of all bootstrap/rescale operations required for the current cut. double Cut_cost(const CUT_TYPE& cut); + //! @brief Ensure no ancester-descendant relationship in cut to avoid repeat cut (O(n)) + bool Verify_scale(const CUT_TYPE& cur_cut); + //! @brief Helper function using DFS with memoization to check elem_id's ancestors + bool HasValidAncestors( REGION_ELEM_ID elem_id, + const std::set& cut_set, + std::map& visited_status); //! @brief Merge node, and update cur_cut. void Update_cut(const SCC_NODE_PTR& node, CUT_TYPE& cur_cut); //! @brief Collect all the nodes originating from upper region to use as diff --git a/fhe-cmplr/ckks/src/min_cut_region.cxx b/fhe-cmplr/ckks/src/min_cut_region.cxx index d512117f..1347aae8 100644 --- a/fhe-cmplr/ckks/src/min_cut_region.cxx +++ b/fhe-cmplr/ckks/src/min_cut_region.cxx @@ -281,6 +281,76 @@ double MIN_CUT_REGION::Init_src_node(void) { return cost; } +bool MIN_CUT_REGION::Verify_scale(const CUT_TYPE& cut) { + // Get all elements in the cut + std::set cut_set = cut.Cut_elem(); + if (cut_set.empty()) return true; + + // Map to store visited status for each element: + // -1: not visited, 0: no ancestor in cut, 1: has ancestor in cut + std::map visited_status; + + // Check each element in the cut + for (REGION_ELEM_ID elem_id : cut_set) { + if (!HasValidAncestors(elem_id, cut_set, visited_status)) { + return false; + } + } + + return true; +} + +bool MIN_CUT_REGION::HasValidAncestors( + REGION_ELEM_ID elem_id, + const std::set& cut_set, + std::map& visited_status) { + + // Check memoization + auto it = visited_status.find(elem_id); + if (it != visited_status.end()) { + if (it->second != -1) { + return it->second == 0; // Return true if no ancestor in cut + }else return true; // checked circle + } else { + visited_status[elem_id] = -1; // Initialize if not in map + } + + // Get element node + REGION_ELEM_PTR elem_node = Region_cntr()->Node(elem_id); + + // Check all predecessors + for (uint32_t i = 0; i < elem_node->Pred_cnt(); ++i) { + REGION_ELEM_ID pred_id = elem_node->Pred_id(i); + if (pred_id == air::base::Null_id) continue; + REGION_ELEM_PTR pred = Region_cntr()->Node(pred_id); + if(pred->Region_id() != elem_node->Region_id()) continue; + + // If predecessor is in the cut, invalid + if (cut_set.find(pred_id) != cut_set.end()) { + visited_status[elem_id] = 1; // Has ancestor in cut + return false; + } + + // Recursively check the predecessor + if (!HasValidAncestors(pred_id, cut_set, visited_status)) { + visited_status[elem_id] = 1; // Has ancestor in cut + //all region nodes in same scc be false + SCC_NODE_ID scc_id = Scc_cntr()->Scc_node(elem_id); + SCC_NODE_PTR scc_node = Scc_cntr()->Node(scc_id); + for(SCC_NODE::ELEM_ITER elt = scc_node->Begin_elem(); elt != scc_node->End_elem(); elt++){ + REGION_ELEM_ID elem_id_t(*elt); + visited_status[elem_id_t] = 1; + } + return false; + } + } + + // All ancestors are valid + visited_status[elem_id] = 0; // No ancestor in cut + return true; +} + + void MIN_CUT_REGION::Update_cut(const SCC_NODE_PTR& scc_node, CUT_TYPE& cut) { // 1. add element node in SCC_NODE into cut for (SCC_NODE::ELEM_ITER elem_iter = scc_node->Begin_elem(); @@ -350,6 +420,8 @@ void MIN_CUT_REGION::Min_cut_phase() { op_cost_incr += Scc_node_op_cost(next_src); Update_cut(next_src, cur_cut); if (cur_cut.Empty()) continue; + //avoid rescale duplicates: ancestor-descendant nodes exist in cut set + if (!Verify_scale(cur_cut)) continue; // cost increase of rescale/bootstrap and FHE operations. double tot_cost_incr = op_cost_incr + Cut_cost(cur_cut); cur_cut.Set_cost_incr(tot_cost_incr); diff --git a/model/labelPsi.onnx b/model/labelPsi.onnx new file mode 100644 index 00000000..932cd7e9 Binary files /dev/null and b/model/labelPsi.onnx differ