-rw-r--r-- 2483 mceliece-sage-20221023/controlbits.py raw
#!/usr/bin/env python3
# copied from https://eprint.iacr.org/2020/1493
# see proofs there

"""Control bits for permutation networks (reference Python implementation).

permutation(c) evaluates a network of conditional swaps on the identity
permutation; controlbits(pi) computes bits c with permutation(c) == pi.
The self-tests at the bottom of the file run when it is executed directly.
"""

# Only the algorithmic functions are public; keeping the test_* helpers out
# of star-imports also stops test runners from collecting them by accident.
__all__ = ['permutation', 'composeinv', 'controlbits']


def permutation(c):
    """Return the permutation produced by applying control bits c to the identity.

    The network acts on n = 2**m positions and has 2*m - 1 layers of n/2
    conditional swaps.  Layer i swaps positions that are
    gap = 2**min(i, 2*m - 2 - i) apart, so the gaps run
    1, 2, ..., 2**(m-1), ..., 2, 1.  Bit c[i*n//2 + j] (tested for
    truthiness) controls swap j of layer i.

    Args:
        c: sequence of (2*m - 1) * 2**(m - 1) control bits, for some m >= 1.

    Returns:
        A list holding a permutation of range(2**m).

    Raises:
        AssertionError: if len(c) is not of that form.  This is an
            ``assert``, so the check is skipped under ``python -O``.
    """
    m = 1
    while (2*m - 1) << (m - 1) < len(c):
        m += 1
    assert (2*m - 1) << (m - 1) == len(c)
    n = 1 << m
    pi = list(range(n))
    for i in range(2*m - 1):
        gap = 1 << min(i, 2*m - 2 - i)
        for j in range(n // 2):
            if c[i*n//2 + j]:
                # Swap j pairs position pos with pos + gap: j % gap is the
                # offset inside a block of 2*gap positions, j // gap picks
                # the block.
                pos = (j % gap) + 2*gap*(j // gap)
                pi[pos], pi[pos + gap] = pi[pos + gap], pi[pos]
    return pi


def composeinv(c, pi):
    """Return the list r with r[pi[k]] == c[k] for every k.

    When pi is a permutation of range(len(pi)), this is c composed with the
    inverse of pi: r[j] == c[k] where pi[k] == j.  It is implemented by
    sorting the pairs (pi[k], c[k]) on pi[k].
    """
    return [y for x, y in sorted(zip(pi, c))]


def controlbits(pi):
    """Compute control bits c such that permutation(c) == pi.

    Args:
        pi: a permutation of range(n) with n = 2**m and m >= 1.  That pi is
            a permutation is assumed, not checked.

    Returns:
        A list of (2*m - 1) * n/2 bits laid out as f + z + l:
        - f: the n/2 first-layer bits;
        - l: the n/2 last-layer bits;
        - z: the bits of two half-size subproblems (computed recursively by
          this function), interleaved element by element.

    Raises:
        AssertionError: if len(pi) is not a power of 2 that is at least 2.
            This is an ``assert``, so the check is skipped under
            ``python -O``.
    """
    n = len(pi)
    m = 1
    while 1 << m < n:
        m += 1
    assert 1 << m == n

    # Base case: one swap.  For pi in {[0, 1], [1, 0]}, pi[0] is the bit.
    if m == 1:
        return [pi[0]]

    # x ^ 1 is the partner of x within the pair {2j, 2j+1}.
    p = [pi[x ^ 1] for x in range(n)]
    q = [pi[x] ^ 1 for x in range(n)]

    piinv = composeinv(range(n), pi)  # inverse of pi: piinv[pi[k]] == k
    p, q = composeinv(p, q), composeinv(q, p)

    # c[x] starts as min(x, p[x]).  It is refined by taking minima against
    # c composed with q^-1, while p and q are repeatedly composed with each
    # other's inverses (m - 1 min-steps in total).  Per the cited paper,
    # the parity of the final c[2*j] gives valid first-layer bits.
    # Tuple assignments below evaluate every right-hand side with the OLD
    # p and q before rebinding.
    c = [min(x, p[x]) for x in range(n)]
    p, q = composeinv(p, q), composeinv(q, p)
    for _ in range(1, m - 1):
        cp, p, q = composeinv(c, q), composeinv(p, q), composeinv(q, p)
        c = [min(c[x], cp[x]) for x in range(n)]

    # First layer: f[j] set means swap positions 2j and 2j+1.
    # F is that layer viewed as a map on positions.
    f = [c[2*j] % 2 for j in range(n // 2)]
    F = [x ^ f[x // 2] for x in range(n)]
    Fpi = composeinv(F, piinv)  # Fpi[i] == F[pi[i]]

    # Last layer, chosen from the parities of Fpi at even positions.
    l = [Fpi[2*k] % 2 for k in range(n // 2)]
    L = [y ^ l[y // 2] for y in range(n)]
    M = composeinv(Fpi, L)  # M[i] == Fpi[L[i]]  (L is its own inverse)

    # Middle layers: after halving, M's even and odd positions give two
    # size-n/2 subproblems.  The construction is meant to make each one a
    # permutation of range(n // 2); see the paper.  They are solved
    # recursively and their bit lists are interleaved into z.
    subM = [[M[2*j + e] // 2 for j in range(n // 2)] for e in range(2)]
    subz = map(controlbits, subM)
    z = [s for s0s1 in zip(*subz) for s in s0s1]
    return f + z + l


# ----- miscellaneous tests

import sys


def test_onepermutation(pi):
    """Check that pi round-trips through controlbits() and permutation().

    Raises:
        ValueError: if len(pi) < 2 or len(pi) is not a power of 2.
            ValueError is a subclass of Exception, so callers catching
            Exception still work.
        AssertionError: if pi is not a permutation of range(len(pi)), or if
            permutation(controlbits(pi)) != pi.
    """
    pi = list(pi)
    n = len(pi)
    if n < 2:
        raise ValueError('testing only permutations of length at least 2')
    m = 1
    while 1 << m < n:
        m += 1
    if 1 << m != n:
        raise ValueError('testing only permutations of power-of-2 length')
    assert sorted(pi) == list(range(n))
    c = controlbits(pi)
    assert pi == permutation(c)


def test_small():
    """Exhaustively check permutations of length 4 and 8."""

    def doit(pi, pos):
        # Heap's-algorithm-style enumeration: rearranges pi[:pos] in place
        # by swaps and tests pi at every leaf of the recursion.
        if pos == 1:
            test_onepermutation(pi)
        else:
            for i in range(pos):
                doit(pi, pos - 1)
                if pos % 2:
                    pi[0], pi[pos - 1] = pi[pos - 1], pi[0]
                else:
                    pi[i], pi[pos - 1] = pi[pos - 1], pi[i]

    for m in range(2, 4):
        n = 1 << m
        print('controlbits scan %d %d' % (m, n))
        sys.stdout.flush()
        pi = list(range(n))
        doit(pi, n)


import random


def test_random():
    """Check 10 random permutations of each length 2**m for m = 1..13."""
    for m in range(1, 14):
        n = 2**m
        print('controlbits random %d %d' % (m, n))
        sys.stdout.flush()
        for _ in range(10):
            # Each key is rand*n + j, so after sorting on the keys,
            # key % n recovers the indices in random order.
            r = [random.randrange(2**64)*n + j for j in range(n)]
            r.sort()
            pi = [x % n for x in r]
            test_onepermutation(pi)


if __name__ == '__main__':
    test_small()
    test_random()