diff --git a/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c b/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c index b471380..27d68f5 100644 --- a/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c +++ b/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c @@ -36,15 +36,6 @@ void JNI_ThrowByName(JNIEnv *penv, const char *name, const char *msg) secp256k1_context_set_error_callback(ctx, my_error_callback_fn, &error_callback_message); \ secp256k1_context_set_illegal_callback(ctx, my_illegal_callback_fn, &illegal_callback_message); -#define CHECKRESULT(errorcheck, message) \ - { \ - if (errorcheck) \ - { \ - JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1Exception", message); \ - return 0; \ - } \ - } - #define CHECKRESULT(errorcheck, message) \ { \ if (error_callback_message) \ @@ -727,7 +718,9 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 SETUP_ERROR_CALLBACKS + // we do not check that recid is valid, which should trigger our illegal callback handler to throw a Secp256k1IllegalCallbackException // CHECKRESULT(recid < 0 || recid > 3, "recid must be 0, 1, 2 or 3") + sigSize = (*penv)->GetArrayLength(penv, jsig); int sigFormat = GetSignatureFormat(sigSize); CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size"); @@ -774,7 +767,6 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 secp256k1_context *ctx = (secp256k1_context *)jctx; jbyte *sig; secp256k1_ecdsa_signature signature; - ; unsigned char der[73]; size_t size; int result = 0; diff --git a/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt b/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt index 8fa6460..ebc34eb 100644 --- a/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt +++ b/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt @@ -4,62 +4,65 @@ import kotlinx.cinterop.* import platform.posix.size_tVar import secp256k1.* -private typealias MyHandler = (String) -> Unit +private typealias Secp256k1CallbackHandler = (String) -> Unit -private object CallbackHandler { +@OptIn(ExperimentalStdlibApi::class) +private class CallbackHandler(ctx: CPointer) : AutoCloseable { var illegalCallBackMessage: String? = null - val illegalHandler: MyHandler = { x: String -> illegalCallBackMessage = x } + val illegalHandler: Secp256k1CallbackHandler = { x: String -> illegalCallBackMessage = x } val illegalCallbackRef = StableRef.create(illegalHandler) var errorCallBackMessage: String? = null - val errorHandler: MyHandler = { x: String -> errorCallBackMessage = x } + val errorHandler: Secp256k1CallbackHandler = { x: String -> errorCallBackMessage = x } val errorCallbackRef = StableRef.create(errorHandler) - fun checkForErrors() { - if (errorCallBackMessage != null) { - val message = errorCallBackMessage - errorCallBackMessage = null - throw Secp256k1ErrorCallbackException(message) - } - if (illegalCallBackMessage != null) { - val message = illegalCallBackMessage - illegalCallBackMessage = null - throw Secp256k1IllegalCallbackException(message) - } - } -} - -@OptIn(ExperimentalUnsignedTypes::class) -public object Secp256k1Native : Secp256k1 { - - private val ctx: CPointer by lazy { - - val ctx = secp256k1_context_create((SECP256K1_FLAGS_TYPE_CONTEXT or SECP256K1_FLAGS_BIT_CONTEXT_SIGN or SECP256K1_FLAGS_BIT_CONTEXT_VERIFY).toUInt()) - ?: error("Could not create secp256k1 context") - + init { secp256k1_context_set_error_callback( ctx, staticCFunction { buffer: CPointer?, data: COpaquePointer? -> if (data != null) { - val callback = data.asStableRef().get() + val callback = data.asStableRef().get() callback(buffer?.toKString() ?: "error callback triggered") } }, - CallbackHandler.errorCallbackRef.asCPointer() + errorCallbackRef.asCPointer() ) secp256k1_context_set_illegal_callback( ctx, staticCFunction { buffer: CPointer?, data: COpaquePointer? -> if (data != null) { - val callback = data.asStableRef().get() + val callback = data.asStableRef().get() callback(buffer?.toKString() ?: "illegal callback triggered") } }, - CallbackHandler.illegalCallbackRef.asCPointer() + illegalCallbackRef.asCPointer() ) + } - ctx + fun checkForErrors() { + errorCallBackMessage?.let { throw Secp256k1ErrorCallbackException(it) } + illegalCallBackMessage?.let { throw Secp256k1IllegalCallbackException(it) } + } + + override fun close() { + // StableRef instances have to be disposed of manually + illegalCallbackRef.dispose() + errorCallbackRef.dispose() + } +} + +@OptIn(ExperimentalUnsignedTypes::class, ExperimentalStdlibApi::class) +public object Secp256k1Native : Secp256k1 { + + private val ctx: CPointer by lazy { + + secp256k1_context_create((SECP256K1_FLAGS_TYPE_CONTEXT or SECP256K1_FLAGS_BIT_CONTEXT_SIGN or SECP256K1_FLAGS_BIT_CONTEXT_VERIFY).toUInt()) + ?: error("Could not create secp256k1 context") } private fun Int.requireSuccess(message: String): Int { - CallbackHandler.checkForErrors() + return if (this != 1) throw Secp256k1Exception(message) else this + } + + private fun Int.requireSuccess(callbackHandler: CallbackHandler, message: String): Int { + callbackHandler.checkForErrors() return if (this != 1) throw Secp256k1Exception(message) else this } @@ -107,168 +110,209 @@ public object Secp256k1Native : Secp256k1 { public override fun verify(signature: ByteArray, message: ByteArray, pubkey: ByteArray): Boolean { require(message.size == 32) require(pubkey.size == 33 || pubkey.size == 65) - memScoped { - val nPubkey = allocPublicKey(pubkey) - val nMessage = toNat(message) - val nSig = allocSignature(signature) - return secp256k1_ecdsa_verify(ctx, nSig.ptr, nMessage, nPubkey.ptr) == 1 + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPubkey = allocPublicKey(pubkey) + val nMessage = toNat(message) + val nSig = allocSignature(signature) + val verify = secp256k1_ecdsa_verify(ctx, nSig.ptr, nMessage, nPubkey.ptr) + callbackHandler.checkForErrors() + return verify == 1 + } } } public override fun sign(message: ByteArray, privkey: ByteArray): ByteArray { require(privkey.size == 32) require(message.size == 32) - memScoped { - val nPrivkey = toNat(privkey) - val nMessage = toNat(message) - val nSig = alloc() - secp256k1_ecdsa_sign(ctx, nSig.ptr, nMessage, nPrivkey, null, null).requireSuccess("secp256k1_ecdsa_sign() failed") - return serializeSignature(nSig) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPrivkey = toNat(privkey) + val nMessage = toNat(message) + val nSig = alloc() + secp256k1_ecdsa_sign(ctx, nSig.ptr, nMessage, nPrivkey, null, null).requireSuccess(callbackHandler, "secp256k1_ecdsa_sign() failed") + return serializeSignature(nSig) + } } } public override fun signatureNormalize(sig: ByteArray): Pair { - require(sig.size >= 64){ "invalid signature ${Hex.encode(sig)}" } - memScoped { - val nSig = allocSignature(sig) - val isHighS = secp256k1_ecdsa_signature_normalize(ctx, nSig.ptr, nSig.ptr) - return Pair(serializeSignature(nSig), isHighS == 1) + require(sig.size >= 64) { "invalid signature ${Hex.encode(sig)}" } + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nSig = allocSignature(sig) + val isHighS = secp256k1_ecdsa_signature_normalize(ctx, nSig.ptr, nSig.ptr) + callbackHandler.checkForErrors() + return Pair(serializeSignature(nSig), isHighS == 1) + } } } public override fun secKeyVerify(privkey: ByteArray): Boolean { if (privkey.size != 32) return false - memScoped { - val nPrivkey = toNat(privkey) - return secp256k1_ec_seckey_verify(ctx, nPrivkey) == 1 + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPrivkey = toNat(privkey) + val result = secp256k1_ec_seckey_verify(ctx, nPrivkey) == 1 + callbackHandler.checkForErrors() + return result + } } } public override fun pubkeyCreate(privkey: ByteArray): ByteArray { require(privkey.size == 32) - memScoped { - val nPrivkey = toNat(privkey) - val nPubkey = alloc() - secp256k1_ec_pubkey_create(ctx, nPubkey.ptr, nPrivkey).requireSuccess("secp256k1_ec_pubkey_create() failed") - return serializePubkey(nPubkey) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPrivkey = toNat(privkey) + val nPubkey = alloc() + secp256k1_ec_pubkey_create(ctx, nPubkey.ptr, nPrivkey).requireSuccess(callbackHandler, "secp256k1_ec_pubkey_create() failed") + return serializePubkey(nPubkey) + } } } public override fun pubkeyParse(pubkey: ByteArray): ByteArray { require(pubkey.size == 33 || pubkey.size == 65) - memScoped { - val nPubkey = allocPublicKey(pubkey) - return serializePubkey(nPubkey) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPubkey = allocPublicKey(pubkey) + val result = serializePubkey(nPubkey) + callbackHandler.checkForErrors() + return result + } } } public override fun privKeyNegate(privkey: ByteArray): ByteArray { require(privkey.size == 32) - memScoped { - val negated = privkey.copyOf() - val negPriv = toNat(negated) - secp256k1_ec_seckey_negate(ctx, negPriv).requireSuccess("secp256k1_ec_seckey_negate() failed") - return negated + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val negated = privkey.copyOf() + val negPriv = toNat(negated) + secp256k1_ec_seckey_negate(ctx, negPriv).requireSuccess(callbackHandler, "secp256k1_ec_seckey_negate() failed") + return negated + } } } public override fun privKeyTweakAdd(privkey: ByteArray, tweak: ByteArray): ByteArray { require(privkey.size == 32) - memScoped { - val added = privkey.copyOf() - val natAdd = toNat(added) - val natTweak = toNat(tweak) - secp256k1_ec_seckey_tweak_add(ctx, natAdd, natTweak).requireSuccess("secp256k1_ec_seckey_tweak_add() failed") - return added + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val added = privkey.copyOf() + val natAdd = toNat(added) + val natTweak = toNat(tweak) + secp256k1_ec_seckey_tweak_add(ctx, natAdd, natTweak).requireSuccess(callbackHandler, "secp256k1_ec_seckey_tweak_add() failed") + return added + } } } public override fun privKeyTweakMul(privkey: ByteArray, tweak: ByteArray): ByteArray { require(privkey.size == 32) - memScoped { - val multiplied = privkey.copyOf() - val natMul = toNat(multiplied) - val natTweak = toNat(tweak) - secp256k1_ec_privkey_tweak_mul(ctx, natMul, natTweak).requireSuccess("secp256k1_ec_privkey_tweak_mul() failed") - return multiplied + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val multiplied = privkey.copyOf() + val natMul = toNat(multiplied) + val natTweak = toNat(tweak) + secp256k1_ec_privkey_tweak_mul(ctx, natMul, natTweak).requireSuccess(callbackHandler, "secp256k1_ec_privkey_tweak_mul() failed") + return multiplied + } } } public override fun pubKeyNegate(pubkey: ByteArray): ByteArray { require(pubkey.size == 33 || pubkey.size == 65) - memScoped { - val nPubkey = allocPublicKey(pubkey) - secp256k1_ec_pubkey_negate(ctx, nPubkey.ptr).requireSuccess("secp256k1_ec_pubkey_negate() failed") - return serializePubkey(nPubkey) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPubkey = allocPublicKey(pubkey) + secp256k1_ec_pubkey_negate(ctx, nPubkey.ptr).requireSuccess(callbackHandler, "secp256k1_ec_pubkey_negate() failed") + return serializePubkey(nPubkey) + } } } public override fun pubKeyTweakAdd(pubkey: ByteArray, tweak: ByteArray): ByteArray { require(pubkey.size == 33 || pubkey.size == 65) - memScoped { - val nPubkey = allocPublicKey(pubkey) - val nTweak = toNat(tweak) - secp256k1_ec_pubkey_tweak_add(ctx, nPubkey.ptr, nTweak).requireSuccess("secp256k1_ec_pubkey_tweak_add() failed") - return serializePubkey(nPubkey) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPubkey = allocPublicKey(pubkey) + val nTweak = toNat(tweak) + secp256k1_ec_pubkey_tweak_add(ctx, nPubkey.ptr, nTweak).requireSuccess(callbackHandler, "secp256k1_ec_pubkey_tweak_add() failed") + return serializePubkey(nPubkey) + } } } public override fun pubKeyTweakMul(pubkey: ByteArray, tweak: ByteArray): ByteArray { require(pubkey.size == 33 || pubkey.size == 65) - memScoped { - val nPubkey = allocPublicKey(pubkey) - val nTweak = toNat(tweak) - secp256k1_ec_pubkey_tweak_mul(ctx, nPubkey.ptr, nTweak).requireSuccess("secp256k1_ec_pubkey_tweak_mul() failed") - return serializePubkey(nPubkey) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPubkey = allocPublicKey(pubkey) + val nTweak = toNat(tweak) + secp256k1_ec_pubkey_tweak_mul(ctx, nPubkey.ptr, nTweak).requireSuccess(callbackHandler, "secp256k1_ec_pubkey_tweak_mul() failed") + return serializePubkey(nPubkey) + } } } public override fun pubKeyCombine(pubkeys: Array): ByteArray { pubkeys.forEach { require(it.size == 33 || it.size == 65) } - memScoped { - val nPubkeys = pubkeys.map { allocPublicKey(it).ptr } - val combined = alloc() - secp256k1_ec_pubkey_combine(ctx, combined.ptr, nPubkeys.toCValues(), pubkeys.size.convert()).requireSuccess("secp256k1_ec_pubkey_combine() failed") - return serializePubkey(combined) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPubkeys = pubkeys.map { allocPublicKey(it).ptr } + val combined = alloc() + secp256k1_ec_pubkey_combine(ctx, combined.ptr, nPubkeys.toCValues(), pubkeys.size.convert()).requireSuccess(callbackHandler, "secp256k1_ec_pubkey_combine() failed") + return serializePubkey(combined) + } } } public override fun ecdh(privkey: ByteArray, pubkey: ByteArray): ByteArray { require(privkey.size == 32) require(pubkey.size == 33 || pubkey.size == 65) - memScoped { - val nPubkey = allocPublicKey(pubkey) - val nPrivkey = toNat(privkey) - val output = allocArray(32) - secp256k1_ecdh(ctx, output, nPubkey.ptr, nPrivkey, null, null).requireSuccess("secp256k1_ecdh() failed") - return output.readBytes(32) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPubkey = allocPublicKey(pubkey) + val nPrivkey = toNat(privkey) + val output = allocArray(32) + secp256k1_ecdh(ctx, output, nPubkey.ptr, nPrivkey, null, null).requireSuccess(callbackHandler, "secp256k1_ecdh() failed") + return output.readBytes(32) + } } } public override fun ecdsaRecover(sig: ByteArray, message: ByteArray, recid: Int): ByteArray { require(sig.size == 64) require(message.size == 32) - require(recid in 0..3) - memScoped { - val nSig = toNat(sig) - val rSig = alloc() - secp256k1_ecdsa_recoverable_signature_parse_compact(ctx, rSig.ptr, nSig, recid).requireSuccess("secp256k1_ecdsa_recoverable_signature_parse_compact() failed") - val nMessage = toNat(message) - val pubkey = alloc() - secp256k1_ecdsa_recover(ctx, pubkey.ptr, rSig.ptr, nMessage).requireSuccess("secp256k1_ecdsa_recover() failed") - return serializePubkey(pubkey) + // we do not check that recid is valid, which should trigger our illegal callback handler to throw a Secp256k1IllegalCallbackException + // require(recid in 0..3) + + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nSig = toNat(sig) + val rSig = alloc() + secp256k1_ecdsa_recoverable_signature_parse_compact(ctx, rSig.ptr, nSig, recid).requireSuccess(callbackHandler, "secp256k1_ecdsa_recoverable_signature_parse_compact() failed") + val nMessage = toNat(message) + val pubkey = alloc() + secp256k1_ecdsa_recover(ctx, pubkey.ptr, rSig.ptr, nMessage).requireSuccess(callbackHandler, "secp256k1_ecdsa_recover() failed") + return serializePubkey(pubkey) + } } } public override fun compact2der(sig: ByteArray): ByteArray { require(sig.size == 64) - memScoped { - val nSig = allocSignature(sig) - val natOutput = allocArray(73) - val len = alloc() - len.value = 73.convert() - secp256k1_ecdsa_signature_serialize_der(ctx, natOutput, len.ptr, nSig.ptr).requireSuccess("secp256k1_ecdsa_signature_serialize_der() failed") - return natOutput.readBytes(len.value.toInt()) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nSig = allocSignature(sig) + val natOutput = allocArray(73) + val len = alloc() + len.value = 73.convert() + secp256k1_ecdsa_signature_serialize_der(ctx, natOutput, len.ptr, nSig.ptr).requireSuccess(callbackHandler, "secp256k1_ecdsa_signature_serialize_der() failed") + return natOutput.readBytes(len.value.toInt()) + } } } @@ -276,13 +320,15 @@ public object Secp256k1Native : Secp256k1 { require(signature.size == 64) require(data.size == 32) require(pub.size == 32) - memScoped { - val nPub = toNat(pub) - val pubkey = alloc() - secp256k1_xonly_pubkey_parse(ctx, pubkey.ptr, nPub).requireSuccess("secp256k1_xonly_pubkey_parse() failed") - val nData = toNat(data) - val nSig = toNat(signature) - return secp256k1_schnorrsig_verify(ctx, nSig, nData, 32, pubkey.ptr) == 1 + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nPub = toNat(pub) + val pubkey = alloc() + secp256k1_xonly_pubkey_parse(ctx, pubkey.ptr, nPub).requireSuccess(callbackHandler, "secp256k1_xonly_pubkey_parse() failed") + val nData = toNat(data) + val nSig = toNat(signature) + return secp256k1_schnorrsig_verify(ctx, nSig, nData, 32u, pubkey.ptr) == 1 + } } } @@ -290,15 +336,17 @@ public object Secp256k1Native : Secp256k1 { require(sec.size == 32) require(data.size == 32) auxrand32?.let { require(it.size == 32) } - memScoped { - val nSec = toNat(sec) - val nData = toNat(data) - val nAuxrand32 = auxrand32?.let { toNat(it) } - val nSig = allocArray(64) - val keypair = alloc() - secp256k1_keypair_create(ctx, keypair.ptr, nSec).requireSuccess("secp256k1_keypair_create() failed") - secp256k1_schnorrsig_sign32(ctx, nSig, nData, keypair.ptr, nAuxrand32).requireSuccess("secp256k1_ecdsa_sign() failed") - return nSig.readBytes(64) + CallbackHandler(ctx).use { callbackHandler -> + memScoped { + val nSec = toNat(sec) + val nData = toNat(data) + val nAuxrand32 = auxrand32?.let { toNat(it) } + val nSig = allocArray(64) + val keypair = alloc() + secp256k1_keypair_create(ctx, keypair.ptr, nSec).requireSuccess(callbackHandler, "secp256k1_keypair_create() failed") + secp256k1_schnorrsig_sign32(ctx, nSig, nData, keypair.ptr, nAuxrand32).requireSuccess(callbackHandler, "secp256k1_ecdsa_sign() failed") + return nSig.readBytes(64) + } } } diff --git a/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt b/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt index a54cba8..66a5d00 100644 --- a/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt +++ b/tests/src/commonTest/kotlin/fr/acinq/secp256k1/Secp256k1Test.kt @@ -276,7 +276,8 @@ class Secp256k1Test { val pub1 = Secp256k1.ecdsaRecover(sig, message, 1) assertTrue(pub.contentEquals(pub0) || pub.contentEquals(pub1)) - assertFails { + // this is a special case, ecdsaRecover explicitly does not check that recid is valid, which triggers our illegal callback handler + assertFailsWith(Secp256k1IllegalCallbackException::class) { Secp256k1.ecdsaRecover(sig, message, 4) } }