Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fhe-cmplr/ckks/include/min_cut_region.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "dfg_region_container.h"
#include "fhe/ckks/config.h"
#include "fhe/core/lower_ctx.h"
#include <map>
namespace fhe {
namespace ckks {

Expand Down Expand Up @@ -156,6 +157,12 @@ class MIN_CUT_REGION {
//! @brief Return the cost of the current cut, which is the sum of the
//! cost of all bootstrap/rescale operations required for the current cut.
double Cut_cost(const CUT_TYPE& cut);
//! @brief Ensure no ancester-descendant relationship in cut to avoid repeat cut (O(n))
bool Verify_scale(const CUT_TYPE& cur_cut);
//! @brief Helper function using DFS with memoization to check elem_id's ancestors
bool HasValidAncestors( REGION_ELEM_ID elem_id,
const std::set<REGION_ELEM_ID>& cut_set,
std::map<REGION_ELEM_ID, int>& visited_status);
//! @brief Merge node, and update cur_cut.
void Update_cut(const SCC_NODE_PTR& node, CUT_TYPE& cur_cut);
//! @brief Collect all the nodes originating from upper region to use as
Expand Down
72 changes: 72 additions & 0 deletions fhe-cmplr/ckks/src/min_cut_region.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,76 @@ double MIN_CUT_REGION::Init_src_node(void) {
return cost;
}

bool MIN_CUT_REGION::Verify_scale(const CUT_TYPE& cut) {
// Get all elements in the cut
std::set<REGION_ELEM_ID> cut_set = cut.Cut_elem();
if (cut_set.empty()) return true;

// Map to store visited status for each element:
// -1: not visited, 0: no ancestor in cut, 1: has ancestor in cut
std::map<REGION_ELEM_ID, int> visited_status;

// Check each element in the cut
for (REGION_ELEM_ID elem_id : cut_set) {
if (!HasValidAncestors(elem_id, cut_set, visited_status)) {
return false;
}
}

return true;
}

bool MIN_CUT_REGION::HasValidAncestors(
REGION_ELEM_ID elem_id,
const std::set<REGION_ELEM_ID>& cut_set,
std::map<REGION_ELEM_ID, int>& visited_status) {

// Check memoization
auto it = visited_status.find(elem_id);
if (it != visited_status.end()) {
if (it->second != -1) {
return it->second == 0; // Return true if no ancestor in cut
}else return true; // checked circle
} else {
visited_status[elem_id] = -1; // Initialize if not in map
}

// Get element node
REGION_ELEM_PTR elem_node = Region_cntr()->Node(elem_id);

// Check all predecessors
for (uint32_t i = 0; i < elem_node->Pred_cnt(); ++i) {
REGION_ELEM_ID pred_id = elem_node->Pred_id(i);
if (pred_id == air::base::Null_id) continue;
REGION_ELEM_PTR pred = Region_cntr()->Node(pred_id);
if(pred->Region_id() != elem_node->Region_id()) continue;

// If predecessor is in the cut, invalid
if (cut_set.find(pred_id) != cut_set.end()) {
visited_status[elem_id] = 1; // Has ancestor in cut
return false;
}

// Recursively check the predecessor
if (!HasValidAncestors(pred_id, cut_set, visited_status)) {
visited_status[elem_id] = 1; // Has ancestor in cut
//all region nodes in same scc be false
SCC_NODE_ID scc_id = Scc_cntr()->Scc_node(elem_id);
SCC_NODE_PTR scc_node = Scc_cntr()->Node(scc_id);
for(SCC_NODE::ELEM_ITER elt = scc_node->Begin_elem(); elt != scc_node->End_elem(); elt++){
REGION_ELEM_ID elem_id_t(*elt);
visited_status[elem_id_t] = 1;
}
return false;
}
}

// All ancestors are valid
visited_status[elem_id] = 0; // No ancestor in cut
return true;
}


void MIN_CUT_REGION::Update_cut(const SCC_NODE_PTR& scc_node, CUT_TYPE& cut) {
// 1. add element node in SCC_NODE into cut
for (SCC_NODE::ELEM_ITER elem_iter = scc_node->Begin_elem();
Expand Down Expand Up @@ -350,6 +420,8 @@ void MIN_CUT_REGION::Min_cut_phase() {
op_cost_incr += Scc_node_op_cost(next_src);
Update_cut(next_src, cur_cut);
if (cur_cut.Empty()) continue;
//avoid rescale duplicates: ancestor-descendant nodes exist in cut set
if (!Verify_scale(cur_cut)) continue;
// cost increase of rescale/bootstrap and FHE operations.
double tot_cost_incr = op_cost_incr + Cut_cost(cur_cut);
cur_cut.Set_cost_incr(tot_cost_incr);
Expand Down
Binary file added model/labelPsi.onnx
Binary file not shown.