Skip to content

Commit d0706cd

Browse files
committed
ellswift: introduce ElligatorSwift encoding and decoding funcs
The BIP324 ElligatorSwift test vectors are also included.
1 parent 67b8efd commit d0706cd

File tree

2 files changed

+1211
-0
lines changed

2 files changed

+1211
-0
lines changed

btcec/ellswift.go

+392
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,392 @@
1+
package btcec
2+
3+
import (
4+
"crypto/rand"
5+
"fmt"
6+
7+
"github.com/btcsuite/btcd/chaincfg/chainhash"
8+
)
9+
10+
var (
11+
// c is sqrt(-3) (mod p)
12+
c FieldVal
13+
14+
cBytes = [32]byte{
15+
0x0a, 0x2d, 0x2b, 0xa9, 0x35, 0x07, 0xf1, 0xdf,
16+
0x23, 0x37, 0x70, 0xc2, 0xa7, 0x97, 0x96, 0x2c,
17+
0xc6, 0x1f, 0x6d, 0x15, 0xda, 0x14, 0xec, 0xd4,
18+
0x7d, 0x8d, 0x27, 0xae, 0x1c, 0xd5, 0xf8, 0x52,
19+
}
20+
21+
ellswiftTag = []byte("bip324_ellswift_xonly_ecdh")
22+
23+
// ErrPointNotOnCurve is returned when we're unable to find a point on the
24+
// curve.
25+
ErrPointNotOnCurve = fmt.Errorf("point does not exist on secp256k1 curve")
26+
)
27+
28+
func init() {
29+
c.SetByteSlice(cBytes[:])
30+
}
31+
32+
// XSwiftEC() takes two field elements (u, t) and gives us an x-coordinate that
33+
// is on the secp256k1 curve. This is used to take an ElligatorSwift-encoded
34+
// public key (u, t) and return the point on the curve it maps to.
35+
// TODO: Rewrite these so to avoid new(FieldVal).Add(...) usage?
36+
// NOTE: u, t MUST be normalized. The result x is normalized.
37+
func XSwiftEC(u, t *FieldVal) *FieldVal {
38+
// 1. Let u' = u if u != 0, else = 1
39+
if u.IsZero() {
40+
u.SetInt(1)
41+
}
42+
43+
// 2. Let t' = t if t != 0, else 1
44+
if t.IsZero() {
45+
t.SetInt(1)
46+
}
47+
48+
// 3. Let t'' = t' if g(u') != -(t'^2); t'' = 2t' otherwise
49+
// g(x) = x^3 + ax + b, a = 0, b = 7
50+
51+
// Calculate g(u').
52+
gu := new(FieldVal).SquareVal(u).Mul(u).AddInt(7).Normalize()
53+
54+
// Calculate the right-hand side of the equation (-t'^2)
55+
rhs := new(FieldVal).SquareVal(t).Negate(1).Normalize()
56+
57+
if gu.Equals(rhs) {
58+
// t'' = 2t'
59+
t = t.Add(t)
60+
}
61+
62+
// 4. X = (u'^3 + b - t''^2) / (2t'')
63+
tSquared := new(FieldVal).SquareVal(t).Negate(1)
64+
xNum := new(FieldVal).SquareVal(u).Mul(u).AddInt(7).Add(tSquared)
65+
xDenom := new(FieldVal).Add2(t, t).Inverse()
66+
x := xNum.Mul(xDenom)
67+
68+
// 5. Y = (X+t'') / (u' * c)
69+
yNum := new(FieldVal).Add2(x, t)
70+
yDenom := new(FieldVal).Mul2(u, &c).Inverse()
71+
y := yNum.Mul(yDenom)
72+
73+
// 6. Return the first x in (u'+4Y^2, -X/2Y - u'/2, X/2Y - u'/2) for which
74+
// x^3 + b is square.
75+
76+
// 6a. Calculate u' +4Y^2 and determine if x^3+7 is square.
77+
ySqr := new(FieldVal).Add(y).Mul(y)
78+
quadYSqr := new(FieldVal).Add(ySqr).MulInt(4)
79+
firstX := new(FieldVal).Add(u).Add(quadYSqr)
80+
81+
firstXCurve := new(FieldVal).Add(firstX).Square().Mul(firstX).AddInt(7)
82+
83+
// Now determine if firstXCurve is square (on the curve).
84+
if new(FieldVal).SquareRootVal(firstXCurve) {
85+
return firstX.Normalize()
86+
}
87+
88+
// 6b. Calculate -X/2Y - u'/2 and determine if x^3 + 7 is square
89+
doubleYInv := new(FieldVal).Add(y).Add(y).Inverse()
90+
xDivDoubleYInv := new(FieldVal).Add(x).Mul(doubleYInv)
91+
negXDivDoubleYInv := new(FieldVal).Add(xDivDoubleYInv).Negate(1)
92+
invTwo := new(FieldVal).AddInt(2).Inverse()
93+
negUDivTwo := new(FieldVal).Add(u).Mul(invTwo).Negate(1)
94+
secondX := new(FieldVal).Add(negXDivDoubleYInv).Add(negUDivTwo)
95+
96+
secondXCurve := new(FieldVal).Add(secondX).Square().Mul(secondX).AddInt(7)
97+
98+
// Now determine if secondXCurve is square.
99+
if new(FieldVal).SquareRootVal(secondXCurve) {
100+
return secondX.Normalize()
101+
}
102+
103+
// 6c. Calculate X/2Y -u'/2 and determine if x^3 + 7 is square
104+
thirdX := new(FieldVal).Add(xDivDoubleYInv).Add(negUDivTwo)
105+
106+
thirdXCurve := new(FieldVal).Add(thirdX).Square().Mul(thirdX).AddInt(7)
107+
108+
// Now determine if thirdXCurve is square.
109+
if new(FieldVal).SquareRootVal(thirdXCurve) {
110+
return thirdX.Normalize()
111+
}
112+
113+
// Should have found a square above.
114+
panic("unreachable - no calculated x-values were square")
115+
}
116+
117+
// XSwiftECInv takes two field elements (u, x) (where x is on the curve) and
118+
// returns a field element t. This is used to take a random field element u and
119+
// a point on the curve and return a field element t where (u, t) forms the
120+
// ElligatorSwift encoding.
121+
// TODO: Rewrite these so to avoid new(FieldVal).Add(...) usage?
122+
// NOTE: u, x MUST be normalized. The result `t` is normalized.
123+
func XSwiftECInv(u, x *FieldVal, caseNum int) *FieldVal {
124+
v := new(FieldVal)
125+
s := new(FieldVal)
126+
twoInv := new(FieldVal).AddInt(2).Inverse()
127+
128+
if caseNum&2 == 0 {
129+
// If lift_x(-x-u) succeeds, return None
130+
if _, found := liftX(new(FieldVal).Add(x).Add(u).Negate(2)); found {
131+
return nil
132+
}
133+
134+
// Let v = x
135+
v.Add(x)
136+
137+
// Let s = -(u^3+7)/(u^2 + uv + v^2)
138+
uSqr := new(FieldVal).Add(u).Square()
139+
vSqr := new(FieldVal).Add(v).Square()
140+
sDenom := new(FieldVal).Add(u).Mul(v).Add(uSqr).Add(vSqr)
141+
sNum := new(FieldVal).Add(uSqr).Mul(u).AddInt(7)
142+
143+
s = sDenom.Inverse().Mul(sNum).Negate(1)
144+
} else {
145+
// Let s = x - u
146+
negU := new(FieldVal).Add(u).Negate(1)
147+
s.Add(x).Add(negU).Normalize()
148+
149+
// If s = 0, return None
150+
if s.IsZero() {
151+
return nil
152+
}
153+
154+
// Let r be the square root of -s(4(u^3 + 7) + 3u^2s)
155+
uSqr := new(FieldVal).Add(u).Square()
156+
lhs := new(FieldVal).Add(uSqr).Mul(u).AddInt(7).MulInt(4)
157+
rhs := new(FieldVal).Add(uSqr).MulInt(3).Mul(s)
158+
159+
// Add the two terms together and multiply by -s.
160+
lhs.Add(rhs).Normalize().Mul(s).Negate(1)
161+
162+
r := new(FieldVal)
163+
if !r.SquareRootVal(lhs) {
164+
// If no square root was found, return None.
165+
return nil
166+
}
167+
168+
if caseNum&1 == 1 && r.Normalize().IsZero() {
169+
// If case & 1 = 1 and r = 0, return None.
170+
return nil
171+
}
172+
173+
// Let v = (r/s - u)/2
174+
sInv := new(FieldVal).Add(s).Inverse()
175+
uNeg := new(FieldVal).Add(u).Negate(1)
176+
177+
v.Add(r).Mul(sInv).Add(uNeg).Mul(twoInv)
178+
}
179+
180+
w := new(FieldVal)
181+
182+
if !w.SquareRootVal(s) {
183+
// If no square root was found, return None.
184+
return nil
185+
}
186+
187+
switch caseNum & 5 {
188+
case 0:
189+
// If case & 5 = 0, return -w(u(1-c)/2 + v)
190+
oneMinusC := new(FieldVal).Add(&c).Negate(1).AddInt(1)
191+
t := new(FieldVal).Add(u).Mul(oneMinusC).Mul(twoInv).Add(v).Mul(w).
192+
Negate(1).Normalize()
193+
194+
return t
195+
196+
case 1:
197+
// If case & 5 = 1, return w(u(1+c)/2 + v)
198+
onePlusC := new(FieldVal).Add(&c).AddInt(1)
199+
t := new(FieldVal).Add(u).Mul(onePlusC).Mul(twoInv).Add(v).Mul(w).
200+
Normalize()
201+
202+
return t
203+
204+
case 4:
205+
// If case & 5 = 4, return w(u(1-c)/2 + v)
206+
oneMinusC := new(FieldVal).Add(&c).Negate(1).AddInt(1)
207+
t := new(FieldVal).Add(u).Mul(oneMinusC).Mul(twoInv).Add(v).Mul(w).
208+
Normalize()
209+
210+
return t
211+
212+
case 5:
213+
// If case & 5 = 5, return -w(u(1+c)/2 + v)
214+
onePlusC := new(FieldVal).Add(&c).AddInt(1)
215+
t := new(FieldVal).Add(u).Mul(onePlusC).Mul(twoInv).Add(v).Mul(w).
216+
Negate(1).Normalize()
217+
218+
return t
219+
}
220+
221+
panic("should not reach here")
222+
}
223+
224+
// XElligatorSwift takes the x-coordinate of a point on secp256k1 and generates
225+
// ElligatorSwift encoding of that point composed of two field elements (u, t).
226+
// NOTE: x MUST be normalized. The return values u, t are normalized.
227+
func XElligatorSwift(x *FieldVal) (*FieldVal, *FieldVal, error) {
228+
// We'll choose a random `u` value and a random case so that we can
229+
// generate a `t` value.
230+
for {
231+
// Choose random u value.
232+
var randUBytes [32]byte
233+
_, err := rand.Read(randUBytes[:])
234+
if err != nil {
235+
return nil, nil, err
236+
}
237+
238+
u := new(FieldVal)
239+
overflow := u.SetBytes(&randUBytes)
240+
if overflow == 1 {
241+
u.Normalize()
242+
}
243+
244+
// Choose a random case in the interval [0, 7]
245+
var randCaseByte [1]byte
246+
_, err = rand.Read(randCaseByte[:])
247+
if err != nil {
248+
return nil, nil, err
249+
}
250+
251+
caseNum := randCaseByte[0] & 7
252+
253+
// Find t, if none is found, continue with the loop.
254+
t := XSwiftECInv(u, x, int(caseNum))
255+
if t != nil {
256+
return u, t, nil
257+
}
258+
}
259+
}
260+
261+
// EllswiftCreate generates a random private key and returns that along with
262+
// the ElligatorSwift encoding of its corresponding public key.
263+
func EllswiftCreate() (*PrivateKey, [64]byte, error) {
264+
var randPrivKeyBytes [64]byte
265+
266+
// Generate a random private key
267+
_, err := rand.Read(randPrivKeyBytes[:])
268+
if err != nil {
269+
return nil, [64]byte{}, err
270+
}
271+
272+
privKey, _ := PrivKeyFromBytes(randPrivKeyBytes[:])
273+
274+
// Fetch the x-coordinate of the public key.
275+
x := getXCoord(privKey)
276+
277+
// Get the ElligatorSwift encoding of the public key.
278+
u, t, err := XElligatorSwift(x)
279+
if err != nil {
280+
return nil, [64]byte{}, err
281+
}
282+
283+
uBytes := u.Bytes()
284+
tBytes := t.Bytes()
285+
286+
// ellswift_pub = bytes(u) || bytes(t), its encoding as 64 bytes
287+
var ellswiftPub [64]byte
288+
copy(ellswiftPub[0:32], (*uBytes)[:])
289+
copy(ellswiftPub[32:64], (*tBytes)[:])
290+
291+
// Return (priv, ellswift_pub)
292+
return privKey, ellswiftPub, nil
293+
}
294+
295+
// EllswiftECDHXOnly takes the ElligatorSwift-encoded public key of a
296+
// counter-party and performs ECDH with our private key.
297+
func EllswiftECDHXOnly(ellswiftTheirs [64]byte, privKey *PrivateKey) ([32]byte,
298+
error) {
299+
300+
// Let u = int(ellswift_theirs[:32]) mod p.
301+
// Let t = int(ellswift_theirs[32:]) mod p.
302+
uBytesTheirs := ellswiftTheirs[0:32]
303+
tBytesTheirs := ellswiftTheirs[32:64]
304+
305+
var uTheirs FieldVal
306+
overflow := uTheirs.SetByteSlice(uBytesTheirs[:])
307+
if overflow {
308+
uTheirs.Normalize()
309+
}
310+
311+
var tTheirs FieldVal
312+
overflow = tTheirs.SetByteSlice(tBytesTheirs[:])
313+
if overflow {
314+
tTheirs.Normalize()
315+
}
316+
317+
// Calculate bytes(x(priv⋅lift_x(XSwiftEC(u, t))))
318+
xTheirs := XSwiftEC(&uTheirs, &tTheirs)
319+
pubKey, found := liftX(xTheirs)
320+
if !found {
321+
return [32]byte{}, ErrPointNotOnCurve
322+
}
323+
324+
var pubJacobian JacobianPoint
325+
pubKey.AsJacobian(&pubJacobian)
326+
327+
var sharedPoint JacobianPoint
328+
ScalarMultNonConst(&privKey.Key, &pubJacobian, &sharedPoint)
329+
sharedPoint.ToAffine()
330+
331+
return *sharedPoint.X.Bytes(), nil
332+
}
333+
334+
// getXCoord fetches the corresponding public key's x-coordinate given a
335+
// private key.
336+
func getXCoord(privKey *PrivateKey) *FieldVal {
337+
var result JacobianPoint
338+
ScalarBaseMultNonConst(&privKey.Key, &result)
339+
result.ToAffine()
340+
return &result.X
341+
}
342+
343+
// liftX returns the point P with x-coordinate `x` and even y-coordinate. If a
344+
// point exists on the curve, it returns true and false otherwise.
345+
// TODO: Use quadratic residue formula instead (see: BIP340)?
346+
func liftX(x *FieldVal) (*PublicKey, bool) {
347+
ySqr := new(FieldVal).Add(x).Square().Mul(x).AddInt(7)
348+
349+
y := new(FieldVal)
350+
if !y.SquareRootVal(ySqr) {
351+
// If we've reached here, the point does not exist on the curve.
352+
return nil, false
353+
}
354+
355+
if !y.Normalize().IsOdd() {
356+
return NewPublicKey(x, y), true
357+
}
358+
359+
// Negate y if it's odd.
360+
if !y.Negate(1).Normalize().IsOdd() {
361+
return NewPublicKey(x, y), true
362+
}
363+
364+
return nil, false
365+
}
366+
367+
// V2Ecdh performs x-only ecdh and returns a shared secret composed of a tagged
368+
// hash which itself is composed of two ElligatorSwift-encoded public keys and
369+
// the x-only ecdh point.
370+
func V2Ecdh(priv *PrivateKey, ellswiftTheirs, ellswiftOurs [64]byte,
371+
initiating bool) (*chainhash.Hash, error) {
372+
373+
ecdhPoint, err := EllswiftECDHXOnly(ellswiftTheirs, priv)
374+
if err != nil {
375+
return nil, err
376+
}
377+
378+
if initiating {
379+
// Initiating, place our public key encoding first.
380+
var msg []byte
381+
msg = append(msg, ellswiftOurs[:]...)
382+
msg = append(msg, ellswiftTheirs[:]...)
383+
msg = append(msg, ecdhPoint[:]...)
384+
return chainhash.TaggedHash(ellswiftTag, msg), nil
385+
}
386+
387+
var msg []byte
388+
msg = append(msg, ellswiftTheirs[:]...)
389+
msg = append(msg, ellswiftOurs[:]...)
390+
msg = append(msg, ecdhPoint[:]...)
391+
return chainhash.TaggedHash(ellswiftTag, msg), nil
392+
}

0 commit comments

Comments
 (0)