"""
Comprehensive arithmetic tests for the integer module.
These tests validate integer module behavior with GCD operations and integer conversions,
specifically designed to catch Python 3.12+ compatibility issues like the Py_SIZE() vs lv_tag bug.
Tests cover:
1. Integer conversion correctness (Python int <-> integer)
2. GCD operations and isCoPrime() method
3. Modular arithmetic (modular inverse, modular operations)
4. Regression tests for Python 3.12+ compatibility
5. Integration tests that mirror real scheme usage
"""
import sys
import unittest
import pytest
from charm.core.math.integer import (
integer, gcd, random, randomPrime, isPrime, bitsize, serialize, deserialize
)
[docs]
class IntegerConversionTest(unittest.TestCase):
"""Test integer conversion correctness between Python int and integer objects."""
[docs]
def test_common_rsa_exponents(self):
"""Verify that common RSA exponents convert correctly."""
common_exponents = [65537, 3, 5, 17, 257, 641, 6700417]
for exp in common_exponents:
with self.subTest(exponent=exp):
result = integer(exp)
self.assertEqual(int(result), exp, f"integer({exp}) should equal {exp}")
self.assertEqual(str(result), str(exp), f"str(integer({exp})) should equal '{exp}'")
[docs]
def test_small_values(self):
"""Test edge cases with small values."""
small_values = [0, 1, 2, 10, 100, 255, 256, 1000]
for val in small_values:
with self.subTest(value=val):
result = integer(val)
self.assertEqual(int(result), val, f"integer({val}) should equal {val}")
[docs]
def test_large_values(self):
"""Test large values that require multiple digits in PyLongObject."""
# These values require multiple 30-bit digits in Python's internal representation
large_values = [
2**30, # Just over one digit
2**60, # Two digits
2**90, # Three digits
2**128, # Common cryptographic size
2**256, # 256-bit value
2**512, # 512-bit value
2**1024, # 1024-bit value (RSA key size)
]
for val in large_values:
with self.subTest(bits=val.bit_length()):
result = integer(val)
self.assertEqual(int(result), val, f"integer(2^{val.bit_length()-1}) conversion failed")
[docs]
def test_negative_values(self):
"""Test negative integer conversion."""
negative_values = [-1, -2, -10, -100, -65537, -2**30, -2**60, -2**128]
for val in negative_values:
with self.subTest(value=val):
result = integer(val)
self.assertEqual(int(result), val, f"integer({val}) should equal {val}")
[docs]
def test_round_trip_conversion(self):
"""Verify round-trip conversion: Python int -> integer -> Python int preserves value."""
test_values = [
0, 1, -1, 65537, -65537,
2**30 - 1, 2**30, 2**30 + 1, # Around digit boundary
2**60 - 1, 2**60, 2**60 + 1, # Two digit boundary
2**256, -2**256,
2**512 + 12345, -2**512 - 12345,
]
for val in test_values:
with self.subTest(value=val if abs(val) < 1000 else f"2^{val.bit_length()-1}"):
result = int(integer(val))
self.assertEqual(result, val, "Round-trip conversion failed")
[docs]
def test_integer_from_integer(self):
"""Test creating integer from another integer object."""
original = integer(65537)
copy = integer(original)
self.assertEqual(int(copy), 65537)
self.assertEqual(int(original), int(copy))
[docs]
class GCDOperationsTest(unittest.TestCase):
"""Test GCD operations with various integer types."""
[docs]
def test_gcd_python_ints(self):
"""Test gcd() with Python integers."""
test_cases = [
(12, 8, 4),
(17, 13, 1), # Coprime
(100, 25, 25),
(65537, 65536, 1), # Common RSA exponent vs power of 2
(2**128, 2**64, 2**64),
]
for a, b, expected in test_cases:
with self.subTest(a=a, b=b):
result = gcd(a, b)
self.assertEqual(int(result), expected)
[docs]
def test_gcd_integer_objects(self):
"""Test gcd() with integer objects."""
a = integer(48)
b = integer(18)
result = gcd(a, b)
self.assertEqual(int(result), 6)
[docs]
def test_gcd_mixed_types(self):
"""Test gcd() with mixed Python int and integer objects."""
a = integer(48)
result1 = gcd(a, 18)
result2 = gcd(48, integer(18))
self.assertEqual(int(result1), 6)
self.assertEqual(int(result2), 6)
[docs]
def test_gcd_edge_cases(self):
"""Test gcd edge cases."""
# gcd(0, n) = n
self.assertEqual(int(gcd(0, 5)), 5)
self.assertEqual(int(gcd(5, 0)), 5)
# gcd(1, n) = 1
self.assertEqual(int(gcd(1, 12345)), 1)
self.assertEqual(int(gcd(12345, 1)), 1)
# gcd(n, n) = n
self.assertEqual(int(gcd(42, 42)), 42)
[docs]
class IsCoPrimeTest(unittest.TestCase):
"""Test isCoPrime() method for coprimality checking."""
[docs]
def test_coprime_common_exponents(self):
"""Test isCoPrime() with common RSA exponents vs typical phi_N values."""
# Simulate phi_N = (p-1)(q-1) for small primes
p, q = 61, 53
phi_N = integer((p - 1) * (q - 1)) # 3120
# 65537 should be coprime to 3120 (gcd = 1)
self.assertTrue(phi_N.isCoPrime(65537))
# 3 should be coprime to 3120 (gcd = 3, not coprime!)
self.assertFalse(phi_N.isCoPrime(3))
# 17 should be coprime to 3120
self.assertTrue(phi_N.isCoPrime(17))
[docs]
def test_coprime_with_integer_objects(self):
"""Test isCoPrime() with integer objects as arguments."""
a = integer(35) # 5 * 7
self.assertTrue(a.isCoPrime(12)) # gcd(35, 12) = 1
self.assertFalse(a.isCoPrime(15)) # gcd(35, 15) = 5
self.assertTrue(a.isCoPrime(integer(12)))
[docs]
def test_coprime_edge_cases(self):
"""Test isCoPrime() edge cases."""
one = integer(1)
self.assertTrue(one.isCoPrime(12345)) # 1 is coprime to everything
# Any number is coprime to 1
n = integer(12345)
self.assertTrue(n.isCoPrime(1))
[docs]
class ModularArithmeticTest(unittest.TestCase):
"""Test modular arithmetic operations."""
[docs]
def test_modular_inverse_basic(self):
"""Test basic modular inverse computation."""
# e = 3, modulus = 11, inverse should be 4 (3*4 = 12 ≡ 1 mod 11)
e = integer(3, 11)
d = e ** -1
self.assertEqual(int(d), 4)
# Verify: e * d ≡ 1 (mod 11)
product = integer(int(e) * int(d), 11)
self.assertEqual(int(product), 1)
[docs]
def test_modular_inverse_rsa_exponent(self):
"""Test modular inverse with RSA-like parameters."""
# Small RSA example: p=61, q=53, phi_N=3120, e=17
phi_N = 3120
e = integer(17, phi_N)
d = e ** -1
# Verify: e * d ≡ 1 (mod phi_N)
product = (int(e) * int(d)) % phi_N
self.assertEqual(product, 1)
[docs]
def test_modular_operations_respect_modulus(self):
"""Test that modular operations respect the modulus."""
modulus = 17
a = integer(20, modulus) # 20 mod 17 = 3
self.assertEqual(int(a), 3)
b = integer(100, modulus) # 100 mod 17 = 15
self.assertEqual(int(b), 15)
[docs]
def test_modular_exponentiation(self):
"""Test modular exponentiation."""
base = integer(2, 13)
# 2^10 = 1024, 1024 mod 13 = 10
result = base ** 10
self.assertEqual(int(result), 1024 % 13)
[docs]
def test_integer_without_modulus(self):
"""Test integer behavior when modulus is not set."""
a = integer(65537)
b = integer(12345)
# Without modulus, operations should work as regular integers
product = a * b
self.assertEqual(int(product), 65537 * 12345)
[docs]
class Python312CompatibilityTest(unittest.TestCase):
"""Regression tests for Python 3.12+ compatibility.
These tests specifically target the Py_SIZE() vs lv_tag bug that was fixed.
The bug caused incorrect digit count extraction for multi-digit integers.
"""
[docs]
def test_65537_regression(self):
"""Test the specific value that exposed the Python 3.12+ bug.
In the buggy version, integer(65537) returned a huge incorrect value
like 12259964326940877255866161939725058870607969088809533441.
"""
result = integer(65537)
self.assertEqual(int(result), 65537)
# Also verify string representation
self.assertEqual(str(result), "65537")
[docs]
def test_multi_digit_integers(self):
"""Test integers that require multiple digits in PyLongObject.
Python uses 30-bit digits internally. Values >= 2^30 require multiple digits.
The bug was in extracting the digit count from lv_tag.
"""
# Single digit (< 2^30)
single_digit = 2**29
self.assertEqual(int(integer(single_digit)), single_digit)
# Two digits (2^30 to 2^60-1)
two_digits = 2**45
self.assertEqual(int(integer(two_digits)), two_digits)
# Three digits (2^60 to 2^90-1)
three_digits = 2**75
self.assertEqual(int(integer(three_digits)), three_digits)
# Many digits
many_digits = 2**300
self.assertEqual(int(integer(many_digits)), many_digits)
[docs]
def test_sign_handling(self):
"""Test sign handling for negative integers.
In Python 3.12+, sign is stored in lv_tag bits 0-1:
- 0 = positive
- 1 = zero
- 2 = negative
"""
# Positive
pos = integer(12345)
self.assertEqual(int(pos), 12345)
self.assertGreater(int(pos), 0)
# Zero
zero = integer(0)
self.assertEqual(int(zero), 0)
# Negative
neg = integer(-12345)
self.assertEqual(int(neg), -12345)
self.assertLess(int(neg), 0)
# Large negative
large_neg = integer(-2**100)
self.assertEqual(int(large_neg), -2**100)
[docs]
def test_digit_boundary_values(self):
"""Test values at digit boundaries (multiples of 2^30)."""
boundaries = [
2**30 - 1, 2**30, 2**30 + 1,
2**60 - 1, 2**60, 2**60 + 1,
2**90 - 1, 2**90, 2**90 + 1,
]
for val in boundaries:
with self.subTest(value=f"2^{val.bit_length()-1}"):
self.assertEqual(int(integer(val)), val)
self.assertEqual(int(integer(-val)), -val)
[docs]
def test_mpz_to_pylong_roundtrip(self):
"""Test that mpzToLongObj correctly creates Python integers.
This tests the reverse direction: GMP mpz_t -> Python int.
"""
# Create integer, perform operation, convert back
a = integer(2**100)
b = integer(2**50)
product = a * b
expected = 2**100 * 2**50
self.assertEqual(int(product), expected)
[docs]
class IntegrationSchemeTest(unittest.TestCase):
"""Integration tests that mirror real cryptographic scheme usage."""
[docs]
def test_rsa_coprime_search_pattern(self):
"""Test the RSA keygen coprime search pattern.
This mirrors the pattern used in pkenc_rsa.py to find e coprime to phi_N.
"""
# Simulate small RSA parameters
p, q = 61, 53
N = p * q # 3233
phi_N = integer((p - 1) * (q - 1)) # 3120
# Common RSA exponents to try
common_exponents = [65537, 3, 5, 17, 257, 641]
e_value = None
for candidate in common_exponents:
if phi_N.isCoPrime(candidate):
e_value = candidate
break
self.assertIsNotNone(e_value, "Should find a coprime exponent")
# Verify it's actually coprime
self.assertEqual(int(gcd(e_value, int(phi_N))), 1)
# Compute modular inverse
e = integer(e_value, int(phi_N))
d = e ** -1
# Verify: e * d ≡ 1 (mod phi_N)
product = (e_value * int(d)) % int(phi_N)
self.assertEqual(product, 1)
[docs]
def test_rsa_encryption_decryption_pattern(self):
"""Test RSA encryption/decryption with integer operations."""
# Small RSA parameters for testing
p, q = 61, 53
N = p * q # 3233
phi_N = (p - 1) * (q - 1) # 3120
e = 17
d = int(integer(e, phi_N) ** -1) # 2753
# Encrypt message m = 123
m = 123
c = pow(m, e, N) # c = 123^17 mod 3233 = 855
# Decrypt
m_decrypted = pow(c, d, N)
self.assertEqual(m_decrypted, m)
[docs]
def test_paillier_pattern(self):
"""Test Paillier-like integer encoding pattern."""
# Paillier uses n^2 as modulus
p, q = 17, 19
n = p * q # 323
n_squared = n * n # 104329
# Encode a message
m = 42
r = 7 # Random value coprime to n
# g = n + 1 is a common choice
g = n + 1
# Encrypt: c = g^m * r^n mod n^2
c = (pow(g, m, n_squared) * pow(r, n, n_squared)) % n_squared
# Verify the ciphertext is in the correct range
self.assertGreater(c, 0)
self.assertLess(c, n_squared)
[docs]
def test_serialization_roundtrip(self):
"""Test serialization and deserialization of integer objects."""
test_values = [0, 1, 65537, 2**128, 2**256, -12345, -2**100]
for val in test_values:
with self.subTest(value=val if abs(val) < 1000 else f"2^{abs(val).bit_length()-1}"):
original = integer(val)
serialized = serialize(original)
deserialized = deserialize(serialized)
self.assertEqual(int(deserialized), val)
[docs]
class ArithmeticOperationsTest(unittest.TestCase):
"""Test basic arithmetic operations on integer objects."""
[docs]
def test_addition(self):
"""Test integer addition."""
a = integer(100)
b = integer(200)
self.assertEqual(int(a + b), 300)
self.assertEqual(int(a + 50), 150)
[docs]
def test_subtraction(self):
"""Test integer subtraction."""
a = integer(200)
b = integer(100)
self.assertEqual(int(a - b), 100)
self.assertEqual(int(a - 50), 150)
[docs]
def test_multiplication(self):
"""Test integer multiplication."""
a = integer(12)
b = integer(34)
self.assertEqual(int(a * b), 408)
self.assertEqual(int(a * 10), 120)
[docs]
def test_division(self):
"""Test integer division."""
a = integer(100)
b = integer(25)
self.assertEqual(int(a / b), 4)
[docs]
def test_exponentiation(self):
"""Test integer exponentiation."""
a = integer(2)
self.assertEqual(int(a ** 10), 1024)
[docs]
def test_comparison(self):
"""Test integer comparison operations."""
a = integer(100)
b = integer(200)
c = integer(100)
self.assertTrue(a < b)
self.assertTrue(b > a)
self.assertTrue(a <= c)
self.assertTrue(a >= c)
self.assertTrue(a == c)
self.assertTrue(a != b)
if __name__ == "__main__":
unittest.main()