git » sdk » commit 34d27d2

Allowing NULL in primary key is bad news

author Stephen Paul Weber
2025-03-24 17:46:45 UTC
committer Stephen Paul Weber
2025-03-24 17:46:45 UTC
parent fe3a7112ee0d6e1737992c62f57330c90f3d83a3

Allowing NULL in primary key is bad news

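SQLite does not reject NULL in the PRIMARY KEY columns of an ordinary
(rowid) table, and it treats every NULL as distinct when checking
uniqueness. As a result, duplicate rows can be inserted whenever any one
of the key columns is NULL. To prevent this, store an empty string
internally instead of NULL for missing values. Since no column of the
messages table needs to hold NULL any more, the table is also made STRICT.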

snikket/persistence/Sqlite.hx +15 -16

diff --git a/snikket/persistence/Sqlite.hx b/snikket/persistence/Sqlite.hx
index f0685d2..d80ea90 100644
--- a/snikket/persistence/Sqlite.hx
+++ b/snikket/persistence/Sqlite.hx
@@ -39,12 +39,11 @@ class Sqlite implements Persistence implements KeyValueStore {
 		final version = db.exec("PRAGMA user_version;").then(iter -> {
 			final version = Std.parseInt(iter.next()?.user_version) ?? 0;
 			return if (version < 1) {
-				// messages cannot be STRICT because mam_id may be NULL
 				db.exec("CREATE TABLE messages (
 					account_id TEXT NOT NULL,
-					mam_id TEXT,
-					mam_by TEXT,
-					stanza_id TEXT,
+					mam_id TEXT NOT NULL,
+					mam_by TEXT NOT NULL,
+					stanza_id TEXT NOT NULL,
 					correction_id TEXT NOT NULL,
 					sync_point INTEGER NOT NULL,
 					chat_id TEXT NOT NULL,
@@ -55,7 +54,7 @@ class Sqlite implements Persistence implements KeyValueStore {
 					type INTEGER NOT NULL,
 					stanza TEXT NOT NULL,
 					PRIMARY KEY (account_id, mam_id, mam_by, stanza_id)
-				);
+				) STRICT;
 				CREATE INDEX messages_created_at ON messages (account_id, chat_id, created_at);
 				CREATE INDEX messages_correction_id ON messages (correction_id);
 				CREATE TABLE chats (
@@ -300,8 +299,8 @@ class Sqlite implements Persistence implements KeyValueStore {
 					final correctable = m;
 					final message = m.versions.length == 1 ? m.versions[0] : m; // TODO: storing multiple versions at once? We never do that right now
 					([
-						accountId, message.serverId, message.serverIdBy,
-						message.localId, correctable.localId ?? correctable.serverId, correctable.syncPoint,
+						accountId, message.serverId ?? "", message.serverIdBy ?? "",
+						message.localId ?? "", correctable.localId ?? correctable.serverId, correctable.syncPoint,
 						correctable.chatId(), correctable.senderId,
 						message.timestamp, message.status, message.direction, message.type,
 						message.asStanza().toString()
@@ -348,8 +347,8 @@ class Sqlite implements Persistence implements KeyValueStore {
 		var q = "SELECT
 			correction_id AS stanza_id,
 			versions.stanza,
-			json_group_object(COALESCE(versions.mam_id, versions.stanza_id), strftime('%FT%H:%M:%fZ', versions.created_at / 1000.0, 'unixepoch')) AS version_times,
-			json_group_object(COALESCE(versions.mam_id, versions.stanza_id), versions.stanza) AS versions,
+			json_group_object(CASE WHEN versions.mam_id IS NULL OR versions.mam_id='' THEN versions.stanza_id ELSE versions.mam_id END, strftime('%FT%H:%M:%fZ', versions.created_at / 1000.0, 'unixepoch')) AS version_times,
+			json_group_object(CASE WHEN versions.mam_id IS NULL OR versions.mam_id='' THEN versions.stanza_id ELSE versions.mam_id END, versions.stanza) AS versions,
 			messages.direction,
 			messages.type,
 			strftime('%FT%H:%M:%fZ', messages.created_at / 1000.0, 'unixepoch') AS timestamp,
@@ -358,7 +357,7 @@ class Sqlite implements Persistence implements KeyValueStore {
 			messages.mam_by,
 			messages.sync_point,
 			MAX(versions.created_at)
-			FROM messages INNER JOIN messages versions USING (correction_id) WHERE (messages.stanza_id IS NULL OR messages.stanza_id=correction_id) AND messages.account_id=? AND messages.chat_id=?";
+			FROM messages INNER JOIN messages versions USING (correction_id) WHERE (messages.stanza_id IS NULL OR messages.stanza_id='' OR messages.stanza_id=correction_id) AND messages.account_id=? AND messages.chat_id=?";
 		final params = [accountId, chatId];
 		if (time != null) {
 			q += " AND messages.created_at " + op + "CAST(unixepoch(?, 'subsec') * 1000 AS INTEGER)";
@@ -427,7 +426,7 @@ class Sqlite implements Persistence implements KeyValueStore {
 		final params: Array<Dynamic> = [accountId]; // subq is first in final q, so subq params first
 
 		final subq = new StringBuf();
-		subq.add("SELECT chat_id, MAX(created_at) AS created_at FROM messages WHERE account_id=?");
+		subq.add("SELECT chat_id, ROWID as row, MAX(created_at) AS created_at FROM messages WHERE account_id=?");
 		subq.add(" AND chat_id IN (");
 		for (i => chat in chats) {
 			if (i != 0) subq.add(",");
@@ -450,14 +449,14 @@ class Sqlite implements Persistence implements KeyValueStore {
 		final q = new StringBuf();
 		q.add("SELECT chat_id AS chatId, stanza, direction, type, sender_id, mam_id, mam_by, sync_point, CASE WHEN subq.created_at IS NULL THEN COUNT(*) ELSE COUNT(*) - 1 END AS unreadCount, strftime('%FT%H:%M:%fZ', MAX(messages.created_at) / 1000.0, 'unixepoch') AS timestamp FROM messages LEFT JOIN (");
 		q.add(subq.toString());
-		q.add(") subq USING (chat_id) WHERE account_id=? AND (stanza_id IS NULL OR stanza_id=correction_id) AND chat_id IN (");
+		q.add(") subq USING (chat_id) WHERE account_id=? AND (stanza_id IS NULL OR stanza_id='' OR stanza_id=correction_id) AND chat_id IN (");
 		params.push(accountId);
 		for (i => chat in chats) {
 			if (i != 0) q.add(",");
 			q.add("?");
 			params.push(chat.chatId);
 		}
-		q.add(") AND (subq.created_at IS NULL OR messages.created_at >= subq.created_at) GROUP BY chat_id;");
+		q.add(") AND (subq.created_at IS NULL OR messages.created_at > subq.created_at OR (messages.created_at=subq.created_at AND messages.ROWID >= subq.row)) GROUP BY chat_id;");
 		db.exec(q.toString(), params).then(result -> {
 			final details = [];
 			final rows: Array<Dynamic> = { iterator: () -> result }.array();
@@ -701,7 +700,7 @@ class Sqlite implements Persistence implements KeyValueStore {
 			final agg: Map<String, Map<String, Array<Dynamic>>> = [];
 			for (row in rows) {
 				final reactions: Array<Dynamic> = Json.parse(row.reactions);
-				final mapId = (row.mam_id == null ? row.stanza_id : row.mam_id + "\n" + row.mam_by) + "\n" + row.chat_id;
+				final mapId = (row.mam_id == null || row.mam_id == "" ? row.stanza_id : row.mam_id + "\n" + row.mam_by) + "\n" + row.chat_id;
 				if (!agg.exists(mapId)) agg.set(mapId, []);
 				final map = agg[mapId];
 				if (!map.exists(row.sender_id)) map[row.sender_id] = [];
@@ -776,7 +775,7 @@ class Sqlite implements Persistence implements KeyValueStore {
 			builder.timestamp = row.timestamp;
 			builder.type = row.type;
 			builder.senderId = row.sender_id;
-			builder.serverId = row.mam_id;
+			builder.serverId = row.mam_id == "" ? null : row.mam_id;
 			builder.serverIdBy = row.mam_by;
 			if (builder.direction != row.direction) {
 				builder.direction = row.direction;
@@ -784,7 +783,7 @@ class Sqlite implements Persistence implements KeyValueStore {
 				builder.replyTo = builder.recipients;
 				builder.recipients = replyTo;
 			}
-			if (row.stanza_id != null) builder.localId = row.stanza_id;
+			if (row.stanza_id != null && row.stanza_id != "") builder.localId = row.stanza_id;
 			if (row.versions != null) {
 				final versionTimes: DynamicAccess<String> = Json.parse(row.version_times);
 				final versions: DynamicAccess<String> =  Json.parse(row.versions);