From 0a6006989f6215a45e982cd696339c503ddfc325 Mon Sep 17 00:00:00 2001 From: Andrew Poelstra Date: Tue, 12 Jul 2022 18:51:48 +0000 Subject: [PATCH] Revert "Remove unused scalar_sqr" This reverts commit 5437e7bdfbffddf69fdf7b4af7e997c78f5dafbf. --- src/bench_internal.c | 10 +++ src/scalar.h | 3 + src/scalar_4x64_impl.h | 166 +++++++++++++++++++++++++++++++++++++++++ src/scalar_8x32_impl.h | 89 ++++++++++++++++++++++ src/scalar_low_impl.h | 4 + src/tests.c | 14 ++++ 6 files changed, 286 insertions(+) diff --git a/src/bench_internal.c b/src/bench_internal.c index 61400519..898c7801 100644 --- a/src/bench_internal.c +++ b/src/bench_internal.c @@ -98,6 +98,15 @@ void bench_scalar_negate(void* arg, int iters) { } } +void bench_scalar_sqr(void* arg, int iters) { + int i; + bench_inv *data = (bench_inv*)arg; + + for (i = 0; i < iters; i++) { + secp256k1_scalar_sqr(&data->scalar[0], &data->scalar[0]); + } +} + void bench_scalar_mul(void* arg, int iters) { int i; bench_inv *data = (bench_inv*)arg; @@ -376,6 +385,7 @@ int main(int argc, char **argv) { if (d || have_flag(argc, argv, "scalar") || have_flag(argc, argv, "add")) run_benchmark("scalar_add", bench_scalar_add, bench_setup, NULL, &data, 10, iters*100); if (d || have_flag(argc, argv, "scalar") || have_flag(argc, argv, "negate")) run_benchmark("scalar_negate", bench_scalar_negate, bench_setup, NULL, &data, 10, iters*100); + if (d || have_flag(argc, argv, "scalar") || have_flag(argc, argv, "sqr")) run_benchmark("scalar_sqr", bench_scalar_sqr, bench_setup, NULL, &data, 10, iters*10); if (d || have_flag(argc, argv, "scalar") || have_flag(argc, argv, "mul")) run_benchmark("scalar_mul", bench_scalar_mul, bench_setup, NULL, &data, 10, iters*10); if (d || have_flag(argc, argv, "scalar") || have_flag(argc, argv, "split")) run_benchmark("scalar_split", bench_scalar_split, bench_setup, NULL, &data, 10, iters); if (d || have_flag(argc, argv, "scalar") || have_flag(argc, argv, "inverse")) run_benchmark("scalar_inverse", bench_scalar_inverse, bench_setup, NULL, &data, 10, iters); diff --git a/src/scalar.h b/src/scalar.h index 36eb0db8..227913cb 100644 --- a/src/scalar.h +++ b/src/scalar.h @@ -65,6 +65,9 @@ static void secp256k1_scalar_mul(secp256k1_scalar *r, const secp256k1_scalar *a, * the low bits that were shifted off */ static int secp256k1_scalar_shr_int(secp256k1_scalar *r, int n); +/** Compute the square of a scalar (modulo the group order). */ +static void secp256k1_scalar_sqr(secp256k1_scalar *r, const secp256k1_scalar *a); + /** Compute the inverse of a scalar (modulo the group order). */ static void secp256k1_scalar_inverse(secp256k1_scalar *r, const secp256k1_scalar *a); diff --git a/src/scalar_4x64_impl.h b/src/scalar_4x64_impl.h index 6b0b44ed..585a4b63 100644 --- a/src/scalar_4x64_impl.h +++ b/src/scalar_4x64_impl.h @@ -224,6 +224,28 @@ static int secp256k1_scalar_cond_negate(secp256k1_scalar *r, int flag) { VERIFY_CHECK(c1 >= th); \ } +/** Add 2*a*b to the number defined by (c0,c1,c2). c2 must never overflow. 
*/ +#define muladd2(a,b) { \ + uint64_t tl, th, th2, tl2; \ + { \ + uint128_t t = (uint128_t)a * b; \ + th = t >> 64; /* at most 0xFFFFFFFFFFFFFFFE */ \ + tl = t; \ + } \ + th2 = th + th; /* at most 0xFFFFFFFFFFFFFFFE (in case th was 0x7FFFFFFFFFFFFFFF) */ \ + c2 += (th2 < th); /* never overflows by contract (verified the next line) */ \ + VERIFY_CHECK((th2 >= th) || (c2 != 0)); \ + tl2 = tl + tl; /* at most 0xFFFFFFFFFFFFFFFE (in case the lowest 63 bits of tl were 0x7FFFFFFFFFFFFFFF) */ \ + th2 += (tl2 < tl); /* at most 0xFFFFFFFFFFFFFFFF */ \ + c0 += tl2; /* overflow is handled on the next line */ \ + th2 += (c0 < tl2); /* second overflow is handled on the next line */ \ + c2 += (c0 < tl2) & (th2 == 0); /* never overflows by contract (verified the next line) */ \ + VERIFY_CHECK((c0 >= tl2) || (th2 != 0) || (c2 != 0)); \ + c1 += th2; /* overflow is handled on the next line */ \ + c2 += (c1 < th2); /* never overflows by contract (verified the next line) */ \ + VERIFY_CHECK((c1 >= th2) || (c2 != 0)); \ +} + /** Add a to the number defined by (c0,c1,c2). c2 must never overflow. */ #define sumadd(a) { \ unsigned int over; \ @@ -733,10 +755,148 @@ static void secp256k1_scalar_mul_512(uint64_t l[8], const secp256k1_scalar *a, c #endif } +static void secp256k1_scalar_sqr_512(uint64_t l[8], const secp256k1_scalar *a) { +#ifdef USE_ASM_X86_64 + __asm__ __volatile__( + /* Preload */ + "movq 0(%%rdi), %%r11\n" + "movq 8(%%rdi), %%r12\n" + "movq 16(%%rdi), %%r13\n" + "movq 24(%%rdi), %%r14\n" + /* (rax,rdx) = a0 * a0 */ + "movq %%r11, %%rax\n" + "mulq %%r11\n" + /* Extract l0 */ + "movq %%rax, 0(%%rsi)\n" + /* (r8,r9,r10) = (rdx,0) */ + "movq %%rdx, %%r8\n" + "xorq %%r9, %%r9\n" + "xorq %%r10, %%r10\n" + /* (r8,r9,r10) += 2 * a0 * a1 */ + "movq %%r11, %%rax\n" + "mulq %%r12\n" + "addq %%rax, %%r8\n" + "adcq %%rdx, %%r9\n" + "adcq $0, %%r10\n" + "addq %%rax, %%r8\n" + "adcq %%rdx, %%r9\n" + "adcq $0, %%r10\n" + /* Extract l1 */ + "movq %%r8, 8(%%rsi)\n" + "xorq %%r8, %%r8\n" + /* (r9,r10,r8) += 2 * a0 * a2 */ + "movq %%r11, %%rax\n" + "mulq %%r13\n" + "addq %%rax, %%r9\n" + "adcq %%rdx, %%r10\n" + "adcq $0, %%r8\n" + "addq %%rax, %%r9\n" + "adcq %%rdx, %%r10\n" + "adcq $0, %%r8\n" + /* (r9,r10,r8) += a1 * a1 */ + "movq %%r12, %%rax\n" + "mulq %%r12\n" + "addq %%rax, %%r9\n" + "adcq %%rdx, %%r10\n" + "adcq $0, %%r8\n" + /* Extract l2 */ + "movq %%r9, 16(%%rsi)\n" + "xorq %%r9, %%r9\n" + /* (r10,r8,r9) += 2 * a0 * a3 */ + "movq %%r11, %%rax\n" + "mulq %%r14\n" + "addq %%rax, %%r10\n" + "adcq %%rdx, %%r8\n" + "adcq $0, %%r9\n" + "addq %%rax, %%r10\n" + "adcq %%rdx, %%r8\n" + "adcq $0, %%r9\n" + /* (r10,r8,r9) += 2 * a1 * a2 */ + "movq %%r12, %%rax\n" + "mulq %%r13\n" + "addq %%rax, %%r10\n" + "adcq %%rdx, %%r8\n" + "adcq $0, %%r9\n" + "addq %%rax, %%r10\n" + "adcq %%rdx, %%r8\n" + "adcq $0, %%r9\n" + /* Extract l3 */ + "movq %%r10, 24(%%rsi)\n" + "xorq %%r10, %%r10\n" + /* (r8,r9,r10) += 2 * a1 * a3 */ + "movq %%r12, %%rax\n" + "mulq %%r14\n" + "addq %%rax, %%r8\n" + "adcq %%rdx, %%r9\n" + "adcq $0, %%r10\n" + "addq %%rax, %%r8\n" + "adcq %%rdx, %%r9\n" + "adcq $0, %%r10\n" + /* (r8,r9,r10) += a2 * a2 */ + "movq %%r13, %%rax\n" + "mulq %%r13\n" + "addq %%rax, %%r8\n" + "adcq %%rdx, %%r9\n" + "adcq $0, %%r10\n" + /* Extract l4 */ + "movq %%r8, 32(%%rsi)\n" + "xorq %%r8, %%r8\n" + /* (r9,r10,r8) += 2 * a2 * a3 */ + "movq %%r13, %%rax\n" + "mulq %%r14\n" + "addq %%rax, %%r9\n" + "adcq %%rdx, %%r10\n" + "adcq $0, %%r8\n" + "addq %%rax, %%r9\n" + "adcq %%rdx, %%r10\n" + "adcq $0, %%r8\n" + /* Extract l5 */ + 
"movq %%r9, 40(%%rsi)\n" + /* (r10,r8) += a3 * a3 */ + "movq %%r14, %%rax\n" + "mulq %%r14\n" + "addq %%rax, %%r10\n" + "adcq %%rdx, %%r8\n" + /* Extract l6 */ + "movq %%r10, 48(%%rsi)\n" + /* Extract l7 */ + "movq %%r8, 56(%%rsi)\n" + : + : "S"(l), "D"(a->d) + : "rax", "rdx", "r8", "r9", "r10", "r11", "r12", "r13", "r14", "cc", "memory"); +#else + /* 160 bit accumulator. */ + uint64_t c0 = 0, c1 = 0; + uint32_t c2 = 0; + + /* l[0..7] = a[0..3] * b[0..3]. */ + muladd_fast(a->d[0], a->d[0]); + extract_fast(l[0]); + muladd2(a->d[0], a->d[1]); + extract(l[1]); + muladd2(a->d[0], a->d[2]); + muladd(a->d[1], a->d[1]); + extract(l[2]); + muladd2(a->d[0], a->d[3]); + muladd2(a->d[1], a->d[2]); + extract(l[3]); + muladd2(a->d[1], a->d[3]); + muladd(a->d[2], a->d[2]); + extract(l[4]); + muladd2(a->d[2], a->d[3]); + extract(l[5]); + muladd_fast(a->d[3], a->d[3]); + extract_fast(l[6]); + VERIFY_CHECK(c1 == 0); + l[7] = c0; +#endif +} + #undef sumadd #undef sumadd_fast #undef muladd #undef muladd_fast +#undef muladd2 #undef extract #undef extract_fast @@ -758,6 +918,12 @@ static int secp256k1_scalar_shr_int(secp256k1_scalar *r, int n) { return ret; } +static void secp256k1_scalar_sqr(secp256k1_scalar *r, const secp256k1_scalar *a) { + uint64_t l[8]; + secp256k1_scalar_sqr_512(l, a); + secp256k1_scalar_reduce_512(r, l); +} + static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r2, const secp256k1_scalar *k) { r1->d[0] = k->d[0]; r1->d[1] = k->d[1]; diff --git a/src/scalar_8x32_impl.h b/src/scalar_8x32_impl.h index acd5ef7d..6086f1ec 100644 --- a/src/scalar_8x32_impl.h +++ b/src/scalar_8x32_impl.h @@ -306,6 +306,28 @@ static int secp256k1_scalar_cond_negate(secp256k1_scalar *r, int flag) { VERIFY_CHECK(c1 >= th); \ } +/** Add 2*a*b to the number defined by (c0,c1,c2). c2 must never overflow. */ +#define muladd2(a,b) { \ + uint32_t tl, th, th2, tl2; \ + { \ + uint64_t t = (uint64_t)a * b; \ + th = t >> 32; /* at most 0xFFFFFFFE */ \ + tl = t; \ + } \ + th2 = th + th; /* at most 0xFFFFFFFE (in case th was 0x7FFFFFFF) */ \ + c2 += (th2 < th); /* never overflows by contract (verified the next line) */ \ + VERIFY_CHECK((th2 >= th) || (c2 != 0)); \ + tl2 = tl + tl; /* at most 0xFFFFFFFE (in case the lowest 63 bits of tl were 0x7FFFFFFF) */ \ + th2 += (tl2 < tl); /* at most 0xFFFFFFFF */ \ + c0 += tl2; /* overflow is handled on the next line */ \ + th2 += (c0 < tl2); /* second overflow is handled on the next line */ \ + c2 += (c0 < tl2) & (th2 == 0); /* never overflows by contract (verified the next line) */ \ + VERIFY_CHECK((c0 >= tl2) || (th2 != 0) || (c2 != 0)); \ + c1 += th2; /* overflow is handled on the next line */ \ + c2 += (c1 < th2); /* never overflows by contract (verified the next line) */ \ + VERIFY_CHECK((c1 >= th2) || (c2 != 0)); \ +} + /** Add a to the number defined by (c0,c1,c2). c2 must never overflow. */ #define sumadd(a) { \ unsigned int over; \ @@ -569,10 +591,71 @@ static void secp256k1_scalar_mul_512(uint32_t *l, const secp256k1_scalar *a, con l[15] = c0; } +static void secp256k1_scalar_sqr_512(uint32_t *l, const secp256k1_scalar *a) { + /* 96 bit accumulator. */ + uint32_t c0 = 0, c1 = 0, c2 = 0; + + /* l[0..15] = a[0..7]^2. 
*/ + muladd_fast(a->d[0], a->d[0]); + extract_fast(l[0]); + muladd2(a->d[0], a->d[1]); + extract(l[1]); + muladd2(a->d[0], a->d[2]); + muladd(a->d[1], a->d[1]); + extract(l[2]); + muladd2(a->d[0], a->d[3]); + muladd2(a->d[1], a->d[2]); + extract(l[3]); + muladd2(a->d[0], a->d[4]); + muladd2(a->d[1], a->d[3]); + muladd(a->d[2], a->d[2]); + extract(l[4]); + muladd2(a->d[0], a->d[5]); + muladd2(a->d[1], a->d[4]); + muladd2(a->d[2], a->d[3]); + extract(l[5]); + muladd2(a->d[0], a->d[6]); + muladd2(a->d[1], a->d[5]); + muladd2(a->d[2], a->d[4]); + muladd(a->d[3], a->d[3]); + extract(l[6]); + muladd2(a->d[0], a->d[7]); + muladd2(a->d[1], a->d[6]); + muladd2(a->d[2], a->d[5]); + muladd2(a->d[3], a->d[4]); + extract(l[7]); + muladd2(a->d[1], a->d[7]); + muladd2(a->d[2], a->d[6]); + muladd2(a->d[3], a->d[5]); + muladd(a->d[4], a->d[4]); + extract(l[8]); + muladd2(a->d[2], a->d[7]); + muladd2(a->d[3], a->d[6]); + muladd2(a->d[4], a->d[5]); + extract(l[9]); + muladd2(a->d[3], a->d[7]); + muladd2(a->d[4], a->d[6]); + muladd(a->d[5], a->d[5]); + extract(l[10]); + muladd2(a->d[4], a->d[7]); + muladd2(a->d[5], a->d[6]); + extract(l[11]); + muladd2(a->d[5], a->d[7]); + muladd(a->d[6], a->d[6]); + extract(l[12]); + muladd2(a->d[6], a->d[7]); + extract(l[13]); + muladd_fast(a->d[7], a->d[7]); + extract_fast(l[14]); + VERIFY_CHECK(c1 == 0); + l[15] = c0; +} + #undef sumadd #undef sumadd_fast #undef muladd #undef muladd_fast +#undef muladd2 #undef extract #undef extract_fast @@ -598,6 +681,12 @@ static int secp256k1_scalar_shr_int(secp256k1_scalar *r, int n) { return ret; } +static void secp256k1_scalar_sqr(secp256k1_scalar *r, const secp256k1_scalar *a) { + uint32_t l[16]; + secp256k1_scalar_sqr_512(l, a); + secp256k1_scalar_reduce_512(r, l); +} + static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r2, const secp256k1_scalar *k) { r1->d[0] = k->d[0]; r1->d[1] = k->d[1]; diff --git a/src/scalar_low_impl.h b/src/scalar_low_impl.h index 47001fcc..aa75f8b0 100644 --- a/src/scalar_low_impl.h +++ b/src/scalar_low_impl.h @@ -105,6 +105,10 @@ static int secp256k1_scalar_shr_int(secp256k1_scalar *r, int n) { return ret; } +static void secp256k1_scalar_sqr(secp256k1_scalar *r, const secp256k1_scalar *a) { + *r = (*a * *a) % EXHAUSTIVE_TEST_ORDER; +} + static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r2, const secp256k1_scalar *a) { *r1 = *a; *r2 = 0; diff --git a/src/tests.c b/src/tests.c index ca1ded47..89f47432 100644 --- a/src/tests.c +++ b/src/tests.c @@ -1898,6 +1898,14 @@ void scalar_test(void) { CHECK(secp256k1_scalar_eq(&r1, &r2)); } + { + /* Test square. */ + secp256k1_scalar r1, r2; + secp256k1_scalar_sqr(&r1, &s1); + secp256k1_scalar_mul(&r2, &s1, &s1); + CHECK(secp256k1_scalar_eq(&r1, &r2)); + } + { /* Test multiplicative identity. */ secp256k1_scalar r1, v1; @@ -2653,6 +2661,12 @@ void run_scalar_tests(void) { CHECK(!secp256k1_scalar_check_overflow(&zz)); CHECK(secp256k1_scalar_eq(&one, &zz)); } + secp256k1_scalar_mul(&z, &x, &x); + CHECK(!secp256k1_scalar_check_overflow(&z)); + secp256k1_scalar_sqr(&zz, &x); + CHECK(!secp256k1_scalar_check_overflow(&zz)); + CHECK(secp256k1_scalar_eq(&zz, &z)); + CHECK(secp256k1_scalar_eq(&r2, &zz)); } } }
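
The limb schedules in both restored secp256k1_scalar_sqr_512 variants (4x64 and 8x32) follow from the identity (sum_i a_i*B^i)^2 = sum_i a_i^2*B^(2i) + 2*sum_{i<j} a_i*a_j*B^(i+j): every diagonal product a_i*a_i is accumulated once via muladd/muladd_fast, and every off-diagonal product is accumulated once, doubled, via the restored muladd2. The standalone sketch below checks that schedule against a plain schoolbook product. It is illustrative only -- 16-bit limbs, hypothetical helper names and arbitrary test values, none of which come from this patch.

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    #define LIMBS 4
    #define LIMB_BITS 16
    #define LIMB_MASK 0xFFFFu

    /* Reference: full schoolbook product r[0..7] = a[0..3] * b[0..3], base 2^16. */
    static void mul_school(uint16_t r[2 * LIMBS], const uint16_t a[LIMBS], const uint16_t b[LIMBS]) {
        uint64_t acc[2 * LIMBS] = {0};
        uint64_t carry = 0;
        int i, j;
        for (i = 0; i < LIMBS; i++) {
            for (j = 0; j < LIMBS; j++) {
                acc[i + j] += (uint64_t)a[i] * b[j];
            }
        }
        for (i = 0; i < 2 * LIMBS; i++) {
            uint64_t v = acc[i] + carry;
            r[i] = (uint16_t)(v & LIMB_MASK);
            carry = v >> LIMB_BITS;
        }
        assert(carry == 0); /* a*b fits in 2*LIMBS limbs */
    }

    /* Squaring schedule: diagonal terms once ("muladd"), cross terms with
     * i < j accumulated once but doubled ("muladd2"). */
    static void sqr_doubled(uint16_t r[2 * LIMBS], const uint16_t a[LIMBS]) {
        uint64_t acc[2 * LIMBS] = {0};
        uint64_t carry = 0;
        int i, j;
        for (i = 0; i < LIMBS; i++) {
            acc[2 * i] += (uint64_t)a[i] * a[i];
            for (j = i + 1; j < LIMBS; j++) {
                acc[i + j] += 2 * (uint64_t)a[i] * a[j];
            }
        }
        for (i = 0; i < 2 * LIMBS; i++) {
            uint64_t v = acc[i] + carry;
            r[i] = (uint16_t)(v & LIMB_MASK);
            carry = v >> LIMB_BITS;
        }
        assert(carry == 0);
    }

    int main(void) {
        const uint16_t a[LIMBS] = {0xFFFF, 0x1234, 0xDEAD, 0x7FFF};
        uint16_t full[2 * LIMBS], fast[2 * LIMBS];
        int i;
        mul_school(full, a, a);
        sqr_doubled(fast, a);
        for (i = 0; i < 2 * LIMBS; i++) {
            assert(full[i] == fast[i]);
        }
        printf("doubled-cross-term square matches a*a\n");
        return 0;
    }

This is also why the 4x64 schedule uses six muladd2 calls for l[1]..l[5] (one per unordered pair of distinct limbs) while the diagonal contributions land on the even-indexed limbs l[0], l[2], l[4] and l[6].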
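
The delicate part of the restored muladd2 macro is the carry tracking when the low and high halves of a*b are doubled in limb-width arithmetic (tl2/th2) and folded into the (c0,c1,c2) accumulator. The model below mirrors that sequence with 16-bit limbs, so the entire 48-bit accumulator state can be checked against a plain uint64_t reference. It is an assumption-level sketch with hypothetical names, not the macro from this patch.

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    /* (c0,c1,c2) accumulator with 16-bit limbs: 48 bits in total, a scaled-down
     * analogue of the 160-bit (4x64) and 96-bit (8x32) accumulators in the patch. */
    static uint16_t c0, c1, c2;

    /* Add 2*a*b to (c0,c1,c2); c2 must not overflow (caller's contract). */
    static void muladd2_model(uint16_t a, uint16_t b) {
        uint32_t t = (uint32_t)a * b;
        uint16_t th = (uint16_t)(t >> 16);  /* high half of a*b, at most 0xFFFE */
        uint16_t tl = (uint16_t)t;          /* low half of a*b */
        uint16_t th2, tl2;
        th2 = (uint16_t)(th + th);          /* doubled high half */
        c2 += (th2 < th);                   /* carry lost by doubling th */
        tl2 = (uint16_t)(tl + tl);          /* doubled low half */
        th2 += (tl2 < tl);                  /* carry from doubling tl */
        c0 += tl2;                          /* add the low part into c0 */
        th2 += (c0 < tl2);                  /* carry out of c0 */
        c2 += (c0 < tl2) & (th2 == 0);      /* that carry itself overflowed th2 */
        c1 += th2;                          /* add the middle part into c1 */
        c2 += (c1 < th2);                   /* carry out of c1 */
    }

    int main(void) {
        /* A handful of limb products; few enough that c2 cannot overflow. */
        const uint16_t xs[] = {0xFFFF, 0x8001, 0x1234, 0xBEEF, 0x7FFF, 0x0001};
        const size_t n = sizeof(xs) / sizeof(xs[0]);
        uint64_t ref = 0;
        size_t i, j;
        for (i = 0; i < n; i++) {
            for (j = 0; j < n; j++) {
                muladd2_model(xs[i], xs[j]);
                ref += 2u * (uint64_t)xs[i] * xs[j];
                assert(ref == (((uint64_t)c2 << 32) | ((uint64_t)c1 << 16) | c0));
            }
        }
        printf("muladd2 model matches the 64-bit reference\n");
        return 0;
    }

The same "never overflows by contract" conditions hold here as long as the running total stays below 2^48, the scaled-down analogue of the accumulator bounds the restored code relies on.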