diff --git a/src/modules/surjection/main_impl.h b/src/modules/surjection/main_impl.h index dcd4d6e0..a214f90b 100644 --- a/src/modules/surjection/main_impl.h +++ b/src/modules/surjection/main_impl.h @@ -298,6 +298,10 @@ int secp256k1_surjectionproof_generate(const secp256k1_context* ctx, secp256k1_s CHECK(proof->initialized == 1); #endif + n_used_pubkeys = secp256k1_surjectionproof_n_used_inputs(ctx, proof); + /* This must be true if the proof was created with surjectionproof_initialize */ + ARG_CHECK(n_used_pubkeys > 0); + /* Compute secret key */ secp256k1_scalar_set_b32(&tmps, input_blinding_key, &overflow); if (overflow) { @@ -321,7 +325,7 @@ int secp256k1_surjectionproof_generate(const secp256k1_context* ctx, secp256k1_s /* Compute public keys */ n_total_pubkeys = secp256k1_surjectionproof_n_total_inputs(ctx, proof); - n_used_pubkeys = secp256k1_surjectionproof_n_used_inputs(ctx, proof); + if (n_used_pubkeys > n_total_pubkeys || n_total_pubkeys != n_ephemeral_input_tags) { return 0; } diff --git a/src/modules/surjection/tests_impl.h b/src/modules/surjection/tests_impl.h index a00f6ad2..a792bb5f 100644 --- a/src/modules/surjection/tests_impl.h +++ b/src/modules/surjection/tests_impl.h @@ -173,31 +173,45 @@ static void test_surjectionproof_api(void) { CHECK(secp256k1_surjectionproof_verify(vrfy, &proof, ephemeral_input_tags, n_inputs, NULL) == 0); CHECK(ecount == 16); + /* Test how surjectionproof_generate fails when the proof was not created + * with surjectionproof_initialize */ + ecount = 0; + CHECK(secp256k1_surjectionproof_generate(sign, &proof, ephemeral_input_tags, n_inputs, &ephemeral_output_tag, 0, input_blinding_key[0], output_blinding_key) == 1); + { + secp256k1_surjectionproof tmp_proof = proof; + tmp_proof.n_inputs = 0; + CHECK(secp256k1_surjectionproof_generate(sign, &tmp_proof, ephemeral_input_tags, n_inputs, &ephemeral_output_tag, 0, input_blinding_key[0], output_blinding_key) == 0); + } + CHECK(ecount == 1); + + CHECK(secp256k1_surjectionproof_generate(sign, &proof, ephemeral_input_tags, n_inputs, &ephemeral_output_tag, 0, input_blinding_key[0], output_blinding_key) == 1); + /* Check serialize */ + ecount = 0; serialized_len = sizeof(serialized_proof); CHECK(secp256k1_surjectionproof_serialize(none, serialized_proof, &serialized_len, &proof) != 0); - CHECK(ecount == 16); + CHECK(ecount == 0); serialized_len = sizeof(serialized_proof); CHECK(secp256k1_surjectionproof_serialize(none, NULL, &serialized_len, &proof) == 0); - CHECK(ecount == 17); + CHECK(ecount == 1); serialized_len = sizeof(serialized_proof); CHECK(secp256k1_surjectionproof_serialize(none, serialized_proof, NULL, &proof) == 0); - CHECK(ecount == 18); + CHECK(ecount == 2); serialized_len = sizeof(serialized_proof); CHECK(secp256k1_surjectionproof_serialize(none, serialized_proof, &serialized_len, NULL) == 0); - CHECK(ecount == 19); + CHECK(ecount == 3); serialized_len = sizeof(serialized_proof); CHECK(secp256k1_surjectionproof_serialize(none, serialized_proof, &serialized_len, &proof) != 0); /* Check parse */ CHECK(secp256k1_surjectionproof_parse(none, &proof, serialized_proof, serialized_len) != 0); - CHECK(ecount == 19); + CHECK(ecount == 3); CHECK(secp256k1_surjectionproof_parse(none, NULL, serialized_proof, serialized_len) == 0); - CHECK(ecount == 20); + CHECK(ecount == 4); CHECK(secp256k1_surjectionproof_parse(none, &proof, NULL, serialized_len) == 0); - CHECK(ecount == 21); + CHECK(ecount == 5); CHECK(secp256k1_surjectionproof_parse(none, &proof, serialized_proof, 0) == 0); - CHECK(ecount == 21); + CHECK(ecount == 5); secp256k1_context_destroy(none); secp256k1_context_destroy(sign);