diff --git a/Cargo.toml b/Cargo.toml index cbb7140..f331444 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ thiserror = "1" serde = { version = "1", optional = true } [dev-dependencies] +rand = "0.8" rand_dev = "0.1" criterion = { version = "0.5", features = ["html_reports"] } libpaillier = { version = "0.5", default-features = false, features = ["gmp"] } diff --git a/src/decryption_key.rs b/src/decryption_key.rs index 080fa06..cbaf27a 100644 --- a/src/decryption_key.rs +++ b/src/decryption_key.rs @@ -15,10 +15,11 @@ pub struct DecryptionKey { p: Integer, q: Integer, + crt_mod_nn: utils::CrtExp, /// Calculates `x ^ N mod N^2`. It's used for faster encryption - exp_to_n_mod_nn: utils::CrtExp, + exp_n: utils::Exponent, /// Calculates `x ^ lambda mod N^2`. It's used for faster decryption - exp_to_lambda_mod_nn: utils::CrtExp, + exp_lambda: utils::Exponent, } impl DecryptionKey { @@ -53,9 +54,9 @@ impl DecryptionKey { // u = lambda^-1 mod N let u = lambda.invert_ref(ek.n()).ok_or(Reason::InvalidPQ)?.into(); - let exp_to_n_mod_nn = utils::CrtExp::build(ek.n(), &p, &q).ok_or(Reason::BuildFastExp)?; - let exp_to_lambda_mod_nn = - utils::CrtExp::build(&lambda, &p, &q).ok_or(Reason::BuildFastExp)?; + let crt_mod_nn = utils::CrtExp::build_nn(&p, &q).ok_or(Reason::BuildFastExp)?; + let exp_n = crt_mod_nn.prepare_exponent(ek.n()); + let exp_lambda = crt_mod_nn.prepare_exponent(&lambda); Ok(Self { ek, @@ -63,8 +64,9 @@ impl DecryptionKey { mu: u, p, q, - exp_to_n_mod_nn, - exp_to_lambda_mod_nn, + crt_mod_nn, + exp_n, + exp_lambda, }) } @@ -75,7 +77,10 @@ impl DecryptionKey { } // a = c^\lambda mod n^2 - let a = self.exp_to_lambda_mod_nn.exp(c); + let a = self + .crt_mod_nn + .exp(c, &self.exp_lambda) + .ok_or(Reason::Decrypt)?; // ell = L(a, N) let l = self.ek.l(&a).ok_or(Reason::Decrypt)?; @@ -109,7 +114,10 @@ impl DecryptionKey { // a = (1 + N)^x mod N^2 = (1 + xN) mod N^2 let a = (Integer::ONE + x * self.ek.n()) % self.ek.nn(); // b = nonce^N mod N^2 - let b = self.exp_to_n_mod_nn.exp(nonce); + let b = self + .crt_mod_nn + .exp(nonce, &self.exp_n) + .ok_or(Reason::Encrypt)?; Ok((a * b) % self.ek.nn()) } diff --git a/src/utils.rs b/src/utils.rs index 045940c..63ef5f7 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -90,51 +90,113 @@ pub fn sieve_generate_safe_primes(rng: &mut impl RngCore, bits: u32, amount: usi } } -/// Faster algorithm for exponentiation based on Chinese remainder theorem +/// Faster algorithm for modular exponentiation based on Chinese remainder theorem when modulo factorization is known +/// +/// `CrtExp` makes exponentation modulo `n` faster when factorization `n = n1 * n2` is known as well as `phi(n1)` and `phi(n2)` +/// (note that `n1` and `n2` don't need to be primes). In this case, you can [build](Self::build) a `CrtExp` and use provided +/// [exponentiation algorithm](Self::exp). #[derive(Clone)] pub struct CrtExp { - pp: Integer, - qq: Integer, + n: Integer, + n1: Integer, + phi_n1: Integer, + n2: Integer, + phi_n2: Integer, + beta: Integer, +} + +/// Exponent for [modular exponentiation](CrtExp::exp) via [`CrtExp`] +#[derive(Clone)] +pub struct Exponent { e_mod_phi_pp: Integer, e_mod_phi_qq: Integer, - beta: Integer, + is_negative: bool, } impl CrtExp { - pub fn build(e: &Integer, p: &Integer, q: &Integer) -> Option { - if e.cmp0().is_lt() || p.cmp0().is_le() || q.cmp0().is_le() { + /// Builds a `CrtExp` for exponentation modulo `n = n1 * n2` + /// + /// `phi_n1 = phi(n1)` and `phi_n2 = phi(n2)` need to be known. For instance, if `p` is a prime, + /// then `phi(p) = p - 1` and `phi(p^2) = p * (p - 1)`. + /// + /// [`CrtExp::build_n`] and [`CrtExp::build_nn`] can be used when `n1` and `n2` are primes or + /// square of primes. + pub fn build(n1: Integer, phi_n1: Integer, n2: Integer, phi_n2: Integer) -> Option { + if n1.cmp0().is_le() + || n2.cmp0().is_le() + || phi_n1.cmp0().is_le() + || phi_n2.cmp0().is_le() + || phi_n1 >= n1 + || phi_n2 >= n2 + { return None; } + let beta = n1.invert_ref(&n2)?.into(); + Some(Self { + n: (&n1 * &n2).complete(), + n1, + phi_n1, + n2, + phi_n2, + beta, + }) + } + + /// Builds a `CrtExp` for exponentiation modulo `n = p * q` where `p`, `q` are primes + pub fn build_n(p: &Integer, q: &Integer) -> Option { + let phi_p = (p - 1u8).complete(); + let phi_q = (q - 1u8).complete(); + Self::build(p.clone(), phi_p, q.clone(), phi_q) + } + + /// Builds a `CrtExp` for exponentiation modulo `nn = (p * q)^2` where `p`, `q` are primes + pub fn build_nn(p: &Integer, q: &Integer) -> Option { let pp = p.square_ref().complete(); let qq = q.square_ref().complete(); - let e_mod_phi_pp = e % (&pp - p).complete(); - let e_mod_phi_qq = e % (&qq - q).complete(); - let beta = pp.invert_ref(&qq)?.into(); - Some(Self { + let phi_pp = (&pp - p).complete(); + let phi_qq = (&qq - q).complete(); + Self::build(pp, phi_pp, qq, phi_qq) + } + + /// Prepares exponent to perform [modular exponentiation](Self::exp) + pub fn prepare_exponent(&self, e: &Integer) -> Exponent { + let neg_e = (-e).complete(); + let is_negative = e.cmp0().is_lt(); + let e = if is_negative { &neg_e } else { e }; + let e_mod_phi_pp = e.modulo_ref(&self.phi_n1).complete(); + let e_mod_phi_qq = e.modulo_ref(&self.phi_n2).complete(); + Exponent { e_mod_phi_pp, e_mod_phi_qq, - pp, - qq, - beta, - }) + is_negative, + } } - pub fn exp(&self, x: &Integer) -> Integer { - let s1 = (x % &self.pp).complete(); - let s2 = (x % &self.qq).complete(); + /// Performs exponentiation modulo `n` + /// + /// Exponent needs to be output of [`CrtExp::prepare_exponent`] + pub fn exp(&self, x: &Integer, e: &Exponent) -> Option { + let s1 = x.modulo_ref(&self.n1).complete(); + let s2 = x.modulo_ref(&self.n2).complete(); // `e_mod_phi_pp` and `e_mod_phi_qq` are guaranteed to be non-negative by construction #[allow(clippy::expect_used)] let r1 = s1 - .pow_mod(&self.e_mod_phi_pp, &self.pp) + .pow_mod(&e.e_mod_phi_pp, &self.n1) .expect("exponent is guaranteed to be non-negative"); #[allow(clippy::expect_used)] let r2 = s2 - .pow_mod(&self.e_mod_phi_qq, &self.qq) + .pow_mod(&e.e_mod_phi_qq, &self.n2) .expect("exponent is guaranteed to be non-negative"); - ((r2 - &r1) * &self.beta).modulo(&self.qq) * &self.pp + &r1 + let result = ((r2 - &r1) * &self.beta).modulo(&self.n2) * &self.n1 + &r1; + + if e.is_negative { + result.invert(&self.n).ok() + } else { + Some(result) + } } } diff --git a/tests/integration.rs b/tests/integration.rs index 9ab2070..94625f8 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,4 +1,5 @@ use fast_paillier::{utils, DecryptionKey}; +use rand::Rng; use rug::{Complete, Integer}; #[test] @@ -145,24 +146,63 @@ fn encryption_with_known_factorization() { } #[test] -fn factorized_exp() { +fn factorized_exp_mod_n() { let mut rng = rand_dev::DevRng::new(); let p = utils::generate_safe_prime(&mut rng, 512); let q = utils::generate_safe_prime(&mut rng, 512); let n = (&p * &q).complete(); + println!("n: {n}"); - let e = Integer::random_bits(1024, &mut utils::external_rand(&mut rng)).into(); + let crt = utils::CrtExp::build_n(&p, &q).unwrap(); - let crt = utils::CrtExp::build(&e, &p, &q).unwrap(); + for _ in 0..100 { + let x: Integer = n + .random_below_ref(&mut utils::external_rand(&mut rng)) + .into(); + let mut e: Integer = Integer::random_bits(1024, &mut utils::external_rand(&mut rng)).into(); + if rng.gen::() { + e = -e + } + let crt_e = crt.prepare_exponent(&e); + + println!(); + println!("x: {x}"); + println!("e: {e}"); + let expected: Integer = x.pow_mod_ref(&e, &n).unwrap().into(); + let actual = crt.exp(&x, &crt_e).unwrap(); + assert_eq!(expected, actual); + } +} + +#[test] +fn factorized_exp_mod_nn() { + let mut rng = rand_dev::DevRng::new(); + + let p = utils::generate_safe_prime(&mut rng, 512); + let q = utils::generate_safe_prime(&mut rng, 512); let nn = (&p * &q).complete().square(); + println!("nn: {nn}"); + + let crt = utils::CrtExp::build_nn(&p, &q).unwrap(); + for _ in 0..100 { let x: Integer = nn .random_below_ref(&mut utils::external_rand(&mut rng)) .into(); - let expected: Integer = x.pow_mod_ref(&e, &n).unwrap().into(); - let actual = crt.exp(&x); + let mut e: Integer = Integer::random_bits(1024, &mut utils::external_rand(&mut rng)).into(); + if rng.gen::() { + e = -e + } + let crt_e = crt.prepare_exponent(&e); + + println!(); + println!("x: {x}"); + println!("e: {e}"); + + let expected: Integer = x.pow_mod_ref(&e, &nn).unwrap().into(); + let actual = crt.exp(&x, &crt_e).unwrap(); assert_eq!(expected, actual); } }