# Source code for threshold_test

"""
Tests for DKLS23 Threshold ECDSA implementation.

Run with: pytest charm/test/schemes/threshold_test.py -v

This module tests:
- SimpleOT: Base Oblivious Transfer protocol
- OTExtension: IKNP-style OT extension
- MtA/MtAwc: Multiplicative-to-Additive conversion
- ThresholdSharing/PedersenVSS: Threshold secret sharing
- DKLS23_DKG: Distributed Key Generation
- DKLS23_Presign: Presigning protocol
- DKLS23_Sign: Signing protocol
- DKLS23: Complete threshold ECDSA protocol
"""

import unittest
try:
    import pytest
except ImportError:
    pytest = None
from charm.toolbox.ecgroup import ECGroup, ZR, G
from charm.toolbox.eccurve import secp256k1

# Import OT components
from charm.toolbox.ot.base_ot import SimpleOT
from charm.toolbox.ot.ot_extension import OTExtension, get_bit
from charm.toolbox.ot.dpf import DPF
from charm.toolbox.ot.mpfss import MPFSS
from charm.toolbox.ot.silent_ot import SilentOT

# Import MtA
from charm.toolbox.mta import MtA, MtAwc

# Import threshold sharing
from charm.toolbox.threshold_sharing import ThresholdSharing, PedersenVSS

# Import DKLS23 protocol components
from charm.schemes.threshold.dkls23_dkg import DKLS23_DKG, KeyShare
from charm.schemes.threshold.dkls23_presign import DKLS23_Presign, Presignature
from charm.schemes.threshold.dkls23_sign import DKLS23_Sign, DKLS23, ThresholdSignature

import os

# Module-level debug switch. Nothing in the code shown in this file reads it;
# presumably kept for ad-hoc debugging or external tooling — TODO confirm.
debug = False


class TestSimpleOT(unittest.TestCase):
    """Exercise the base 1-out-of-2 oblivious transfer (Chou-Orlandi style)."""

    def setUp(self):
        self.group = ECGroup(secp256k1)

    def _transfer(self, choice, m0, m1):
        """Run one complete OT between two fresh parties.

        The sender offers (m0, m1); the receiver selects index ``choice``.
        Returns whatever the receiver recovers.
        """
        ot_sender = SimpleOT(self.group)
        ot_receiver = SimpleOT(self.group)
        params = ot_sender.sender_setup()
        response, state = ot_receiver.receiver_choose(params, choice)
        ciphertexts = ot_sender.sender_transfer(response, m0, m1)
        return ot_receiver.receiver_retrieve(ciphertexts, state)

    def test_ot_choice_zero(self):
        """With choice bit 0 the receiver must recover m0."""
        # Both payloads are exactly 16 bytes, matching the block-cipher size.
        msg_zero = b'message zero!!!!'
        msg_one = b'message one!!!!!'
        learned = self._transfer(0, msg_zero, msg_one)
        self.assertEqual(learned, msg_zero, "Receiver should get m0 when choice=0")

    def test_ot_choice_one(self):
        """With choice bit 1 the receiver must recover m1."""
        msg_zero = b'message zero!!!!'
        msg_one = b'message one!!!!!'
        learned = self._transfer(1, msg_zero, msg_one)
        self.assertEqual(learned, msg_one, "Receiver should get m1 when choice=1")

    def test_ot_multiple_transfers(self):
        """Independent OT runs, one per choice bit, each yield the chosen message."""
        payloads = (b'zero message !!!', b'one message !!! ')
        for choice in (0, 1):
            learned = self._transfer(choice, *payloads)
            self.assertEqual(learned, payloads[choice])

    def test_ot_invalid_point_rejected(self):
        """The receiver must reject the point at infinity from a malicious sender."""
        sender = SimpleOT(self.group)
        receiver = SimpleOT(self.group)
        honest = sender.sender_setup()

        # Raising any group element to the zero scalar gives the identity element.
        zero_scalar = self.group.init(ZR, 0)
        infinity = self.group.random(G) ** zero_scalar

        tampered_variants = (
            {'A': infinity, 'g': honest['g']},  # identity as sender public key A
            {'A': honest['A'], 'g': infinity},  # identity as generator g
        )
        for bad_params in tampered_variants:
            with self.assertRaises(ValueError) as ctx:
                receiver.receiver_choose(bad_params, 0)
            self.assertIn("infinity", str(ctx.exception).lower())

    def test_ot_reset_sender(self):
        """reset_sender wipes the sender state, and a later setup still works."""
        sender = SimpleOT(self.group)
        sender.sender_setup()
        for attr in ('_a', '_A', '_g'):
            self.assertIsNotNone(getattr(sender, attr))

        sender.reset_sender()
        for attr in ('_a', '_A', '_g'):
            self.assertIsNone(getattr(sender, attr))

        params = sender.sender_setup()
        self.assertIsNotNone(sender._a)
        self.assertIn('A', params)
        self.assertIn('g', params)


class TestOTExtension(unittest.TestCase):
    """Exercise the IKNP-style OT extension built on top of base OTs."""

    def setUp(self):
        self.group = ECGroup(secp256k1)

    def _new_pair(self):
        """Return a fresh (sender, receiver) pair of 128-bit OTExtension endpoints."""
        sender_ext = OTExtension(self.group, security_param=128)
        receiver_ext = OTExtension(self.group, security_param=128)
        return sender_ext, receiver_ext

    def _run_base_ot_setup(self, sender_ext, receiver_ext):
        """Drive the base-OT phase that seeds both extension endpoints.

        Roles are swapped for the base OTs: the extension receiver generates
        seed pairs and acts as base-OT sender, while the extension sender
        picks one seed per pair according to its secret bits.
        """
        sender_ext.sender_setup_base_ots()
        setups = receiver_ext.receiver_setup_base_ots()
        replies = sender_ext.sender_respond_base_ots(setups)
        encrypted_seeds = receiver_ext.receiver_transfer_seeds(replies)
        sender_ext.sender_receive_seeds(encrypted_seeds)

    def _run_extension(self, sender_ext, receiver_ext, count, choice_bits, pairs):
        """Run one extension round over ``count`` OTs; return the receiver's outputs."""
        sender_ext.sender_init()
        request, pending = receiver_ext.receiver_extend(count, choice_bits)
        encrypted = sender_ext.sender_extend(count, pairs, request)
        return receiver_ext.receiver_output(encrypted, pending)

    def test_ot_extension_basic(self):
        """256 extended OTs with every choice bit 0 all deliver m0."""
        sender_ext, receiver_ext = self._new_pair()
        self._run_base_ot_setup(sender_ext, receiver_ext)

        count = 256
        all_zero = bytes(count // 8)
        pairs = [(os.urandom(32), os.urandom(32)) for _ in range(count)]
        outputs = self._run_extension(sender_ext, receiver_ext, count, all_zero, pairs)

        for idx, (first, _second) in enumerate(pairs):
            self.assertEqual(outputs[idx], first, f"OT {idx} failed with choice=0")

    def test_ot_extension_alternating_bits(self):
        """256 extended OTs with choice bits 1010... deliver the selected message each time."""
        sender_ext, receiver_ext = self._new_pair()
        self._run_base_ot_setup(sender_ext, receiver_ext)

        count = 256
        choice_bits = bytes([0b10101010]) * (count // 8)
        pairs = [(os.urandom(32), os.urandom(32)) for _ in range(count)]
        outputs = self._run_extension(sender_ext, receiver_ext, count, choice_bits, pairs)

        for idx, pair in enumerate(pairs):
            bit = get_bit(choice_bits, idx)
            self.assertEqual(outputs[idx], pair[bit], f"OT {idx} failed with choice bit={bit}")

    def test_base_ot_required_for_sender_init(self):
        """sender_init must refuse to run before the base-OT phase."""
        fresh = OTExtension(self.group, security_param=128)
        with self.assertRaises(RuntimeError) as ctx:
            fresh.sender_init()
        self.assertIn("Base OT setup must be completed", str(ctx.exception))

    def test_base_ot_required_for_receiver_extend(self):
        """receiver_extend must refuse to run before the base-OT phase."""
        fresh = OTExtension(self.group, security_param=128)
        with self.assertRaises(RuntimeError) as ctx:
            fresh.receiver_extend(256, bytes(32))
        self.assertIn("Base OT setup must be completed", str(ctx.exception))

    def test_sender_s_not_exposed(self):
        """After base OT the receiver holds only its seed pairs, never the sender's bits."""
        sender_ext, receiver_ext = self._new_pair()
        self._run_base_ot_setup(sender_ext, receiver_ext)

        self.assertIsNone(receiver_ext._sender_random_bits)
        seed_pairs = receiver_ext._receiver_seed_pairs
        self.assertIsNotNone(seed_pairs)
        self.assertEqual(len(seed_pairs), 128)


class TestMtA(unittest.TestCase):
    """Exercise OT-based Multiplicative-to-Additive (MtA) share conversion."""

    def setUp(self):
        self.group = ECGroup(secp256k1)

    def _convert(self, a, b):
        """Run the full two-round MtA protocol between fresh parties.

        Returns (alpha, beta): the sender's and the receiver's additive shares
        of a*b.
        """
        sender_side = MtA(self.group)
        receiver_side = MtA(self.group)
        first_msg = sender_side.sender_round1(a)
        reply, _ = receiver_side.receiver_round1(b, first_msg)
        alpha, transfer = sender_side.sender_round2(reply)
        beta = receiver_side.receiver_round2(transfer)
        return alpha, beta

    def test_mta_correctness(self):
        """alpha + beta must equal a*b (mod q) when run over real SimpleOT."""
        a = self.group.random(ZR)
        b = self.group.random(ZR)
        alpha, beta = self._convert(a, b)
        self.assertEqual(a * b, alpha + beta, "MtA correctness: a*b should equal alpha + beta")

    def test_mta_multiple_invocations(self):
        """Repeated conversions of random inputs all stay correct."""
        for _ in range(3):
            a, b = self.group.random(ZR), self.group.random(ZR)
            alpha, beta = self._convert(a, b)
            self.assertEqual(a * b, alpha + beta)

    def test_mta_uses_real_ot(self):
        """The sender publishes only OT setup parameters, never raw message pairs."""
        sender_side = MtA(self.group)
        receiver_side = MtA(self.group)
        a = self.group.random(ZR)
        b = self.group.random(ZR)

        first_msg = sender_side.sender_round1(a)
        self.assertIn('ot_params', first_msg, "Sender should provide OT params")
        self.assertNotIn('ot_messages', first_msg, "Sender should NOT expose raw OT messages")
        for params in first_msg['ot_params']:
            self.assertIn('A', params, "OT params should have public key A")
            self.assertIn('g', params, "OT params should have generator g")
            for leaked in ('m0', 'm1'):
                self.assertNotIn(leaked, params)

        reply, _ = receiver_side.receiver_round1(b, first_msg)
        alpha, transfer = sender_side.sender_round2(reply)
        beta = receiver_side.receiver_round2(transfer)
        self.assertEqual(a * b, alpha + beta)

    def test_mta_edge_case_near_order(self):
        """Inputs just below the group order still convert correctly (MEDIUM-04)."""
        q = int(self.group.order())
        cases = (
            (q - 1, 2, "MtA correctness: a*b should equal alpha + beta even for values near order"),
            (q - 2, 3, "MtA correctness: should work for values close to order boundary"),
        )
        for a_int, b_int, why in cases:
            a = self.group.init(ZR, a_int)
            b = self.group.init(ZR, b_int)
            alpha, beta = self._convert(a, b)
            self.assertEqual(a * b, alpha + beta, why)

    def test_mta_return_types(self):
        """Each MtA round returns the documented container shape (LOW-03)."""
        sender_side = MtA(self.group)
        receiver_side = MtA(self.group)
        a = self.group.random(ZR)
        b = self.group.random(ZR)

        # sender_round1 -> dict carrying the OT params and an adjustment term
        first_msg = sender_side.sender_round1(a)
        self.assertIsInstance(first_msg, dict)
        for key in ('ot_params', 'adjustment'):
            self.assertIn(key, first_msg)

        # receiver_round1 -> (dict, None)
        round1_out = receiver_side.receiver_round1(b, first_msg)
        self.assertIsInstance(round1_out, tuple)
        self.assertEqual(len(round1_out), 2)
        reply, placeholder = round1_out
        self.assertIsInstance(reply, dict)
        self.assertIn('ot_responses', reply)
        self.assertIsNone(placeholder)

        # sender_round2 -> (alpha, dict)
        round2_out = sender_side.sender_round2(reply)
        self.assertIsInstance(round2_out, tuple)
        self.assertEqual(len(round2_out), 2)
        alpha, transfer = round2_out
        self.assertIsInstance(transfer, dict)
        self.assertIn('ot_ciphertexts', transfer)

        # receiver_round2 -> field element, checked by using it in arithmetic
        beta = receiver_side.receiver_round2(transfer)
        self.assertEqual(a * b, alpha + beta)


class TestMtAwc(unittest.TestCase):
    """Exercise MtA with a correctness check (commitment-based proof)."""

    def setUp(self):
        self.group = ECGroup(secp256k1)

    def _prove(self, protocol, a, b):
        """Run MtAwc through the sender's second round.

        Both sides commit, then the OT-based exchange runs. Returns the
        (alpha, proof) tuple produced by ``sender_round2``.
        """
        commit_a = protocol.sender_commit(a)
        commit_b = protocol.receiver_commit(b)
        opening = protocol.sender_round1(a, commit_b)
        reply, _ = protocol.receiver_round1(b, commit_a, opening)
        return protocol.sender_round2(reply)

    def test_mtawc_correctness(self):
        """A valid proof is accepted and the resulting shares sum to a*b."""
        protocol = MtAwc(self.group)
        a = self.group.random(ZR)
        b = self.group.random(ZR)
        alpha, proof = self._prove(protocol, a, b)

        beta, ok = protocol.receiver_verify(proof)
        self.assertTrue(ok, "MtAwc proof should be valid")
        self.assertEqual(a * b, alpha + beta, "MtAwc: a*b should equal alpha + beta")

    def test_mtawc_proof_does_not_reveal_sender_bits(self):
        """The proof must not contain the sender's bits (CRITICAL-02 regression)."""
        protocol = MtAwc(self.group)
        a = self.group.random(ZR)
        b = self.group.random(ZR)
        alpha, proof = self._prove(protocol, a, b)

        self.assertNotIn('sender_bits', proof,
                         "SECURITY: Proof must NOT contain sender_bits - this would reveal sender's secret!")
        expected_fields = (
            ('challenge', "Proof should use challenge-response"),
            ('response', "Proof should contain response"),
            ('commitment_randomness', "Proof should contain commitment randomness"),
        )
        for field, why in expected_fields:
            self.assertIn(field, proof, why)

        beta, ok = protocol.receiver_verify(proof)
        self.assertTrue(ok, "MtAwc proof should still be valid")
        self.assertEqual(a * b, alpha + beta, "MtAwc: a*b should equal alpha + beta")


class TestThresholdSharing(unittest.TestCase):
    """Exercise Shamir-style threshold sharing and Feldman verifiable sharing."""

    def setUp(self):
        self.group = ECGroup(secp256k1)
        self.ts = ThresholdSharing(self.group)

    def test_basic_sharing_and_reconstruction(self):
        """In a 2-of-3 split, every pair of shares recovers the secret."""
        secret = self.group.random(ZR)
        shares = self.ts.share(secret, threshold=2, num_parties=3)
        self.assertEqual(len(shares), 3, "Should have 3 shares")

        checks = (
            ((1, 2), "Should reconstruct original secret"),
            ((1, 3), "Should reconstruct from different pair"),
            ((2, 3), "Should reconstruct from any pair"),
        )
        for pair, why in checks:
            subset = {pid: shares[pid] for pid in pair}
            self.assertEqual(secret, self.ts.reconstruct(subset, threshold=2), why)

    def test_feldman_vss_verification(self):
        """Honest Feldman shares verify against the published commitments."""
        secret = self.group.random(ZR)
        g = self.group.random(G)
        shares, commitments = self.ts.share_with_verification(secret, g, threshold=2, num_parties=3)
        for party_id in (1, 2, 3):
            ok = self.ts.verify_share(party_id, shares[party_id], commitments, g)
            self.assertTrue(ok, f"Share {party_id} should verify")

    def test_feldman_vss_detects_invalid_share(self):
        """A share perturbed by a random offset fails Feldman verification."""
        secret = self.group.random(ZR)
        g = self.group.random(G)
        shares, commitments = self.ts.share_with_verification(secret, g, threshold=2, num_parties=3)

        forged = shares[1] + self.group.random(ZR)
        self.assertFalse(self.ts.verify_share(1, forged, commitments, g),
                         "Tampered share should not verify")

    def test_threshold_3_of_5(self):
        """A 3-of-5 split reconstructs from shares 1, 3 and 5."""
        secret = self.group.random(ZR)
        shares = self.ts.share(secret, threshold=3, num_parties=5)
        self.assertEqual(len(shares), 5)

        subset = {pid: shares[pid] for pid in (1, 3, 5)}
        self.assertEqual(secret, self.ts.reconstruct(subset, threshold=3))

    def test_insufficient_shares_raises_error(self):
        """Reconstruction with fewer than ``threshold`` shares raises ValueError."""
        secret = self.group.random(ZR)
        shares = self.ts.share(secret, threshold=3, num_parties=5)
        with self.assertRaises(ValueError):
            self.ts.reconstruct({pid: shares[pid] for pid in (1, 2)}, threshold=3)

    def test_invalid_threshold_raises_error(self):
        """A threshold above num_parties or below 1 is rejected."""
        secret = self.group.random(ZR)
        for bad_threshold in (5, 0):  # first > num_parties, then < 1
            with self.assertRaises(ValueError):
                self.ts.share(secret, threshold=bad_threshold, num_parties=3)

    def test_threshold_limit_validation(self):
        """Thresholds above the safe limit of 256 are rejected (MEDIUM-05)."""
        secret = self.group.random(ZR)
        with self.assertRaises(ValueError) as ctx:
            self.ts.share(secret, threshold=300, num_parties=500)

        message = str(ctx.exception)
        self.assertIn("256", message, "Error should mention the safe limit of 256")
        self.assertIn("300", message, "Error should mention the requested threshold")


class TestPedersenVSS(unittest.TestCase):
    """Exercise Pedersen VSS (information-theoretically hiding commitments)."""

    def setUp(self):
        self.group = ECGroup(secp256k1)
        self.pvss = PedersenVSS(self.group)

    def _deal_2_of_3(self):
        """Share a fresh random secret 2-of-3 under random bases g and h.

        Returns (secret, g, h, shares, blindings, commitments).
        """
        g = self.group.random(G)
        h = self.group.random(G)
        secret = self.group.random(ZR)
        shares, blindings, commitments = self.pvss.share_with_blinding(secret, g, h, 2, 3)
        return secret, g, h, shares, blindings, commitments

    def test_pedersen_vss_verification(self):
        """Every honestly dealt share/blinding pair verifies against the commitments."""
        _secret, g, h, shares, blindings, commitments = self._deal_2_of_3()
        for pid in (1, 2, 3):
            ok = self.pvss.verify_pedersen_share(pid, shares[pid], blindings[pid], commitments, g, h)
            self.assertTrue(ok, f"Pedersen share {pid} should verify")

    def test_pedersen_vss_reconstruction(self):
        """Shares 1 and 3 of a Pedersen 2-of-3 split recover the secret."""
        secret, _g, _h, shares, _blindings, _commitments = self._deal_2_of_3()
        recovered = self.pvss.reconstruct({1: shares[1], 3: shares[3]}, threshold=2)
        self.assertEqual(secret, recovered)


class TestDKLS23_DKG(unittest.TestCase):
    """Tests for DKLS23 Distributed Key Generation."""

    def setUp(self):
        self.group = ECGroup(secp256k1)

    def _dkg_round1(self, dkg, g, session_id, num_parties=3):
        """Run DKG round 1 for parties 1..num_parties.

        Returns (round1_msgs, private_states), both lists indexed by
        party_id - 1.
        """
        outputs = [dkg.keygen_round1(pid, g, session_id) for pid in range(1, num_parties + 1)]
        return [out[0] for out in outputs], [out[1] for out in outputs]

    def _dkg_finish(self, dkg, round1_msgs, private_states):
        """Run DKG rounds 2 and 3 and return each party's KeyShare, in party order.

        Fails the calling test if any party reports a complaint in round 3.
        """
        num_parties = len(private_states)
        round2 = [dkg.keygen_round2(pid, private_states[pid - 1], round1_msgs)
                  for pid in range(1, num_parties + 1)]
        outgoing = [res[0] for res in round2]
        states_r2 = [res[1] for res in round2]

        key_shares = []
        for pid in range(1, num_parties + 1):
            # Party ``pid`` collects the share each sender (1-indexed) produced for it.
            received = {sender + 1: outgoing[sender][pid] for sender in range(num_parties)}
            ks, complaint = dkg.keygen_round3(pid, states_r2[pid - 1], received, round1_msgs)
            self.assertIsNone(complaint, f"Party {pid} should not have complaints")
            key_shares.append(ks)
        return key_shares

    def test_2_of_3_dkg(self):
        """Test 2-of-3 distributed key generation."""
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        g = self.group.random(G)
        round1_msgs, private_states = self._dkg_round1(dkg, g, b"test-session-2of3-dkg")

        # Compare secret VALUES pairwise. An id()-based check is vacuous,
        # because distinct live objects always have distinct ids.
        secrets = [st['secret'] for st in private_states]
        for i, s_i in enumerate(secrets):
            for s_j in secrets[i + 1:]:
                self.assertFalse(s_i == s_j, "Each party should have unique secret")

        key_shares = self._dkg_finish(dkg, round1_msgs, private_states)
        for ks in key_shares:
            self.assertIsInstance(ks, KeyShare)

    def test_all_parties_same_pubkey(self):
        """All parties should derive the same public key."""
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        g = self.group.random(G)
        round1_msgs, private_states = self._dkg_round1(dkg, g, b"test-session-same-pubkey")
        key_shares = self._dkg_finish(dkg, round1_msgs, private_states)

        pub_keys = [ks.X for ks in key_shares]
        self.assertTrue(all(pk == pub_keys[0] for pk in pub_keys),
                        "All parties should have same public key")

    def test_dkg_computes_correct_public_key(self):
        """The DKG public key equals g ** (sum of the parties' secrets)."""
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        g = self.group.random(G)
        round1_msgs, private_states = self._dkg_round1(dkg, g, b"test-session-correct-pubkey")

        secrets = [st['secret'] for st in private_states]
        expected_pk = g ** (secrets[0] + secrets[1] + secrets[2])

        all_comms = [msg['commitments'] for msg in round1_msgs]
        computed_pk = dkg.compute_public_key(all_comms, g)
        self.assertEqual(expected_pk, computed_pk, "DKG should compute correct public key")

    def test_dkg_rejects_none_session_id(self):
        """keygen_round1 rejects a None session_id."""
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        g = self.group.random(G)
        with self.assertRaises(ValueError) as ctx:
            dkg.keygen_round1(1, g, session_id=None)
        self.assertIn("required", str(ctx.exception))

    def test_dkg_rejects_empty_session_id(self):
        """keygen_round1 rejects empty bytes and empty str session IDs."""
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        g = self.group.random(G)
        for empty in (b"", ""):
            with self.assertRaises(ValueError):
                dkg.keygen_round1(1, g, session_id=empty)


class TestDKLS23_Presign(unittest.TestCase):
    """Exercise the four-round DKLS23 presigning protocol."""

    def setUp(self):
        self.group = ECGroup(secp256k1)
        self.ts = ThresholdSharing(self.group)

    @staticmethod
    def _new_session_id():
        """Return a fresh 32-byte session ID for all participants to share.

        In practice this is agreed on before the protocol starts.
        """
        from charm.toolbox.securerandom import OpenSSLRand
        return OpenSSLRand().getRandomBytes(32)

    @staticmethod
    def _route(outbox, participants):
        """Transpose p2p messages so that inbox[recipient][sender] == outbox[sender][recipient]."""
        return {
            rcpt: {snd: outbox[snd][rcpt] for snd in participants if snd != rcpt}
            for rcpt in participants
        }

    def _run_presign(self, presign, participants, x_shares, g, session_id):
        """Run all four presigning rounds in-process and return {pid: Presignature}.

        Fails the calling test if any party reports failed parties in round 4.
        """
        states = {}

        round1 = {}
        for pid in participants:
            round1[pid], states[pid] = presign.presign_round1(
                pid, x_shares[pid], participants, g, session_id=session_id)

        round2 = {}
        outbox2 = {}
        for pid in participants:
            round2[pid], outbox2[pid], states[pid] = presign.presign_round2(pid, states[pid], round1)
        inbox2 = self._route(outbox2, participants)

        outbox3 = {}
        for pid in participants:
            outbox3[pid], states[pid] = presign.presign_round3(pid, states[pid], round2, inbox2[pid])
        inbox3 = self._route(outbox3, participants)

        presigs = {}
        for pid in participants:
            presig, failed = presign.presign_round4(pid, states[pid], inbox3[pid])
            self.assertEqual(failed, [], f"Party {pid} should have no failed parties")
            presigs[pid] = presig
        return presigs

    def test_presign_generates_valid_presignature(self):
        """With all three parties taking part, each gets a valid Presignature."""
        presign = DKLS23_Presign(self.group)
        g = self.group.random(G)
        x = self.group.random(ZR)
        x_shares = self.ts.share(x, 2, 3)
        participants = [1, 2, 3]
        session_id = self._new_session_id()

        presigs = self._run_presign(presign, participants, x_shares, g, session_id)
        for pid, presig in presigs.items():
            self.assertIsInstance(presig, Presignature)
            self.assertTrue(presig.is_valid(), f"Presignature for party {pid} should be valid")

    def test_presignatures_have_same_r(self):
        """When two of three parties presign, both end up with the same r value."""
        presign = DKLS23_Presign(self.group)
        g = self.group.random(G)
        x = self.group.random(ZR)
        x_shares = self.ts.share(x, 2, 3)
        participants = [1, 2]
        session_id = self._new_session_id()

        presigs = self._run_presign(presign, participants, x_shares, g, session_id)
        r_values = [presigs[pid].r for pid in participants]
        self.assertTrue(all(r == r_values[0] for r in r_values),
                        "All presignatures should have same r")

    def test_presign_rejects_none_session_id(self):
        """presign_round1 rejects a None session_id."""
        presign = DKLS23_Presign(self.group)
        g = self.group.random(G)
        x_i = self.group.random(ZR)
        with self.assertRaises(ValueError) as ctx:
            presign.presign_round1(1, x_i, [1, 2, 3], g, session_id=None)
        self.assertIn("required", str(ctx.exception))

    def test_presign_rejects_empty_session_id(self):
        """presign_round1 rejects empty bytes and empty str session IDs."""
        presign = DKLS23_Presign(self.group)
        g = self.group.random(G)
        x_i = self.group.random(ZR)
        for empty in (b"", ""):
            with self.assertRaises(ValueError):
                presign.presign_round1(1, x_i, [1, 2, 3], g, session_id=empty)


class TestDKLS23_Sign(unittest.TestCase):
    """Exercise DKLS23 signature-share generation, share checks and ECDSA verification."""

    def setUp(self):
        self.group = ECGroup(secp256k1)
        self.signer = DKLS23_Sign(self.group)
        self.ts = ThresholdSharing(self.group)

    def _simulated_presig(self, party_id, R):
        """Build a stand-alone Presignature for ``party_id`` over the point R.

        Fresh k_i, gamma_i and chi_i are drawn, and delta_i is set to k_i * gamma_i.
        Returns (presig, delta_inv), where delta_inv = delta_i ** -1. That value
        stands in for the inverse of the aggregated delta.
        """
        k_i = self.group.random(ZR)
        gamma_i = self.group.random(ZR)
        chi_i = self.group.random(ZR)
        delta_i = k_i * gamma_i
        r = self.group.zr(R)
        presig = Presignature(party_id, R, r, k_i, chi_i, [1, 2], gamma_i=gamma_i, delta_i=delta_i)
        return presig, delta_i ** -1

    def _dummy_key_share(self, party_id, g):
        """Return a 2-of-3 KeyShare for ``party_id`` holding a random secret share."""
        return KeyShare(party_id, self.group.random(ZR), g, g, 2, 3)

    def _textbook_ecdsa(self, g, message):
        """Sign ``message`` with single-party ECDSA: s = k^{-1} (e + r*x).

        Returns (public_key, ThresholdSignature).
        """
        x = self.group.random(ZR)
        pk = g ** x
        k = self.group.random(ZR)
        r = self.group.zr(g ** k)
        e = self.signer._hash_message(message)
        s = (e + r * x) * (k ** -1)
        return pk, ThresholdSignature(r, s)

    def test_signature_share_generation(self):
        """sign_round1 yields a share together with a proof that names the party."""
        g = self.group.random(G)
        R = g ** self.group.random(ZR)
        presig, delta_inv = self._simulated_presig(1, R)
        key_share = self._dummy_key_share(1, g)

        sig_share, proof = self.signer.sign_round1(
            1, presig, key_share, b"test message", [1, 2], delta_inv)
        self.assertIsNotNone(sig_share, "Signature share should be generated")
        self.assertIn('party_id', proof, "Proof should contain party_id")

    def test_signature_verification_correct(self):
        """A textbook ECDSA signature verifies under its public key."""
        g = self.group.random(G)
        message = b"test message"
        pk, sig = self._textbook_ecdsa(g, message)
        self.assertTrue(self.signer.verify(pk, sig, message, g), "Valid signature should verify")

    def test_signature_verification_wrong_message(self):
        """A valid signature is rejected when checked against a different message."""
        g = self.group.random(G)
        pk, sig = self._textbook_ecdsa(g, b"original message")
        self.assertFalse(self.signer.verify(pk, sig, b"wrong message", g),
                         "Signature should not verify with wrong message")

    def test_signature_share_verification(self):
        """Invalid signature shares or proofs are detected (MEDIUM-06)."""
        g = self.group.random(G)
        R = g ** self.group.random(ZR)
        presig, delta_inv = self._simulated_presig(1, R)
        key_share = self._dummy_key_share(1, g)
        message = b"test message"

        sig_share, proof = self.signer.sign_round1(1, presig, key_share, message, [1, 2], delta_inv)
        check = self.signer.verify_signature_share

        self.assertTrue(check(1, sig_share, proof, presig, message),
                        "Valid signature share should pass verification")
        self.assertFalse(check(1, None, proof, presig, message),
                         "None share should fail verification")
        self.assertFalse(check(1, sig_share, {'party_id': 99, 'R': presig.R}, presig, message),
                         "Share with wrong party_id in proof should fail verification")
        self.assertFalse(check(1, sig_share, {}, presig, message),
                         "Share with empty proof should fail verification")

        # combine_signatures must validate the shares whenever proofs are supplied.
        presig2, delta_inv2 = self._simulated_presig(2, R)
        key_share2 = self._dummy_key_share(2, g)
        sig_share2, proof2 = self.signer.sign_round1(
            2, presig2, key_share2, message, [1, 2], delta_inv2)
        shares = {1: sig_share, 2: sig_share2}

        combined = self.signer.combine_signatures(
            shares, presig, [1, 2], {1: proof, 2: proof2}, message)
        self.assertIsNotNone(combined, "combine_signatures should succeed with valid proofs")

        bad_proofs = {1: proof, 2: {'party_id': 99, 'R': R}}
        with self.assertRaises(ValueError) as context:
            self.signer.combine_signatures(shares, presig, [1, 2], bad_proofs, message)
        self.assertIn("party 2", str(context.exception))


class TestDKLS23_Complete(unittest.TestCase):
    """End-to-end tests for the complete DKLS23 protocol.

    Each test drives the public ``DKLS23`` facade through some portion of
    keygen -> presign -> sign -> verify.
    """

    def setUp(self):
        self.group = ECGroup(secp256k1)

    def _keygen(self, threshold=2, num_parties=3):
        """Create a DKLS23 instance, draw a random generator and run distributed keygen.

        Returns:
            tuple: ``(dkls, g, key_shares, public_key)``.
        """
        dkls = DKLS23(self.group, threshold=threshold, num_parties=num_parties)
        g = self.group.random(G)
        key_shares, public_key = dkls.distributed_keygen(g)
        return dkls, g, key_shares, public_key

    def test_complete_2_of_3_signing(self):
        """Complete flow: keygen -> presign -> sign -> verify"""
        dkls, g, key_shares, public_key = self._keygen()
        self.assertEqual(len(key_shares), 3, "Should have 3 key shares")

        signers = [1, 2]
        presignatures = dkls.presign(signers, key_shares, g)
        self.assertEqual(len(presignatures), 2, "Should have 2 presignatures")

        message = b"Hello, threshold ECDSA!"
        signature = dkls.sign(signers, presignatures, key_shares, message, g)
        self.assertIsInstance(signature, ThresholdSignature)

        self.assertTrue(dkls.verify(public_key, signature, message, g),
                        "Signature should verify correctly")

    def test_different_participant_combinations(self):
        """Test that any 2 of 3 parties can sign"""
        dkls, g, key_shares, public_key = self._keygen()
        message = b"Test message for any 2 of 3"

        # Every 2-party subset of {1, 2, 3} must be able to sign.
        for signers in ([1, 2], [1, 3], [2, 3]):
            presigs = dkls.presign(signers, key_shares, g)
            sig = dkls.sign(signers, presigs, key_shares, message, g)
            self.assertTrue(dkls.verify(public_key, sig, message, g),
                            f"Signature with participants {signers} should verify")

    def test_signature_is_standard_ecdsa(self):
        """Verify that output is standard ECDSA signature format"""
        dkls, g, key_shares, _ = self._keygen()
        presigs = dkls.presign([1, 2], key_shares, g)
        sig = dkls.sign([1, 2], presigs, key_shares, b"Standard ECDSA test", g)

        # A plain ECDSA signature exposes its (r, s) pair.
        for component in ('r', 's'):
            self.assertTrue(hasattr(sig, component),
                            f"Signature should have {component} component")

        # DER encoding is an ASN.1 SEQUENCE, i.e. it starts with tag 0x30.
        der_bytes = sig.to_der()
        self.assertIsInstance(der_bytes, bytes, "DER encoding should produce bytes")
        self.assertEqual(der_bytes[0], 0x30, "DER should start with SEQUENCE tag")

    def test_wrong_message_fails_verification(self):
        """Test that signature verification fails with wrong message"""
        dkls, g, key_shares, public_key = self._keygen()
        presigs = dkls.presign([1, 2], key_shares, g)
        sig = dkls.sign([1, 2], presigs, key_shares, b"Original message", g)

        self.assertFalse(dkls.verify(public_key, sig, b"Different message", g),
                         "Verification should fail with wrong message")

    def test_insufficient_participants_raises_error(self):
        """Test that signing with insufficient participants raises error"""
        dkls, g, key_shares, _ = self._keygen()

        # A single participant is below the threshold of 2.
        with self.assertRaises(ValueError):
            dkls.presign([1], key_shares, g)

    def test_3_of_5_threshold(self):
        """Test 3-of-5 threshold scheme"""
        dkls, g, key_shares, public_key = self._keygen(threshold=3, num_parties=5)

        signers = [1, 3, 5]  # exactly the threshold number of parties
        presigs = dkls.presign(signers, key_shares, g)
        message = b"3-of-5 threshold test"
        sig = dkls.sign(signers, presigs, key_shares, message, g)

        self.assertTrue(dkls.verify(public_key, sig, message, g),
                        "3-of-5 signature should verify")

    def test_multiple_messages_same_keys(self):
        """Test signing multiple messages with same key shares"""
        dkls, g, key_shares, public_key = self._keygen()

        for msg in (b"First message", b"Second message", b"Third message"):
            # Presignatures are single-use, so draw a fresh batch per message.
            presigs = dkls.presign([1, 2], key_shares, g)
            sig = dkls.sign([1, 2], presigs, key_shares, msg, g)
            self.assertTrue(dkls.verify(public_key, sig, msg, g),
                            f"Signature for '{msg.decode()}' should verify")

    def test_invalid_threshold_raises_error(self):
        """Test that invalid threshold/num_parties raises error"""
        # 5 exceeds num_parties; 0 is below the minimum of 1.
        for bad_threshold in (5, 0):
            with self.assertRaises(ValueError):
                DKLS23(self.group, threshold=bad_threshold, num_parties=3)

    def test_keygen_interface(self):
        """Test the PKSig-compatible keygen interface"""
        dkls = DKLS23(self.group, threshold=2, num_parties=3)

        # keygen() takes no explicit generator.
        key_shares, public_key = dkls.keygen()

        self.assertEqual(len(key_shares), 3)
        self.assertIsNotNone(public_key)
class TestCurveAgnostic(unittest.TestCase):
    """Tests for curve agnosticism (MEDIUM-11)"""

    def test_curve_agnostic_prime256v1(self):
        """Run the full DKLS23 flow on prime256v1 (P-256/secp256r1) (MEDIUM-11).

        The rest of the suite uses secp256k1. Succeeding on a second curve
        shows the protocol is not tied to secp256k1.
        """
        from charm.toolbox.eccurve import prime256v1

        p256 = ECGroup(prime256v1)
        protocol = DKLS23(p256, threshold=2, num_parties=3)
        generator = p256.random(G)

        # keygen -> presign -> sign -> verify
        shares, pub = protocol.distributed_keygen(generator)
        presigs = protocol.presign([1, 2], shares, generator)
        text = b"Testing curve agnosticism with P-256"
        signature = protocol.sign([1, 2], presigs, shares, text, generator)

        self.assertTrue(protocol.verify(pub, signature, text, generator),
                        "Signature with prime256v1 should verify")
class TestThresholdSignature(unittest.TestCase):
    """Tests for the ThresholdSignature value class (equality and DER encoding)."""

    def setUp(self):
        self.group = ECGroup(secp256k1)

    def test_signature_equality(self):
        """Test ThresholdSignature equality comparison"""
        r, s = self.group.random(ZR), self.group.random(ZR)
        self.assertEqual(ThresholdSignature(r, s), ThresholdSignature(r, s),
                         "Signatures with same r,s should be equal")

    def test_signature_inequality(self):
        """Test ThresholdSignature inequality"""
        # Arguments are evaluated left to right: r1, s1, then r2, s2.
        first = ThresholdSignature(self.group.random(ZR), self.group.random(ZR))
        second = ThresholdSignature(self.group.random(ZR), self.group.random(ZR))
        self.assertNotEqual(first, second, "Different signatures should not be equal")

    def test_der_encoding(self):
        """Test DER encoding produces valid structure"""
        sig = ThresholdSignature(self.group.random(ZR), self.group.random(ZR))
        der = sig.to_der()

        # Expected layout: SEQUENCE tag (0x30), then a length byte covering the rest.
        tag, body_length = der[0], der[1]
        self.assertEqual(tag, 0x30, "Should start with SEQUENCE")
        self.assertEqual(body_length, len(der) - 2, "Length should match")
class TestMaliciousParties(unittest.TestCase):
    """Tests for adversarial/malicious party scenarios in threshold ECDSA.

    These tests verify that the protocol correctly detects and handles
    various forms of malicious behavior, including:

    - Invalid shares during DKG
    - Wrong commitments
    - Commitment mismatches during presigning
    - Invalid signature shares
    """

    @classmethod
    def setUpClass(cls):
        cls.group = ECGroup(secp256k1)
        cls.g = cls.group.random(G)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _dkg_round1(self, dkg, session_id, num_parties=3):
        """Run DKG round 1 for parties ``1..num_parties``.

        Returns:
            tuple: ``(round1_msgs, private_states)``, both lists indexed by
            ``party_id - 1``.
        """
        results = [dkg.keygen_round1(pid, self.g, session_id)
                   for pid in range(1, num_parties + 1)]
        return [res[0] for res in results], [res[1] for res in results]

    def _dkg_round2(self, dkg, private_states, round1_msgs):
        """Run DKG round 2 for every party.

        Returns:
            tuple: ``(shares_for_others, states_r2)``, both lists indexed by
            ``sender_id - 1``. ``shares_for_others[s][j]`` is the share that
            sender ``s + 1`` produced for recipient party ``j``.
        """
        results = [dkg.keygen_round2(idx + 1, state, round1_msgs)
                   for idx, state in enumerate(private_states)]
        return [res[0] for res in results], [res[1] for res in results]

    @staticmethod
    def _shares_received_by(shares_for_others, recipient):
        """Collect the shares every sender addressed to ``recipient``, keyed by sender id."""
        return {idx + 1: per_sender[recipient]
                for idx, per_sender in enumerate(shares_for_others)}

    def _sum_shares(self, shares, participants):
        """Add up the signature shares of ``participants`` in the scalar field ZR."""
        total = self.group.init(ZR, 0)
        for pid in participants:
            total = total + shares[pid]
        return total

    def _run_mta(self, a, b):
        """Run one 3-message MtA with fresh instances.

        The sender holds ``a`` and the receiver holds ``b``.

        Returns:
            tuple: ``(alpha, beta)``, the sender's and receiver's additive shares.
        """
        sender = MtA(self.group)
        receiver = MtA(self.group)
        sender_msg = sender.sender_round1(a)
        receiver_msg, _ = receiver.receiver_round1(b, sender_msg)
        alpha, ot_ciphertexts = sender.sender_round2(receiver_msg)
        beta = receiver.receiver_round2(ot_ciphertexts)
        return alpha, beta

    # ------------------------------------------------------------------
    # DKG
    # ------------------------------------------------------------------

    def test_dkg_invalid_share_detected(self):
        """Test that DKG detects tampered shares during round 3.

        Run DKG with 3 parties. In round 2, tamper with party 3's share to
        party 1 by adding 1 to the share value. Verify that party 1 detects
        the invalid share in round 3 (returns a complaint).
        """
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        session_id = b"test-session-invalid-share"

        round1_msgs, private_states = self._dkg_round1(dkg, session_id)
        shares_for_others, states_r2 = self._dkg_round2(dkg, private_states, round1_msgs)

        # Corrupt party 3's share for party 1.
        one = self.group.init(ZR, 1)
        shares_for_others[2][1] = shares_for_others[2][1] + one

        received_shares_p1 = self._shares_received_by(shares_for_others, 1)

        # Round 3 returns (KeyShare, complaint). The complaint should name party 3.
        key_share, complaint = dkg.keygen_round3(1, states_r2[0], received_shares_p1, round1_msgs)

        self.assertIsNone(key_share, "Key share should be None when verification fails")
        self.assertIsNotNone(complaint, "Complaint should be generated for invalid share")
        self.assertEqual(complaint['accused'], 3, "Complaint should accuse party 3")
        self.assertEqual(complaint['accuser'], 1, "Complaint should be from party 1")

    def test_dkg_wrong_commitment_detected(self):
        """Test that DKG detects when a party's commitment doesn't match their shares.

        Run DKG round 1, then replace the first commitment in party 2's list
        with a random point. Verify that share verification fails for party
        2's shares.
        """
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        session_id = b"test-session-wrong-commitment"

        round1_msgs, private_states = self._dkg_round1(dkg, session_id)

        # Swap party 2's first Feldman commitment for an unrelated random point.
        round1_msgs[1]['commitments'][0] = self.g ** self.group.random(ZR)

        shares_for_others, states_r2 = self._dkg_round2(dkg, private_states, round1_msgs)
        received_shares_p1 = self._shares_received_by(shares_for_others, 1)

        # Party 2's (honest) share no longer matches its (tampered) commitment.
        key_share, complaint = dkg.keygen_round3(1, states_r2[0], received_shares_p1, round1_msgs)

        self.assertIsNone(key_share, "Key share should be None when verification fails")
        self.assertIsNotNone(complaint, "Complaint should be generated for mismatched commitment")
        self.assertEqual(complaint['accused'], 2, "Complaint should accuse party 2")

    # ------------------------------------------------------------------
    # Presigning
    # ------------------------------------------------------------------

    def test_presign_commitment_mismatch_detected(self):
        """Test that presigning detects when Gamma_i doesn't match the commitment.

        Run presign rounds 1 and 2 with 3 parties. Then replace party 2's
        revealed Gamma_i with a value that does not match its round-1
        commitment.

        The commitment check is exercised directly through
        ``_compute_commitment``, without going through the full MtA
        completion of round 3.
        """
        presign = DKLS23_Presign(self.group)
        ts = ThresholdSharing(self.group)

        # Simulated 2-of-3 key shares.
        x = self.group.random(ZR)
        x_shares = ts.share(x, 2, 3)
        participants = [1, 2, 3]
        session_id = b"test-session-presign-mismatch"

        r1_results = {}
        states = {}
        for pid in participants:
            broadcast, state = presign.presign_round1(pid, x_shares[pid], participants, self.g, session_id)
            r1_results[pid] = broadcast
            states[pid] = state

        r2_results = {}
        for pid in participants:
            # Point-to-point messages are not needed for the commitment check.
            broadcast, _p2p, state = presign.presign_round2(pid, states[pid], r1_results)
            r2_results[pid] = broadcast
            states[pid] = state

        # Tamper: reveal a random Gamma_i for party 2 instead of the committed one.
        r2_results[2]['Gamma_i'] = self.g ** self.group.random(ZR)

        # Commitments are bound to the session id stored in the party state
        # and to the participant list.
        bound_session_id = states[2]['session_id']
        commitment = r1_results[2]['Gamma_commitment']

        tampered_commitment = presign._compute_commitment(
            r2_results[2]['Gamma_i'], session_id=bound_session_id, participants=participants
        )
        self.assertNotEqual(commitment, tampered_commitment,
                            "Tampered Gamma_i should not match original commitment")

        # Sanity check: the untampered Gamma_i does reproduce the commitment.
        original_commitment = presign._compute_commitment(
            states[2]['Gamma_i'], session_id=bound_session_id, participants=participants
        )
        self.assertEqual(commitment, original_commitment,
                         "Original Gamma_i should match commitment")

    # ------------------------------------------------------------------
    # Signing
    # ------------------------------------------------------------------

    def test_signature_invalid_share_produces_invalid_sig(self):
        """Test that tampering with a signature share is detectable.

        Build simulated presignatures for a 2-of-3 key and compute honest
        signature shares. Then add 1 to party 2's share (s_i). The
        aggregated tampered signature must fail ECDSA verification and must
        differ from the honestly aggregated one.
        """
        signer = DKLS23_Sign(self.group)
        ts = ThresholdSharing(self.group)

        # ECDSA key pair, shared 2-of-3.
        x = self.group.random(ZR)
        pk = self.g ** x
        x_shares = ts.share(x, 2, 3)
        participants = [1, 2]

        # k = nonce, gamma = blinding factor.
        # Share k, delta = k*gamma, sigma = gamma*x and gamma.
        k = self.group.random(ZR)
        gamma = self.group.random(ZR)
        k_shares = ts.share(k, 2, 3)
        delta_shares = ts.share(k * gamma, 2, 3)
        sigma_shares = ts.share(gamma * x, 2, 3)
        gamma_shares = ts.share(gamma, 2, 3)

        R = self.g ** k  # nonce point
        r = self.group.zr(R)

        key_shares = {
            pid: KeyShare(
                party_id=pid,
                private_share=x_shares[pid],
                public_key=pk,
                verification_key=self.g ** x_shares[pid],
                threshold=2,
                num_parties=3,
            )
            for pid in participants
        }
        presignatures = {
            pid: Presignature(
                party_id=pid,
                R=R,
                r=r,
                k_share=k_shares[pid],
                chi_share=sigma_shares[pid],  # gamma*x share
                participants=participants,
                gamma_i=gamma_shares[pid],
                delta_i=delta_shares[pid],
            )
            for pid in participants
        }

        message = b"Test message for malicious party"

        # delta is public in the protocol. Here it is the plain sum of the
        # simulated delta_i values.
        total_delta = self.group.init(ZR, 0)
        for pid in participants:
            total_delta = total_delta + presignatures[pid].delta_i
        delta_inv = total_delta ** -1

        # Honest signature shares; the per-share proofs are not needed here.
        valid_shares = {}
        for pid in participants:
            s_i, _proof = signer.sign_round1(
                pid, presignatures[pid], key_shares[pid], message, participants, delta_inv
            )
            valid_shares[pid] = s_i

        # Tamper with party 2's share on a copy, so the honest shares stay intact.
        one = self.group.init(ZR, 1)
        tampered_shares = dict(valid_shares)
        tampered_shares[2] = tampered_shares[2] + one

        s_valid = self._sum_shares(valid_shares, participants)
        s_tampered = self._sum_shares(tampered_shares, participants)
        tampered_signature = ThresholdSignature(r, s_tampered)

        self.assertFalse(
            signer.verify(pk, tampered_signature, message, self.g),
            "Tampered signature should not verify"
        )

        # The simplified presignature setup is not guaranteed to yield a
        # signature that verifies even without tampering. delta_inv above is
        # a plain sum of Shamir shares; whether sign_round1 compensates is
        # not visible here. So instead of verifying the honest signature,
        # check that the tampering actually altered the aggregated s.
        self.assertNotEqual(s_tampered, s_valid,
                            "Tampering a share should change the aggregated signature")

    # ------------------------------------------------------------------
    # MtA
    # ------------------------------------------------------------------

    def test_mta_receiver_learns_only_chosen_message(self):
        """Test MtA security property: receiver's beta depends only on chosen values.

        Run MtA twice with the same sender input ``a`` but different
        receiver inputs. Check that each run satisfies a*b = alpha + beta
        and that the receiver's beta changes with its input.
        """
        a = self.group.random(ZR)
        b = self.group.random(ZR)

        alpha, beta = self._run_mta(a, b)
        self.assertEqual(a * b, alpha + beta, "MtA correctness should hold")

        # Same a, different b: Bob's beta should be completely different.
        b2 = self.group.random(ZR)
        while b2 == b:
            b2 = self.group.random(ZR)

        alpha2, beta2 = self._run_mta(a, b2)
        self.assertEqual(a * b2, alpha2 + beta2, "Second MtA run should be correct")

        # Differs with overwhelming probability; shows beta depends on b.
        self.assertNotEqual(beta, beta2,
                            "Beta should differ for different receiver inputs (security property)")

    # ------------------------------------------------------------------
    # Multiple malicious parties
    # ------------------------------------------------------------------

    def test_dkg_insufficient_honest_parties(self):
        """Test that a party can identify malicious parties when multiple collude.

        Run 2-of-3 DKG where parties 2 and 3 both send invalid shares to
        party 1. Verify that party 1 can identify both malicious parties.
        """
        dkg = DKLS23_DKG(self.group, threshold=2, num_parties=3)
        session_id = b"test-session-insufficient-honest"

        round1_msgs, private_states = self._dkg_round1(dkg, session_id)
        shares_for_others, states_r2 = self._dkg_round2(dkg, private_states, round1_msgs)

        # Parties 2 and 3 both send a corrupted share to party 1.
        one = self.group.init(ZR, 1)
        for bad_sender in (2, 3):
            shares_for_others[bad_sender - 1][1] = shares_for_others[bad_sender - 1][1] + one

        received_shares_p1 = self._shares_received_by(shares_for_others, 1)

        # keygen_round3 reports a single complaint (whichever bad sender it
        # checks first).
        key_share, complaint = dkg.keygen_round3(1, states_r2[0], received_shares_p1, round1_msgs)
        self.assertIsNone(key_share, "Key share should be None when bad share detected")
        self.assertIsNotNone(complaint, "Complaint should be generated for bad share")

        # To find ALL malicious parties, check every received share against
        # its sender's commitments using the internal verification method.
        malicious_parties = [
            sender_id
            for sender_id in (1, 2, 3)
            if not dkg._verify_share_against_commitments(
                sender_id, 1, received_shares_p1[sender_id],
                round1_msgs[sender_id - 1]['commitments'], self.g
            )
        ]

        self.assertIn(2, malicious_parties, "Party 2 should be identified as malicious")
        self.assertIn(3, malicious_parties, "Party 3 should be identified as malicious")
        self.assertNotIn(1, malicious_parties, "Party 1's share should be valid")
class TestDPF(unittest.TestCase):
    """Tests for Distributed Point Function (GGM-based)"""

    @staticmethod
    def _combine(share0, share1):
        """Reconstruct a DPF output from the two parties' additive shares modulo 2**64."""
        return (share0 + share1) % (2**64)

    def test_dpf_single_point(self):
        """Test DPF correctness at target point."""
        dpf = DPF(security_param=128, domain_bits=8)
        target, payload = 42, 12345
        key0, key1 = dpf.gen(target, payload)

        # At the target point, the two evaluations must recombine to the payload.
        self.assertEqual(
            self._combine(dpf.eval(0, key0, target), dpf.eval(1, key1, target)),
            payload,
        )

    def test_dpf_off_points(self):
        """Test DPF correctness at non-target points."""
        dpf = DPF(security_param=128, domain_bits=8)
        key0, key1 = dpf.gen(42, 12345)

        # Away from the target, the shares must cancel to zero.
        for x in (0, 10, 41, 43, 100, 255):
            self.assertEqual(
                self._combine(dpf.eval(0, key0, x), dpf.eval(1, key1, x)),
                0,
                f"DPF should be 0 at x={x}",
            )

    def test_dpf_full_eval(self):
        """Test DPF full domain evaluation."""
        dpf = DPF(security_param=128, domain_bits=6)  # Domain size 64
        target, payload = 20, 99999
        key0, key1 = dpf.gen(target, payload)

        outputs0 = dpf.full_eval(0, key0)
        outputs1 = dpf.full_eval(1, key1)

        for i in range(64):
            if i == target:
                expected = beta_at_i = payload
            else:
                expected = beta_at_i = 0
            self.assertEqual(self._combine(outputs0[i], outputs1[i]), beta_at_i,
                             f"DPF full_eval wrong at i={i}")

    def test_dpf_key_independence(self):
        """Test that individual keys reveal nothing about alpha/beta."""
        dpf = DPF(security_param=128, domain_bits=8)

        # Two DPFs with different targets and payloads.
        key0_a, _key1_a = dpf.gen(10, 100)
        key0_b, _key1_b = dpf.gen(20, 200)

        # A single party's key only yields pseudorandom-looking integers.
        value_a = dpf.eval(0, key0_a, 10)
        value_b = dpf.eval(0, key0_b, 10)

        for value in (value_a, value_b):
            self.assertIsInstance(value, int)
class TestMPFSS(unittest.TestCase):
    """Tests for Multi-Point Function Secret Sharing"""

    @staticmethod
    def _reconstruct(mpfss, k0, k1, x):
        """Evaluate both keys at ``x`` and recombine the shares modulo 2**64."""
        return (mpfss.eval(0, k0, x) + mpfss.eval(1, k1, x)) % (2**64)

    def test_mpfss_single_point(self):
        """Test MPFSS with single point (should match DPF)."""
        mpfss = MPFSS(security_param=128, domain_bits=10)
        k0, k1 = mpfss.gen([(100, 5000)])

        # At the target point, the payload is recovered.
        self.assertEqual(self._reconstruct(mpfss, k0, k1, 100), 5000)
        # Elsewhere, the shares cancel.
        self.assertEqual(self._reconstruct(mpfss, k0, k1, 50), 0)

    def test_mpfss_multiple_points(self):
        """Test MPFSS with multiple points."""
        mpfss = MPFSS(security_param=128, domain_bits=8)
        points = [(10, 100), (20, 200), (30, 300)]
        k0, k1 = mpfss.gen(points)

        for alpha, expected in points:
            self.assertEqual(self._reconstruct(mpfss, k0, k1, alpha), expected,
                             f"MPFSS wrong at {alpha}")

        for x in [0, 15, 25, 100, 255]:
            self.assertEqual(self._reconstruct(mpfss, k0, k1, x), 0,
                             f"MPFSS should be 0 at {x}")

    def test_mpfss_full_eval(self):
        """Test MPFSS full domain evaluation."""
        mpfss = MPFSS(security_param=128, domain_bits=6)  # Domain 64
        points = [(5, 50), (10, 100), (60, 600)]
        k0, k1 = mpfss.gen(points)

        result0 = mpfss.full_eval(0, k0)
        result1 = mpfss.full_eval(1, k1)

        point_dict = dict(points)
        for i in range(64):
            expected = point_dict.get(i, 0)
            actual = (result0[i] + result1[i]) % (2**64)
            self.assertEqual(actual, expected, f"MPFSS full_eval wrong at {i}")

    def test_mpfss_empty(self):
        """Test MPFSS with an empty point set: every domain entry reconstructs to 0."""
        domain_bits = 8
        mpfss = MPFSS(security_param=128, domain_bits=domain_bits)
        k0, k1 = mpfss.gen([])

        result0 = mpfss.full_eval(0, k0)
        result1 = mpfss.full_eval(1, k1)

        # Check the whole domain, not just a prefix. Any non-zero entry
        # would reveal a spurious point.
        for i in range(2 ** domain_bits):
            self.assertEqual((result0[i] + result1[i]) % (2**64), 0,
                             f"MPFSS should be 0 at {i}")
class TestSilentOT(unittest.TestCase):
    """Tests for Silent OT Extension (PCG-based)"""

    @staticmethod
    def _expand_both(sot):
        """Generate a seed pair and expand both sides.

        Returns:
            tuple: ``(choice_bits, sender_msgs, receiver_msgs)``.
        """
        seed_s, seed_r = sot.gen()
        choice_bits, sender_msgs = sot.expand_sender(seed_s)
        receiver_msgs = sot.expand_receiver(seed_r)
        return choice_bits, sender_msgs, receiver_msgs

    def test_silent_ot_basic(self):
        """Test basic Silent OT correctness."""
        sot = SilentOT(security_param=128, output_size=32, sparsity=4)
        choice_bits, sender_msgs, receiver_msgs = self._expand_both(sot)

        for produced in (choice_bits, sender_msgs, receiver_msgs):
            self.assertEqual(len(produced), 32)

        # OT correlation: the sender's message equals the receiver's
        # message selected by the choice bit.
        for i, c in enumerate(choice_bits):
            self.assertEqual(sender_msgs[i], receiver_msgs[i][c],
                             f"OT correlation failed at i={i}, c={c}")

    def test_silent_ot_larger(self):
        """Test Silent OT with larger output size."""
        sot = SilentOT(security_param=128, output_size=128, sparsity=10)
        choice_bits, sender_msgs, receiver_msgs = self._expand_both(sot)

        for i in range(128):
            bit = choice_bits[i]
            self.assertEqual(sender_msgs[i], receiver_msgs[i][bit],
                             f"OT correlation failed at i={i}")

    def test_silent_ot_choice_distribution(self):
        """Test that choice bits come from sparse set."""
        sot = SilentOT(security_param=128, output_size=64, sparsity=8)
        sender_seed, _ = sot.gen()
        choice_bits, _ = sot.expand_sender(sender_seed)

        # The number of 1-bits must equal the configured sparsity.
        self.assertEqual(sum(choice_bits), 8, "Should have exactly 'sparsity' 1-bits")

    def test_silent_ot_messages_32_bytes(self):
        """Test that OT messages are 32 bytes each."""
        sot = SilentOT(security_param=128, output_size=16, sparsity=4)
        _, sender_msgs, receiver_msgs = self._expand_both(sot)

        for sent in sender_msgs:
            self.assertEqual(len(sent), 32, "Sender msg should be 32 bytes")
        for first, second in receiver_msgs:
            self.assertEqual(len(first), 32, "Receiver m0 should be 32 bytes")
            self.assertEqual(len(second), 32, "Receiver m1 should be 32 bytes")

    def test_silent_ot_different_messages(self):
        """Test that m0 and m1 are different for each OT."""
        sot = SilentOT(security_param=128, output_size=32, sparsity=4)
        _, receiver_seed = sot.gen()
        pairs = sot.expand_receiver(receiver_seed)

        # Each OT must offer two distinct messages.
        for i, (m0, m1) in enumerate(pairs):
            self.assertNotEqual(m0, m1, f"m0 and m1 should differ at i={i}")
# Allow running this test module directly as a script, in addition to
# discovery by pytest / unittest.
if __name__ == '__main__': unittest.main()