diff --git a/fhe-cmplr/include/fhe/ckks/ir2c_handler.h b/fhe-cmplr/include/fhe/ckks/ir2c_handler.h index 5f7b857e..8f9f8b88 100644 --- a/fhe-cmplr/include/fhe/ckks/ir2c_handler.h +++ b/fhe-cmplr/include/fhe/ckks/ir2c_handler.h @@ -231,7 +231,10 @@ class IR2C_HANDLER : public INVALID_HANDLER { air::base::NODE_PTR parent = ctx.Parent(1); AIR_ASSERT(parent != air::base::Null_ptr && parent->Is_st()); - if (ctx.Provider() == core::PROVIDER::ANT) { + bool emit_data_file = ctx.Provider() == core::PROVIDER::ANT || + (ctx.Provider() == core::PROVIDER::PHANTOM && + ctx.Emit_data_file()); + if (emit_data_file) { ctx.template Emit_encode(visitor, parent, node); } else { ctx.template Emit_runtime_encode(visitor, parent, node); @@ -276,4 +279,4 @@ class IR2C_HANDLER : public INVALID_HANDLER { } // namespace ckks } // namespace fhe -#endif \ No newline at end of file +#endif diff --git a/fhe-cmplr/rtlib/cmake/modules/phantom.cmake b/fhe-cmplr/rtlib/cmake/modules/phantom.cmake index 113fe15a..d28988b9 100644 --- a/fhe-cmplr/rtlib/cmake/modules/phantom.cmake +++ b/fhe-cmplr/rtlib/cmake/modules/phantom.cmake @@ -8,8 +8,8 @@ # Build external Phantom project dependent function function(build_external_phantom) - set(PHANTOM_URL "https://git:$ENV{CI_TOKEN}@code.alipay.com/zhanggongliang.zgl/phantom-fhe.git") - set(PHANTOM_URL_SSH "git@code.alipay.com:zhanggongliang.zgl/phantom-fhe.git") + set(PHANTOM_URL "https://github.com/zggl404/phantom-fhe.git") + set(PHANTOM_URL_SSH "git@github.com:zggl404/phantom-fhe.git") if(EXTERNAL_URL_SSH) set(REPO_PHANTOM_URL ${PHANTOM_URL_SSH}) else() diff --git a/fhe-cmplr/rtlib/include/common/rt_env.h b/fhe-cmplr/rtlib/include/common/rt_env.h index 8d117a3d..dae36fd0 100644 --- a/fhe-cmplr/rtlib/include/common/rt_env.h +++ b/fhe-cmplr/rtlib/include/common/rt_env.h @@ -25,6 +25,8 @@ #define ENV_PT_ENTRY_COUNT "PT_ENTRY_COUNT" //! PT_PREFETCH_COUNT: number of pt for prefetching. default: 2 #define ENV_PT_PREFETCH_COUNT "PT_PREFETCH_COUNT" +//! 
PT_MSG_DUMP_COUNT=int: number of message values to dump. default: 0 +#define ENV_PT_MSG_DUMP_COUNT "PT_MSG_DUMP_COUNT" //! environment variable to control rt data file reader (RT_DATA_FILE) //! RT_DATA_ASYNC_READ=0|1: use asynchronous read. default: 0 diff --git a/fhe-cmplr/rtlib/phantom/src/phantom_lib.cu b/fhe-cmplr/rtlib/phantom/src/phantom_lib.cu index 5a8b3204..de091ef4 100755 --- a/fhe-cmplr/rtlib/phantom/src/phantom_lib.cu +++ b/fhe-cmplr/rtlib/phantom/src/phantom_lib.cu @@ -4,10 +4,17 @@ #include "common/common.h" #include "common/error.h" #include "common/io_api.h" +#include "common/pt_mgr.h" #include "common/rt_api.h" #include "common/rtlib_timing.h" #include "boot/Bootstrapper.cuh" +#include +#include +#include +#include +#include + using namespace phantom; using namespace phantom::arith; using namespace phantom::util; @@ -57,7 +64,9 @@ public: std::vector vec(input->_vals, input->_vals + len); CKKS_PARAMS *prog_param = Get_context_params(); Plaintext pt; - _evaluator->encoder.encode(vec, std::pow(2.0, _scaling_mod_size), pt); + int chain_index = _num_prime_parts - prog_param->_input_level; + double encode_scale = std::pow(2.0, _scaling_mod_size); + _evaluator->encoder.encode(vec, chain_index, encode_scale, pt); Ciphertext *ct = new Ciphertext; _evaluator->encryptor.encrypt(pt, *ct); Io_set_input(name, 0, ct); @@ -91,8 +100,9 @@ public: LEVEL_T level) { std::vector vec(input, input + len); - std::vector vec_after; - _evaluator->encoder.encode(vec, _num_prime_parts - level, std::pow(2.0, _scaling_mod_size * scale), *pt); + double encode_scale = std::pow(2.0, _scaling_mod_size * scale); + int chain_index = _num_prime_parts - level; + _evaluator->encoder.encode(vec, chain_index, encode_scale, *pt); } void Encode_float_cst_lvl(Plaintext *pt, float *input, size_t len, @@ -109,7 +119,9 @@ public: LEVEL_T level) { std::vector vec(len, input); - _evaluator->encoder.encode(vec, _num_prime_parts - level, std::pow(2.0, _scaling_mod_size * scale), *pt); + double 
encode_scale = std::pow(2.0, _scaling_mod_size * scale); + int chain_index = _num_prime_parts - level; + _evaluator->encoder.encode(vec, chain_index, encode_scale, *pt); } void Encode_float_mask_cst_lvl(Plaintext *pt, float input, size_t len, SCALE_T scale, int level) @@ -276,26 +288,73 @@ public: _evaluator->evaluator.relinearize(*op1, *_rlk, *res); } } + void Bootstrap(Ciphertext *op1, Ciphertext *res, int level, int slot) { _evaluator->evaluator.mod_switch_to_inplace(*op1, _num_prime_parts - 1); - switch (slot) + int effective_slot = slot; + if (effective_slot == 0) + { + effective_slot = 32768; + } + + Ciphertext *bootstrap_input = op1; + + Ciphertext cyclic_input; + bool cyclic_input_used = false; + if (effective_slot > 0 && static_cast(effective_slot) < _slot_count) + { + cyclic_input = *bootstrap_input; + Ciphertext rotated_input = *bootstrap_input; + for (uint64_t offset = effective_slot; offset < _slot_count; offset += effective_slot) + { + Rotate(&rotated_input, effective_slot, &rotated_input); + Add(&cyclic_input, &rotated_input, &cyclic_input); + } + rotated_input.release(); + bootstrap_input = &cyclic_input; + cyclic_input_used = true; + } + + Ciphertext in_place_input; + bool in_place_bootstrap = res == op1; + if (in_place_bootstrap) + { + in_place_input = *bootstrap_input; + bootstrap_input = &in_place_input; + } + + res->release(); + switch (effective_slot) { case 16384: - _bootstrapper_16384->bootstrap_3(*res, *op1); + _bootstrapper_16384->bootstrap_real_3(*res, *bootstrap_input); break; case 8192: - _bootstrapper_8192->bootstrap_3(*res, *op1); + _bootstrapper_8192->bootstrap_real_3(*res, *bootstrap_input); break; case 4096: - _bootstrapper_4096->bootstrap_3(*res, *op1); + _bootstrapper_4096->bootstrap_real_3(*res, *bootstrap_input); + break; + case 32768: + _bootstrapper_32768->bootstrap_real_3(*res, *bootstrap_input); break; default: - std::cout<<"Unsupported slot size for bootstrap: (must 16384,8192,4096)" << slot << std::endl; + IS_TRUE(false, 
"Unsupported slot size for bootstrap"); break; } + if (in_place_bootstrap) + { + in_place_input.release(); + } + + if (cyclic_input_used) + { + cyclic_input.release(); + } + int target_level = _num_prime_parts - level; if (level != 0 && target_level > res->chain_index()) { @@ -327,17 +386,28 @@ public: EncryptionParameters parms(scheme_type::ckks); uint32_t degree = prog_param->_poly_degree; parms.set_poly_modulus_degree(degree); + + uint32_t _bts_required_level = 14; + uint32_t _bts_remaining_level = prog_param->_mul_depth - 14; std::vector bits; - bits.push_back(prog_param->_scaling_mod_size); - for (uint32_t i = 0; i < prog_param->_mul_depth; ++i) + bits.push_back(prog_param->_first_mod_size); + for (uint32_t i = 0; i < _bts_remaining_level; ++i) { bits.push_back(prog_param->_scaling_mod_size); } - - bits.push_back(prog_param->_first_mod_size); + for (uint32_t i = 0; i < _bts_required_level; ++i) + { + bits.push_back(prog_param->_first_mod_size); + } + constexpr size_t special_modulus_size = 4; + for (size_t i = 0; i < special_modulus_size; i++) + { + bits.push_back(prog_param->_first_mod_size); + } parms.set_coeff_modulus(phantom::arith::CoeffModulus::Create(degree, bits)); parms.set_secret_key_hamming_weight(192); - _num_prime_parts = bits.size(); + parms.set_special_modulus_size(special_modulus_size); + _num_prime_parts = bits.size() - special_modulus_size + 1; phantom::arith::sec_level_type sec = phantom::arith::sec_level_type::tc128; switch (prog_param->_sec_level) { @@ -379,6 +449,9 @@ public: long loge = 10; int log_slot_count = 15; + _bootstrapper_32768 = std::make_unique( + loge, 15, log_slot_count, prog_param->_mul_depth, std::pow(2.0, _scaling_mod_size), + boundary_K, deg, scale_factor, inverse_deg, _evaluator.get()); _bootstrapper_16384 = std::make_unique( loge, 14, log_slot_count, prog_param->_mul_depth, std::pow(2.0, _scaling_mod_size), boundary_K, deg, scale_factor, inverse_deg, _evaluator.get()); @@ -389,6 +462,7 @@ public: loge, 12, 
log_slot_count, prog_param->_mul_depth, std::pow(2.0, _scaling_mod_size), boundary_K, deg, scale_factor, inverse_deg, _evaluator.get()); + _bootstrapper_32768->prepare_mod_polynomial(); _bootstrapper_16384->prepare_mod_polynomial(); _bootstrapper_8192->prepare_mod_polynomial(); _bootstrapper_4096->prepare_mod_polynomial(); @@ -400,22 +474,22 @@ public: gal_steps_vector.push_back((1 << i)); } + _bootstrapper_32768->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector); _bootstrapper_16384->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector); _bootstrapper_8192->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector); _bootstrapper_4096->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector); - std::cout << "the size of gal_steps_vector is " << gal_steps_vector.size() << std::endl; _evaluator->decryptor.create_galois_keys_from_steps(gal_steps_vector, *(_evaluator.get()->galois_keys)); - std::cout << "gen rot key done " << gal_steps_vector.size() << std::endl; // log2(32768) = 15, log2(16384) = 14, log2(8192) = 13, log2(4096) = 12 - _bootstrapper_16384->slot_vec.push_back(14); + _bootstrapper_32768->slot_vec.push_back(15); + _bootstrapper_16384->slot_vec.push_back(14); _bootstrapper_8192->slot_vec.push_back(13); _bootstrapper_4096->slot_vec.push_back(12); + _bootstrapper_32768->generate_LT_coefficient_3(); _bootstrapper_16384->generate_LT_coefficient_3(); _bootstrapper_8192->generate_LT_coefficient_3(); _bootstrapper_4096->generate_LT_coefficient_3(); - printf( "ckks_param: _provider = %d, _poly_degree = %d, _sec_level = %ld, " "mul_depth = %ld, _first_mod_size = %ld, _scaling_mod_size = %ld, " @@ -488,14 +562,6 @@ public: } _evaluator->decryptor.create_galois_keys_from_steps(gal_steps_vector, *(_evaluator.get()->galois_keys)); - printf( - "ckks_param: _provider = %d, _poly_degree = %d, _sec_level = %ld, " - "mul_depth = %ld, _first_mod_size = %ld, _scaling_mod_size = %ld, " - "_num_q_parts = %ld, _num_rot_idx = %ld,_num_prime_parts = %ld\n", - prog_param->_provider, 
prog_param->_poly_degree, prog_param->_sec_level, - prog_param->_mul_depth, prog_param->_first_mod_size, - prog_param->_scaling_mod_size, prog_param->_num_q_parts, - prog_param->_num_rot_idx, _num_prime_parts); } SCALE_T Scale(const Ciphertext *op) { @@ -556,10 +622,19 @@ void Prepare_context() Init_rtlib_timing(); Io_init(); PHANTOM_CONTEXT::Init_context(); + RT_DATA_INFO *data_info = Get_rt_data_info(); + if (data_info != nullptr) + { + Pt_mgr_init(data_info->_file_name); + } } void Finalize_context() { + if (Get_rt_data_info() != nullptr) + { + Pt_mgr_fini(); + } PHANTOM_CONTEXT::Fini_context(); Io_fini(); } diff --git a/fhe-cmplr/rtlib/phantom/src/pt_mgr.cu b/fhe-cmplr/rtlib/phantom/src/pt_mgr.cu new file mode 100644 index 00000000..473f223d --- /dev/null +++ b/fhe-cmplr/rtlib/phantom/src/pt_mgr.cu @@ -0,0 +1,235 @@ +//-*-c++-*- +//============================================================================= +// +// Copyright (c) Ant Group Co., Ltd +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//============================================================================= + +#include "common/pt_mgr.h" + +#include +#include +#include + +#include "common/error.h" +#include "common/rt_data_file.h" +#include "common/rt_env.h" +#include "rt_phantom/rt_phantom.h" + +struct PT_MGR { + struct RT_DATA_FILE* _file; + char* _msg_buf; + uint64_t _msg_size; + bool _sync_read; +}; + +static PT_MGR Pt_mgr = {0}; + +static size_t Pt_msg_dump_count(); +static void Dump_msg_preview(const char* tag, uint32_t index, size_t ofst, + const float* data, size_t len, uint32_t scale, + uint32_t level); + +bool Pt_mgr_init(const char* fname) { + IS_TRUE(fname != NULL, "missing rt data file name"); + IS_TRUE(Pt_mgr._file == NULL, "pt mgr already initialized"); + + bool sync_read = true; + const char* sr_env = getenv(ENV_RT_DATA_ASYNC_READ); + if (sr_env != NULL && atoi(sr_env) == 1) { + sync_read = false; + } + + if (Block_io_init(sync_read) == false) { + return false; + 
} + + Pt_mgr._sync_read = sync_read; + Pt_mgr._file = Rt_data_open(fname, sync_read); + if (Pt_mgr._file == NULL) { + Block_io_fini(sync_read); + Pt_mgr._sync_read = false; + return false; + } + + IS_TRUE(!Rt_data_is_plaintext(Pt_mgr._file), + "phantom pt mgr does not support plaintext rt data yet"); + + Pt_mgr._msg_size = Rt_data_size(Pt_mgr._file); + Pt_mgr._msg_buf = (char*)malloc(Pt_mgr._msg_size); + IS_TRUE(Pt_mgr._msg_buf != NULL || Pt_mgr._msg_size == 0, + "failed to malloc rt data buffer"); + bool fill_ok = Rt_data_fill(Pt_mgr._file, Pt_mgr._msg_buf, Pt_mgr._msg_size); + FMT_ASSERT(fill_ok, "failed to fill rt data from file"); + return true; +} + +void Pt_mgr_fini() { + if (Pt_mgr._file != NULL) { + Rt_data_close(Pt_mgr._file); + Pt_mgr._file = NULL; + } + free(Pt_mgr._msg_buf); + Pt_mgr._msg_buf = NULL; + Pt_mgr._msg_size = 0; + Block_io_fini(Pt_mgr._sync_read); + Pt_mgr._sync_read = false; +} + +bool Pt_pre_encode() { return false; } + +void Pt_prefetch(uint32_t index) { (void)index; } + +void* Pt_get(uint32_t index, size_t len, uint32_t scale, uint32_t level) { + (void)index; + (void)len; + (void)scale; + (void)level; + IS_TRUE(false, "phantom pt mgr does not support plaintext rt data"); + return NULL; +} + +void* Pt_get_validate(float* buf, uint32_t index, size_t len, uint32_t scale, + uint32_t level) { + (void)buf; + (void)index; + (void)len; + (void)scale; + (void)level; + IS_TRUE(false, "phantom pt mgr does not support plaintext rt data"); + return NULL; +} + +void Pt_free(uint32_t index) { (void)index; } + +void Free_data(void* poly) { (void)poly; } + +static float* Msg_ptr(uint32_t index, uint64_t ofst, size_t len) { + IS_TRUE(Pt_mgr._file != NULL, "pt mgr is not initialized"); + uint64_t file_ofst = Rt_data_entry_offset(Pt_mgr._file, index, + (ofst + len) * sizeof(float)); + IS_TRUE(file_ofst + (ofst + len) * sizeof(float) <= Pt_mgr._msg_size, + "entry offset too large"); + return (float*)&Pt_mgr._msg_buf[file_ofst + ofst * sizeof(float)]; +} + 
+//! @brief Number of leading values Dump_msg_preview prints per message.
+//! Parsed from PT_MSG_DUMP_COUNT on the first call; unset means 0 (no dump).
+static size_t Pt_msg_dump_count() {
+  static const char* const env = getenv(ENV_PT_MSG_DUMP_COUNT);
+  static const size_t dump_count = (env == NULL) ? 0 : strtoull(env, NULL, 10);
+  return dump_count;
+}
+
+//! @brief Write a one-line summary of data[0..len) to stderr: header fields,
+//! leading values, non-zero count, value range and the first non-zero run.
+//! No-op when the dump count is 0 or data is NULL. index, ofst, scale and
+//! level are only echoed in the header; tag names the calling entry point.
+static void Dump_msg_preview(const char* tag, uint32_t index, size_t ofst,
+                             const float* data, size_t len, uint32_t scale,
+                             uint32_t level) {
+  size_t dump_count = Pt_msg_dump_count();
+  if (dump_count == 0 || data == NULL) {
+    return;
+  }
+  size_t first_nonzero = len;  // len means "no non-zero value seen yet"
+  size_t last_nonzero  = len;
+  size_t nonzero_count = 0;
+  float  min_val       = (len > 0) ? data[0] : 0.0f;
+  float  max_val       = min_val;
+  for (size_t i = 0; i < len; ++i) {
+    float val = data[i];
+    min_val   = (val < min_val) ? val : min_val;
+    max_val   = (val > max_val) ? val : max_val;
+    if (fabs(val) > 0.000001f) {  // magnitudes up to 1e-6 count as zero
+      first_nonzero = (nonzero_count == 0) ? i : first_nonzero;
+      last_nonzero  = i;
+      ++nonzero_count;
+    }
+  }
+
+  size_t preview = (len < dump_count) ? len : dump_count;
+  fprintf(stderr,
+          "[pt_mgr] %s index=%u ofst=%zu len=%zu scale=%u level=%u head:",
+          tag, index, ofst, len, scale, level);
+  for (size_t i = 0; i < preview; ++i) {
+    fprintf(stderr, " %g", data[i]);
+  }
+  if (preview < len) {
+    fprintf(stderr, " ...");
+  }
+  fprintf(stderr, " | nz=%zu", nonzero_count);
+  if (len > 0) {
+    fprintf(stderr, " min=%g max=%g", min_val, max_val);
+  }
+  if (first_nonzero < len) {
+    size_t nz_left    = len - first_nonzero;
+    size_t nz_preview = (nz_left < dump_count) ? nz_left : dump_count;
+    fprintf(stderr, " first_nz=%zu:%g last_nz=%zu:%g nz_head:", first_nonzero,
+            data[first_nonzero], last_nonzero, data[last_nonzero]);
+    for (size_t i = 0; i < nz_preview; ++i) {
+      fprintf(stderr, " %g", data[first_nonzero + i]);
+    }
+    if (nz_preview < nz_left) {
+      fprintf(stderr, " ...");
+    }
+  } else {
+    fprintf(stderr, " all_zero=yes");
+  }
+  fprintf(stderr, "\n");
+  fflush(stderr);
+}
+
+//! @brief Assert that buf and data differ by less than 1e-6 on each of the
+//! len elements; tag is the entry point named in the failure message.
+static void Validate_msg(const char* tag, const float* buf, const float* data,
+                         uint32_t index, size_t len) {
+  for (size_t i = 0; i < len; ++i) {  // size_t: a 32-bit i could wrap
+    FMT_ASSERT(fabs(buf[i] - data[i]) < 0.000001,
+               "%s failed. index=%u, i=%zu: %f != %f.", tag, index, i, buf[i],
+               data[i]);
+  }
+}
+
+//! @brief Encode len floats from the start of message entry `index` into pt
+//! via Encode_float with the given scale and level. @return pt
+void* Pt_from_msg(void* pt, uint32_t index, size_t len, uint32_t scale,
+                  uint32_t level) {
+  float* data = Msg_ptr(index, 0, len);
+  Dump_msg_preview("Pt_from_msg", index, 0, data, len, scale, level);
+  Encode_float((PLAIN)pt, data, len, scale, level);
+  return pt;
+}
+
+//! @brief Same as Pt_from_msg, but the slice starts ofst floats into the
+//! message entry. @return pt
+void* Pt_from_msg_ofst(void* pt, uint32_t index, size_t ofst, size_t len,
+                       uint32_t scale, uint32_t level) {
+  float* data = Msg_ptr(index, ofst, len);
+  Dump_msg_preview("Pt_from_msg_ofst", index, ofst, data, len, scale, level);
+  Encode_float((PLAIN)pt, data, len, scale, level);
+  return pt;
+}
+
+//! @brief Check that the first len floats of message entry `index` match buf,
+//! then encode them into pt.
+void Pt_from_msg_validate(void* pt, float* buf, uint32_t index, size_t len,
+                          uint32_t scale, uint32_t level) {
+  float* data = Msg_ptr(index, 0, len);
+  Dump_msg_preview("Pt_from_msg_validate", index, 0, data, len, scale, level);
+  Validate_msg("Pt_from_msg_validate", buf, data, index, len);
+  Encode_float((PLAIN)pt, data, len, scale, level);
+}
+
+//! @brief Offset variant of Pt_from_msg_validate: buf is compared with the len
+//! floats that start ofst floats into message entry `index`.
+void Pt_from_msg_ofst_validate(void* pt, float* buf, uint32_t index,
+                               size_t ofst, size_t len, uint32_t scale,
+                               uint32_t level) {
+  float* data = Msg_ptr(index, ofst, len);
+  Dump_msg_preview("Pt_from_msg_ofst_validate", index, ofst, data, len, scale,
+                   level);
+  Validate_msg("Pt_from_msg_ofst_validate", buf, data, index, len);
+  Encode_float((PLAIN)pt, data, len, scale, level);
+}