Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions fhe-cmplr/include/fhe/ckks/ir2c_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,10 @@ class IR2C_HANDLER : public INVALID_HANDLER {
air::base::NODE_PTR parent = ctx.Parent(1);

AIR_ASSERT(parent != air::base::Null_ptr && parent->Is_st());
if (ctx.Provider() == core::PROVIDER::ANT) {
bool emit_data_file = ctx.Provider() == core::PROVIDER::ANT ||
(ctx.Provider() == core::PROVIDER::PHANTOM &&
ctx.Emit_data_file());
if (emit_data_file) {
ctx.template Emit_encode<RETV, VISITOR>(visitor, parent, node);
} else {
ctx.template Emit_runtime_encode<RETV, VISITOR>(visitor, parent, node);
Expand Down Expand Up @@ -276,4 +279,4 @@ class IR2C_HANDLER : public INVALID_HANDLER {
} // namespace ckks
} // namespace fhe

#endif
#endif
4 changes: 2 additions & 2 deletions fhe-cmplr/rtlib/cmake/modules/phantom.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
# Build external Phantom project dependent function
function(build_external_phantom)

set(PHANTOM_URL "https://git:$ENV{CI_TOKEN}@code.alipay.com/zhanggongliang.zgl/phantom-fhe.git")
set(PHANTOM_URL_SSH "git@code.alipay.com:zhanggongliang.zgl/phantom-fhe.git")
set(PHANTOM_URL "https://github.com/zggl404/phantom-fhe.git")
set(PHANTOM_URL_SSH "git@github.com:zggl404/phantom-fhe.git")
if(EXTERNAL_URL_SSH)
set(REPO_PHANTOM_URL ${PHANTOM_URL_SSH})
else()
Expand Down
2 changes: 2 additions & 0 deletions fhe-cmplr/rtlib/include/common/rt_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
#define ENV_PT_ENTRY_COUNT "PT_ENTRY_COUNT"
//! PT_PREFETCH_COUNT: number of pt for prefetching. default: 2
#define ENV_PT_PREFETCH_COUNT "PT_PREFETCH_COUNT"
//! PT_MSG_DUMP_COUNT=int: number of message values to dump. default: 0
#define ENV_PT_MSG_DUMP_COUNT "PT_MSG_DUMP_COUNT"

//! environment variable to control rt data file reader (RT_DATA_FILE)
//! RT_DATA_ASYNC_READ=0|1: use asynchronous read. default: 0
Expand Down
127 changes: 101 additions & 26 deletions fhe-cmplr/rtlib/phantom/src/phantom_lib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
#include "common/common.h"
#include "common/error.h"
#include "common/io_api.h"
#include "common/pt_mgr.h"
#include "common/rt_api.h"
#include "common/rtlib_timing.h"
#include "boot/Bootstrapper.cuh"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <cstdint>
#include <string>

using namespace phantom;
using namespace phantom::arith;
using namespace phantom::util;
Expand Down Expand Up @@ -57,7 +64,9 @@ public:
std::vector<double> vec(input->_vals, input->_vals + len);
CKKS_PARAMS *prog_param = Get_context_params();
Plaintext pt;
_evaluator->encoder.encode(vec, std::pow(2.0, _scaling_mod_size), pt);
int chain_index = _num_prime_parts - prog_param->_input_level;
double encode_scale = std::pow(2.0, _scaling_mod_size);
_evaluator->encoder.encode(vec, chain_index, encode_scale, pt);
Ciphertext *ct = new Ciphertext;
_evaluator->encryptor.encrypt(pt, *ct);
Io_set_input(name, 0, ct);
Expand Down Expand Up @@ -91,8 +100,9 @@ public:
LEVEL_T level)
{
std::vector<double> vec(input, input + len);
std::vector<double> vec_after;
_evaluator->encoder.encode(vec, _num_prime_parts - level, std::pow(2.0, _scaling_mod_size * scale), *pt);
double encode_scale = std::pow(2.0, _scaling_mod_size * scale);
int chain_index = _num_prime_parts - level;
_evaluator->encoder.encode(vec, chain_index, encode_scale, *pt);
}

void Encode_float_cst_lvl(Plaintext *pt, float *input, size_t len,
Expand All @@ -109,7 +119,9 @@ public:
LEVEL_T level)
{
std::vector<double> vec(len, input);
_evaluator->encoder.encode(vec, _num_prime_parts - level, std::pow(2.0, _scaling_mod_size * scale), *pt);
double encode_scale = std::pow(2.0, _scaling_mod_size * scale);
int chain_index = _num_prime_parts - level;
_evaluator->encoder.encode(vec, chain_index, encode_scale, *pt);
}
void Encode_float_mask_cst_lvl(Plaintext *pt, float input, size_t len,
SCALE_T scale, int level)
Expand Down Expand Up @@ -276,26 +288,73 @@ public:
_evaluator->evaluator.relinearize(*op1, *_rlk, *res);
}
}

void Bootstrap(Ciphertext *op1, Ciphertext *res, int level, int slot)
{
_evaluator->evaluator.mod_switch_to_inplace(*op1, _num_prime_parts - 1);

switch (slot)
int effective_slot = slot;
if (effective_slot == 0)
{
effective_slot = 32768;
}

Ciphertext *bootstrap_input = op1;

Ciphertext cyclic_input;
bool cyclic_input_used = false;
if (effective_slot > 0 && static_cast<uint64_t>(effective_slot) < _slot_count)
{
cyclic_input = *bootstrap_input;
Ciphertext rotated_input = *bootstrap_input;
for (uint64_t offset = effective_slot; offset < _slot_count; offset += effective_slot)
{
Rotate(&rotated_input, effective_slot, &rotated_input);
Add(&cyclic_input, &rotated_input, &cyclic_input);
}
rotated_input.release();
bootstrap_input = &cyclic_input;
cyclic_input_used = true;
}

Ciphertext in_place_input;
bool in_place_bootstrap = res == op1;
if (in_place_bootstrap)
{
in_place_input = *bootstrap_input;
bootstrap_input = &in_place_input;
}

res->release();
switch (effective_slot)
{
case 16384:
_bootstrapper_16384->bootstrap_3(*res, *op1);
_bootstrapper_16384->bootstrap_real_3(*res, *bootstrap_input);
break;
case 8192:
_bootstrapper_8192->bootstrap_3(*res, *op1);
_bootstrapper_8192->bootstrap_real_3(*res, *bootstrap_input);
break;
case 4096:
_bootstrapper_4096->bootstrap_3(*res, *op1);
_bootstrapper_4096->bootstrap_real_3(*res, *bootstrap_input);
break;
case 32768:
_bootstrapper_32768->bootstrap_real_3(*res, *bootstrap_input);
break;
default:
std::cout<<"Unsupported slot size for bootstrap: (must 16384,8192,4096)" << slot << std::endl;
IS_TRUE(false, "Unsupported slot size for bootstrap");
break;
}

if (in_place_bootstrap)
{
in_place_input.release();
}

if (cyclic_input_used)
{
cyclic_input.release();
}

int target_level = _num_prime_parts - level;
if (level != 0 && target_level > res->chain_index())
{
Expand Down Expand Up @@ -327,17 +386,28 @@ public:
EncryptionParameters parms(scheme_type::ckks);
uint32_t degree = prog_param->_poly_degree;
parms.set_poly_modulus_degree(degree);

uint32_t _bts_required_level = 14;
uint32_t _bts_remaining_level = prog_param->_mul_depth - 14;
std::vector<int> bits;
bits.push_back(prog_param->_scaling_mod_size);
for (uint32_t i = 0; i < prog_param->_mul_depth; ++i)
bits.push_back(prog_param->_first_mod_size);
for (uint32_t i = 0; i < _bts_remaining_level; ++i)
{
bits.push_back(prog_param->_scaling_mod_size);
}

bits.push_back(prog_param->_first_mod_size);
for (uint32_t i = 0; i < _bts_required_level; ++i)
{
bits.push_back(prog_param->_first_mod_size);
}
constexpr size_t special_modulus_size = 4;
for (size_t i = 0; i < special_modulus_size; i++)
{
bits.push_back(prog_param->_first_mod_size);
}
parms.set_coeff_modulus(phantom::arith::CoeffModulus::Create(degree, bits));
parms.set_secret_key_hamming_weight(192);
_num_prime_parts = bits.size();
parms.set_special_modulus_size(special_modulus_size);
_num_prime_parts = bits.size() - special_modulus_size + 1;
phantom::arith::sec_level_type sec = phantom::arith::sec_level_type::tc128;
switch (prog_param->_sec_level)
{
Expand Down Expand Up @@ -379,6 +449,9 @@ public:
long loge = 10;

int log_slot_count = 15;
_bootstrapper_32768 = std::make_unique<Bootstrapper>(
loge, 15, log_slot_count, prog_param->_mul_depth, std::pow(2.0, _scaling_mod_size),
boundary_K, deg, scale_factor, inverse_deg, _evaluator.get());
_bootstrapper_16384 = std::make_unique<Bootstrapper>(
loge, 14, log_slot_count, prog_param->_mul_depth, std::pow(2.0, _scaling_mod_size),
boundary_K, deg, scale_factor, inverse_deg, _evaluator.get());
Expand All @@ -389,6 +462,7 @@ public:
loge, 12, log_slot_count, prog_param->_mul_depth, std::pow(2.0, _scaling_mod_size),
boundary_K, deg, scale_factor, inverse_deg, _evaluator.get());

_bootstrapper_32768->prepare_mod_polynomial();
_bootstrapper_16384->prepare_mod_polynomial();
_bootstrapper_8192->prepare_mod_polynomial();
_bootstrapper_4096->prepare_mod_polynomial();
Expand All @@ -400,22 +474,22 @@ public:
gal_steps_vector.push_back((1 << i));
}

_bootstrapper_32768->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector);
_bootstrapper_16384->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector);
_bootstrapper_8192->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector);
_bootstrapper_4096->addLeftRotKeys_Linear_to_vector_3(gal_steps_vector);

std::cout << "the size of gal_steps_vector is " << gal_steps_vector.size() << std::endl;
_evaluator->decryptor.create_galois_keys_from_steps(gal_steps_vector, *(_evaluator.get()->galois_keys));
std::cout << "gen rot key done " << gal_steps_vector.size() << std::endl;
// log2(32768) = 15, log2(16384) = 14, log2(8192) = 13, log2(4096) = 12
_bootstrapper_16384->slot_vec.push_back(14);
_bootstrapper_32768->slot_vec.push_back(15);
_bootstrapper_16384->slot_vec.push_back(14);
_bootstrapper_8192->slot_vec.push_back(13);
_bootstrapper_4096->slot_vec.push_back(12);

_bootstrapper_32768->generate_LT_coefficient_3();
_bootstrapper_16384->generate_LT_coefficient_3();
_bootstrapper_8192->generate_LT_coefficient_3();
_bootstrapper_4096->generate_LT_coefficient_3();

printf(
"ckks_param: _provider = %d, _poly_degree = %d, _sec_level = %ld, "
"mul_depth = %ld, _first_mod_size = %ld, _scaling_mod_size = %ld, "
Expand Down Expand Up @@ -488,14 +562,6 @@ public:
}

_evaluator->decryptor.create_galois_keys_from_steps(gal_steps_vector, *(_evaluator.get()->galois_keys));
printf(
"ckks_param: _provider = %d, _poly_degree = %d, _sec_level = %ld, "
"mul_depth = %ld, _first_mod_size = %ld, _scaling_mod_size = %ld, "
"_num_q_parts = %ld, _num_rot_idx = %ld,_num_prime_parts = %ld\n",
prog_param->_provider, prog_param->_poly_degree, prog_param->_sec_level,
prog_param->_mul_depth, prog_param->_first_mod_size,
prog_param->_scaling_mod_size, prog_param->_num_q_parts,
prog_param->_num_rot_idx, _num_prime_parts);
}
SCALE_T Scale(const Ciphertext *op)
{
Expand Down Expand Up @@ -556,10 +622,19 @@ void Prepare_context()
Init_rtlib_timing();
Io_init();
PHANTOM_CONTEXT::Init_context();
RT_DATA_INFO *data_info = Get_rt_data_info();
if (data_info != nullptr)
{
Pt_mgr_init(data_info->_file_name);
}
}

void Finalize_context()
{
if (Get_rt_data_info() != nullptr)
{
Pt_mgr_fini();
}
PHANTOM_CONTEXT::Fini_context();
Io_fini();
}
Expand Down
Loading