# Source code for pkenc_rabin

'''
| From: "Digitalized Signatures and Public-Key Functions as Intractable as Factorization".
| Published in: 1979
| Security Assumption: Integer Factorization

* type:           public-key encryption
* setting:        Integer

:Authors: Christina Garman
:Date:            09/2011
'''

from charm.core.math.integer import integer
from charm.toolbox.PKEnc import PKEnc
from charm.toolbox.PKSig import PKSig
from charm.toolbox.paddingschemes import OAEPEncryptionPadding,SAEPEncryptionPadding
from charm.toolbox.redundancyschemes import InMessageRedundancy
from charm.toolbox.conversion import Conversion
from charm.toolbox.bitstring import Bytes
from charm.toolbox.specialprimes import BlumWilliamsInteger
from math import ceil 

# Module-wide switch: when True, encrypt/decrypt/sign/verify print their
# intermediate values (encoded messages, candidate roots, signatures).
debug = False
class Rabin():
    """Key generation shared by the Rabin encryption and signature schemes.

    The modulus N = p * q comes from a Blum-Williams integer generator, and the
    secret key also stores the CRT coefficients used to recombine square roots
    computed modulo p and modulo q.
    """

    def __init__(self, modulus=None):
        """Create the scheme.

        :param modulus: object exposing ``generateBlumWilliamsInteger(secparam)``
            that returns ``(p, q, N)``. A fresh ``BlumWilliamsInteger()`` is
            used when omitted.
        """
        # A ``BlumWilliamsInteger()`` default in the signature would be built
        # once at class-definition time and shared by every instance.
        self.modulustype = BlumWilliamsInteger() if modulus is None else modulus

    def paramgen(self, secparam):
        """Generate the primes, the modulus and the CRT coefficients.

        :param secparam: security parameter forwarded to the modulus generator.
        :returns: ``(p, yp, q, yq, N)`` where ``yp`` is the inverse of p
            modulo q and ``yq`` is the inverse of q modulo p.
        """
        (p, q, N) = self.modulustype.generateBlumWilliamsInteger(secparam)
        # With charm modular integers, ``x ** -1`` is the modular inverse; the
        # CRT recombination yp*p*mq + yq*q*mp in decrypt/sign relies on this.
        yp = (p % q) ** -1
        yq = (q % p) ** -1
        return (p, yp, q, yq, N)

    def keygen(self, s0, secparam=1024, params=None):
        """Build a key pair.

        :param s0: SAEP redundancy length in bits, stored in the public key.
        :param secparam: modulus size parameter, stored in the public key as 'n'.
        :param params: optional pre-existing parameters ``(N, p, q, yp, yq)``.
            When given, they are converted with :meth:`convert` instead of
            generating new ones.
        :returns: ``(pk, sk)``, where pk has keys 'N', 'n', 's0' and sk has
            keys 'p', 'q', 'N', 'yp', 'yq'.
        """
        if params:
            # convert() takes the five values as separate positional arguments;
            # passing the tuple itself raised TypeError.
            (N, p, q, yp, yq) = self.convert(*params)
        else:
            (p, yp, q, yq, N) = self.paramgen(secparam)
        pk = {'N': N, 'n': secparam, 's0': s0}
        sk = {'p': p, 'q': q, 'N': N, 'yp': yp, 'yq': yq}
        return (pk, sk)

    def convert(self, N, p, q, yp, yq):
        """Wrap the five key parameters as charm ``integer`` values.

        :returns: ``(N, p, q, yp, yq)``.
        """
        return (integer(N), integer(p), integer(q), integer(yp), integer(yq))
class Rabin_Enc(Rabin, PKEnc):
    """
    Rabin public-key encryption with SAEP padding (default) or another padding
    scheme combined with in-message redundancy.

    >>> rabin = Rabin_Enc()
    >>> (public_key, secret_key) = rabin.keygen(128, 1024)
    >>> msg = b'This is a test'
    >>> cipher_text = rabin.encrypt(public_key, msg)
    >>> decrypted_msg = rabin.decrypt(public_key, secret_key, cipher_text)
    >>> decrypted_msg == msg
    True
    """
    def __init__(self, padding=None, redundancy=None, params=None):
        """
        :param padding: padding scheme; a fresh ``SAEPEncryptionPadding()``
            when omitted.
        :param redundancy: redundancy scheme used with non-SAEP padding; a
            fresh ``InMessageRedundancy()`` when omitted.
        :param params: accepted for backward compatibility; unused.
        """
        Rabin.__init__(self)
        PKEnc.__init__(self)
        # Defaults are built per instance rather than once in the signature,
        # so instances never share padding/redundancy objects.
        self.paddingscheme = SAEPEncryptionPadding() if padding is None else padding
        self.redundancyscheme = InMessageRedundancy() if redundancy is None else redundancy

    # m : Bytes
    def encrypt(self, pk, m, salt=None):
        """Encrypt the byte string ``m``.

        :param pk: public key from keygen (reads 'N', 'n', 's0').
        :param m: message bytes.
        :param salt: seed forwarded to non-SAEP padding; ignored by SAEP.
        :returns: ciphertext ``EM^2 mod N``.
        """
        if self.paddingscheme.name == "SAEPEncryptionPadding":
            EM = self.paddingscheme.encode(m, pk['n'], pk['s0'])
        else:
            m = self.redundancyscheme.encode(m)
            octetlen = int(ceil(int(pk['N']).bit_length() / 8.0))
            EM = self.paddingscheme.encode(m, octetlen, "", salt)
        if debug:
            print("EM == >", EM)
        i = Conversion.OS2IP(EM)
        ip = integer(i) % pk['N']  # Convert to modular integer
        return (ip ** 2) % pk['N']

    def decrypt(self, pk, sk, c):
        """Decrypt ciphertext ``c``.

        Computes the four square roots of c modulo N via CRT, then selects the
        one whose padding or redundancy check succeeds.

        :raises AssertionError: "invalid ciphertext" when c is not a square
            mod p and q, or when the SAEP tag check is ambiguous or fails.
            The exception is raised explicitly so the check survives ``python -O``.
        :returns: the message bytes. With non-SAEP padding, returns ``None``
            when no root carries valid redundancy (unchanged behaviour).
        """
        p = sk['p']
        q = sk['q']
        yp = sk['yp']
        yq = sk['yq']
        # The exponent (p+1)/4 gives a square root only when p = 3 (mod 4),
        # which the Blum-Williams modulus generator is expected to guarantee.
        mp = (c ** ((p+1)/4)) % p
        mq = (c ** ((q+1)/4)) % q
        if not (((c % p) == (mp ** 2)) and ((c % q) == (mq ** 2))):
            raise AssertionError("invalid ciphertext")

        N = int(sk['N'])
        # CRT: a = 0 (mod p), a = mq (mod q); b = mp (mod p), b = 0 (mod q).
        a = int(yp) * int(p) * int(mq)
        b = int(yq) * int(q) * int(mp)
        r1 = (a + b) % N
        r2 = N - int(r1)
        s1 = (a - b) % N
        s2 = N - int(s1)
        m1 = r1 % N
        m2 = r2 % N
        m3 = s1 % N
        m4 = s2 % N

        if self.paddingscheme.name == "SAEPEncryptionPadding":
            half = integer(N // 2)
            # Keep two candidates, preferring roots below N/2. (m1, m2) and
            # (m3, m4) are complementary pairs: m2 = N - m1, m4 = N - m3.
            if m1 < half:
                os1 = Conversion.IP2OS(int(m1))
                if m2 < half:
                    os2 = Conversion.IP2OS(int(m2))
                elif m3 < half:
                    os2 = Conversion.IP2OS(int(m3))
                else:
                    os2 = Conversion.IP2OS(int(m4))
            elif m2 < half:
                os1 = Conversion.IP2OS(int(m2))
                if m3 < half:
                    os2 = Conversion.IP2OS(int(m3))
                else:
                    os2 = Conversion.IP2OS(int(m4))
            else:
                os1 = Conversion.IP2OS(int(m3))
                os2 = Conversion.IP2OS(int(m4))

            if debug:
                print("OS1 =>", os1)
                print("OS2 =>", os2)

            (m1, t1) = self.paddingscheme.decode(os1, pk['n'], pk['s0'])
            (m2, t2) = self.paddingscheme.decode(os2, pk['n'], pk['s0'])

            # Valid SAEP decodings carry an all-zero tag of s0 bits. Use integer
            # division: s0/8 is a float under Python 3.
            zero_tag = Bytes.fill(b'\x00', pk['s0'] // 8)
            t1_ok = (t1 == zero_tag)
            t2_ok = (t2 == zero_tag)
            if t1_ok and t2_ok:
                # Both candidates look valid: ambiguous, so reject.
                raise AssertionError("invalid ciphertext")
            if t1_ok:
                return m1
            if t2_ok:
                return m2
            raise AssertionError("invalid ciphertext")

        octetlen = int(ceil(int(pk['N']).bit_length() / 8.0))
        os1 = Conversion.IP2OS(int(m1), octetlen)
        os2 = Conversion.IP2OS(int(m2), octetlen)
        os3 = Conversion.IP2OS(int(m3), octetlen)
        os4 = Conversion.IP2OS(int(m4), octetlen)
        if debug:
            print("OS1 =>", os1)
            print("OS2 =>", os2)
            print("OS3 =>", os3)
            print("OS4 =>", os4)
        for candidate in (os1, os2, os3, os4):
            (isMessage, message) = self.redundancyscheme.decode(self.paddingscheme.decode(candidate))
            if isMessage:
                return message
        # No root carried valid redundancy; preserved legacy result.
        return None
class Rabin_Sig(Rabin, PKSig):
    """
    Rabin signatures: the signature is a square root, modulo N, of the padded
    message.

    >>> msg = b'This is a test message.'
    >>> rabin = Rabin_Sig()
    >>> (public_key, secret_key) = rabin.keygen(1024)
    >>> signature = rabin.sign(secret_key, msg)
    >>> rabin.verify(public_key, msg, signature)
    True
    """
    def __init__(self, padding=None):
        """
        :param padding: message encoding; a fresh ``OAEPEncryptionPadding()``
            when omitted (built per instance instead of shared).
        """
        Rabin.__init__(self)
        PKSig.__init__(self)
        self.paddingscheme = OAEPEncryptionPadding() if padding is None else padding

    def sign(self, sk, M, salt=None):
        """Sign message bytes ``M``.

        Re-encodes M (fresh random seed each round when ``salt`` is None)
        until the encoding is a quadratic residue mod N, then takes its
        square root via CRT.

        :returns: dict with 's1' (the root used by :meth:`verify`) and 's2'
            (= N - s1). Other roots are NOT returned: two roots that are not
            negatives of each other reveal the factorization, because
            gcd(r1 - s1, N) = q.
        :raises ValueError: if a fixed ``salt`` yields a non-residue.
            Re-encoding with the same salt cannot change the result, so the
            previous behaviour was an infinite loop.
        """
        N = sk['N']
        n_int = int(N)
        p = sk['p']
        q = sk['q']
        yp = sk['yp']
        yq = sk['yq']
        octetlen = int(ceil(n_int.bit_length() / 8.0))
        while True:
            em = self.paddingscheme.encode(M, octetlen, "", salt)
            # NOTE(review): the encoding may exceed N and is reduced here,
            # which would make verification fail.
            m = integer(Conversion.OS2IP(em)) % N
            mp = (m ** ((p+1)/4)) % p
            mq = (m ** ((q+1)/4)) % q
            a = int(yp) * int(p) * int(mq)
            b = int(yq) * int(q) * int(mp)
            r1 = (a + b) % n_int
            r2 = n_int - int(r1)
            s1 = (a - b) % n_int
            s2 = n_int - int(s1)
            if any(int((integer(root) ** 2) % N - m) == 0 for root in (r1, r2, s1, s2)):
                break
            if salt is not None:
                raise ValueError("encoding with the given salt is not a quadratic residue mod N")
        # s1/s2 (the other root pair) are deliberately not exposed; see docstring.
        S = {'s1': r1, 's2': r2}
        if debug:
            print("Signing")
            print("m =>", m)
            print("em =>", em)
            print("S =>", S)
        return S

    def verify(self, pk, M, S, salt=None):
        """Check that ``S['s1']`` squared mod N decodes to ``M``.

        :param salt: unused; kept for interface compatibility.
        :returns: True when the decoded message equals M.
        """
        octetlen = int(ceil(int(pk['N']).bit_length() / 8.0))
        sig_mess = (integer(S['s1']) ** 2) % pk['N']
        sig_mess = Conversion.IP2OS(int(sig_mess), octetlen)
        if debug:
            print("OS1 =>", sig_mess)
        dec_mess = self.paddingscheme.decode(sig_mess)
        if debug:
            print("Verifying")
            print("sig_mess =>", sig_mess)
            print("dec_mess =>", dec_mess)
            print("S =>", S)
        return (dec_mess == M)
def main():
    """Demo: run one Rabin_Enc encrypt/decrypt round trip and print whether
    the recovered plaintext matches the original."""
    scheme = Rabin_Enc()
    pk, sk = scheme.keygen(128, 1024)
    plaintext = b'This is a test'
    ciphertext = scheme.encrypt(pk, plaintext)
    recovered = scheme.decrypt(pk, sk, ciphertext)
    print(recovered == plaintext)
# Run the encrypt/decrypt demo only when executed as a script, not on import.
if __name__ == "__main__": main()