Skip to content

Commit

Permalink
Merge pull request #8 from dfns-labs/crt-omul
Browse files Browse the repository at this point in the history
Use CRT to speed up homomorhpic multiplication
  • Loading branch information
survived authored Sep 25, 2023
2 parents 2fcc313 + 48c5d13 commit c3ee636
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 88 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ thiserror = "1"
serde = { version = "1", optional = true }

[dev-dependencies]
rand = "0.8"
rand_dev = "0.1"
criterion = { version = "0.5", features = ["html_reports"] }
libpaillier = { version = "0.5", default-features = false, features = ["gmp"] }
Expand Down
61 changes: 47 additions & 14 deletions benches/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,17 @@ fn decryption(c: &mut criterion::Criterion) {
let p = Integer::from_str_radix(P, 16).unwrap();
let q = Integer::from_str_radix(Q, 16).unwrap();

let dk_naive =
fast_paillier::DecryptionKey::<utils::NaiveExp>::from_primes(p.clone(), q.clone()).unwrap();
let dk_crt =
fast_paillier::DecryptionKey::<utils::CrtExp>::from_primes(p.clone(), q.clone()).unwrap();
let ek = dk_naive.encryption_key();
let dk = fast_paillier::DecryptionKey::from_primes(p.clone(), q.clone()).unwrap();
let ek = dk.encryption_key();

let mut group = c.benchmark_group("Decrypt");

let mut generate_inputs = || utils::sample_in_mult_group(&mut rng, ek.nn());

group.bench_function("Naive Decrypt", |b| {
b.iter_batched(
&mut generate_inputs,
|enc_x| dk_naive.decrypt(&enc_x).unwrap(),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("Decrypt with CRT", |b| {
b.iter_batched(
&mut generate_inputs,
|enc_x| dk_crt.decrypt(&enc_x).unwrap(),
|enc_x| dk.decrypt(&enc_x).unwrap(),
criterion::BatchSize::SmallInput,
)
});
Expand All @@ -126,6 +116,42 @@ fn decryption(c: &mut criterion::Criterion) {
});
}

fn omul(c: &mut criterion::Criterion) {
let mut rng = rand_dev::DevRng::new();

let p = Integer::from_str_radix(P, 16).unwrap();
let q = Integer::from_str_radix(Q, 16).unwrap();

let dk = fast_paillier::DecryptionKey::from_primes(p.clone(), q.clone()).unwrap();
let ek = dk.encryption_key();

let mut group = c.benchmark_group("OMul");

let mut generate_inputs = || {
let scalar = ek
.nn()
.random_below_ref(&mut utils::external_rand(&mut rng))
.into();
let enc_x = utils::sample_in_mult_group(&mut rng, ek.nn());
(scalar, enc_x)
};

group.bench_function("with CRT", |b| {
b.iter_batched(
&mut generate_inputs,
|(scalar, enc_x)| dk.omul(&scalar, &enc_x).unwrap(),
criterion::BatchSize::SmallInput,
)
});
group.bench_function("without CRT", |b| {
b.iter_batched(
&mut generate_inputs,
|(scalar, enc_x)| ek.omul(&scalar, &enc_x).unwrap(),
criterion::BatchSize::SmallInput,
)
});
}

/// Old implementation of safe primes
pub fn naive_safe_prime(rng: &mut impl rand_core::RngCore, bits: u32) -> Integer {
use rug::{integer::IsPrime, Assign};
Expand Down Expand Up @@ -184,7 +210,14 @@ fn rng_covertion(c: &mut criterion::Criterion) {
});
}

criterion::criterion_group!(benches, encryption, decryption, safe_primes, rng_covertion);
criterion::criterion_group!(
benches,
encryption,
decryption,
omul,
safe_primes,
rng_covertion
);
criterion::criterion_main!(benches);

fn convert_integer_to_unknown_order(x: &Integer) -> libpaillier::unknown_order::BigNumber {
Expand Down
47 changes: 37 additions & 10 deletions src/decryption_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{utils, Ciphertext, EncryptionKey, Nonce, Plaintext};
use crate::{Error, Reason};

#[derive(Clone)]
pub struct DecryptionKey<FastExp = utils::CrtExp> {
pub struct DecryptionKey {
ek: EncryptionKey,
/// `lcm(p-1, q-1)`
lambda: Integer,
Expand All @@ -15,13 +15,14 @@ pub struct DecryptionKey<FastExp = utils::CrtExp> {
p: Integer,
q: Integer,

crt_mod_nn: utils::CrtExp,
/// Calculates `x ^ N mod N^2`. It's used for faster encryption
exp_to_n_mod_nn: FastExp,
exp_n: utils::Exponent,
/// Calculates `x ^ lambda mod N^2`. It's used for faster decryption
exp_to_lambda_mod_nn: FastExp,
exp_lambda: utils::Exponent,
}

impl<FastExp: utils::FactorizedExp> DecryptionKey<FastExp> {
impl DecryptionKey {
/// Generates a paillier key
///
/// Samples two safe 1536-bits primes that meets 128 bits security level
Expand Down Expand Up @@ -53,17 +54,19 @@ impl<FastExp: utils::FactorizedExp> DecryptionKey<FastExp> {
// u = lambda^-1 mod N
let u = lambda.invert_ref(ek.n()).ok_or(Reason::InvalidPQ)?.into();

let exp_to_n_mod_nn = FastExp::build(ek.n(), &p, &q).ok_or(Reason::BuildFastExp)?;
let exp_to_lambda_mod_nn = FastExp::build(&lambda, &p, &q).ok_or(Reason::BuildFastExp)?;
let crt_mod_nn = utils::CrtExp::build_nn(&p, &q).ok_or(Reason::BuildFastExp)?;
let exp_n = crt_mod_nn.prepare_exponent(ek.n());
let exp_lambda = crt_mod_nn.prepare_exponent(&lambda);

Ok(Self {
ek,
lambda,
mu: u,
p,
q,
exp_to_n_mod_nn,
exp_to_lambda_mod_nn,
crt_mod_nn,
exp_n,
exp_lambda,
})
}

Expand All @@ -74,7 +77,10 @@ impl<FastExp: utils::FactorizedExp> DecryptionKey<FastExp> {
}

// a = c^\lambda mod n^2
let a = self.exp_to_lambda_mod_nn.exp(c);
let a = self
.crt_mod_nn
.exp(c, &self.exp_lambda)
.ok_or(Reason::Decrypt)?;

// ell = L(a, N)
let l = self.ek.l(&a).ok_or(Reason::Decrypt)?;
Expand Down Expand Up @@ -108,7 +114,10 @@ impl<FastExp: utils::FactorizedExp> DecryptionKey<FastExp> {
// a = (1 + N)^x mod N^2 = (1 + xN) mod N^2
let a = (Integer::ONE + x * self.ek.n()) % self.ek.nn();
// b = nonce^N mod N^2
let b = self.exp_to_n_mod_nn.exp(nonce);
let b = self
.crt_mod_nn
.exp(nonce, &self.exp_n)
.ok_or(Reason::Encrypt)?;

Ok((a * b) % self.ek.nn())
}
Expand All @@ -130,6 +139,24 @@ impl<FastExp: utils::FactorizedExp> DecryptionKey<FastExp> {
Ok((ciphertext, nonce))
}

/// Homomorphic multiplication of scalar at ciphertext
///
/// It uses the fact that factorization of `N` is known to speed up an operation.
///
/// ```text
/// omul(a, Enc(c)) = Enc(a * c)
/// ```
pub fn omul(&self, scalar: &Integer, ciphertext: &Ciphertext) -> Result<Ciphertext, Error> {
if !utils::in_mult_group_abs(scalar, self.n())
|| !utils::in_mult_group(ciphertext, self.ek.nn())
{
return Err(Reason::Ops.into());
}

let e = self.crt_mod_nn.prepare_exponent(scalar);
Ok(self.crt_mod_nn.exp(ciphertext, &e).ok_or(Reason::Ops)?)
}

/// Returns a (public) encryption key corresponding to the (secret) decryption key
pub fn encryption_key(&self) -> &EncryptionKey {
&self.ek
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ impl AnyEncryptionKey for DecryptionKey {
}

fn omul(&self, scalar: &Integer, ciphertext: &Ciphertext) -> Result<Ciphertext, Error> {
self.encryption_key().omul(scalar, ciphertext)
self.omul(scalar, ciphertext)
}

fn oneg(&self, ciphertext: &Ciphertext) -> Result<Ciphertext, Error> {
Expand Down
148 changes: 94 additions & 54 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod small_primes;
use std::fmt;

use rand_core::RngCore;
use rug::{Assign, Complete, Integer};

mod small_primes;

/// Wraps any randomness source that implements [`rand_core::RngCore`] and makes
/// it compatible with [`rug::rand`].
pub fn external_rand(rng: &mut impl RngCore) -> rug::rand::ThreadRandState {
Expand Down Expand Up @@ -90,87 +92,125 @@ pub fn sieve_generate_safe_primes(rng: &mut impl RngCore, bits: u32, amount: usi
}
}

/// Faster exponentiation `x^e mod N^2` when factorization of `N = pq` is known and `e` is fixed
pub trait FactorizedExp: Sized {
/// Precomputes data for exponentiation
fn build(e: &Integer, p: &Integer, q: &Integer) -> Option<Self>;
/// Returns `x^e mod (p q)^2`
fn exp(&self, x: &Integer) -> Integer;
/// Faster algorithm for modular exponentiation based on Chinese remainder theorem when modulo factorization is known
///
/// `CrtExp` makes exponentation modulo `n` faster when factorization `n = n1 * n2` is known as well as `phi(n1)` and `phi(n2)`
/// (note that `n1` and `n2` don't need to be primes). In this case, you can [build](Self::build) a `CrtExp` and use provided
/// [exponentiation algorithm](Self::exp).
#[derive(Clone)]
pub struct CrtExp {
n: Integer,
n1: Integer,
phi_n1: Integer,
n2: Integer,
phi_n2: Integer,
beta: Integer,
}

/// Naive `x^e mod N` implementation without optimizations
/// Exponent for [modular exponentiation](CrtExp::exp) via [`CrtExp`]
#[derive(Clone)]
pub struct NaiveExp {
nn: Integer,
e: Integer,
pub struct Exponent {
e_mod_phi_pp: Integer,
e_mod_phi_qq: Integer,
is_negative: bool,
}

impl FactorizedExp for NaiveExp {
fn build(e: &Integer, p: &Integer, q: &Integer) -> Option<Self> {
if e.cmp0().is_lt() || p.cmp0().is_le() || q.cmp0().is_le() {
impl CrtExp {
/// Builds a `CrtExp` for exponentation modulo `n = n1 * n2`
///
/// `phi_n1 = phi(n1)` and `phi_n2 = phi(n2)` need to be known. For instance, if `p` is a prime,
/// then `phi(p) = p - 1` and `phi(p^2) = p * (p - 1)`.
///
/// [`CrtExp::build_n`] and [`CrtExp::build_nn`] can be used when `n1` and `n2` are primes or
/// square of primes.
pub fn build(n1: Integer, phi_n1: Integer, n2: Integer, phi_n2: Integer) -> Option<Self> {
if n1.cmp0().is_le()
|| n2.cmp0().is_le()
|| phi_n1.cmp0().is_le()
|| phi_n2.cmp0().is_le()
|| phi_n1 >= n1
|| phi_n2 >= n2
{
return None;
}
let n = (p * q).complete();

let beta = n1.invert_ref(&n2)?.into();
Some(Self {
e: e.clone(),
nn: n.square(),
n: (&n1 * &n2).complete(),
n1,
phi_n1,
n2,
phi_n2,
beta,
})
}

fn exp(&self, x: &Integer) -> Integer {
// We check that `e` is non-negative at the construction in `Self::build`
#[allow(clippy::expect_used)]
x.pow_mod_ref(&self.e, &self.nn)
.expect("`e` is checked to be non-negative")
.into()
/// Builds a `CrtExp` for exponentiation modulo `n = p * q` where `p`, `q` are primes
pub fn build_n(p: &Integer, q: &Integer) -> Option<Self> {
let phi_p = (p - 1u8).complete();
let phi_q = (q - 1u8).complete();
Self::build(p.clone(), phi_p, q.clone(), phi_q)
}
}

/// Faster algorithm for exponentiation based on Chinese remainder theorem
#[derive(Clone)]
pub struct CrtExp {
pp: Integer,
qq: Integer,
e_mod_phi_pp: Integer,
e_mod_phi_qq: Integer,
beta: Integer,
}

impl FactorizedExp for CrtExp {
fn build(e: &Integer, p: &Integer, q: &Integer) -> Option<Self> {
if e.cmp0().is_lt() || p.cmp0().is_le() || q.cmp0().is_le() {
return None;
}

/// Builds a `CrtExp` for exponentiation modulo `nn = (p * q)^2` where `p`, `q` are primes
pub fn build_nn(p: &Integer, q: &Integer) -> Option<Self> {
let pp = p.square_ref().complete();
let qq = q.square_ref().complete();
let e_mod_phi_pp = e % (&pp - p).complete();
let e_mod_phi_qq = e % (&qq - q).complete();
let beta = pp.invert_ref(&qq)?.into();
Some(Self {
let phi_pp = (&pp - p).complete();
let phi_qq = (&qq - q).complete();
Self::build(pp, phi_pp, qq, phi_qq)
}

/// Prepares exponent to perform [modular exponentiation](Self::exp)
pub fn prepare_exponent(&self, e: &Integer) -> Exponent {
let neg_e = (-e).complete();
let is_negative = e.cmp0().is_lt();
let e = if is_negative { &neg_e } else { e };
let e_mod_phi_pp = e.modulo_ref(&self.phi_n1).complete();
let e_mod_phi_qq = e.modulo_ref(&self.phi_n2).complete();
Exponent {
e_mod_phi_pp,
e_mod_phi_qq,
pp,
qq,
beta,
})
is_negative,
}
}

fn exp(&self, x: &Integer) -> Integer {
let s1 = (x % &self.pp).complete();
let s2 = (x % &self.qq).complete();
/// Performs exponentiation modulo `n`
///
/// Exponent needs to be output of [`CrtExp::prepare_exponent`]
pub fn exp(&self, x: &Integer, e: &Exponent) -> Option<Integer> {
let s1 = x.modulo_ref(&self.n1).complete();
let s2 = x.modulo_ref(&self.n2).complete();

// `e_mod_phi_pp` and `e_mod_phi_qq` are guaranteed to be non-negative by construction
#[allow(clippy::expect_used)]
let r1 = s1
.pow_mod(&self.e_mod_phi_pp, &self.pp)
.pow_mod(&e.e_mod_phi_pp, &self.n1)
.expect("exponent is guaranteed to be non-negative");
#[allow(clippy::expect_used)]
let r2 = s2
.pow_mod(&self.e_mod_phi_qq, &self.qq)
.pow_mod(&e.e_mod_phi_qq, &self.n2)
.expect("exponent is guaranteed to be non-negative");

((r2 - &r1) * &self.beta).modulo(&self.qq) * &self.pp + &r1
let result = ((r2 - &r1) * &self.beta).modulo(&self.n2) * &self.n1 + &r1;

if e.is_negative {
result.invert(&self.n).ok()
} else {
Some(result)
}
}
}

impl fmt::Debug for CrtExp {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("CrtExp")
}
}

impl fmt::Debug for Exponent {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("CrtExponent")
}
}

Expand Down
Loading

0 comments on commit c3ee636

Please sign in to comment.