-rw-r--r-- 5623 mceliece-sage-20221023/test-checksums.sage raw
# Deterministic checksum test.  For every selected parameter set this script
# runs 64 rounds of keygen / encap / decap (plus three decaps of corrupted
# ciphertexts per round) with a reproducible randomness source, folds every
# output into a running checksum, and prints that checksum after rounds 7
# and 63.
#
# The file is meant to be run through the Sage preparser: `ZZ` and `copy`
# come from Sage's global namespace, and xor is spelled
# `ZZ(...).__xor__(...)` (presumably because the preparser turns `^` into
# exponentiation).
#
# NOTE(review): the statement nesting below was reconstructed from a copy of
# this file whose line breaks had been lost; confirm the placement of the two
# trailing `print(result)` calls against the expected test output.

import sys  # used below for sys.argv / sys.stdout; do not rely on Sage exporting it

import parameters
import encap
import decap
import keygen


def _quarterround(a, b, c, d):
    """Apply one ChaCha quarter-round to four 32-bit words.

    Returns the updated (a, b, c, d).  Every intermediate value is masked
    back to 32 bits because the integers here are unbounded.
    """
    a += b
    a &= 0xffffffff
    d = ZZ(d).__xor__(a)
    d = (d << 16) | (d >> 16)
    d &= 0xffffffff
    c += d
    c &= 0xffffffff
    b = ZZ(b).__xor__(c)
    b = (b << 12) | (b >> 20)
    b &= 0xffffffff
    a += b
    a &= 0xffffffff
    d = ZZ(d).__xor__(a)
    d = (d << 8) | (d >> 24)
    d &= 0xffffffff
    c += d
    c &= 0xffffffff
    b = ZZ(b).__xor__(c)
    b = (b << 7) | (b >> 25)
    b &= 0xffffffff
    return a, b, c, d


def randomuint8():
    """Return the next byte of the deterministic random stream.

    Reads and updates the globals randombytes_key / randombytes_buf /
    randombytes_pos.  When the buffer is exhausted, 12 blocks (768 bytes) are
    generated from the current key with block counters 0..11; the first 32
    bytes replace the key and the remaining 736 become the new buffer.
    """
    global randombytes_key
    global randombytes_buf
    global randombytes_pos

    if randombytes_pos == len(randombytes_buf):
        h = []
        for pos in range(12):
            # State layout: 4 constant words, 8 key words, counter, 3 zeros.
            x = [1634760805, 857760878, 2036477234, 1797285236]
            k = randombytes_key
            for i in range(0, 32, 4):
                x += [k[i] + (k[i+1] << 8) + (k[i+2] << 16) + (k[i+3] << 24)]
            x += [pos, 0, 0, 0]
            y = copy(x)
            for i in range(10):
                # Four column rounds followed by four diagonal rounds.
                x[0], x[4], x[8], x[12] = _quarterround(x[0], x[4], x[8], x[12])
                x[1], x[5], x[9], x[13] = _quarterround(x[1], x[5], x[9], x[13])
                x[2], x[6], x[10], x[14] = _quarterround(x[2], x[6], x[10], x[14])
                x[3], x[7], x[11], x[15] = _quarterround(x[3], x[7], x[11], x[15])
                x[0], x[5], x[10], x[15] = _quarterround(x[0], x[5], x[10], x[15])
                x[1], x[6], x[11], x[12] = _quarterround(x[1], x[6], x[11], x[12])
                x[2], x[7], x[8], x[13] = _quarterround(x[2], x[7], x[8], x[13])
                x[3], x[4], x[9], x[14] = _quarterround(x[3], x[4], x[9], x[14])
            for i in range(16):
                # Feed-forward, then emit the word as 4 little-endian bytes.
                z = x[i] + y[i]
                z &= 0xffffffff
                for _ in range(4):
                    h += [z & 255]
                    z >>= 8
        randombytes_key = h[:32]
        randombytes_buf = h[32:]
        randombytes_pos = 0

    c = randombytes_buf[randombytes_pos]
    # Wipe the consumed slot so a byte can never be handed out twice.
    randombytes_buf[randombytes_pos] = None
    randombytes_pos += 1
    return c


def randombytes(r):
    """Return r bytes drawn from randomuint8() as a bytes object."""
    return bytes(bytearray([randomuint8() for j in range(r)]))


def L32(x, c):
    """Rotate the low 32 bits of x left by c bits."""
    x &= 0xffffffff
    result = (x << c) | (x >> (32 - c))
    result &= 0xffffffff
    return result


def ld32(x):
    """Decode a 4-element byte sequence as a little-endian 32-bit integer."""
    assert len(x) == 4
    u = x[3]
    u = (u << 8) | x[2]
    u = (u << 8) | x[1]
    return (u << 8) | x[0]


def st32(x):
    """Encode the low 32 bits of x as a list of 4 little-endian bytes."""
    result = []
    for i in range(4):
        result += [255 & x]
        x >>= 8
    return result


def core(block, k):
    """Salsa20 core: map a 16-byte block and a key to 64 output bytes.

    Only k[0:32] is read, so a longer sequence (such as the 64-byte checksum
    state) may be passed as the key.  Returns a list of 64 byte values.
    """
    assert len(block) == 16
    x = [None]*16
    sigma = list(bytearray(b'expand 32-byte k'))
    for i in range(4):
        x[5*i] = ld32(sigma[4*i:4*i+4])
        x[1+i] = ld32(k[4*i:4*i+4])
        x[6+i] = ld32(block[4*i:4*i+4])
        x[11+i] = ld32(k[4*i+16:4*i+20])
    y = copy(x)
    for i in range(20):
        # Each round reads four words along a diagonal and writes them back
        # transposed, so the same code serves as both round types.
        w = [None]*16
        for j in range(4):
            t = [ZZ(x[(5*j+4*m) % 16]) for m in range(4)]
            t[1] = t[1].__xor__(L32(t[0]+t[3], 7))
            t[2] = t[2].__xor__(L32(t[1]+t[0], 9))
            t[3] = t[3].__xor__(L32(t[2]+t[1], 13))
            t[0] = t[0].__xor__(L32(t[3]+t[2], 18))
            for m in range(4):
                w[4*j+((j+m) % 4)] = t[m]
        x = w
    result = []
    for i in range(16):
        z = 0xffffffff & (x[i]+y[i])
        result += st32(z)
    return result


def checksum(x):
    """Absorb x (bytes or a list of byte values) into the global checksum_state.

    Each full 16-byte block is passed through core() keyed by the current
    state.  The remaining 0..15 bytes are padded with a 1 byte and zeros, and
    absorbed after flipping the low bit of state byte 0.
    """
    global checksum_state
    if isinstance(x, bytes):
        x = list(bytearray(x))
    # Step through the input by index: re-slicing `x = x[16:]` on every block
    # is quadratic in the input length.
    full = len(x) - len(x) % 16
    for start in range(0, full, 16):
        checksum_state = core(x[start:start+16], checksum_state)
    tail = x[full:]
    block = tail + [1] + [0]*(15-len(tail))
    checksum_state[0] = checksum_state[0].__xor__(1)
    checksum_state = core(block, checksum_state)


def salsa20(outlen, n, k):
    """Return outlen bytes of Salsa20 keystream as a list of byte values.

    n is an 8-byte nonce and k a 32-byte key, both given as lists.  Bytes
    8..15 of the core input hold a little-endian block counter starting at 0.
    """
    assert outlen >= 0
    assert len(n) == 8
    assert len(k) == 32
    result = []
    if outlen == 0:
        return result
    z = n + [0]*8
    while outlen >= 64:
        result += core(z, k)
        outlen -= 64
        # Increment the counter, carrying only while a byte wraps to 0.
        for i in range(8, 16):
            z[i] = 255 & (z[i]+1)
            if z[i]:
                break
    if outlen > 0:
        result += core(z, k)[:outlen]
    return result


def testvector(outlen):
    """Return outlen deterministic bytes and advance the global nonce.

    Uses salsa20() with a fixed key and the global testvector_n as nonce;
    the nonce is incremented in place (little-endian) after every call.
    """
    k = list(bytearray(b'generate inputs for test vectors'))
    result = salsa20(outlen, testvector_n, k)
    for i in range(8):
        testvector_n[i] = 255 & (testvector_n[i]+1)
        if testvector_n[i]:
            break
    return result


def myrandom():
    """Return a deterministic 64-bit integer built from 8 testvector() bytes."""
    x = testvector(8)
    return sum(x[i] << (8*i) for i in range(8))


# Parameter sets to test: all of them, unless names are given on the command line.
systems = parameters.alltests
if len(sys.argv) > 1:
    systems = sys.argv[1:]

for system in systems:
    # Reset all deterministic state so each system's checksum is independent.
    randombytes_key = [0]*32
    randombytes_buf = [0]*736
    randombytes_pos = len(randombytes_buf)  # forces a refill on first use
    checksum_state = [0]*64
    testvector_n = [0]*8

    params = parameters.parameters(system, allowtestparams=True)

    result = system
    print(result)
    sys.stdout.flush()

    for loop in range(64):
        print(loop)
        sys.stdout.flush()

        pk, sk = keygen.keygen(randombytes, params)
        checksum(pk)
        checksum(sk)
        C, k = encap.encap(pk, randombytes, params)
        checksum(C)
        checksum(k)
        assert decap.decap(C, sk, params) == k
        checksum(k)

        # Decapsulate three corrupted ciphertexts.  C is reassigned each
        # time, so the corruptions accumulate.
        for loop2 in range(3):
            Clen = len(C)
            C = list(bytearray(C))
            offset = 1 + (myrandom() % 255)  # non-zero, so the byte always changes
            pos = myrandom() % Clen
            C[pos] = 255 & (C[pos]+offset)
            C = bytes(bytearray(C))
            k2 = decap.decap(C, sk, params)
            # A result equal to False means there is no key to absorb;
            # absorb the corrupted ciphertext instead.
            if k2 == False:
                checksum(C)
            else:
                checksum(k2)

        # Record the first 32 state bytes as hex after rounds 7 and 63.
        if loop in [7, 63]:
            checksumhex = ''.join('%02x' % checksum_state[i] for i in range(32))
            result += ' ' + checksumhex
            print(result)
            sys.stdout.flush()

    print(result)
    sys.stdout.flush()