git » sdk » commit f1fed24

OMEMO: Fixes for OMEMO carbons and session building

author Matthew Wild
2025-05-24 14:56:37 UTC
committer Stephen Paul Weber
2025-09-29 13:49:58 UTC
parent c5a8197946cbbda397808f967a118cd8fc3a5aa8

OMEMO: Fixes for OMEMO carbons and session building

doc/OMEMO.md +1 -0
snikket/Client.hx +7 -205
snikket/MessageSync.hx +1 -1
snikket/OMEMO.hx +70 -21

diff --git a/doc/OMEMO.md b/doc/OMEMO.md
index c130206..4862263 100644
--- a/doc/OMEMO.md
+++ b/doc/OMEMO.md
@@ -14,6 +14,7 @@ compile with the NO_OMEMO flag.
 - [x] Remove and replace consumed prekeys
 - [x] Allow non-OMEMO messages to recipients with no published keys when policy allows
 - [x] Encrypt outgoing messages to the sending account's other devices
+- [x] OMEMO carbons working
 - [x] Persistence: IndexedDB (for web)
 - [x] Use cache for remote contact devices
 - [x] Fix that encryption status reported by the API can be forged by sender
diff --git a/snikket/Client.hx b/snikket/Client.hx
index cd42046..56ad44d 100644
--- a/snikket/Client.hx
+++ b/snikket/Client.hx
@@ -172,13 +172,6 @@ class Client extends EventEmitter {
 			}
 
 			final from = stanza.attr.get("from") == null ? null : JID.parse(stanza.attr.get("from"));
-
-			if (stanza.attr.get("type") == "error" && from != null) {
-				final chat = getChat(from.asBare().asString());
-				final channel = Std.downcast(chat, Channel);
-				if (channel != null) channel.selfPing(true);
-			}
-
 			var fwd = null;
 			if (from != null && from.asBare().asString() == accountId()) {
 				var carbon = stanza.getChild("received", "urn:xmpp:carbons:2");
@@ -188,198 +181,17 @@ class Client extends EventEmitter {
 				}
 			}
 
-			final message = Message.fromStanza(stanza, this.jid, (builder, stanza) -> {
-				var chat = getChat(builder.chatId());
-				if (chat == null && stanza.attr.get("type") != "groupchat") chat = getDirectChat(builder.chatId());
-				if (chat == null) return builder;
-				return chat.prepareIncomingMessage(builder, stanza);
-			});
-			switch (message.parsed) {
-				case ChatMessageStanza(chatMessage):
-					for (hash in chatMessage.inlineHashReferences()) {
-						fetchMediaByHash([hash], [chatMessage.from]);
-					}
-					final chat = getChat(chatMessage.chatId());
-					if (chat != null) {
-						final updateChat = (chatMessage) -> {
-							notifyMessageHandlers(chatMessage, chatMessage.versions.length > 1 ? CorrectionEvent : DeliveryEvent);
-							if (chatMessage.versions.length < 1 || chat.lastMessageId() == chatMessage.serverId || chat.lastMessageId() == chatMessage.localId) {
-								chat.setLastMessage(chatMessage);
-								if (chatMessage.versions.length < 1) chat.setUnreadCount(chatMessage.isIncoming() ? chat.unreadCount() + 1 : 0);
-								chatActivity(chat);
-							}
-						};
-						if (chatMessage.serverId == null) {
-							updateChat(chatMessage);
-						} else {
-							storeMessages([chatMessage]).then((stored) -> updateChat(stored[0]));
-						}
-					}
-				case ReactionUpdateStanza(update):
-					for (hash in update.inlineHashReferences()) {
-						fetchMediaByHash([hash], [from]);
-					}
-					persistence.storeReaction(accountId(), update).then((stored) -> if (stored != null) notifyMessageHandlers(stored, ReactionEvent));
-				case ModerateMessageStanza(action):
-					moderateMessage(action).then((stored) -> if (stored != null) notifyMessageHandlers(stored, CorrectionEvent));
-				default:
-					// ignore
-			}
-
-#if !NO_JINGLE
-			final jmiP = stanza.getChild("propose", "urn:xmpp:jingle-message:0");
-			if (jmiP != null && jmiP.attr.get("id") != null) {
-				final session = new IncomingProposedSession(this, from, jmiP.attr.get("id"));
-				final chat = getDirectChat(from.asBare().asString());
-				if (!chat.jingleSessions.exists(session.sid)) {
-					chat.jingleSessions.set(session.sid, session);
-					chatActivity(chat);
-					session.ring();
-				}
-			}
-
-			final jmiR = stanza.getChild("retract", "urn:xmpp:jingle-message:0");
-			if (jmiR != null && jmiR.attr.get("id") != null) {
-				final chat = getDirectChat(from.asBare().asString());
-				final session = chat.jingleSessions.get(jmiR.attr.get("id"));
-				if (session != null) {
-					session.retract();
-					chat.jingleSessions.remove(session.sid);
-				}
-			}
-
-			// Another resource picked this up
-			final jmiProFwd = fwd?.getChild("proceed", "urn:xmpp:jingle-message:0");
-			if (jmiProFwd != null && jmiProFwd.attr.get("id") != null) {
-				final chat = getDirectChat(JID.parse(fwd.attr.get("to")).asBare().asString());
-				final session = chat.jingleSessions.get(jmiProFwd.attr.get("id"));
-				if (session != null) {
-					session.retract();
-					chat.jingleSessions.remove(session.sid);
-				}
-			}
-
-			final jmiPro = stanza.getChild("proceed", "urn:xmpp:jingle-message:0");
-			if (jmiPro != null && jmiPro.attr.get("id") != null) {
-				final chat = getDirectChat(from.asBare().asString());
-				final session = chat.jingleSessions.get(jmiPro.attr.get("id"));
-				if (session != null) {
-					try {
-						chat.jingleSessions.set(session.sid, session.initiate(stanza));
-					} catch (e) {
-						trace("JMI proceed failed", e);
-					}
-				}
-			}
-
-			final jmiRej = stanza.getChild("reject", "urn:xmpp:jingle-message:0");
-			if (jmiRej != null && jmiRej.attr.get("id") != null) {
-				final chat = getDirectChat(from.asBare().asString());
-				final session = chat.jingleSessions.get(jmiRej.attr.get("id"));
-				if (session != null) {
-					session.retract();
-					chat.jingleSessions.remove(session.sid);
-				}
-			}
-#end
-
-			if (stanza.attr.get("type") != "error") {
-				final chatState = stanza.getChild(null, "http://jabber.org/protocol/chatstates");
-				final userState = switch (chatState?.name) {
-					case "active": UserState.Active;
-					case "inactive": UserState.Inactive;
-					case "gone": UserState.Gone;
-					case "composing": UserState.Composing;
-					case "paused": UserState.Paused;
-					default: null;
-				};
-				if (userState != null) {
-					final chat = getChat(from.asBare().asString());
-					if (chat == null || !chat.getParticipantDetails(message.senderId).isSelf) {
-						for (handler in chatStateHandlers) {
-							handler(message.senderId, message.chatId, message.threadId, userState);
-						}
-					}
-				}
-			}
-
-			final pubsubEvent = PubsubEvent.fromStanza(stanza);
-			if (pubsubEvent != null && pubsubEvent.getFrom() != null && pubsubEvent.getNode() == "urn:xmpp:avatar:metadata" && pubsubEvent.getItems().length > 0) {
-				final item = pubsubEvent.getItems()[0];
-				final avatarSha1Hex = pubsubEvent.getItems()[0].attr.get("id");
-				final avatarSha1 = Hash.fromHex("sha-1", avatarSha1Hex)?.hash;
-				final metadata = item.getChild("metadata", "urn:xmpp:avatar:metadata");
-				var mime = "image/png";
-				if (metadata != null) {
-					final info = metadata.getChild("info"); // should have xmlns matching metadata
-					if (info != null && info.attr.get("type") != null) {
-						mime = info.attr.get("type");
-					}
-				}
-				if (avatarSha1 != null) {
-					final chat = this.getDirectChat(JID.parse(pubsubEvent.getFrom()).asBare().asString(), false);
-					chat.setAvatarSha1(avatarSha1);
-					persistence.storeChats(accountId(), [chat]);
-					persistence.hasMedia("sha-1", avatarSha1).then((has) -> {
-						if (has) {
-							this.trigger("chats/update", [chat]);
-						} else {
-							final pubsubGet = new PubsubGet(pubsubEvent.getFrom(), "urn:xmpp:avatar:data", avatarSha1Hex);
-							pubsubGet.onFinished(() -> {
-								final item = pubsubGet.getResult()[0];
-								if (item == null) return;
-								final dataNode = item.getChild("data", "urn:xmpp:avatar:data");
-								if (dataNode == null) return;
-								persistence.storeMedia(mime, Base64.decode(StringTools.replace(dataNode.getText(), "\n", "")).getData()).then(_ -> {
-									this.trigger("chats/update", [chat]);
-								});
-							});
-							sendQuery(pubsubGet);
-						}
-					});
-				}
-			}
-
-			trace("pubsubEvent "+Std.string(pubsubEvent!=null));
-			if (pubsubEvent != null && pubsubEvent.getFrom() != null) {
-				final fromBare = JID.parse(pubsubEvent.getFrom()).asBare();
-				final isOwnAccount = fromBare.asString() == accountId();
-				final pubsubNode = pubsubEvent.getNode();
-
-				if(isOwnAccount && pubsubEvent.getNode() == "urn:xmpp:mds:displayed:0" && pubsubEvent.getItems().length > 0) {
-					for (item in pubsubEvent.getItems()) {
-						if (item.attr.get("id") != null) {
-							final upTo = item.getChild("displayed", "urn:xmpp:mds:displayed:0")?.getChild("stanza-id", "urn:xmpp:sid:0");
-							final chat = getChat(item.attr.get("id"));
-							if (chat == null) {
-								startChatWith(item.attr.get("id"), (caps) -> Closed, (chat) -> chat.markReadUpToId(upTo.attr.get("id"), upTo.attr.get("by")));
-							} else {
-								chat.markReadUpToId(upTo.attr.get("id"), upTo.attr.get("by"), () -> {
-									persistence.storeChats(accountId(), [chat]);
-									this.trigger("chats/update", [chat]);
-								});
-							}
-						}
-					}
-				}
-						
-				if (isOwnAccount && pubsubNode == "http://jabber.org/protocol/nick" && pubsubEvent.getItems().length > 0) {
-					updateDisplayName(pubsubEvent.getItems()[0].getChildText("nick", "http://jabber.org/protocol/nick"));
-				}
-
-				trace("pubsubNode == "+pubsubNode);
-			}
 #if !NO_OMEMO
-			if(stanza.hasChild("encrypted", NS.OMEMO)) {
-				omemo.decryptMessage(stanza).then((decryptionResult) -> {
+			if((fwd??stanza).hasChild("encrypted", NS.OMEMO)) {
+				omemo.decryptMessage(stanza, fwd).then((decryptionResult) -> {
 					trace("OMEMO: Decrypted message, now processing...");
-					processLiveMessage(decryptionResult.stanza, decryptionResult.encryptionInfo);
+					processLiveMessage(decryptionResult.stanza, fwd, decryptionResult.encryptionInfo);
 					return true;
 				});
 				return EventHandled;
 			}
 #end
-			processLiveMessage(stanza);
+			processLiveMessage(stanza, fwd);
 			return EventHandled;
 		});
 
@@ -596,7 +408,7 @@ class Client extends EventEmitter {
 	}
 
 	@:allow(snikket)
-	private function processLiveMessage(stanza:Stanza, ?encryptionInfo:EncryptionInfo):Void {
+	private function processLiveMessage(stanza:Stanza, fwd:Null<Stanza>, ?encryptionInfo:EncryptionInfo):Void {
 		final from = stanza.attr.get("from") == null ? null : JID.parse(stanza.attr.get("from"));
 
 		if (stanza.attr.get("type") == "error" && from != null) {
@@ -605,16 +417,6 @@ class Client extends EventEmitter {
 			if (channel != null) channel.selfPing(true);
 		}
 
-		var fwd = null;
-		if (from != null && from.asBare().asString() == accountId()) {
-			var carbon = stanza.getChild("received", "urn:xmpp:carbons:2");
-			if (carbon == null) carbon = stanza.getChild("sent", "urn:xmpp:carbons:2");
-			if (carbon != null) {
-				fwd = carbon.getChild("forwarded", "urn:xmpp:forward:0")?.getFirstChild();
-			}
-		}
-
-
 		final message = Message.fromStanza(stanza, this.jid, (builder, stanza) -> {
 			var chat = getChat(builder.chatId());
 			if (chat == null && stanza.attr.get("type") != "groupchat") chat = getDirectChat(builder.chatId());
@@ -640,14 +442,14 @@ class Client extends EventEmitter {
 					if (chatMessage.serverId == null) {
 						updateChat(chatMessage);
 					} else {
-						storeMessages([chatMessage], (stored) -> updateChat(stored[0]));
+						storeMessages([chatMessage]).then((stored) -> updateChat(stored[0]));
 					}
 				}
 			case ReactionUpdateStanza(update):
 				for (hash in update.inlineHashReferences()) {
 					fetchMediaByHash([hash], [from]);
 				}
-				persistence.storeReaction(accountId(), update, (stored) -> if (stored != null) notifyMessageHandlers(stored, ReactionEvent));
+				persistence.storeReaction(accountId(), update).then((stored) -> if (stored != null) notifyMessageHandlers(stored, ReactionEvent));
 			case ModerateMessageStanza(action):
 				moderateMessage(action).then((stored) -> if (stored != null) notifyMessageHandlers(stored, CorrectionEvent));
 			default:
diff --git a/snikket/MessageSync.hx b/snikket/MessageSync.hx
index e01b795..52360b5 100644
--- a/snikket/MessageSync.hx
+++ b/snikket/MessageSync.hx
@@ -93,7 +93,7 @@ class MessageSync {
 			if (originalMessage.hasChild("encrypted", NS.OMEMO)) {
 #if !NO_OMEMO
 				trace("MAM: Processing OMEMO message from " + originalMessage.attr.get("from"));
-				promisedMessages.push(client.omemo.decryptMessage(originalMessage).then((decryptionResult) -> {
+				promisedMessages.push(client.omemo.decryptMessage(originalMessage, null).then((decryptionResult) -> {
 					final decryptedStanza = decryptionResult.stanza;
 					trace("MAM: Decrypted stanza: "+decryptedStanza);
 
diff --git a/snikket/OMEMO.hx b/snikket/OMEMO.hx
index d750829..f2098ce 100644
--- a/snikket/OMEMO.hx
+++ b/snikket/OMEMO.hx
@@ -814,7 +814,7 @@ class OMEMO {
 				// Incoming message used a prekey - build a new session between
 				// us and the sender
 				trace("OMEMO: Received an encrypted message using a prekey. Creating session...");
-				final promSession = buildSession(deviceId, fromBare, payload.sid);
+				final promSession = buildSession(deviceId, fromBare, payload.sid, "prekey");
 				promSession.then((session) -> {
 					getSessionCipher(deviceId, fromBare, payload.sid).then((cipher) -> {
 						resolve(cipher);
@@ -879,13 +879,38 @@ class OMEMO {
 		});
 	}
 
-	public function decryptMessage(stanza: Stanza):Promise<OMEMODecryptionResult> {
-		final header = OMEMOPayload.fromMessageStanza(stanza);
-		return client.omemo.getDeviceId().then((deviceId:Int) -> {
-			final deviceKey = header.findKey(deviceId);
+	public function decryptMessage(stanza: Stanza, fwd: Null<Stanza>):Promise<OMEMODecryptionResult> {
+		// Check for carbon-forwarded message
+		final from = stanza.attr.get("from") == null ? null : JID.parse(stanza.attr.get("from")).asBare();
+		final header = OMEMOPayload.fromMessageStanza(fwd??stanza);
+		final senderAddress = new SignalProtocolAddress(from.asString(), header.sid);
+		final sessionMeta = new Promise<OMEMOSessionMetadata>((resolve, reject) -> {
+			persistence.getOmemoMetadata(client.accountId(), senderAddress.toString(), resolve);
+		});
+		final promDeviceId = client.omemo.getDeviceId();
+		var deviceKey:Null<OMEMOPayloadKey>;
+		final promResult = promDeviceId.then((deviceId:Int) -> {
+			if(deviceId == header.sid) {
+				// Message was sent by us (it was probably fetched from MAM)
+				// We're not going to build a session with ourself (that won't
+				// work!). We either have the original message locally, or we
+				// don't, but we can't decrypt this copy.
+				return Promise.resolve(
+					new OMEMODecryptionResult(
+						stanza,
+						new EncryptionInfo(
+							DecryptionFailure,
+							NS.OMEMO,
+							"own-message",
+							"Past message sent from this device (cannot be decrypted)"
+						)
+					)
+				);
+			}
+			deviceKey = header.findKey(deviceId);
 			if(deviceKey == null) {
 				trace("OMEMO: Message not encrypted for our device (looked for "+deviceId+")");
-				stanza.removeChildren("encrypted", NS.OMEMO);
+				(fwd??stanza).removeChildren("encrypted", NS.OMEMO);
 				return Promise.resolve(
 					new OMEMODecryptionResult(
 						stanza,
@@ -900,7 +925,6 @@ class OMEMO {
 			}
 			// FIXME: Identify correct JID for group chats
 			trace("OMEMO: Decrypting payload...");
-			final from = JID.parse(stanza.attr.get("from")).asBare();
 			final promPayload = decryptPayload(deviceId, deviceKey, from.asString(), header);
 			return promPayload.then((decryptedPayload:BytesData) -> {
 				if(decryptedPayload == null) {
@@ -916,9 +940,9 @@ class OMEMO {
 					));
 				}
 
-				stanza.removeChildren("body");
+				(fwd??stanza).removeChildren("body");
 				// FIXME: Verify valid UTF-8, etc.
-				stanza.textTag("body", Bytes.ofData(decryptedPayload).toString());
+				(fwd??stanza).textTag("body", Bytes.ofData(decryptedPayload).toString());
 				trace("OMEMO: Payload decrypted OK!");
 				return Promise.resolve(new OMEMODecryptionResult(
 					stanza,
@@ -929,13 +953,6 @@ class OMEMO {
 				));
 			}, (err:Any) -> {
 				trace("OMEMO: Failed to decrypt message: " + err);
-				// FIXME: Rebuilding the session should not be unconditional, as this
-				// can be triggered by a MITM and effectively bypass the double ratchet
-				// part of the protocol.
-				buildSession(deviceId, from.asString(), header.sid).then((session) -> {
-					// Broken session? Send key to start new session...
-					sendKeyExchange(deviceId, from.asString(), header.sid);
-				});
 				return Promise.resolve(new OMEMODecryptionResult(
 					stanza,
 					new EncryptionInfo(
@@ -947,6 +964,38 @@ class OMEMO {
 				));
 			});
 		});
+
+		// Some post-decryption tasks, such as updating the session metadata
+		// and sending a key exchange if necessary
+		promResult.then((decryptionResult) -> {
+			sessionMeta.then((metadata) -> {
+				promDeviceId.then((deviceId) -> {
+					if(metadata == null) {
+						// No metadata in storage, so create a default
+						metadata = new OMEMOSessionMetadata(false, false, false);
+					}
+					final decryptedOk = decryptionResult.encryptionInfo.status == DecryptionSuccess;
+					var needUpdate = metadata.lastMessageDecryptedOk != decryptedOk;
+					final receivedSessionMessage = deviceKey != null && !deviceKey.prekey;
+					needUpdate = needUpdate || receivedSessionMessage != metadata.receivedSessionMessageOk;
+					// Send a key exchange if decryption failed, this wasn't a prekey message, and
+					// if we haven't already sent a key exchange
+					final shouldSendKeyExchange = !decryptedOk && receivedSessionMessage && !metadata.sentKeyExchange;
+					if(shouldSendKeyExchange) {
+						needUpdate = true;
+						trace("OMEMO: Possible broken session with <"+senderAddress.toString()+">, sending key exchange...");
+						buildSession(deviceId, from.asString(), header.sid, "replacement").then((session) -> {
+							sendKeyExchange(deviceId, from.asString(), header.sid);
+						});
+					}
+					if(needUpdate) {
+						persistence.storeOmemoMetadata(client.accountId(), senderAddress.toString(), new OMEMOSessionMetadata(receivedSessionMessage||metadata.receivedSessionMessageOk, decryptedOk, shouldSendKeyExchange));
+					}
+				});
+			});
+		});
+
+		return promResult;
 	}
 
 	private function decryptPayloadWithKey(rawPayload:BytesData, rawKeyWithTag:BytesData, rawIv:BytesData):Promise<BytesData> {
@@ -1119,10 +1168,10 @@ class OMEMO {
 		return promStanza;
 	}
 
-	private function buildSession(sid:Int, jid:String, rid:Int):Promise<SignalSession> {
+	private function buildSession(sid:Int, jid:String, rid:Int, reason:String):Promise<SignalSession> {
 		final address = new SignalProtocolAddress(jid, rid);
 		final promBundle = getContactBundle(jid, rid);
-		trace("OMEMO: Building session (fetching bundle)...");
+		trace("OMEMO: Building session for <"+address.toString()+"> for "+reason+" (fetching bundle)...");
 		final promSession = promBundle.then((bundle:OMEMOBundle) -> {
 			trace("OMEMO: Fetched bundle");
 			final contactPreKey = bundle.getRandomPreKey();
@@ -1140,10 +1189,10 @@ class OMEMO {
 				},
 			});
 		}).then((_) -> {
-			trace("OMEMO: Built session!");
+			trace("OMEMO: Built session! ("+address.toString()+" for "+reason+")");
 			return signalStore.loadSession(address);
 		}, (err:Any) -> {
-			trace("OMEMO: Failed to build session: "+err);
+			trace("OMEMO: Failed to build "+reason+" session for <"+address.toString()+">: "+err);
 			return signalStore.loadSession(address);
 		});
 
@@ -1158,7 +1207,7 @@ class OMEMO {
 		final promReadySession = promSession.then((session) -> {
 			if(session == null) {
 				trace("OMEMO: No session for "+address.toString());
-				return buildSession(sid, jid, rid);
+				return buildSession(sid, jid, rid, "new");
 			}
 			return session;
 		});