1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
| from gmpy2 import invert import sm3, key from random import choice
def random_hex(x):
    """Return a string of ``x`` random lowercase hexadecimal digits.

    Intended for secret values such as SM2 ephemeral scalars, so the digits
    come from the ``secrets`` CSPRNG.  The ``random`` module (Mersenne
    Twister) is predictable from its output and must not be used for keys
    or nonces.  ``x <= 0`` yields an empty string.
    """
    import secrets  # local import: the file-level imports only bring in random.choice
    return ''.join(secrets.choice('0123456789abcdef') for _ in range(x))
# Curve domain parameters for y^2 = x^3 + a*x + b over GF(p):
#   p - field prime (modulus for all coordinate arithmetic)
#   a, b - curve coefficients
#   G - base point (x, y)
#   n - modulus used for signature scalars r, s
# NOTE(review): these appear to be the example parameters from the SM2
# standard's appendix rather than the recommended sm2p256v1 curve --
# confirm before any production use.
default_ecc_paras = dict(
    p=0x8542D69E4C044F18E8B92435BF6FF7DE457283915C45517D722EDB8B08F1DFC3,
    a=0x787968B4FA32C3FD2417842E73BBFEFF2F3C848B6831D7E0EC65228B3937E498,
    b=0x63E4C6D3B23B0C849CF84241484BFE48F61D59A5B16BA06E6E12D1DA27C5249A,
    G=(
        0x421DEBD61B62EAB6746434EBC3CC315E32220B3BADD50BDC4C4E6C147FEDD43D,
        0x0680512BCBB42C07D47349D2153B70C4E5D7FDFCBFA36EA1A85841B9E46E09A2,
    ),
    n=0x8542D69E4C044F18E8B92435BF6FF7DD297720630485628D5AE74EE7C32E79B7,
)

# Encoding of the point at infinity (the group identity) as an (x, y) pair.
zero_point = (0, 0)
class SM2(object):
    """SM2 digital signatures and public-key encryption over a prime curve.

    Curve points are ``(x, y)`` integer pairs and ``zero_point`` stands for
    the point at infinity.  ``get_kg`` results and signatures are fixed-width
    lowercase hex strings with ``para_len`` digits per coordinate/component.

    NOTE(review): ``sign``/``verify`` use the message bytes directly as the
    integer ``e`` rather than SM3(Z_A || M) as the SM2 standard specifies,
    so signatures are not interoperable with standard implementations.
    Kept as-is so existing signatures still verify.
    """

    def __init__(self, private_key, public_key, ecc_paras=default_ecc_paras):
        """Store the key pair and curve parameters.

        :param private_key: integer secret scalar ``d``.
        :param public_key: point used as the base for verification and
            encryption; presumably ``d*G`` as an ``(x, y)`` pair -- confirm
            against ``key.gen_key``.
        :param ecc_paras: curve parameters dict with keys ``p``, ``a``,
            ``b``, ``G`` and ``n`` (only read, never mutated).
        """
        self.private_key = private_key
        self.public_key = public_key
        self.ecc_paras = ecc_paras
        # Hex digits per coordinate / scalar (64 hex digits = 256 bits).
        self.para_len = 64

    def add_point(self, P1, P2):
        """Return the affine sum P1 + P2 as an ``(int, int)`` pair.

        Handles the identity (``zero_point``), P + (-P) and point doubling.
        """
        if P1 == zero_point:
            return P2
        if P2 == zero_point:
            return P1
        (p1x, p1y), (p2x, p2y) = P1, P2
        p = self.ecc_paras['p']
        if p1x == p2x and (p1y != p2y or p1y == 0):
            # P + (-P), or doubling a point whose y is 0: the result is infinity.
            return zero_point
        if p1x == p2x:
            # Doubling: tangent slope (3x^2 + a) / (2y).
            lam = (3 * p1x * p1x + self.ecc_paras['a']) * invert(2 * p1y, p) % p
        else:
            # Addition: chord slope (y2 - y1) / (x2 - x1).
            lam = (p2y - p1y) * invert(p2x - p1x, p) % p
        x = (lam * lam - p1x - p2x) % p
        y = (lam * (p1x - x) - p1y) % p
        return (int(x), int(y))

    def get_kg(self, k, P):
        """Return the scalar multiple k*P as a hex string x || y.

        Each coordinate is zero-padded to ``para_len`` hex digits; ``k == 0``
        yields the all-zero encoding of ``zero_point``.

        Right-to-left double-and-add.  NOTE(review): the loop branches on the
        bits of ``k``, so its running time depends on the (possibly secret)
        scalar -- it is not constant-time.

        :raises ValueError: if ``k`` is negative.
        """
        if k < 0:
            # A negative k never reaches 0 under ``>>=`` and would loop forever.
            raise ValueError('scalar k must be non-negative')
        acc = zero_point
        addend = P
        while k:
            if k & 1:
                acc = self.add_point(acc, addend)
            addend = self.add_point(addend, addend)
            k >>= 1
        return '%0*x%0*x' % (self.para_len, acc[0], self.para_len, acc[1])

    def _split_point(self, xy):
        """Parse a ``get_kg`` hex string back into an ``(x, y)`` int pair."""
        width = self.para_len
        return int(xy[:width], 16), int(xy[width:2 * width], 16)

    def _random_scalar(self):
        """Return a uniformly random scalar in [1, n - 1] from a CSPRNG."""
        import secrets  # local import: not among the file-level imports
        return secrets.randbelow(self.ecc_paras['n'] - 1) + 1

    def sign(self, data):
        """Sign ``data`` (a bytes-like object) with the private key.

        :returns: hex string r || s, each zero-padded to ``para_len`` digits.
        """
        n = self.ecc_paras['n']
        d = self.private_key
        # Same value as int(data.hex(), 16), but also defined (0) for b''.
        e = int.from_bytes(data, 'big')
        inv_d1 = invert(d + 1, n)
        while True:
            # Draw a fresh nonce on every retry; re-using k after s == 0
            # would recompute the same s forever.
            k = self._random_scalar()
            x1, _ = self._split_point(self.get_kg(k, self.ecc_paras['G']))
            r = (e + x1) % n
            if r == 0 or r + k == n:
                continue
            s = int(inv_d1 * (k - r * d) % n)
            if s != 0:
                return '%0*x%0*x' % (self.para_len, r, self.para_len, s)

    def verify(self, sign, data):
        """Check a signature produced by ``sign`` against ``data``.

        :param sign: hex string r || s (``para_len`` digits each).
        :param data: the signed bytes-like message.
        :returns: True if the signature is valid, otherwise False, including
            for malformed or out-of-range signatures.
        """
        n = self.ecc_paras['n']
        width = self.para_len
        if len(sign) != 2 * width:
            return False
        try:
            r = int(sign[:width], 16)
            s = int(sign[width:], 16)
        except ValueError:
            return False
        # Explicit checks instead of asserts, which vanish under ``python -O``.
        if not (1 <= r <= n - 1 and 1 <= s <= n - 1):
            return False
        e = int.from_bytes(data, 'big')
        t = (r + s) % n
        if t == 0:
            print('验证失败!')
            return False
        P1 = self._split_point(self.get_kg(s, self.ecc_paras['G']))
        P2 = self._split_point(self.get_kg(t, self.public_key))
        point = self.add_point(P1, P2)
        if point == zero_point:
            return False
        x1, _ = point
        return (e + x1) % n == r

    def encrypt(self, data):
        """Encrypt ``data`` (non-empty bytes-like) to the public key.

        :returns: bytes laid out as C1 || C3 || C2, where C1 is k*G
            (x || y, ``para_len`` hex digits each, no 0x04 prefix), C3 is the
            SM3 hash of x2 || M || y2 and C2 is M XOR the KDF output.
        :raises ValueError: if ``data`` is empty.
        """
        msg = data.hex()
        mlen = len(msg)
        if mlen == 0:
            raise ValueError('cannot encrypt empty data')
        while True:
            k = self._random_scalar()
            C1 = self.get_kg(k, self.ecc_paras['G'])
            xy = self.get_kg(k, self.public_key)
            # Integer byte length; ``mlen / 2`` would pass a float.
            t = sm3.sm3_kdf(xy, mlen // 2)
            # An all-zero KDF output would leak the plaintext; retry with a
            # new k instead of giving up.
            if int(t, 16) != 0:
                break
        x2 = xy[:self.para_len]
        y2 = xy[self.para_len:2 * self.para_len]
        C2 = '%0*x' % (mlen, int(msg, 16) ^ int(t, 16))
        C3 = sm3.sm3_hash(list(bytes.fromhex(x2 + msg + y2)))
        return bytes.fromhex(C1 + C3 + C2)

    def decrypt(self, data):
        """Decrypt ciphertext bytes laid out as C1 || C3 || C2 (see encrypt).

        :returns: the plaintext bytes, or None when the ciphertext is too
            short, C1 is not a valid curve point, the KDF output is all zero,
            or the C3 integrity hash does not match.
        """
        data = data.hex()
        p = self.ecc_paras['p']
        a = self.ecc_paras['a']
        b = self.ecc_paras['b']
        len_2 = 2 * self.para_len
        len_3 = len_2 + 64  # C3 occupies 64 hex digits (an SM3 digest)
        if len(data) <= len_3:
            return None
        C1x = int(data[:self.para_len], 16)
        C1y = int(data[self.para_len:len_2], 16)
        # Explicit curve-membership check instead of an assert (stripped by -O).
        on_curve = (C1y * C1y - (C1x ** 3 + a * C1x + b)) % p == 0
        if C1x >= p or C1y >= p or not on_curve:
            return None
        C3 = data[len_2:len_3]
        C2 = data[len_3:]
        xy = self.get_kg(self.private_key, (C1x, C1y))
        x2 = xy[:self.para_len]
        y2 = xy[self.para_len:len_2]
        cl = len(C2)
        t = sm3.sm3_kdf(xy, cl // 2)
        if int(t, 16) == 0:
            return None
        M = '%0*x' % (cl, int(C2, 16) ^ int(t, 16))
        u = sm3.sm3_hash(list(bytes.fromhex(x2 + M + y2)))
        # Integrity check: reject tampered ciphertexts.  C3 comes from
        # bytes.hex() (lowercase), so compare case-insensitively.
        if u.lower() != C3:
            return None
        return bytes.fromhex(M)
def sign_and_verify():
    """Demo: sign a sample message and print whether the signature verifies.

    The key pair comes from ``key.gen_key()``; a blank line is printed after
    the verification result.
    """
    secret, public = key.gen_key()
    signer = SM2(secret, public)
    message = b'sh1kaku' * 100
    signature = signer.sign(message)
    print(signer.verify(signature, message))
    print()
def encrypt_and_decrypt():
    """Demo: encrypt a sample message, then decrypt it, printing both results.

    The key pair comes from ``key.gen_key()``.  The printed labels mean
    "after encryption" and "decrypted".
    """
    secret, public = key.gen_key()
    engine = SM2(secret, public)
    plaintext = b'sh1kaku' * 100
    ciphertext = engine.encrypt(plaintext)
    print('加密后:', ciphertext)
    recovered = engine.decrypt(ciphertext)
    print('解密:', recovered)
if __name__ == '__main__':
    # Run the demos only when executed as a script, so importing this
    # module does not generate keys and print output as a side effect.
    sign_and_verify()
    encrypt_and_decrypt()
|