-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
143 lines (116 loc) · 2.99 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from Crypto.Util.number import getPrime, isPrime, bytes_to_long
from Crypto.Random.random import randint
import math
# Challenge parameters.
#   N  -- overall bit budget: the cofactor q is built from get_prime(N - ML).
#   ML -- minimum bit length of the smooth factor m, built by getSmoothComp(ML).
N, ML = 2048, 256

# Secret value that the script later embeds via create_covert().
flag = b"tjctf{hay_gato_encerrado28f23}"

# Reference cited by the original author (solve notes):
# https://www.esat.kuleuven.be/cosic/publications/article-54.pdf
def get_prime(bits, exclude=None):
    """Return a random prime that is not contained in *exclude*.

    As in the original implementation, the prime is drawn with
    ``getPrime(bits + 1)``, so it is ``bits + 1`` bits long. Whenever a
    drawn prime collides with *exclude*, the requested size is increased by
    one bit before drawing again.

    Args:
        bits: base bit size; the result has ``bits + 1`` bits (more after
            any collisions).
        exclude: optional iterable of primes that must not be returned.
            ``None`` (the default) means no exclusions.

    Returns:
        A prime not present in *exclude*.
    """
    # Snapshot the exclusions into a fresh set. This avoids a shared mutable
    # default argument and makes the collision check O(1).
    excluded = set(exclude) if exclude is not None else set()
    out = getPrime(bits + 1)
    while out in excluded:
        bits += 1  # original behavior: grow the size before redrawing
        out = getPrime(bits + 1)
    return out
def next_prime(a):
    """Return the smallest prime greater than or equal to *a*.

    An even starting point is bumped to the next odd number, and the search
    then steps through odd candidates only. Inputs <= 2 return 2, the only
    even prime, which the odd-only search would otherwise skip.
    """
    if a <= 2:
        return 2
    if a % 2 == 0:
        a += 1
    while not isPrime(a):
        a += 2
    return a
# Technique reference: https://ctftime.org/writeup/32914
def getSmooth(numBits, exclude=None, smoothness=15):
    """Build a prime ``P`` such that ``P - 1`` is a product of known primes.

    The function starts from 2 and multiplies in distinct primes from
    ``get_prime(smoothness, ...)`` until the product reaches *numBits* bits.
    It then searches primes upward from 23 for a final factor ``t``, skipping
    primes already used or excluded, such that ``1 + product * t`` is prime.

    Args:
        numBits: minimum bit length of the product before the final factor.
        exclude: optional iterable of primes that must not be used as
            factors. ``None`` (the default) means no exclusions.
        smoothness: base bit size passed to ``get_prime`` for each factor.

    Returns:
        ``(P, factors)``, where ``P == 1 + prod(factors)`` is prime and
        *factors* is sorted ascending.
    """
    exclude = [] if exclude is None else list(exclude)
    out = 2
    outL = [2]
    while out.bit_length() < numBits:
        r = get_prime(smoothness, exclude + outL)
        out *= r
        outL.append(r)
    prime = 23
    while True:
        # Skip candidates that would repeat a factor or that the caller banned.
        while prime in outL or prime in exclude:
            prime = next_prime(prime + 1)
        if isPrime(1 + out * prime):
            outL.append(prime)
            out = 1 + out * prime
            break
        prime = next_prime(prime + 1)
    outL.sort()
    return out, outL
def getSmoothComp(s):
    """Return ``(product, primes)`` for a smooth number of at least *s* bits.

    Random 15-bit primes are drawn and multiplied together until the
    product's bit length reaches *s*. A prime that was already drawn is
    discarded, so *primes* holds distinct values in draw order.
    """
    product, primes = 1, []
    while product.bit_length() < s:
        candidate = getPrime(15)
        if candidate not in primes:
            primes.append(candidate)
            product *= candidate
    return product, primes
# Search for a prime p = q * m + 1, so that p - 1 = q * m contains the smooth
# factor m produced by getSmoothComp. i counts how many retries were needed.
# NOTE(review): the first candidate uses q = 2 * get_prime(N - ML), while every
# retry draws q = 2 * randint(0, 2**(N - ML)); confirm the asymmetry is intended.
m, mfac = getSmoothComp(ML)
q = 2 * get_prime(N - ML)
i = 0
while not isPrime(p := 1 + q * m):
    i += 1
    m, mfac = getSmoothComp(ML)
    q = 2 * randint(0, 2 ** (N - ML))
# ElGamal key pair: fixed base g, secret exponent x, public value y = g^x mod p.
g = 3
x = randint(1, p - 1)
y = pow(g, x, p)

# Publish the public parameters (mfac and q are not printed).
print("ElGamal...")
print("p = m * q + 1")
print("m =", m)
print("p =", p)
print("g =", g)
print("y =", y)
def sign(M, k=p - 1):
    """Produce an ElGamal signature ``(r, s)`` on the integer message *M*.

    Uses the module-level key material ``p``, ``g`` and ``x``.

    Args:
        M: message representative (integer).
        k: optional nonce. It is used as-is only if it is coprime to
            ``p - 1``. Otherwise, including the default ``p - 1``, random
            nonces from ``randint(1, p - 1)`` are drawn until one is coprime.

    Returns:
        ``(r, s)`` with ``r = g^k mod p`` and
        ``s = (M - x*r) * k^-1 mod (p - 1)``. A signature with ``s == 0`` is
        never returned, because ``verify`` requires ``0 < s``; a fresh random
        nonce is drawn instead.
    """
    order = p - 1
    while True:
        while math.gcd(k, order) != 1:
            k = randint(1, order)
        r = pow(g, k, p)
        s = ((M - x * r) * pow(k, -1, order)) % order
        if s != 0:
            return (r, s)
        # s == 0 can never verify, so force a new random nonce on the next pass.
        k = order
def verify(r, s, M):
    """Check an ElGamal signature ``(r, s)`` on *M* against the public key.

    Returns False when ``r`` or ``s`` is out of range. Otherwise returns
    whether ``g^M == y^r * r^s (mod p)``.
    """
    if not (0 < r < p and 0 < s < p - 1):
        return False
    expected = pow(g, M, p)
    actual = pow(y, r, p) * pow(r, s, p) % p
    return expected == actual
def create_covert(M, c):
    """Sign *M* with a nonce whose residue modulo ``m`` hides the value *c*.

    The nonce is ``k = c + k_p * m``, where ``k_p`` comes from
    ``randint(1, 2**10)``. It is redrawn until ``k`` is coprime to
    ``p - 1``, so ``k ≡ c (mod m)``. The signature is then computed as in
    ElGamal signing, using the module-level ``p``, ``g`` and ``x``.

    Args:
        M: message representative (integer).
        c: value to embed. Only ``c mod m`` is carried by the nonce.

    Returns:
        ``(r, s)`` ElGamal signature.

    Raises:
        ValueError: if ``gcd(c, m) != 1``. In that case every candidate
            ``k`` shares that factor with ``m``, which divides ``p - 1``
            (since ``p = q * m + 1``), so no valid nonce exists and the
            search loop would never terminate.
    """
    if math.gcd(c, m) != 1:
        raise ValueError("covert value must be coprime to m")
    order = p - 1
    k = order  # sentinel: gcd(p - 1, p - 1) != 1 forces at least one draw
    while math.gcd(k, order) != 1:
        k_p = randint(1, 2**10)
        k = c + k_p * m
    r = pow(g, k, p)
    s = ((M - x * r) * pow(k, -1, order)) % order
    return r, s
def verify_covert(r, s, M):
    """Verify a signature produced by ``create_covert``.

    Applies the same range checks and ElGamal equation as plain
    verification: ``g^M == y^r * r^s (mod p)``. It also asserts the sanity
    invariant ``(g^q)^m == 1 (mod p)``, which is ``g^(p - 1) == 1`` because
    ``p = q * m + 1``.

    Returns:
        True if the signature verifies, False otherwise (including
        out-of-range ``r`` or ``s``).

    Raises:
        AssertionError: if the invariant above does not hold (for example,
            if the module-level p, q, m are inconsistent). The check is
            skipped when Python runs with -O.
    """
    if not (0 < r < p and 0 < s < p - 1):
        return False
    # Since (g^q)^m == 1, a create_covert nonce k = c + k_p*m gives
    # r^q == (g^q)^c (mod p).
    base = pow(g, q, p)
    assert pow(base, m, p) == 1
    return pow(g, M, p) == (pow(y, r, p) * pow(r, s, p)) % p
# Sign a decoy message while carrying the real flag through the nonce.
M = bytes_to_long(b"tjctf{fake_flag}")
print("M =", M)
snip = bytes_to_long(flag)
r, s = create_covert(M, snip)
print("r =", r)
print("s =", s)
# Do not silently ignore the verification result: a published signature
# that does not verify means the generated parameters are broken.
if not verify_covert(r, s, M):
    raise RuntimeError("covert signature failed verification")