'''A collection of encryption and signature padding schemes'''
from charm.toolbox.bitstring import Bytes,py3
from charm.toolbox.securerandom import SecureRandomFactory
import charm.core.crypto.cryptobase
import hashlib
import math
import struct
import sys
# Module-wide switch for the diagnostic print() tracing in the padding
# classes of this module (MGF1 uses its own local flag instead).
debug = False
class OAEPEncryptionPadding:
    '''
    :Authors: Gary Belvin

    OAEPEncryptionPadding

    Implements the OAEP padding scheme. Appropriate for RSA-OAEP encryption.
    Implemented according to PKCS#1 v2.1 Section 7 ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1.pdf
    '''
    def __init__(self, _hash_type='sha1'):
        '''
        :Parameters:
          - ``_hash_type``: name of a ``hashlib`` algorithm, used both for the
            label hash and inside MGF1 (default ``'sha1'``)
        '''
        self.name = "OAEPEncryptionPadding"
        self.hashFn = hashFunc(_hash_type)
        # hLen: digest size of the chosen hash, in octets
        self.hashFnOutputBytes = len(hashlib.new(_hash_type).digest())

    def _labelHash(self, label):
        '''Return lHash = Hash(label); text labels are UTF-8 encoded on Python 3.'''
        if isinstance(label, bytes):
            # Already an octet string (this also covers str on Python 2).
            return self.hashFn(Bytes(label))
        if py3:
            return self.hashFn(Bytes(label, 'utf8'))
        return self.hashFn(Bytes(label))

    def encode(self, message, emLen, label="", seed=None):
        '''
        EME-OAEP encoding.

        :Parameters:
          - ``message``: the octet string M to encode
          - ``emLen``: the intended length of the encoded message in OCTETS
            (the length of the RSA modulus)
          - ``label``: optional label L; only its hash is embedded
          - ``seed``: optional seed of exactly hLen octets; a fresh random seed
            is drawn when omitted (supply one only for deterministic testing)
        :Return: a Bytes object 0x00 || maskedSeed || maskedDB, ``emLen`` octets long
        :Raises: AssertionError if the message is too long or the seed has the wrong length
        '''
        # Skipped: label input length checking. (L must be less than 2^61 octets for SHA1)
        hLen = self.hashFnOutputBytes
        if len(message) > (emLen - (2 * hLen) - 2):
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; AssertionError keeps the type callers saw before.
            raise AssertionError("message too long")
        lHash = self._labelHash(label)
        # DB = lHash || PS || 0x01 || M, PS being (emLen - mLen - 2hLen - 2) zero octets,
        # so len(DB) == emLen - hLen - 1.
        PS = Bytes.fill(b'\x00', emLen - len(message) - (2 * hLen) - 2)
        DB = lHash + PS + b'\x01' + bytes(message)
        if seed is None:
            rand = SecureRandomFactory.getInstance()
            seed = rand.getRandomBytes(hLen)
        elif len(seed) != hLen:
            # A seed of any other size produces an encoding decode() cannot parse.
            raise AssertionError("seed must be %d octets long" % hLen)
        # maskedDB = DB xor MGF1(seed, emLen - hLen - 1)
        dbMask = MGF1(seed, len(DB), self.hashFn, hLen)
        maskedDB = DB ^ dbMask
        # maskedSeed = seed xor MGF1(maskedDB, hLen)
        seedMask = MGF1(maskedDB, len(seed), self.hashFn, hLen)
        maskedSeed = seedMask ^ seed
        if debug:
            print("Encoding")
            print("label =>", label)
            print("lhash =>", lHash)
            print("seed =>", seed)
            print("db =>", DB)
            print("db len =>", len(DB))
            print("db mask =>", dbMask)
            print("maskedDB =>", maskedDB)
            print("seedMask =>", seedMask)
            print("maskedSed=>", maskedSeed)
        return Bytes(b'\x00') + maskedSeed + maskedDB

    def decode(self, encMessage, label=""):
        '''
        EME-OAEP decoding.

        :Parameters:
          - ``encMessage``: the encoded message 0x00 || maskedSeed || maskedDB
          - ``label``: the label that was used when encoding
        :Return: the recovered message M
        :Raises: AssertionError if ``encMessage`` is too short, or is not a
          valid encoding for ``label``
        '''
        hLen = self.hashFnOutputBytes
        if len(encMessage) < (2 * hLen + 2):
            raise AssertionError("encoded string not long enough.")
        lHash = self._labelHash(label)
        # Parse the encoded string as Y || maskedSeed || maskedDB.
        # bytearray(...)[0] yields an int on both Python 2 and 3.
        Y = bytearray(encMessage[0:1])[0]
        maskedSeed = Bytes(encMessage[1:(1 + hLen)])
        maskedDB = Bytes(encMessage[(1 + hLen):])
        # seed = maskedSeed xor MGF1(maskedDB, hLen)
        seedMask = MGF1(maskedDB, len(maskedSeed), self.hashFn, hLen)
        seed = maskedSeed ^ seedMask
        # DB = maskedDB xor MGF1(seed, k - hLen - 1)
        dbMask = MGF1(seed, len(maskedDB), self.hashFn, hLen)
        DB = dbMask ^ maskedDB
        if debug:
            print("decoding:")
            print("MaskedSeed => ", maskedSeed)
            print("maskedDB => ", maskedDB)
            print("r seed =>", seed)
            print("r DB =>", DB)
        # DB = lHash' || PS || 0x01 || M. The separator search must start after
        # lHash': the label hash itself may contain a 0x01 octet.
        lHashPrime = DB[0:hLen]
        rest = bytearray(DB[hLen:])
        sep = rest.find(b'\x01')
        valid = (Y == 0) and (lHashPrime == lHash) and sep >= 0 and not any(rest[:sep])
        if not valid:
            # One generic error for every failure mode, so callers cannot tell
            # which check failed (PKCS#1 warns against distinguishable errors).
            raise AssertionError("decryption error")
        return DB[hLen + sep + 1:]
#def MGF1(seed:Bytes, maskBytes:int, hashFn, hLen:int):
[docs]def MGF1(seed, maskBytes, hashFn, hLen):
''' MGF1 Mask Generation Function
Implemented according to PKCS #1 specification, see appendix B.2.1:
:Parameters:
- ``hLen``: is the output length of the hash function
- ``maskBytes``: the number of mask bytes to return
'''
debug = False
# Skipped output size checking. Must be less than 2^32 * hLen
ran = range(int(math.ceil(maskBytes / float(hLen))))
if debug:
print("calc =>", math.ceil(maskBytes / float(hLen)))
print("Range =>", ran)
test = [hashFn(struct.pack(">%dsI" % (len(seed)), seed, i)) for i in ran]
if debug:
print("test =>", test)
result = b''.join(test)
return Bytes(result[0:maskBytes])
class hashFunc:
    '''
    Callable wrapper around a ``hashlib`` algorithm.

    Each call hashes its argument with a fresh copy of the underlying hash
    object, so calls are independent, and returns the digest as ``Bytes``.
    '''
    def __init__(self, _hash_type=None):
        '''
        :Parameters:
          - ``_hash_type``: name of a ``hashlib`` algorithm; ``'sha1'`` when omitted
        '''
        if _hash_type is None:
            _hash_type = 'sha1'
        self.hashObj = hashlib.new(_hash_type)

    def __call__(self, message):
        '''
        Hash ``message`` and return the digest.

        :Parameters:
          - ``message``: an octet string (``bytes``, ``Bytes`` or ``bytearray``),
            or text, which is UTF-8 encoded before hashing
        :Return: a Bytes object holding the digest
        :Raises: TypeError for any other argument type
        '''
        h = self.hashObj.copy()
        h.update(self._toBytes(message))
        return Bytes(h.digest())

    def _toBytes(self, message):
        '''Normalize ``message`` to ``bytes`` for ``hash.update``.'''
        if isinstance(message, (bytes, bytearray, Bytes)):
            # Octet strings (including str on Python 2) are hashed as-is.
            return bytes(message)
        if hasattr(message, 'encode'):
            # Text: str on Python 3, unicode on Python 2.
            return message.encode('utf-8')
        # Unsupported types used to be hashed as the empty string without error.
        raise TypeError("hashFunc: cannot hash object of type %s" % type(message).__name__)
class PSSPadding:
    '''
    :Authors: Gary Belvin

    PSSSignaturePadding

    Implements the PSS signature padding scheme. Appropriate for RSA-PSS signing.
    Implemented according to section 8 of ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1.pdf.
    '''
    def __init__(self, _hash_type='sha1'):
        '''
        :Parameters:
          - ``_hash_type``: name of a ``hashlib`` algorithm used for mHash, H and MGF1
        '''
        self.hashFn = hashFunc(_hash_type)
        self.hLen = len(hashlib.new(_hash_type).digest())
        self.sLen = self.hLen  # The length of the default salt

    def encode(self, M, emBits=None, salt=None):
        '''Encodes a message with PSS padding.

        emBits will be set to the minimum allowed length if not explicitly set.

        :Parameters:
          - ``M``: the message to encode
          - ``emBits``: bit length of the encoded message (rounded up to whole octets)
          - ``salt``: optional salt of exactly ``sLen`` octets; random when omitted
        :Return: EM = maskedDB || H || 0xbc as a Bytes object
        :Raises: AssertionError if emBits is too small or the salt has the wrong size
        '''
        # Skipped: message length check (len(M) < 2^61 - 1 for SHA-1).
        if emBits is None:
            emBits = 8 * self.hLen + 8 * self.sLen + 9
        # Round to the next byte.
        # NOTE(review): after this rounding numzeros below is always 0, so the
        # "clear the leftmost bits" step never clears anything, while verify()
        # does not round. An explicit non-byte-aligned emBits passed to both
        # encode() and verify() can therefore fail to verify. Kept because the
        # default encode()/verify() pair relies on both sides being
        # byte-aligned — TODO confirm against PKCS#1 9.1.1 step 11 and callers.
        emBits = int(math.ceil(emBits / 8.0)) * 8
        # Preconditions are raised explicitly so they survive ``python -O``;
        # AssertionError keeps the exception type callers saw before.
        if emBits < 8 * self.hLen + 8 * self.sLen + 9:
            raise AssertionError("Not enough emBits")
        emLen = int(math.ceil(emBits / 8.0))
        if emLen < self.hLen + self.sLen + 2:
            raise AssertionError("emLen too small")
        if salt is None:
            if self.sLen > 0:
                salt = SecureRandomFactory.getInstance().getRandomBytes(self.sLen)
            else:
                salt = b''
        if len(salt) != self.sLen:
            raise AssertionError("Salt wrong size")
        # M' = (0x)00 00 00 00 00 00 00 00 || mHash || salt; H = Hash(M')
        eightzerobytes = Bytes.fill(b'\x00', 8)
        mHash = self.hashFn(M)
        Mprime = eightzerobytes + mHash + salt
        H = self.hashFn(Mprime)
        # DB = PS || 0x01 || salt, PS being emLen - sLen - hLen - 2 zero octets
        # (possibly none); DB is emLen - hLen - 1 octets long.
        pslen = emLen - self.sLen - self.hLen - 2
        ps = Bytes.fill(b'\x00', pslen)
        DB = ps + Bytes(b'\x01') + salt
        # maskedDB = DB xor MGF(H, emLen - hLen - 1)
        masklen = emLen - self.hLen - 1
        dbMask = MGF1(H, masklen, self.hashFn, self.hLen)
        maskedDB = DB ^ dbMask
        # Set the leftmost 8emLen - emBits bits of the leftmost octet in maskedDB to zero.
        numzeros = 8 * emLen - emBits
        bitmask = int('0' * numzeros + '1' * (8 - numzeros), 2)
        ba = bytearray(maskedDB)
        ba[0] &= bitmask
        maskedDB = Bytes(ba)
        EM = maskedDB + H + Bytes(b'\xbc')
        if debug:
            print("PSS Encoding:")
            print("M =>", M)
            print("mHash =>", mHash)
            print("salt =>", salt)
            print("M' =>", Mprime)
            print("H =>", H)
            print("DB =>", DB)
            print("dbmask=>", dbMask)
            print("masked=>", maskedDB)
            print("EM =>", EM)
        return EM

    def verify(self, M, EM, emBits=None):
        '''
        Verifies that EM is a correct encoding for M

        :Parameters:
          - M - the message to verify
          - EM - the encoded message
          - emBits - bit length of the encoded message; ``8 * len(EM)`` when omitted
        :Return: true for 'consistent' or false for 'inconsistent'
        :Raises: AssertionError if emBits is too small or does not match len(EM)
        '''
        if debug: print("PSS Decoding:")
        # Preconditions (raised explicitly so they survive ``python -O``)
        if emBits is None:
            emBits = 8 * len(EM)
        if emBits < 8 * self.hLen + 8 * self.sLen + 9:
            raise AssertionError("Not enough emBits")
        emLen = int(math.ceil(emBits / 8.0))
        if len(EM) != emLen:
            raise AssertionError("EM length not equivalent to bits provided")
        # Skipped: message length check (len(M) < 2^61 - 1 for SHA-1).
        mHash = self.hashFn(M)
        # If emLen < hLen + sLen + 2, output 'inconsistent' and stop.
        if emLen < self.hLen + self.sLen + 2:
            if debug: print("emLen too short")
            return False
        # The rightmost octet of EM must be 0xbc.
        if EM[len(EM) - 1:] != b'\xbc':
            if debug: print("0xbc not found")
            return False
        # maskedDB = leftmost emLen - hLen - 1 octets of EM; H = the next hLen octets.
        maskedDBlen = emLen - self.hLen - 1
        maskedDB = Bytes(EM[:maskedDBlen])
        H = EM[maskedDBlen:maskedDBlen + self.hLen]
        # The leftmost 8emLen - emBits bits of maskedDB must all be zero.
        numzeros = 8 * emLen - emBits
        bitmask = int('1' * numzeros + '0' * (8 - numzeros), 2)
        _mask_check = maskedDB[0]
        if not py3: _mask_check = ord(_mask_check)
        if _mask_check & bitmask != 0:
            if debug: print("leftmost %d bits of masked db not zero, found %s" % (numzeros, bin(_mask_check)))
            return False
        # DB = maskedDB xor MGF(H, emLen - hLen - 1)
        dbMask = MGF1(H, maskedDBlen, self.hashFn, self.hLen)
        DB = maskedDB ^ dbMask
        # Set the leftmost 8emLen - emBits bits of the leftmost octet in DB to zero.
        bitmask = int('0' * numzeros + '1' * (8 - numzeros), 2)
        ba = bytearray(DB)
        ba[0] &= bitmask
        DB = Bytes(ba)
        # The emLen - hLen - sLen - 2 leftmost octets of DB must be zero ...
        zerolen = emLen - self.hLen - self.sLen - 2
        if DB[:zerolen] != Bytes.fill(b'\x00', zerolen):
            if debug: print("DB did not start with %d zero octets" % zerolen)
            return False
        # ... and be followed by a 0x01 octet.
        _db_check = DB[zerolen]
        if not py3: _db_check = ord(_db_check)
        if _db_check != 0x01:
            if debug: print("DB did not have 0x01 at %s, found %s instead" % (zerolen, DB[zerolen]))
            return False
        # salt = last sLen octets of DB; M' = (0x)00 x8 || mHash || salt; H' = Hash(M')
        salt = DB[len(DB) - self.sLen:]
        mPrime = Bytes.fill(b'\x00', 8) + mHash + salt
        HPrime = self.hashFn(mPrime)
        if debug:
            print("M =>", M)
            print("mHash =>", mHash)
            print("salt =>", salt)
            print("M' =>", mPrime)
            print("H =>", H)
            print("DB =>", DB)
            print("dbmask=>", dbMask)
            print("masked=>", maskedDB)
            print("EM =>", EM)
        # If H = H', output 'consistent'. Otherwise, output 'inconsistent'.
        return H == HPrime
class SAEPEncryptionPadding:
    '''
    :Authors: Christina Garman

    SAEPEncryptionPadding

    Pads a message into an n-bit block with n = m + s0 + s1, where m = n/4.
    The padded message M' = message || 0x80 || 0x00.. || 0x80 fills m bits,
    t is s0 bits of zeros and r is s1 random bits:
    v = M' || t, x = v xor Hash(r), output x || r.
    '''
    def __init__(self, _hash_type='sha384'):
        '''
        :Parameters:
          - ``_hash_type``: name of a ``hashlib`` algorithm used to mask v
            (default ``'sha384'``)
        '''
        self.name = "SAEPEncryptionPadding"
        self.hashFn = hashFunc(_hash_type)
        self.hashFnOutputBytes = len(hashlib.new(_hash_type).digest())

    def encode(self, message, n, s0):
        '''
        :Parameters:
          - ``message``: the message octets; at most n/32 - 2 octets
          - ``n``: total block size in bits (n = m + s0 + s1, with m = n/4)
          - ``s0``: size in bits of the all-zero string t
        :Return: x || r, a Bytes object
        :Raises: AssertionError if the message does not fit in the m-bit field
        '''
        m = int(n / 4)  # message field size in bits, usually 256
        mBytes = int(m / 8)
        # The field holds message || 0x80 || zeros || 0x80, so two octets are
        # reserved for the markers; longer messages would overflow into t.
        if len(message) > mBytes - 2:
            raise AssertionError("message too long")
        # Integer lengths: float counts passed to Bytes.fill fail on Python 3.
        zeroCount = mBytes - 2 - len(message)
        message_ext = bytes(message) + b'\x80' + b'\x00' * zeroCount + b'\x80'
        s1 = n - m - s0
        t = Bytes.fill(b'\x00', int(s0 / 8))
        rand = SecureRandomFactory.getInstance()
        r = rand.getRandomBytes(int(s1 / 8))
        v = Bytes(bytes(message_ext) + t)
        x = v ^ self.hashFn(r)
        y = x + r
        if debug:
            print("Encoding")
            print("m =>", m)
            print("s0 =>", s0)
            print("s1 =>", s1)
            print("t =>", t, len(t))
            print("r =>", r, len(r))
            print("v =>", v, len(v))
            print("x =>", x)
            print("y =>", y, len(y))
        return y

    def decode(self, encMessage, n, s0):
        '''
        :Parameters:
          - ``encMessage``: x || r as produced by encode()
          - ``n``, ``s0``: the same sizes in bits that were passed to encode()
        :Return: a tuple (M, t) of the recovered message and the string t,
          which encode() filled with zeros
        '''
        m = int(n / 4)
        mBytes = int(m / 8)
        xLen = int((m + s0) / 8)      # octets of x
        rLen = int((n - m - s0) / 8)  # octets of r (s1 bits)
        x = encMessage[:xLen]
        # Byte offsets throughout: the old end index n - m - s0 was s1 in
        # bits and only worked because slicing clamps to the end.
        r = encMessage[xLen:xLen + rLen]
        v = Bytes(x) ^ self.hashFn(r)
        M = v[:mBytes]
        # Was v[int(m/8):int(m+s0/8)] — a precedence slip hidden by clamping.
        t = v[mBytes:xLen]
        # Strip the 0x80 || zeros || 0x80 padding; bytearray gives ints on py2 and py3.
        tail = bytearray(M[-2:])
        if tail[-1] == 0x80 and (tail[-2] == 0 or tail[-2] == 0x80):
            index = M[:(len(M) - 1)].rindex(b'\x80')
            M = M[:index]
        else:
            M = M[:len(M) - 1]
        if debug:
            print("decoding:")
            print("x => ", x)
            print("r => ", r)
            print("v => ", v)
            print("M => ", M)
            print("t => ", t)
            print("r =>", r)
        return (M, t)
class PKCS7Padding(object):
    '''
    PKCS#7 padding: append ``pad`` copies of the octet value ``pad`` so the
    result is a multiple of the block size. ``pad`` is always between 1 and
    the block size, so an already aligned input gains a whole extra block and
    the padding can always be removed unambiguously.
    '''
    def __init__(self, block_size=16):
        # Block size in octets; the pad length is stored in one octet, so it must be < 256.
        self.block_size = block_size

    def encode(self, _bytes, block_size=None):
        '''
        Pad ``_bytes`` to a multiple of the block size.

        :Parameters:
          - ``_bytes``: the octet string to pad
          - ``block_size``: block size in octets; defaults to the instance's
            ``block_size`` (this argument used to be accepted but ignored)
        :Return: the padded octet string
        '''
        pad = self._padlength(_bytes, block_size)
        # bytes(bytearray([n])) is the single octet n on both Python 2 and 3.
        return _bytes.ljust(len(_bytes) + pad, bytes(bytearray([pad])))

    def decode(self, _bytes):
        '''
        Strip PKCS#7 padding from ``_bytes``.

        :Return: the unpadded octet string
        :Raises: ValueError if ``_bytes`` does not end in valid padding
        '''
        if not _bytes:
            raise ValueError("invalid PKCS#7 padding")
        pad = bytearray(_bytes[-1:])[0]
        # A zero pad would make _bytes[:-0] return b''; every pad octet must equal pad.
        if pad == 0 or pad > len(_bytes) or _bytes[-pad:] != bytes(bytearray([pad])) * pad:
            raise ValueError("invalid PKCS#7 padding")
        return _bytes[:-pad]

    def _padlength(self, _bytes, block_size=None):
        '''Return the number of padding octets (1..block_size) ``_bytes`` needs.'''
        if block_size is None:
            block_size = self.block_size
        # Never 0: an input that is already aligned gets a full block of padding.
        return block_size - (len(_bytes) % block_size)