diff --git a/src/modules/surjection/main_impl.h b/src/modules/surjection/main_impl.h index e76ebbf9..9614e5f7 100644 --- a/src/modules/surjection/main_impl.h +++ b/src/modules/surjection/main_impl.h @@ -373,6 +373,11 @@ int secp256k1_surjectionproof_verify(const secp256k1_context* ctx, const secp256 return 0; } + /* Reject proofs with too many used inputs in USE_REDUCED_SURJECTION_PROOF_SIZE mode */ + if (n_used_pubkeys > SECP256K1_SURJECTIONPROOF_MAX_USED_INPUTS) { + return 0; + } + if (secp256k1_surjection_compute_public_keys(ring_pubkeys, n_used_pubkeys, ephemeral_input_tags, n_total_pubkeys, proof->used_inputs, ephemeral_output_tag, 0, NULL) == 0) { return 0; } diff --git a/src/modules/surjection/surjection_impl.h b/src/modules/surjection/surjection_impl.h index 1a839ff8..f3652567 100644 --- a/src/modules/surjection/surjection_impl.h +++ b/src/modules/surjection/surjection_impl.h @@ -69,6 +69,9 @@ SECP256K1_INLINE static int secp256k1_surjection_compute_public_keys(secp256k1_g secp256k1_ge tmpge; secp256k1_generator_load(&tmpge, &input_tags[i]); secp256k1_ge_neg(&tmpge, &tmpge); + + VERIFY_CHECK(j < SECP256K1_SURJECTIONPROOF_MAX_USED_INPUTS); + VERIFY_CHECK(j < n_pubkeys); secp256k1_gej_set_ge(&pubkeys[j], &tmpge); secp256k1_generator_load(&tmpge, output_tag); @@ -77,11 +80,10 @@ SECP256K1_INLINE static int secp256k1_surjection_compute_public_keys(secp256k1_g *ring_input_index = j; } j++; - if (j > n_pubkeys || j > SECP256K1_SURJECTIONPROOF_MAX_USED_INPUTS) { - return 0; - } } } + /* Caller needs to ensure that the number of set bits in used_tags (which we counted in j) equals n_pubkeys. */ + VERIFY_CHECK(j == n_pubkeys); return 1; } diff --git a/src/modules/surjection/tests_impl.h b/src/modules/surjection/tests_impl.h index 4885a8e8..ca0b09a0 100644 --- a/src/modules/surjection/tests_impl.h +++ b/src/modules/surjection/tests_impl.h @@ -427,6 +427,7 @@ static void test_gen_verify(size_t n_inputs, size_t n_used) { CHECK(secp256k1_surjectionproof_parse(ctx, &proof, serialized_proof, serialized_len)); result = secp256k1_surjectionproof_verify(ctx, &proof, ephemeral_input_tags, n_inputs, &ephemeral_input_tags[n_inputs]); CHECK(result == 1); + /* various fail cases */ if (n_inputs > 1) { result = secp256k1_surjectionproof_verify(ctx, &proof, ephemeral_input_tags, n_inputs, &ephemeral_input_tags[n_inputs - 1]); @@ -441,6 +442,15 @@ static void test_gen_verify(size_t n_inputs, size_t n_used) { n_inputs += 1; } + for (i = 0; i < n_inputs; i++) { + /* flip bit */ + proof.used_inputs[i / 8] ^= (1 << (i % 8)); + result = secp256k1_surjectionproof_verify(ctx, &proof, ephemeral_input_tags, n_inputs, &ephemeral_input_tags[n_inputs]); + CHECK(result == 0); + /* reset the bit */ + proof.used_inputs[i / 8] ^= (1 << (i % 8)); + } + /* cleanup */ for (i = 0; i < n_inputs + 1; i++) { free(input_blinding_key[i]);