SM2

大创撸了个 SM2（水一下博客）

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import hmac
import secrets
from random import choice

from gmpy2 import invert

import sm3, key

def random_hex(x):
    """Return a string of ``x`` random lowercase hexadecimal digits.

    Digits are drawn with :func:`secrets.choice` (a CSPRNG), so the output is
    safe to use for secret values such as nonces; ``random.choice`` is a
    predictable PRNG and must not be used for that. ``x <= 0`` yields ``''``.
    """
    return ''.join(secrets.choice('0123456789abcdef') for _ in range(x))

# Default elliptic-curve parameters: short-Weierstrass curve
# y^2 = x^3 + a*x + b over GF(p).
# NOTE(review): these values appear to be the 256-bit *example* curve from the
# SM2 specification (GM/T 0003 / GB/T 32918 examples), not the recommended
# production curve sm2p256v1 -- confirm before protecting real data.
default_ecc_paras = {
# p: prime modulus of the base field; all coordinates are reduced mod p.
'p' : 0x8542D69E4C044F18E8B92435BF6FF7DE457283915C45517D722EDB8B08F1DFC3,
# a, b: curve coefficients.
'a' : 0x787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E498,
'b' : 0x63E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A,
# G: base point as an (x, y) tuple of ints.
'G' : (0x421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D,
0x0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A2),
# n: order of G; the SM2 class uses it as the modulus for signature values.
'n' : 0x8542D69E4C044F18E8B92435BF6FF7DD297720630485628D5AE74EE7C32E79B7,
}
# Sentinel for the point at infinity (group identity). (0, 0) cannot be a real
# curve point here because it would require b == 0 (mod p).
zero_point = (0, 0)

class SM2(object):
    """SM2-style signatures and public-key encryption over a prime-field curve.

    Points are affine ``(x, y)`` integer tuples; ``zero_point`` stands in for
    the point at infinity.

    Args:
        private_key: integer private key ``d``.
        public_key: public point as an ``(x, y)`` tuple of ints (it is
            unpacked as a pair by ``add_point``).
        ecc_paras: curve parameter dict with keys ``'p'``, ``'a'``, ``'b'``,
            ``'G'`` and ``'n'``. The shared default is only read, never mutated.
    """

    # Hex length of one SM3 digest (256 bits): the width of C3 in ciphertexts.
    _HASH_HEX_LEN = 64

    def __init__(self, private_key, public_key, ecc_paras=default_ecc_paras):
        self.private_key = private_key
        self.public_key = public_key
        self.ecc_paras = ecc_paras
        # Hex digits per coordinate / per signature component (256-bit values).
        self.para_len = 64

    def add_point(self, P1, P2):
        """Return ``P1 + P2`` on the curve as an ``(x, y)`` tuple of ints."""
        if P1 == zero_point:
            return P2
        if P2 == zero_point:
            return P1

        p = self.ecc_paras['p']
        (p1x, p1y), (p2x, p2y) = P1, P2
        # P + (-P), and doubling a point whose y is 0, both give the identity.
        if p1x == p2x and (p1y != p2y or p1y == 0):
            return zero_point

        if p1x == p2x:
            # Tangent slope (point doubling).
            lam = (3 * p1x * p1x + self.ecc_paras['a']) * invert(2 * p1y, p) % p
        else:
            # Chord slope (two distinct points).
            lam = (p2y - p1y) * invert(p2x - p1x, p) % p

        x = (lam * lam - p1x - p2x) % p
        y = (lam * (p1x - x) - p1y) % p
        return (int(x), int(y))

    def _mul_point(self, k, P):
        """Return the scalar multiple ``k*P`` as an ``(x, y)`` tuple.

        LSB-first double-and-add. NOTE: the running time depends on the bits
        of ``k``, so this is not side-channel resistant.

        Raises:
            ValueError: if ``k`` is negative (the shift loop would never end).
        """
        if k < 0:
            raise ValueError('scalar must be non-negative')
        acc = zero_point
        while k:
            if k & 1:
                acc = self.add_point(acc, P)
            P = self.add_point(P, P)
            k >>= 1
        return acc

    def get_kg(self, k, P):
        """Return ``k*P`` as the hex string ``x || y`` (``para_len`` digits each)."""
        x, y = self._mul_point(k, P)
        return '%0*x%0*x' % (self.para_len, x, self.para_len, y)

    def _random_scalar(self):
        """Return a CSPRNG-drawn integer uniform in ``[1, n - 1]``."""
        return secrets.randbelow(self.ecc_paras['n'] - 1) + 1

    def sign(self, data):
        """Sign ``data`` (bytes) with the private key.

        Returns:
            The signature as the hex string ``r || s`` (``para_len`` digits each).

        NOTE(review): standard SM2 signs ``e = SM3(Z_A || M)``. This class uses
        the raw message bytes as the integer ``e``, so its signatures do not
        interoperate with standard SM2 implementations.
        """
        n = self.ecc_paras['n']
        d = self.private_key
        # Equal to int(data.hex(), 16) for non-empty data, and 0 for b''.
        e = int.from_bytes(data, 'big')
        d_inv = invert(d + 1, n)
        while True:
            # A fresh secret nonce per attempt. Since s*(1 + d) = k - r*d, anyone
            # who learns or predicts k can solve for the private key d.
            k = self._random_scalar()
            x1, _ = self._mul_point(k, self.ecc_paras['G'])
            r = (e + x1) % n
            if r == 0 or r + k == n:
                continue
            s = (k - r * d) * d_inv % n
            if s != 0:
                return '%0*x%0*x' % (self.para_len, int(r), self.para_len, int(s))

    def verify(self, sign, data):
        """Check a hex ``r || s`` signature (as produced by ``sign``) over ``data``.

        Returns:
            True if the signature is valid. Returns False otherwise, including
            for malformed or out-of-range signatures.
        """
        n = self.ecc_paras['n']
        try:
            r = int(sign[:self.para_len], 16)
            s = int(sign[self.para_len:], 16)
        except ValueError:
            return False
        # Explicit checks rather than asserts, which vanish under `python -O`.
        if not (1 <= r <= n - 1 and 1 <= s <= n - 1):
            return False
        e = int.from_bytes(data, 'big')
        t = (r + s) % n
        if t == 0:
            return False

        point = self.add_point(self._mul_point(s, self.ecc_paras['G']),
                               self._mul_point(t, self.public_key))
        if point == zero_point:
            return False
        return (e + point[0]) % n == r

    def encrypt(self, data):
        """Encrypt ``data`` (non-empty bytes) to the public key.

        Returns:
            Ciphertext bytes laid out as ``C1 || C3 || C2``. C1 is ``k*G`` as
            ``x || y``, C3 is ``sm3_hash(x2 || M || y2)``, and C2 is ``M`` XOR
            ``sm3_kdf(x2 || y2)``, where ``(x2, y2) = k*public_key``.

        Raises:
            ValueError: if ``data`` is empty.
        """
        if not data:
            raise ValueError('cannot encrypt empty data')
        msg = data.hex()
        mlen = len(msg)
        while True:
            k = self._random_scalar()
            C1 = self.get_kg(k, self.ecc_paras['G'])
            xy = self.get_kg(k, self.public_key)
            # Integer byte count (mlen is an even number of hex digits).
            t = sm3.sm3_kdf(xy, mlen // 2)
            # An all-zero key stream would leave M in the clear: draw a new k.
            if int(t, 16) != 0:
                break
        x2, y2 = xy[:self.para_len], xy[self.para_len:]
        C2 = '%0*x' % (mlen, int(msg, 16) ^ int(t, 16))
        C3 = sm3.sm3_hash(list(bytes.fromhex('%s%s%s' % (x2, msg, y2))))
        return bytes.fromhex('%s%s%s' % (C1, C3, C2))

    def decrypt(self, data):
        """Decrypt a ``C1 || C3 || C2`` ciphertext produced by ``encrypt``.

        Returns:
            The plaintext bytes. Returns None if the ciphertext is too short,
            C1 is not a valid curve point, the derived key stream is all zero,
            or the C3 integrity check fails.
        """
        p = self.ecc_paras['p']
        data = data.hex()
        len_2 = 2 * self.para_len
        len_3 = len_2 + self._HASH_HEX_LEN
        # Need C1, C3, and at least one byte of C2.
        if len(data) <= len_3:
            return None
        C1, C3, C2 = data[:len_2], data[len_2:len_3], data[len_3:]
        C1x = int(C1[:self.para_len], 16)
        C1y = int(C1[self.para_len:], 16)
        # Reject C1 unless it is a reduced point on the curve (explicit check;
        # an assert would be stripped under `python -O`).
        on_curve = C1y * C1y % p == (C1x ** 3 + self.ecc_paras['a'] * C1x
                                     + self.ecc_paras['b']) % p
        if not (C1x < p and C1y < p and on_curve):
            return None

        xy = self.get_kg(self.private_key, (C1x, C1y))
        x2, y2 = xy[:self.para_len], xy[self.para_len:]
        cl = len(C2)
        t = sm3.sm3_kdf(xy, cl // 2)
        if int(t, 16) == 0:
            return None
        M = '%0*x' % (cl, int(C2, 16) ^ int(t, 16))
        u = sm3.sm3_hash(list(bytes.fromhex('%s%s%s' % (x2, M, y2))))
        # Integrity check: the recomputed hash must match C3, otherwise the
        # ciphertext was corrupted or tampered with. Constant-time compare;
        # C3 comes from bytes.hex() and is therefore lowercase.
        if not hmac.compare_digest(str(u).lower(), C3):
            return None
        return bytes.fromhex(M)


def sign_and_verify():
    """Demo: sign a sample message under a fresh key pair and print the verify result.

    Prints the return value of ``SM2.verify`` followed by an empty line.
    """
    private_key, public_key = key.gen_key()
    signer = SM2(private_key, public_key)
    message = b'sh1kaku' * 100
    signature = signer.sign(message)
    print(signer.verify(signature, message))
    print()

def encrypt_and_decrypt():
    """Demo: encrypt a sample message under a fresh key pair, then decrypt it.

    Prints the ciphertext first and the decrypted result second.
    """
    private_key, public_key = key.gen_key()
    cipher = SM2(private_key, public_key)
    plaintext = b'sh1kaku' * 100
    ciphertext = cipher.encrypt(plaintext)
    print('加密后:', ciphertext)
    print('解密:', cipher.decrypt(ciphertext))

if __name__ == '__main__':
    # Run the demos only when executed as a script, so that importing this
    # module (e.g. to reuse the SM2 class) has no side effects.
    sign_and_verify()
    encrypt_and_decrypt()