diff --git a/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c b/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c index b424403..b471380 100644 --- a/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c +++ b/jni/c/src/fr_acinq_secp256k1_Secp256k1CFunctions.c @@ -23,6 +23,18 @@ void JNI_ThrowByName(JNIEnv *penv, const char *name, const char *msg) (*penv)->DeleteLocalRef(penv, cls); } } +/** + * secp256k1 uses callbacks for errors that are either hw pbs or bugs in the calling library, for example + * passing parameters with values that are explicitly defined as illegal in the API, and should never be called for normal operations + * But if they are, default behaviour is to print an error to stderr and abort which is not what we want especially in mobile apps + * => we set up string pointers in every method, and custom callback that will set them to the message passed in by sec256k1's callbacks, which + * we turn into specific Sec256k1 exceptions +*/ +#define SETUP_ERROR_CALLBACKS \ + char *error_callback_message = NULL; \ + char *illegal_callback_message = NULL; \ + secp256k1_context_set_error_callback(ctx, my_error_callback_fn, &error_callback_message); \ + secp256k1_context_set_illegal_callback(ctx, my_illegal_callback_fn, &illegal_callback_message); #define CHECKRESULT(errorcheck, message) \ { \ @@ -33,16 +45,62 @@ void JNI_ThrowByName(JNIEnv *penv, const char *name, const char *msg) } \ } -#define CHECKRESULT1(errorcheck, message, dosomething) \ - { \ - if (errorcheck) \ - { \ - dosomething; \ - JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1Exception", message); \ - return 0; \ - } \ +#define CHECKRESULT(errorcheck, message) \ + { \ + if (error_callback_message) \ + { \ + JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1ErrorCallbackException", error_callback_message); \ + return 0; \ + } \ + if (illegal_callback_message) \ + { \ + JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1IllegalCallbackException", illegal_callback_message); \ + return 0; \ + } \ + if (errorcheck) \ + { \ + JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1Exception", message); \ + return 0; \ + } \ } +#define CHECKRESULT1(errorcheck, message, dosomething) \ + { \ + if (error_callback_message) \ + { \ + dosomething; \ + JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1ErrorCallbackException", error_callback_message); \ + return 0; \ + } \ + if (illegal_callback_message) \ + { \ + dosomething; \ + JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1IllegalCallbackException", illegal_callback_message); \ + return 0; \ + } \ + if (errorcheck) \ + { \ + JNI_ThrowByName(penv, "fr/acinq/secp256k1/Secp256k1Exception", message); \ + return 0; \ + } \ + } + +void my_illegal_callback_fn(const char *str, void *data) +{ + if (data != NULL) + { + *(char **)data = str; + } +} + +void my_error_callback_fn(const char *str, void *data) +{ + if (data != NULL) + { + *(char **)data = str; + } +} + /* * Class: fr_acinq_bitcoin_Secp256k1Bindings * Method: secp256k1_context_create @@ -84,6 +142,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1ec if ((*penv)->GetArrayLength(penv, jseckey) != 32) return 0; + SETUP_ERROR_CALLBACKS + seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0); result = secp256k1_ec_seckey_verify(ctx, (unsigned char *)seckey); (*penv)->ReleaseByteArrayElements(penv, jseckey, seckey, 0); @@ -108,6 +168,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jpubkey == NULL) return 0; + SETUP_ERROR_CALLBACKS + size = (*penv)->GetArrayLength(penv, jpubkey); CHECKRESULT((size != 33) && (size != 65), "invalid public key size"); @@ -144,6 +206,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jctx == 0) return NULL; + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes"); seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0); result = secp256k1_ec_pubkey_create(ctx, &pub, (unsigned char *)seckey); @@ -178,6 +242,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jseckey == NULL) return NULL; + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes"); CHECKRESULT((*penv)->GetArrayLength(penv, jmsg) != 32, "message key must be 32 bytes"); seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0); @@ -228,6 +294,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1ec if (jpubkey == NULL) return 0; + SETUP_ERROR_CALLBACKS + sigSize = (*penv)->GetArrayLength(penv, jsig); int sigFormat = GetSignatureFormat(sigSize); CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size"); @@ -285,6 +353,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1ec if (jsigout == NULL) return 0; + SETUP_ERROR_CALLBACKS + size = (*penv)->GetArrayLength(penv, jsigin); sigFormat = GetSignatureFormat(size); CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size"); @@ -328,6 +398,9 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 return 0; if (jseckey == NULL) return 0; + + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes"); seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0); result = secp256k1_ec_seckey_negate(ctx, (unsigned char *)seckey); @@ -354,6 +427,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jpubkey == NULL) return 0; + SETUP_ERROR_CALLBACKS + size = (*penv)->GetArrayLength(penv, jpubkey); CHECKRESULT((size != 33) && (size != 65), "invalid public key size"); pub = (*penv)->GetByteArrayElements(penv, jpubkey, 0); @@ -391,6 +466,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jtweak == NULL) return NULL; + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes"); CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes"); seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0); @@ -422,6 +499,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jtweak == NULL) return NULL; + SETUP_ERROR_CALLBACKS + size = (*penv)->GetArrayLength(penv, jpubkey); CHECKRESULT((size != 33) && (size != 65), "invalid public key size"); CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes"); @@ -463,6 +542,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jtweak == NULL) return NULL; + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes"); CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes"); seckey = (*penv)->GetByteArrayElements(penv, jseckey, 0); @@ -494,6 +575,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jtweak == NULL) return NULL; + SETUP_ERROR_CALLBACKS + size = (*penv)->GetArrayLength(penv, jpubkey); CHECKRESULT((size != 33) && (size != 65), "invalid public key size"); CHECKRESULT((*penv)->GetArrayLength(penv, jtweak) != 32, "tweak must be 32 bytes"); @@ -548,6 +631,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jpubkeys == NULL) return NULL; + SETUP_ERROR_CALLBACKS + count = (*penv)->GetArrayLength(penv, jpubkeys); pubkeys = calloc(count, sizeof(secp256k1_pubkey *)); @@ -596,6 +681,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jpubkey == NULL) return NULL; + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "invalid private key size"); size = (*penv)->GetArrayLength(penv, jpubkey); @@ -637,7 +724,10 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 return NULL; if (jmsg == NULL) return NULL; - CHECKRESULT(recid < 0 || recid > 3, "recid must be 0, 1, 2 or 3") + + SETUP_ERROR_CALLBACKS + + // CHECKRESULT(recid < 0 || recid > 3, "recid must be 0, 1, 2 or 3") sigSize = (*penv)->GetArrayLength(penv, jsig); int sigFormat = GetSignatureFormat(sigSize); CHECKRESULT(sigFormat == SIG_FORMAT_UNKNOWN, "invalid signature size"); @@ -693,6 +783,9 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 return 0; if (jsig == NULL) return 0; + + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jsig) != 64, "invalid signature size"); size = (*penv)->GetArrayLength(penv, jsig); @@ -732,6 +825,8 @@ JNIEXPORT jbyteArray JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256 if (jseckey == NULL) return NULL; + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jseckey) != 32, "secret key must be 32 bytes"); CHECKRESULT((*penv)->GetArrayLength(penv, jmsg) != 32, "message must be 32 bytes"); if (jauxrand32 != 0) @@ -785,6 +880,8 @@ JNIEXPORT jint JNICALL Java_fr_acinq_secp256k1_Secp256k1CFunctions_secp256k1_1sc if (jpubkey == NULL) return 0; + SETUP_ERROR_CALLBACKS + CHECKRESULT((*penv)->GetArrayLength(penv, jsig) != 64, "signature must be 64 bytes"); CHECKRESULT((*penv)->GetArrayLength(penv, jpubkey) != 32, "public key must be 32 bytes"); CHECKRESULT((*penv)->GetArrayLength(penv, jmsg) != 32, "message must be 32 bytes"); diff --git a/src/commonMain/kotlin/fr/acinq/secp256k1/Secp256k1.kt b/src/commonMain/kotlin/fr/acinq/secp256k1/Secp256k1.kt index 8911051..4aef8f4 100644 --- a/src/commonMain/kotlin/fr/acinq/secp256k1/Secp256k1.kt +++ b/src/commonMain/kotlin/fr/acinq/secp256k1/Secp256k1.kt @@ -166,7 +166,17 @@ public interface Secp256k1 { internal expect fun getSecpk256k1(): Secp256k1 -public class Secp256k1Exception : RuntimeException { +public open class Secp256k1Exception : RuntimeException { + public constructor() : super() + public constructor(message: String?) : super(message) +} + +public class Secp256k1ErrorCallbackException : Secp256k1Exception { + public constructor() : super() + public constructor(message: String?) : super(message) +} + +public class Secp256k1IllegalCallbackException : Secp256k1Exception { public constructor() : super() public constructor(message: String?) : super(message) } \ No newline at end of file diff --git a/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt b/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt index f77b6e6..8fa6460 100644 --- a/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt +++ b/src/nativeMain/kotlin/fr/acinq/secp256k1/Secp256k1Native.kt @@ -4,15 +4,64 @@ import kotlinx.cinterop.* import platform.posix.size_tVar import secp256k1.* +private typealias MyHandler = (String) -> Unit + +private object CallbackHandler { + var illegalCallBackMessage: String? = null + val illegalHandler: MyHandler = { x: String -> illegalCallBackMessage = x } + val illegalCallbackRef = StableRef.create(illegalHandler) + var errorCallBackMessage: String? = null + val errorHandler: MyHandler = { x: String -> errorCallBackMessage = x } + val errorCallbackRef = StableRef.create(errorHandler) + + fun checkForErrors() { + if (errorCallBackMessage != null) { + val message = errorCallBackMessage + errorCallBackMessage = null + throw Secp256k1ErrorCallbackException(message) + } + if (illegalCallBackMessage != null) { + val message = illegalCallBackMessage + illegalCallBackMessage = null + throw Secp256k1IllegalCallbackException(message) + } + } +} + @OptIn(ExperimentalUnsignedTypes::class) public object Secp256k1Native : Secp256k1 { private val ctx: CPointer by lazy { - secp256k1_context_create((SECP256K1_FLAGS_TYPE_CONTEXT or SECP256K1_FLAGS_BIT_CONTEXT_SIGN or SECP256K1_FLAGS_BIT_CONTEXT_VERIFY).toUInt()) + + val ctx = secp256k1_context_create((SECP256K1_FLAGS_TYPE_CONTEXT or SECP256K1_FLAGS_BIT_CONTEXT_SIGN or SECP256K1_FLAGS_BIT_CONTEXT_VERIFY).toUInt()) ?: error("Could not create secp256k1 context") + + secp256k1_context_set_error_callback( + ctx, staticCFunction { buffer: CPointer?, data: COpaquePointer? -> + if (data != null) { + val callback = data.asStableRef().get() + callback(buffer?.toKString() ?: "error callback triggered") + } + }, + CallbackHandler.errorCallbackRef.asCPointer() + ) + secp256k1_context_set_illegal_callback( + ctx, staticCFunction { buffer: CPointer?, data: COpaquePointer? -> + if (data != null) { + val callback = data.asStableRef().get() + callback(buffer?.toKString() ?: "illegal callback triggered") + } + }, + CallbackHandler.illegalCallbackRef.asCPointer() + ) + + ctx } - private fun Int.requireSuccess(message: String): Int = if (this != 1) throw Secp256k1Exception(message) else this + private fun Int.requireSuccess(message: String): Int { + CallbackHandler.checkForErrors() + return if (this != 1) throw Secp256k1Exception(message) else this + } private fun MemScope.allocSignature(input: ByteArray): secp256k1_ecdsa_signature { val sig = alloc()