Source code for lhe.lhe

# pylint: disable=C0103# to allow capital letters in method signatures
"""
Python homomorphic encryption library supporting up
to three multiplications and unlimited additions.
"""
from __future__ import annotations
from typing import NamedTuple, Union, Optional, Tuple
import doctest
from mclbn256 import Fr, G1, G2, GT


[docs]class CTG1(NamedTuple): """Ciphertext in strictly `G1 x G1` only.""" g1r: G1 g1m_pr: G1 def __add__(self: CTG1, other: CTG1) -> CTG1: return add_G1(self, other) def __mul__(self: CTG1, other: CTG2) -> CTGT: return multiply_G1_G2(self, other)
[docs]class CTG2(NamedTuple): """Ciphertext in strictly `G2 x G2` only.""" g2r: G2 g2m_pr: G2 def __add__(self: CTG2, other: CTG2) -> CTG2: return add_G2(self, other) def __mul__(self: CTG2, other: CTG1) -> CTGT: return multiply_G1_G2(other, self)
[docs]class CTGT(NamedTuple): """Level-2 ciphertext in $\textsf{GT}^{4}$.""" z_r1_r2: GT z_m2_s2_r2__r1: GT z_m1_s1_r1__r2: GT z_m1_s1_r1__m2_s2_r2: GT def __add__(self: CTGT, other: CTGT) -> CTGT: return add_GT(self, other) # def __mul__(self: CTG1, other: Fr): # return multiply_constant_GT_Fr(self, other) def __mul__(self: CTG1, other: Fr): return CTGT( self.z_r1_r2 ** other, self.z_m2_s2_r2__r1 ** other, self.z_m1_s1_r1__r2 ** other, self.z_m1_s1_r1__m2_s2_r2 ** other )
[docs]class CT1(NamedTuple): """All-purpose (dual) level-1 ciphertext making use of both G1 and G2.""" ctg1: CTG1 ctg2: CTG2 def __add__(self: CT1, other: CT1) -> CT1: return CT1( self.ctg1 + other.ctg1, self.ctg2 + other.ctg2 ) def __mul__(self: CT1, other: Union[CT1, Fr, int]) -> CT1 | CT2: if type(other) == int: other = Fr(other) if type(other) == Fr: return CT1( CTG1( self.ctg1.g1r * other, self.ctg1.g1m_pr * other ), CTG2( self.ctg2.g2r * other, self.ctg2.g2m_pr * other ) ) else: ct = multiply_G1_G2(other.ctg1, self.ctg2) return CT2( ct or multiply_G1_G2(self.ctg1, other.ctg2) ) # `or` just in case first product is corrupted def __rmul__(self: CT1, other: Union[Fr, int]) -> CT1: if type(other) == int: other = Fr(other) if type(other) == Fr: return CT1( CTG1( self.ctg1.g1r * other, self.ctg1.g1m_pr * other ), CTG2( self.ctg2.g2r * other, self.ctg2.g2m_pr * other ) ) def __neg__(self: CT1) -> CT1: return CT1( CTG1( -self.ctg1.g1r, -self.ctg1.g1m_pr ), CTG2( -self.ctg2.g2r, -self.ctg2.g2m_pr ) )
[docs]class CT2(NamedTuple): """Level-2 ciphertext (wrapper around GT^4).""" ctgt: CTGT def __add__(self: CT2, other: CT2) -> CT2: return CT2(self.ctgt + other.ctgt) def __mul__(self: CT2, other: Union[Fr, int]) -> CT2: if type(other) == int: other = Fr(other) if type(other) == Fr: return CT2( self.ctgt * other ) raise TypeError("Cannot perform constant multiplication with these operands.") def __rmul__(self: CT1, other: Union[Fr, int]) -> CT1: return self.__mul__(other)
[docs]class KPG1(NamedTuple): """Keypair for G1-based encryption.""" sk: Fr # secret scalar pk: G1 # public point (in Group 1)
[docs]class KPG2(NamedTuple): """Keypair for G2-based encryption.""" sk: Fr # secret scalar pk: G2 # public point (in Group 2)
[docs]class SK(NamedTuple): """Dual secret key for decryption""" s1: Fr # secret scalar s2: Fr
[docs]class PK(NamedTuple): """Dual public key for encryption (either 'dumb' group-agnostic, or optimal)""" p1: G1 # public point (in Group 1) p2: G2 # public point (in Group 2)
# z2: GT# = g1 @ p2 # z3: GT# = p1 @ g2 # z4: GT# = p1 @ p2 g1 = G1().hash("Fixed public point in Group 1") g2 = G2().hash("Fixed public point in Group 2") z1 = g1 @ g2 # z a.k.a. z1 is the pairing of the two generators # and is also a generator in its own right, for GT.
[docs]def keygen_G1() -> Tuple[Fr, G1]: """Generate a G1 keypair.""" s = Fr() p = g1 * s return KPG1(s, p)
[docs]def keygen_G2() -> KPG2: """Generate a G2 keypair.""" s = Fr() p = g2 * s return KPG2(s, p)
[docs]def keygen() -> Tuple[SK, PK]: """Generate a dual keypair.""" s1, p1 = keygen_G1() s2, p2 = keygen_G2() return SK(s1, s2), PK(p1, p2)
[docs]def encrypt_G1(p: G1, m: int) -> CTG1: """Encrypt a plaintext to be a G1 ciphertext.""" r = Fr() return CTG1( g1 * r, (g1 * Fr(m)) + (p * r) )
[docs]def encrypt_G2(p: G2, m: int) -> CTG2: """Encrypt a plaintext to be a G2 ciphertext.""" r = Fr() return CTG2( g2 * r, (g2 * Fr(m)) + (p * r) )
[docs]def encrypt_GT(p1: G1, p2: G2, m: int) -> CTGT: """ Encrypt a plaintext to be a GT ciphertext. Each such ciphertext is made of the components z1 ** (-t + r + s) z1 ** (r * s2) z1 ** (s * s1) z1 ** (m + (t * s1 * s2)) from which (given the secrets, s1 and s2) you can compute z1 ** s and z1 ** r, then subtract from the first to get z1 ** -t which can then yield z1 ** (-t * s1 * s2), the inverse of everything in z1 ** (m + (t * s1 * s2)) except z1 ** m. m is finally extracted by a d.log. computation. """ r = Fr() # random scalar s = Fr() # random scalar t = Fr() # random scalar # z1 = g1 @ g2 # z1 := e(g1, g2); constant/static global/public scalar z1_s2 = g1 @ p2 # e(g1, p2) = e(g1, g2 * s2) = e(g1, g2) ** s2 = z1 ** s2 z1_s1 = p1 @ g2 # e(p1, g2) = e(g1 * s1, g2) = e(g1, g2) ** s1 = z1 ** s1 z1_s1_s2 = p1 @ p2 # e(p1, p2) = e(g1 * s1, g2 * s2) = e(g1, g2) ** s1 ** s2 = z1 ** s1 ** s2 return CTGT( z1 ** (r + s - t), # z1 ** (r + s - t) z1_s2 ** r, # z1 ** (r * s2) z1_s1 ** s, # z1 ** (s * s1) z1_s1_s2 ** t * (z1 ** Fr(m)) # z1 ** (m + (t * s1 * s2)) )
[docs]def encrypt_lvl_1(pk: PK, m: int) -> CT1: """Encrypt a plaintext to be a dual ('dumb') ciphertext.""" ct1 = encrypt_G1(pk.p1, m) ct2 = encrypt_G2(pk.p2, m) return CT1(ct1, ct2)
[docs]def encrypt_lvl_2(pk: PK, m: int) -> CT2: """Encrypt a level-2 ciphertext.""" ct = encrypt_GT(pk.p1, pk.p2, m) return CT2(ct)
[docs]def add_G1(ct1: CTG1, ct2: CTG1) -> CTG1: """Homomorphically add two G1 ciphertexts in level 1.""" return CTG1( ct1.g1r + ct2.g1r, ct1.g1m_pr + ct2.g1m_pr )
[docs]def add_G2(ct1: CTG2, ct2: CTG2) -> CTG2: """Homomorphically add two G2 ciphertexts in level 1.""" return CTG2( ct1.g2r + ct2.g2r, ct1.g2m_pr + ct2.g2m_pr )
[docs]def add_GT(ct1: CTGT, ct2: CTGT) -> CTGT: """Homomorphically add two GT ciphertexts in level 2.""" return CTGT( ct1.z_r1_r2 * ct2.z_r1_r2, # multiply because the m's we want to add are up in powers of z1 ct1.z_m2_s2_r2__r1 * ct2.z_m2_s2_r2__r1, ct1.z_m1_s1_r1__r2 * ct2.z_m1_s1_r1__r2, ct1.z_m1_s1_r1__m2_s2_r2 * ct2.z_m1_s1_r1__m2_s2_r2 )
[docs]def multiply_G1_G2(ct1: CTG1, ct2: CTG2) -> CTGT: """ Homomorphically multiply two complementary level-1 ciphertexts and return a level-2 ciphertext of their product. """ return CTGT( ct1.g1r @ ct2.g2r, # z1 ** (r1 * r2) ct1.g1r @ ct2.g2m_pr, # z1 ** (r1 * (m2 + (r2 * s2))) = z1 ** ((m2 + (s2 * r2)) * r1) ct1.g1m_pr @ ct2.g2r, # z1 ** ((m1 + (s1 * r1)) * r2) ct1.g1m_pr @ ct2.g2m_pr # e((g1 * m1) + (p1 * r1), (g2 * m2) + (p2 * r2)) # = e((g1 * m1) + (g1 * s1 * r1), (g2 * m2) + (g2 * s2 * r2)) # = e((g1 * m1) + (g1 * (s1 * r1)), (g2 * m2) + (g2 * (s2 * r2))) # = e(g1 * (m1 + (s1 * r1))), g2 * (m2 + (s2 * r2))) # = e(g1 * (m1 + (s1 * r1))), g2) ** (m2 + (s2 * r2)) # = e(g1, g2) ** (m1 + (s1 * r1)) ** (m2 + (s2 * r2)) # = z1 ** ((m1 + (s1 * r1)) * (m2 + (s2 * r2))) # = z1 ** ( # (m2 * r1 * s1) + # (m1 * r2 * s2) + # (r1 * r2 * s1 * s2) + # (m1 * m2) # ) # = z1 ** (m2 * r1 * s1) # * z1 ** (m1 * r2 * s2) # * z1 ** (r1 * r2 * s1 * s2) # * z1 ** (m1 * m2) # dec: m1m2 = (r1r2)(s1s2) + r1(m2+s2r2)(-s1) + r2(m1+s1r1)(-s2) + (m1+s1r1)(m2+s2r2) )
[docs]def decrypt_G1(s1: Fr, ct: CTG1) -> Optional[int]: """ Decrypt a G1 ciphertext to a plaintext. >>> sk, pk = keygen_G1() >>> ct = encrypt_G1(pk, 737) >>> print(decrypt_G1(sk, ct)) 737 """ g1m = ct.g1m_pr - (ct.g1r * s1) # remember, p = g^s return dlog(g1, g1m)
[docs]def decrypt_G2(s2: Fr, ct: CTG2) -> Optional[int]: """ Decrypt a G2 ciphertext to a plaintext. >>> sk, pk = keygen_G2() >>> ct = encrypt_G2(pk, 747) >>> int(decrypt_G2(sk, ct)) 747 """ g2m = ct.g2m_pr - (ct.g2r * s2) # remember, p = g^s return dlog(g2, g2m)
[docs]def decrypt_GT(s1: Fr, s2: Fr, ct: CTGT): """ Decrypt a level-2 ciphertext. >>> sk1, pk1 = keygen_G1() >>> sk2, pk2 = keygen_G2() >>> ct11 = encrypt_G1(pk1, 1) >>> ct12 = encrypt_G1(pk1, 2) >>> ct21 = encrypt_G2(pk2, 200) >>> ct22 = encrypt_G2(pk2, 22) >>> ct1 = ct11 + ct12 >>> ct2 = ct21 + ct22 >>> ct3 = ct1 * ct2 >>> pt = decrypt_GT(sk1, sk2, ct3) >>> int(pt) 666 >>> sk, pk = keygen() >>> ct_1 = encrypt_lvl_1(pk, 1) >>> ct_2 = encrypt_lvl_1(pk, 2) >>> ct_200 = encrypt_lvl_1(pk, 200) >>> ct_22 = encrypt_lvl_1(pk, 22) >>> ct_3 = ct_1 + ct_2 >>> ct_222 = ct_200 + ct_22 >>> ct_666 = ct_3 * ct_222 >>> pt = decrypt(sk, ct_666) >>> int(pt) 666 The goal is to unmask the last ciphertext component and get z1 ** (m1 * m2). Note that that component, z1 ** ((m1 + (s1 * r1)) * (m2 + (s2 * r2))), expands to equal = z1 ** (m2 * r1 * s1) * z1 ** (m1 * r2 * s2) * z1 ** (r1 * r2 * s1 * s2) * z1 ** (m1 * m2) for whose terms we already have the ingredients to construct. The z1 ** (r1 * r2 * s1 * s2) specifically cancels the last negative term in the ct.z_m2_s2_r2__r1 by ct.z_m1_s1_r1__r2 product. We have z1 to the power of, (m2 + r1 s1)(m1 + r2 s2) = (m1 m2) + (m1 r1 s1) + (m2 r2 s2) + (r1 r2 s1 s2). And z1 to the power of, (r1 r2)(s1 s2) + r1 (m1 + r2 s2)(-s1) + r2 (m2 + r1 s1)(-s2) = -(m1 r1 s1) + -(m2 r2 s2) + -(r1 r2 s2 s1). Thus, we may decrypt by add these exponents (by multiplying powers) to get m1*m2 which can be extracted by a discrete log. """ z1_m1_m2 = \ (ct.z_r1_r2 ** (s1 * s2)) * \ (ct.z_m2_s2_r2__r1 ** (-s1)) * \ (ct.z_m1_s1_r1__r2 ** (-s2)) * \ ct.z_m1_s1_r1__m2_s2_r2 return dlog(z1, z1_m1_m2)
[docs]def decrypt(sk: SK, ct: Union[CT1, CT2, CTG1, CTG2, CTGT]) -> Fr: """ Type-generic decryption helper >>> sk, pk = keygen() >>> pt_m = Fr() % (2 ** 12) >>> m = int(pt_m) >>> 0 <= m < 2 ** 12 True >>> decrypt(sk, encrypt_G1(pk.p1, m)) == pt_m True >>> decrypt(sk, encrypt_G2(pk.p2, m)) == pt_m True >>> decrypt(sk, encrypt_GT(pk.p1, pk.p2, m)) == pt_m True >>> decrypt(sk, encrypt_lvl_1(pk, m)) == pt_m True >>> decrypt(sk, encrypt_lvl_2(pk, m)) == pt_m True """ if type(ct) is CT2: return decrypt_GT(sk.s1, sk.s2, ct.ctgt) if type(ct) is CT1: pt = decrypt_G1(sk.s1, ct.ctg1) return pt or decrypt_G2(sk.s2, ct.ctg2) # `or` in case maybe one of them got corrupted? if type(ct) is CTGT: return decrypt_GT(sk.s1, sk.s2, ct) if type(ct) is CTG1: return decrypt_G1(sk.s1, ct) if type(ct) is CTG2: return decrypt_G2(sk.s2, ct)
[docs]def dlog(base: Union[Fr, G1, G2, GT], power: GT, unsigned=False) -> Optional[Fr]: """ Discrete logarithm on any group, either Fr, G1, G2, or GT. Can work with up to 20-bits before giving up. The example below tests 16-bit exponents of each type (for efficiency). This helper may be replaced with Pollard's Kangaroo method for a big boost (~2x) in performance. That optimization is unimplemented. Alternatively, we may use a lookup table. >>> x = Fr() >>> a = Fr() % (2 ** 16) >>> dlog(x, x ** a) == a True >>> x = G1().randomize() >>> a = Fr() % (2 ** 16) >>> y = x * a >>> dlog(x, y) == a True >>> x = G2().randomize() >>> a = Fr() % (2 ** 16) >>> y = x * a >>> dlog(x, y) == a True >>> x = G1().randomize() @ G2().randomize() >>> a = Fr() % (2 ** 16) >>> y = x ** a >>> dlog(x, y) == a True """ domain = range(pow(2, 20)) if unsigned else \ [e for i in range(pow(2, 20)) for e in (i, -i)] try: for exponent in map(Fr, domain): if base ** exponent == power: return exponent except TypeError: for exponent in map(Fr, domain): if base * exponent == power: return exponent # raise ValueError("No such exponent.") return None
[docs]def main(): sk1, pk1 = keygen_G1() sk2, pk2 = keygen_G2() # ct1 = encrypt_G1(pk1, 5005) # ct2 = encrypt_G2(pk2, 111) ct1 = encrypt_G1(pk1, 3) ct2 = encrypt_G2(pk2, 222) ct3 = multiply_G1_G2(ct1, ct2) ct4 = add_GT(ct3, ct3) pt = decrypt_GT(sk1, sk2, ct3) print("This may take a bit of time for large plaintexts...") print(pt) pt = decrypt_GT(sk1, sk2, ct4) print("This may take a bit of time for large plaintexts...") print(pt)
# if __name__ == "__main__": # main() if __name__ == "__main__": doctest.testmod() # pragma: no cover # alias for 'dumb' API encrypt = encrypt_lvl_1