Skip to content

Commit 20aa98d

Browse files
committed
ellswift: introduce ElligatorSwift encoding and decoding funcs
The BIP324 ElligatorSwift test vectors are also included.
1 parent 67b8efd commit 20aa98d

File tree

2 files changed

+1210
-0
lines changed

2 files changed

+1210
-0
lines changed

btcec/ellswift.go

+391
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,391 @@
1+
package btcec
2+
3+
import (
4+
"crypto/rand"
5+
"fmt"
6+
7+
"github.com/btcsuite/btcd/chaincfg/chainhash"
8+
)
9+
10+
var (
11+
// c is sqrt(-3) (mod p)
12+
c FieldVal
13+
14+
cBytes = [32]byte{
15+
0x0a, 0x2d, 0x2b, 0xa9, 0x35, 0x07, 0xf1, 0xdf,
16+
0x23, 0x37, 0x70, 0xc2, 0xa7, 0x97, 0x96, 0x2c,
17+
0xc6, 0x1f, 0x6d, 0x15, 0xda, 0x14, 0xec, 0xd4,
18+
0x7d, 0x8d, 0x27, 0xae, 0x1c, 0xd5, 0xf8, 0x52,
19+
}
20+
21+
// ErrPointNotOnCurve is returned when we're unable to find a point on the
22+
// curve.
23+
ErrPointNotOnCurve = fmt.Errorf("point does not exist on secp256k1 curve")
24+
)
25+
26+
func init() {
27+
c.SetByteSlice(cBytes[:])
28+
}
29+
30+
// XSwiftEC() takes two field elements (u, t) and gives us an x-coordinate that
31+
// is on the secp256k1 curve. This is used to take an ElligatorSwift-encoded
32+
// public key (u, t) and return the point on the curve it maps to.
33+
// TODO: Rewrite these so to avoid new(FieldVal).Add(...) usage?
34+
// NOTE: u, t MUST be normalized. The result x is normalized.
35+
func XSwiftEC(u, t *FieldVal) *FieldVal {
36+
// 1. Let u' = u if u != 0, else = 1
37+
if u.IsZero() {
38+
u.SetInt(1)
39+
}
40+
41+
// 2. Let t' = t if t != 0, else 1
42+
if t.IsZero() {
43+
t.SetInt(1)
44+
}
45+
46+
// 3. Let t'' = t' if g(u') != -(t'^2); t'' = 2t' otherwise
47+
// g(x) = x^3 + ax + b, a = 0, b = 7
48+
49+
// Calculate g(u').
50+
gu := new(FieldVal).SquareVal(u).Mul(u).AddInt(7).Normalize()
51+
52+
// Calculate the right-hand side of the equation (-t'^2)
53+
rhs := new(FieldVal).SquareVal(t).Negate(1).Normalize()
54+
55+
if gu.Equals(rhs) {
56+
// t'' = 2t'
57+
t = t.Add(t)
58+
}
59+
60+
// 4. X = (u'^3 + b - t''^2) / (2t'')
61+
tSquared := new(FieldVal).SquareVal(t).Negate(1)
62+
xNum := new(FieldVal).SquareVal(u).Mul(u).AddInt(7).Add(tSquared)
63+
xDenom := new(FieldVal).Add2(t, t).Inverse()
64+
x := xNum.Mul(xDenom)
65+
66+
// 5. Y = (X+t'') / (u' * c)
67+
yNum := new(FieldVal).Add2(x, t)
68+
yDenom := new(FieldVal).Mul2(u, &c).Inverse()
69+
y := yNum.Mul(yDenom)
70+
71+
// 6. Return the first x in (u'+4Y^2, -X/2Y - u'/2, X/2Y - u'/2) for which
72+
// x^3 + b is square.
73+
74+
// 6a. Calculate u' +4Y^2 and determine if x^3+7 is square.
75+
ySqr := new(FieldVal).Add(y).Mul(y)
76+
quadYSqr := new(FieldVal).Add(ySqr).MulInt(4)
77+
firstX := new(FieldVal).Add(u).Add(quadYSqr)
78+
79+
firstXCurve := new(FieldVal).Add(firstX).Square().Mul(firstX).AddInt(7)
80+
81+
// Now determine if firstXCurve is square (on the curve).
82+
if new(FieldVal).SquareRootVal(firstXCurve) {
83+
return firstX.Normalize()
84+
}
85+
86+
// 6b. Calculate -X/2Y - u'/2 and determine if x^3 + 7 is square
87+
doubleYInv := new(FieldVal).Add(y).Add(y).Inverse()
88+
xDivDoubleYInv := new(FieldVal).Add(x).Mul(doubleYInv)
89+
negXDivDoubleYInv := new(FieldVal).Add(xDivDoubleYInv).Negate(1)
90+
invTwo := new(FieldVal).AddInt(2).Inverse()
91+
negUDivTwo := new(FieldVal).Add(u).Mul(invTwo).Negate(1)
92+
secondX := new(FieldVal).Add(negXDivDoubleYInv).Add(negUDivTwo)
93+
94+
secondXCurve := new(FieldVal).Add(secondX).Square().Mul(secondX).AddInt(7)
95+
96+
// Now determine if secondXCurve is square.
97+
if new(FieldVal).SquareRootVal(secondXCurve) {
98+
return secondX.Normalize()
99+
}
100+
101+
// 6c. Calculate X/2Y -u'/2 and determine if x^3 + 7 is square
102+
thirdX := new(FieldVal).Add(xDivDoubleYInv).Add(negUDivTwo)
103+
104+
thirdXCurve := new(FieldVal).Add(thirdX).Square().Mul(thirdX).AddInt(7)
105+
106+
// Now determine if thirdXCurve is square.
107+
if new(FieldVal).SquareRootVal(thirdXCurve) {
108+
return thirdX.Normalize()
109+
}
110+
111+
// Should have found a square above.
112+
panic("unreachable - no calculated x-values were square")
113+
}
114+
115+
// XSwiftECInv takes two field elements (u, x) (where x is on the curve) and
116+
// returns a field element t. This is used to take a random field element u and
117+
// a point on the curve and return a field element t where (u, t) forms the
118+
// ElligatorSwift encoding.
119+
// TODO: Rewrite these so to avoid new(FieldVal).Add(...) usage?
120+
// NOTE: u, x MUST be normalized. The result `t` is normalized.
121+
func XSwiftECInv(u, x *FieldVal, caseNum int) *FieldVal {
122+
v := new(FieldVal)
123+
s := new(FieldVal)
124+
twoInv := new(FieldVal).AddInt(2).Inverse()
125+
126+
if caseNum&2 == 0 {
127+
// If lift_x(-x-u) succeeds, return None
128+
if _, found := liftX(new(FieldVal).Add(x).Add(u).Negate(2)); found {
129+
return nil
130+
}
131+
132+
// Let v = x
133+
v.Add(x)
134+
135+
// Let s = -(u^3+7)/(u^2 + uv + v^2)
136+
uSqr := new(FieldVal).Add(u).Square()
137+
vSqr := new(FieldVal).Add(v).Square()
138+
sDenom := new(FieldVal).Add(u).Mul(v).Add(uSqr).Add(vSqr)
139+
sNum := new(FieldVal).Add(uSqr).Mul(u).AddInt(7)
140+
141+
s = sDenom.Inverse().Mul(sNum).Negate(1)
142+
} else {
143+
// Let s = x - u
144+
negU := new(FieldVal).Add(u).Negate(1)
145+
s.Add(x).Add(negU).Normalize()
146+
147+
// If s = 0, return None
148+
if s.IsZero() {
149+
return nil
150+
}
151+
152+
// Let r be the square root of -s(4(u^3 + 7) + 3u^2s)
153+
uSqr := new(FieldVal).Add(u).Square()
154+
lhs := new(FieldVal).Add(uSqr).Mul(u).AddInt(7).MulInt(4)
155+
rhs := new(FieldVal).Add(uSqr).MulInt(3).Mul(s)
156+
157+
// Add the two terms together and multiply by -s.
158+
lhs.Add(rhs).Normalize().Mul(s).Negate(1)
159+
160+
r := new(FieldVal)
161+
if !r.SquareRootVal(lhs) {
162+
// If no square root was found, return None.
163+
return nil
164+
}
165+
166+
if caseNum&1 == 1 && r.Normalize().IsZero() {
167+
// If case & 1 = 1 and r = 0, return None.
168+
return nil
169+
}
170+
171+
// Let v = (r/s - u)/2
172+
sInv := new(FieldVal).Add(s).Inverse()
173+
uNeg := new(FieldVal).Add(u).Negate(1)
174+
175+
v.Add(r).Mul(sInv).Add(uNeg).Mul(twoInv)
176+
}
177+
178+
w := new(FieldVal)
179+
180+
if !w.SquareRootVal(s) {
181+
// If no square root was found, return None.
182+
return nil
183+
}
184+
185+
switch caseNum & 5 {
186+
case 0:
187+
// If case & 5 = 0, return -w(u(1-c)/2 + v)
188+
oneMinusC := new(FieldVal).Add(&c).Negate(1).AddInt(1)
189+
t := new(FieldVal).Add(u).Mul(oneMinusC).Mul(twoInv).Add(v).Mul(w).
190+
Negate(1).Normalize()
191+
192+
return t
193+
194+
case 1:
195+
// If case & 5 = 1, return w(u(1+c)/2 + v)
196+
onePlusC := new(FieldVal).Add(&c).AddInt(1)
197+
t := new(FieldVal).Add(u).Mul(onePlusC).Mul(twoInv).Add(v).Mul(w).
198+
Normalize()
199+
200+
return t
201+
202+
case 4:
203+
// If case & 5 = 4, return w(u(1-c)/2 + v)
204+
oneMinusC := new(FieldVal).Add(&c).Negate(1).AddInt(1)
205+
t := new(FieldVal).Add(u).Mul(oneMinusC).Mul(twoInv).Add(v).Mul(w).
206+
Normalize()
207+
208+
return t
209+
210+
case 5:
211+
// If case & 5 = 5, return -w(u(1+c)/2 + v)
212+
onePlusC := new(FieldVal).Add(&c).AddInt(1)
213+
t := new(FieldVal).Add(u).Mul(onePlusC).Mul(twoInv).Add(v).Mul(w).
214+
Negate(1).Normalize()
215+
216+
return t
217+
}
218+
219+
panic("should not reach here")
220+
}
221+
222+
// XElligatorSwift takes the x-coordinate of a point on secp256k1 and generates
223+
// ElligatorSwift encoding of that point composed of two field elements (u, t).
224+
// NOTE: x MUST be normalized. The return values u, t are normalized.
225+
func XElligatorSwift(x *FieldVal) (*FieldVal, *FieldVal, error) {
226+
// We'll choose a random `u` value and a random case so that we can
227+
// generate a `t` value.
228+
for {
229+
// Choose random u value.
230+
var randUBytes [32]byte
231+
_, err := rand.Read(randUBytes[:])
232+
if err != nil {
233+
return nil, nil, err
234+
}
235+
236+
u := new(FieldVal)
237+
overflow := u.SetBytes(&randUBytes)
238+
if overflow == 1 {
239+
u.Normalize()
240+
}
241+
242+
// Choose a random case in the interval [0, 7]
243+
var randCaseByte [1]byte
244+
_, err = rand.Read(randCaseByte[:])
245+
if err != nil {
246+
return nil, nil, err
247+
}
248+
249+
caseNum := randCaseByte[0] & 7
250+
251+
// Find t, if none is found, continue with the loop.
252+
t := XSwiftECInv(u, x, int(caseNum))
253+
if t != nil {
254+
return u, t, nil
255+
}
256+
}
257+
}
258+
259+
// EllswiftCreate generates a random private key and returns that along with
260+
// the ElligatorSwift encoding of its corresponding public key.
261+
func EllswiftCreate() (*PrivateKey, [64]byte, error) {
262+
var randPrivKeyBytes [64]byte
263+
264+
// Generate a random private key
265+
_, err := rand.Read(randPrivKeyBytes[:])
266+
if err != nil {
267+
return nil, [64]byte{}, err
268+
}
269+
270+
privKey, _ := PrivKeyFromBytes(randPrivKeyBytes[:])
271+
272+
// Fetch the x-coordinate of the public key.
273+
x := getXCoord(privKey)
274+
275+
// Get the ElligatorSwift encoding of the public key.
276+
u, t, err := XElligatorSwift(x)
277+
if err != nil {
278+
return nil, [64]byte{}, err
279+
}
280+
281+
uBytes := u.Bytes()
282+
tBytes := t.Bytes()
283+
284+
// ellswift_pub = bytes(u) || bytes(t), its encoding as 64 bytes
285+
var ellswiftPub [64]byte
286+
copy(ellswiftPub[0:32], (*uBytes)[:])
287+
copy(ellswiftPub[32:64], (*tBytes)[:])
288+
289+
// Return (priv, ellswift_pub)
290+
return privKey, ellswiftPub, nil
291+
}
292+
293+
// EllswiftECDHXOnly takes the ElligatorSwift-encoded public key of a
294+
// counter-party and performs ECDH with our private key.
295+
func EllswiftECDHXOnly(ellswiftTheirs [64]byte, privKey *PrivateKey) ([32]byte,
296+
error) {
297+
298+
// Let u = int(ellswift_theirs[:32]) mod p.
299+
// Let t = int(ellswift_theirs[32:]) mod p.
300+
uBytesTheirs := ellswiftTheirs[0:32]
301+
tBytesTheirs := ellswiftTheirs[32:64]
302+
303+
var uTheirs FieldVal
304+
overflow := uTheirs.SetByteSlice(uBytesTheirs[:])
305+
if overflow {
306+
uTheirs.Normalize()
307+
}
308+
309+
var tTheirs FieldVal
310+
overflow = tTheirs.SetByteSlice(tBytesTheirs[:])
311+
if overflow {
312+
tTheirs.Normalize()
313+
}
314+
315+
// Calculate bytes(x(priv⋅lift_x(XSwiftEC(u, t))))
316+
xTheirs := XSwiftEC(&uTheirs, &tTheirs)
317+
pubKey, found := liftX(xTheirs)
318+
if !found {
319+
return [32]byte{}, ErrPointNotOnCurve
320+
}
321+
322+
var pubJacobian JacobianPoint
323+
pubKey.AsJacobian(&pubJacobian)
324+
325+
var sharedPoint JacobianPoint
326+
ScalarMultNonConst(&privKey.Key, &pubJacobian, &sharedPoint)
327+
sharedPoint.ToAffine()
328+
329+
return *sharedPoint.X.Bytes(), nil
330+
}
331+
332+
// getXCoord fetches the corresponding public key's x-coordinate given a
333+
// private key.
334+
func getXCoord(privKey *PrivateKey) *FieldVal {
335+
var result JacobianPoint
336+
ScalarBaseMultNonConst(&privKey.Key, &result)
337+
result.ToAffine()
338+
return &result.X
339+
}
340+
341+
// liftX returns the point P with x-coordinate `x` and even y-coordinate. If a
342+
// point exists on the curve, it returns true and false otherwise.
343+
// TODO: Use quadratic residue formula instead (see: BIP340)?
344+
func liftX(x *FieldVal) (*PublicKey, bool) {
345+
ySqr := new(FieldVal).Add(x).Square().Mul(x).AddInt(7)
346+
347+
y := new(FieldVal)
348+
if !y.SquareRootVal(ySqr) {
349+
// If we've reached here, the point does not exist on the curve.
350+
return nil, false
351+
}
352+
353+
if !y.Normalize().IsOdd() {
354+
return NewPublicKey(x, y), true
355+
}
356+
357+
// Negate y if it's odd.
358+
if !y.Negate(1).Normalize().IsOdd() {
359+
return NewPublicKey(x, y), true
360+
}
361+
362+
return nil, false
363+
}
364+
365+
// V2Ecdh performs x-only ecdh and returns a shared secret composed of a tagged
366+
// hash which itself is composed of two ElligatorSwift-encoded public keys and
367+
// the x-only ecdh point.
368+
func V2Ecdh(priv *PrivateKey, ellswiftTheirs, ellswiftOurs [64]byte,
369+
initiating bool) (*chainhash.Hash, error) {
370+
371+
ecdhPoint, err := EllswiftECDHXOnly(ellswiftTheirs, priv)
372+
if err != nil {
373+
return nil, err
374+
}
375+
376+
if initiating {
377+
// Initiating, place our public key encoding first.
378+
var msg []byte
379+
msg = append(msg, ellswiftOurs[:]...)
380+
msg = append(msg, ellswiftTheirs[:]...)
381+
msg = append(msg, ecdhPoint[:]...)
382+
return chainhash.TaggedHash([]byte("bip324_ellswift_xonly_ecdh"), msg),
383+
nil
384+
}
385+
386+
var msg []byte
387+
msg = append(msg, ellswiftTheirs[:]...)
388+
msg = append(msg, ellswiftOurs[:]...)
389+
msg = append(msg, ecdhPoint[:]...)
390+
return chainhash.TaggedHash([]byte("bip324_ellswift_xonly_ecdh"), msg), nil
391+
}

0 commit comments

Comments
 (0)