diff --git a/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c b/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c index 5187159..b424403 100644 --- a/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c +++ b/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c @@ -637,7 +637,7 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 return NULL; if (jmsg == NULL) return NULL; - + CHECKRESULT(recid < 0 || recid > 3, "recid must be 0, 1, 2 or 3") sigSize = (*penv)->GetArrayLength(penv, jsig); int sigFormat = GetSignatureFormat(sigSize); CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size"); diff --git a/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt b/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt index 7ffe516..f77b6e6 100644 --- a/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt +++ b/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt @@ -199,6 +199,7 @@ public object Secp256k1Native : Secp256k1 { public override fun ecdsaRecover(sig: ByteArray, message: ByteArray, recid: Int): ByteArray { require(sig.size == 64) require(message.size == 32) + require(recid in 0..3) memScoped { val nSig = toNat(sig) val rSig = alloc() diff --git a/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt b/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt index c36ef50..a54cba8 100644 --- a/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt +++ b/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt @@ -275,6 +275,10 @@ class Secp256k1Test { val pub0 = Secp256k1.ecdsaRecover(sig, message, 0) val pub1 = Secp256k1.ecdsaRecover(sig, message, 1) assertTrue(pub.contentEquals(pub0) || pub.contentEquals(pub1)) + + assertFails { + Secp256k1.ecdsaRecover(sig, message, 4) + } } @Test