git » sdk » commit c88a354

Support some generics in swift

author Stephen Paul Weber
2026-04-29 03:01:55 UTC
committer Stephen Paul Weber
2026-04-29 03:11:23 UTC
parent 368b47c8d425ba37a82d054056838c1da00107c9

Support some generics in swift

HaxeSwiftBridge.hx +102 -21

diff --git a/HaxeSwiftBridge.hx b/HaxeSwiftBridge.hx
index ec05c0c..bb781a1 100644
--- a/HaxeSwiftBridge.hx
+++ b/HaxeSwiftBridge.hx
@@ -131,16 +131,17 @@ class HaxeSwiftBridge {
 		return nativeName;
 	}
 
-	static function getSwiftType(type, arg = false) {
+	static function getSwiftType(type, arg = false, erase = null) {
 		return switch type {
 		case TInst(_.get().name => "String", params):
 			return "String";
 		case TInst(_.get().name => "Array", [param]):
-			return "Array<" + getSwiftType(param, arg) + ">";
+			return "Array<" + getSwiftType(param, arg, erase) + ">";
 		case TInst(_.get() => t, params):
+			if ((erase ?? []).contains(t.name)) return "AnyObject";
 			return t.name;
 		case TAbstract(_.get().name => "Null", [param]):
-			return getSwiftType(param) + "?";
+			return getSwiftType(param, false, erase) + "?";
 		case TAbstract(_.get().name => "Int", []):
 			return "Int32";
 		case TAbstract(_.get() => t, []):
@@ -154,24 +155,67 @@ class HaxeSwiftBridge {
 			}
 			return t.name;
 		case TAbstract(_.get() => t, params):
-			return getSwiftType(TypeTools.followWithAbstracts(type, false), arg);
+			return getSwiftType(TypeTools.followWithAbstracts(type, false), arg, erase);
 		case TFun(args, ret):
 			final builder = new hx.strings.StringBuilder(arg ? "@Sendable @escaping (" : "@Sendable (");
 			for (i => arg in args) {
 				if (i > 0) builder.add(", ");
-				builder.add(getSwiftType(arg.t));
+				builder.add(getSwiftType(arg.t, false, erase));
 			}
 			builder.add(")->");
-			builder.add(getSwiftType(ret));
-
+			builder.add(getSwiftType(ret, false, erase));
 			return builder.toString();
 		case TType(_.get() => t, params):
-			return getSwiftType(TypeTools.follow(type, true), arg);
+			return getSwiftType(TypeTools.follow(type, true), arg, erase);
 		default:
 			Context.fatalError("No implemented Swift type conversion for: " + type, Context.currentPos());
 		}
 	}
 
+	// Builds the Swift source of a `__erased_<name>` closure that wraps callback `name`,
+	// replacing the generic type names listed in `params` with AnyObject in its signature.
+	// Returns null when no argument or return type actually changes under erasure.
+	static function makeErased(name, args: Array<{t: Type}>, ret, params) {
+		var erasedAny = false;
+		final builder = new hx.strings.StringBuilder("let __erased_"+name+": @Sendable (");
+		for (i => arg in args) {
+			if (i > 0) builder.add(", ");
+			// Render both forms with the same `arg` flag, else "@escaping" alone looks like erasure.
+			var st = getSwiftType(arg.t, true);
+			var est = getSwiftType(arg.t, true, params);
+			erasedAny = erasedAny || st != est;
+			builder.add(est);
+		}
+		builder.add(")->");
+		var st = getSwiftType(ret, false);
+		var est = getSwiftType(ret, false, params);
+		erasedAny = erasedAny || st != est;
+		builder.add(est);
+
+		if (!erasedAny) return null;
+
+		builder.add(" = { (");
+		for (i => _ in args) {
+			if (i != 0) builder.add(", ");
+			builder.add("a" + i);
+		}
+		builder.add(") in\n");
+		for (i => arg in args) {
+			builder.add("\t\t\tlet c" + i + " = ");
+			builder.add(castToGeneric("a" + i, arg.t, params));
+			builder.add("\n");
+		}
+		builder.add("\t\t\treturn " + name + "(");
+		for (i => arg in args) {
+			if (i != 0) builder.add(", ");
+			builder.add("c" + i);
+		}
+		builder.add(") as! ");
+		builder.add(est);
+		builder.add("\n\t\t}\n\t\t");
+		return builder.toString();
+	}
+
 	static function convertArgs(builder: hx.strings.StringBuilder, args: Array<{ name: String, opt: Bool, t: haxe.macro.Type }>, ?kind: FieldType) {
 		for (i => arg in args) {
 			if (i > 0) builder.add(", ");
@@ -204,7 +248,7 @@ class HaxeSwiftBridge {
 		}
 	}
 
-	static function castToSwift(item: String, type: haxe.macro.Type, canNull = false, isRet = false) {
+	static function castToSwift(item: String, type: haxe.macro.Type, canNull = false, isRet = false, forRet = false) {
 		return switch type {
 		case TInst(_.get().name => "String", params):
 			return "useString(" + item + ")" + (canNull ? "" : "!");
@@ -219,13 +263,16 @@ class HaxeSwiftBridge {
 					"{" +
 					"var __ret: UnsafeMutablePointer<" + ptrType + ">? = nil;" +
 					"let __ret_length = " + ~/\)$/.replace(item, ", &__ret);") +
-					"return " + castToSwift("__ret", type, canNull, false) + ";" +
+					"return " + castToSwift("__ret", type, canNull, false, forRet) + ";" +
 					"}()";
 			} else {
 				return
-					"{" +
+					(canNull ?
+						"{ if(" + item + "_length < 0) { return nil; }; " :
+						"{ "
+					) +
 					"let __r = UnsafeMutableBufferPointer<" + ptrType + ">(start: " + item + ", count: " + item + "_length).map({" +
-					castToSwift("$0", param) +
+					castToSwift("$0", param, false, false, forRet) +
 					"});" +
 					"c_" + libName + "." + libName + "_release(" + item + ");" +
 					"return __r;" +
@@ -234,7 +281,7 @@ class HaxeSwiftBridge {
 		case TInst(_.get() => t, params):
 			final wrapper = switch (Context.follow(type)) {
 				case TInst(_.get() => {kind: KTypeParameter(_)}, _):
-					"";
+					return "Unmanaged<AnyObject>.fromOpaque(" + item + "!).takeRetainedValue()" + (forRet ? " as! " + t.name : "");
 				default:
 					t.isInterface ? 'Any${t.name}' : t.name;
 			};
@@ -244,11 +291,11 @@ class HaxeSwiftBridge {
 				return wrapper + "(" + item + "!)";
 			}
 		case TAbstract(_.get().name => "Null", [param]):
-			return castToSwift(item, param, true, isRet);
+			return castToSwift(item, param, true, isRet, forRet);
 		case TAbstract(_.get() => t, []):
 			return item;
 		case TAbstract(_.get() => t, params):
-			return castToSwift(item, TypeTools.followWithAbstracts(type, false), canNull, isRet);
+			return castToSwift(item, TypeTools.followWithAbstracts(type, false), canNull, isRet, forRet);
 		case TType(_.get() => t, params):
 			return castToSwift(item, TypeTools.follow(type, true), canNull);
 		default:
@@ -263,11 +310,18 @@ class HaxeSwiftBridge {
 		case TInst(_.get().name => "Array", [param = TInst(_.get().name => "String", _)]):
 			return "__" + item;
 		case TInst(_.get().name => "Array", [param = TInst(_)]):
-			return item + ".map { " + castToC("$0", param, canNull) + " }";
+			final toC = castToC("$0", param, canNull);
+			if (toC == "$0") return item;
+			return item + ".map { " + toC + " }";
 		case TInst(_.get().name => "Array", [param]):
 			return item;
 		case TInst(_.get() => t, []):
-			return item + (canNull ? "?" : "") + ".o";
+			switch (Context.follow(type)) {
+			case TInst(_.get() => {kind: KTypeParameter(_)}, _):
+				return "Unmanaged.passRetained(" + item + ").toOpaque()";
+			default:
+				return item + (canNull ? "?" : "") + ".o";
+			}
 		case TAbstract(_.get().name => "Null", [param]):
 			return castToC(item, param, true);
 		case TAbstract(_.get() => t, []):
@@ -279,6 +333,27 @@ class HaxeSwiftBridge {
 		}
 	}
 
+	// Returns a Swift expression that casts the type-erased `item` back to `type`, where
+	// `params` names the generic type parameters; `canNull` marks `item` as a Swift optional.
+	static function castToGeneric(item: String, type: haxe.macro.Type, params: Array<String>, canNull = false) {
+		return switch type {
+		case TInst(_.get().name => "Array", [param = TInst(_)]):
+			final casted = castToGeneric("$0", param, params);
+			if (casted == "$0") return item;
+			// Optional chaining on a non-optional array does not compile in Swift.
+			return item + (canNull ? "?" : "") + ".map { " + casted + " }";
+		case TInst(_.get().name => "Array", [param]):
+			return item;
+		case TInst(_.get().name => name, []):
+			// For nullable values cast to `T?` so nil passes through instead of trapping.
+			return params.contains(name) ? (item + " as! " + name + (canNull ? "?" : "")) : item;
+		case TAbstract(_.get().name => "Null", [param]):
+			return castToGeneric(item, param, params, true);
+		default:
+			return item;
+		}
+	}
+
 	static function identToStr(expr: Expr) {
 		switch (expr.expr) {
 		case EConst(CIdent(s)):
@@ -386,11 +461,13 @@ class HaxeSwiftBridge {
 		if (genAccess) builder.add("public ");
 		builder.add("func ");
 		builder.add(funcName);
+		var fparams = [];
 		switch (fld?.kind) {
 		case FFun(func):
 			if (func.params.length > 0) {
+				fparams = func.params.map(p -> p.name);
 				builder.add("<");
-				builder.add(func.params.map(p -> p.name).join(","));
+				builder.add(fparams.join(","));
 				builder.add(">");
 			}
 		default:
@@ -421,10 +498,12 @@ class HaxeSwiftBridge {
 			for (arg in targs) {
 				switch (arg.t) {
 				case TFun(fargs, fret):
+					final erased = makeErased(arg.name, fargs, fret, fparams);
+					if (erased != null) builder.add(erased);
 					builder.add("let __");
 					builder.add(arg.name);
 					builder.add("_ptr = UnsafeMutableRawPointer(Unmanaged.passRetained(");
-					builder.add(arg.name);
+					builder.add(erased == null ? arg.name : "__erased_" + arg.name);
 					builder.add(" as AnyObject).toOpaque())\n\t\t");
 				default:
 				}
@@ -460,6 +539,8 @@ class HaxeSwiftBridge {
 						switch (farg.t) {
 						case TInst(_.get().name => "Array", params):
 							ibuilder.add(", a" + i + "_length");
+						case TAbstract(_.get().name => "Null", [TInst(_.get().name => "Array", _)]):
+							ibuilder.add(", a" + i + "_length");
 						default:
 						}
 					}
@@ -468,7 +549,7 @@ class HaxeSwiftBridge {
 					ibuilder.add(") in\n\t\t\t\tlet ");
 					ibuilder.add(arg.name);
 					ibuilder.add(" = Unmanaged<AnyObject>.fromOpaque(ctx!).takeUnretainedValue() as! ");
-					ibuilder.add(getSwiftType(arg.t));
+					ibuilder.add(getSwiftType(arg.t, false, fparams));
 					ibuilder.add("\n\t\t\t\t");
 					final cbuilder = new hx.strings.StringBuilder(arg.name);
 					cbuilder.add("(");
@@ -499,7 +580,7 @@ class HaxeSwiftBridge {
 			}
 			ibuilder.add("\n\t\t)");
 			builder.add("let __result = ");
-			builder.add(castToSwift(ibuilder.toString(), finalTret, false, true));
+			builder.add(castToSwift(ibuilder.toString(), finalTret, false, true, true));
 			for (arg in targs) {
 				switch TypeTools.followWithAbstracts(arg.t, false) {
 				case TFun(fargs, fret):