git » sdk » commit 00ec9bb

Callback to override TLS checks by the app

author Stephen Paul Weber
2025-07-08 20:13:51 UTC
committer Stephen Paul Weber
2025-07-08 20:27:59 UTC
parent e0c3381340be84396d4fd2bd8bde82415f05a01c

Callback to override TLS checks by the app

snikket/Client.hx +11 -0
snikket/streams/XmppStropheStream.hx +43 -0

diff --git a/snikket/Client.hx b/snikket/Client.hx
index eeead42..241fe20 100644
--- a/snikket/Client.hx
+++ b/snikket/Client.hx
@@ -1081,6 +1081,17 @@ class Client extends EventEmitter {
 		});
 	}
 
+	/**
+		Register a handler that lets the app override a failed TLS check.
+		The handler is attached to the stream's "tls/check" event.
+
+		@param handler called with the PEM of the cert and an array of DNS names; must return true to accept or false to reject
+	**/
+	public function addTlsCheckListener(handler:(String, Array<String>)->Bool): Void {
+		// The handler's Bool verdict is wrapped in EventValue as the listener's result
+		stream.on("tls/check", (check) -> EventValue(handler(check.pem, check.dnsNames)));
+	}
+
 	#if !cpp
 	// TODO: haxe cpp erases enum into int, so using it as a callback arg is hard
 	// could just use int in C bindings, or need to come up with a good strategy
diff --git a/snikket/streams/XmppStropheStream.hx b/snikket/streams/XmppStropheStream.hx
index 4da6348..52ad562 100644
--- a/snikket/streams/XmppStropheStream.hx
+++ b/snikket/streams/XmppStropheStream.hx
@@ -55,8 +55,25 @@ extern class StropheCtx {
 	static function stop(ctx:StropheCtx):Void;
 }
 
+@:native("const xmpp_tlscert_t*")
+@:unreflective
+extern class StropheTlsCert { // extern binding: a value of this type is a native `const xmpp_tlscert_t*`
+	@:native("xmpp_tlscert_get_conn")
+	static function get_conn(cert: StropheTlsCert):StropheConn; // presumably the connection the cert belongs to -- confirm against strophe.h
+
+	@:native("xmpp_tlscert_get_pem")
+	static function get_pem(cert: StropheTlsCert):ConstPointer<Char>; // C string; the certfail handler passes it on as the cert's PEM
+
+	@:native("xmpp_tlscert_get_dnsname")
+	static function get_dnsname(cert: StropheTlsCert, n: cpp.SizeT):ConstPointer<Char>; // n-th DNS name; the certfail handler counts n up from 0 until this returns null
+
+	@:native("xmpp_tlscert_get_userdata")
+	static function get_userdata(cert: StropheTlsCert):RawPointer<Void>; // opaque pointer; the certfail handler casts it back to the XmppStropheStream
+}
+
 @:include("strophe.h")
 @:native("xmpp_conn_t*")
+@:unreflective
 extern class StropheConn {
 	@:native("xmpp_conn_new")
 	static function create(ctx:StropheCtx):StropheConn;
@@ -89,6 +106,13 @@ extern class StropheConn {
 		userdata:RawPointer<Void>
 	):cpp.Int32;
 
+	// Installs the native callback given the cert and a C string (named `err` by the handler in this file); the handler here returns 1 to accept, 0 to reject
+	@:native("xmpp_conn_set_certfail_handler")
+	static function set_certfail_handler(
+		conn:StropheConn,
+		handler:cpp.Callable<StropheTlsCert->RawConstPointer<Char>->Int>
+	):Void;
+
 	@:native("xmpp_send")
 	static function send(conn:StropheConn, stanza:StropheStanza):Void;
 
@@ -178,6 +202,7 @@ class XmppStropheStream extends GenericStream {
 			null,
 			untyped __cpp__("(void*)this")
 		);
+		StropheConn.set_certfail_handler(conn, cpp.Callable.fromStaticFunction(strophe_certfail_handler));
 		NativeGc.addFinalizable(this, false);
 	}
 
@@ -185,6 +210,24 @@ class XmppStropheStream extends GenericStream {
 		return ID.long();
 	}
 
+	// Certfail callback handed to Strophe: asks "tls/check" listeners and returns 1 to accept, 0 otherwise
+	public static function strophe_certfail_handler(cert:StropheTlsCert, err:RawConstPointer<Char>): Int {
+		final userdata = StropheTlsCert.get_userdata(cert);
+		final stream: XmppStropheStream = untyped __cpp__("static_cast<hx::Object*>(userdata)");
+		final names: Array<String> = [];
+		var index = 0;
+		while (true) {
+			final name = StropheTlsCert.get_dnsname(cert, index++);
+			if (name == null) break;
+			names.push(NativeString.fromPointer(name));
+		}
+		final certPem = NativeString.fromPointer(StropheTlsCert.get_pem(cert));
+		return switch (stream.trigger("tls/check", { pem: certPem, dnsNames: names })) {
+		case EventValue(accepted): accepted ? 1 : 0;
+		default: 0;
+		};
+	}
+
 	public static function strophe_stanza(conn:StropheConn, sstanza:StropheStanza, userdata:RawPointer<Void>):Int {
 		final stream: XmppStropheStream = untyped __cpp__("static_cast<hx::Object*>(userdata)");
 		stream.stanzaThisPoll = true;