feat: add built-in backup decryption

This commit is contained in:
Samuel 2024-12-31 16:16:18 +01:00
parent df218b9a56
commit bd951310d1
17 changed files with 641 additions and 437 deletions

View file

@ -1,5 +1,5 @@
import { type Component } from "solid-js";
import { Route } from "@solidjs/router";
import { type Component } from "solid-js";
import { DmId, GroupId, Home, Overview, preloadDmId } from "./pages";
import "./app.css";

View file

@ -0,0 +1,34 @@
import type { Component, JSX, ValidComponent } from "solid-js"
import { splitProps } from "solid-js"
import type { PolymorphicProps } from "@kobalte/core/polymorphic"
import * as ProgressPrimitive from "@kobalte/core/progress"
import { Label } from "~/components/ui/label"
type ProgressRootProps<T extends ValidComponent = "div"> =
ProgressPrimitive.ProgressRootProps<T> & { children?: JSX.Element }
const Progress = <T extends ValidComponent = "div">(
props: PolymorphicProps<T, ProgressRootProps<T>>
) => {
const [local, others] = splitProps(props as ProgressRootProps, ["children"])
return (
<ProgressPrimitive.Root {...others}>
{local.children}
<ProgressPrimitive.Track class="relative h-2 w-full overflow-hidden rounded-full bg-secondary">
<ProgressPrimitive.Fill class="h-full w-[var(--kb-progress-fill-width)] flex-1 bg-primary transition-all" />
</ProgressPrimitive.Track>
</ProgressPrimitive.Root>
)
}
const ProgressLabel: Component<ProgressPrimitive.ProgressLabelProps> = (props) => {
return <ProgressPrimitive.Label as={Label} {...props} />
}
const ProgressValueLabel: Component<ProgressPrimitive.ProgressValueLabelProps> = (props) => {
return <ProgressPrimitive.ValueLabel as={Label} {...props} />
}
export { Progress, ProgressLabel, ProgressValueLabel }

View file

@ -1,17 +1,52 @@
import { sql, type NotNull } from "kysely";
import { db, kyselyDb, SELF_ID, setDbHash } from "./db";
import { cached } from "./lib/db-cache";
import { kyselyDb, SELF_ID } from "./db";
import { hashString } from "./lib/hash";
export const loadDb = (
statements: string[],
progressCallback?: (percentage: number) => void,
) => {
const length = statements.length;
let percentage = 0;
for (let i = 0; i < length; i++) {
const statement = statements[i];
const newPercentage = Math.round((i / length) * 100);
try {
db.exec(statement);
if (newPercentage !== percentage) {
progressCallback?.(newPercentage);
percentage = newPercentage;
}
} catch (e) {
throw new Error(`statement failed: ${statement}`, {
cause: e,
});
}
}
setDbHash(hashString(statements.join()));
};
const allThreadsOverviewQueryRaw = () =>
kyselyDb()
?.selectFrom("thread")
kyselyDb
.selectFrom("thread")
.innerJoin(
(eb) =>
eb
.selectFrom("message")
.select((eb) => ["message.thread_id", eb.fn.countAll().as("message_count")])
.select((eb) => [
"message.thread_id",
eb.fn.countAll().as("message_count"),
])
.where((eb) => {
return eb.and([eb("message.body", "is not", null), eb("message.body", "is not", "")]);
return eb.and([
eb("message.body", "is not", null),
eb("message.body", "is not", ""),
]);
})
.groupBy("message.thread_id")
.as("message"),
@ -41,8 +76,8 @@ const allThreadsOverviewQueryRaw = () =>
export const allThreadsOverviewQuery = cached(allThreadsOverviewQueryRaw);
const overallSentMessagesQueryRaw = (recipientId: number) =>
kyselyDb()
?.selectFrom("message")
kyselyDb
.selectFrom("message")
.select((eb) => eb.fn.countAll().as("messageCount"))
.where((eb) =>
eb.and([
@ -56,8 +91,8 @@ const overallSentMessagesQueryRaw = (recipientId: number) =>
export const overallSentMessagesQuery = cached(overallSentMessagesQueryRaw);
const dmPartnerRecipientQueryRaw = (dmId: number) =>
kyselyDb()
?.selectFrom("recipient")
kyselyDb
.selectFrom("recipient")
.select([
"recipient._id",
"recipient.system_joined_name",
@ -65,7 +100,9 @@ const dmPartnerRecipientQueryRaw = (dmId: number) =>
"recipient.nickname_joined_name",
])
.innerJoin("thread", "recipient._id", "thread.recipient_id")
.where((eb) => eb.and([eb("thread._id", "=", dmId), eb("recipient._id", "!=", SELF_ID)]))
.where((eb) =>
eb.and([eb("thread._id", "=", dmId), eb("recipient._id", "!=", SELF_ID)]),
)
.$narrowType<{
_id: number;
}>()
@ -74,30 +111,47 @@ const dmPartnerRecipientQueryRaw = (dmId: number) =>
export const dmPartnerRecipientQuery = cached(dmPartnerRecipientQueryRaw);
const threadSentMessagesOverviewQueryRaw = (threadId: number) =>
kyselyDb()
?.selectFrom("message")
.select(["from_recipient_id", sql<string>`datetime(date_sent / 1000, 'unixepoch')`.as("message_datetime")])
kyselyDb
.selectFrom("message")
.select([
"from_recipient_id",
sql<string>`datetime(date_sent / 1000, 'unixepoch')`.as(
"message_datetime",
),
])
.orderBy(["message_datetime"])
.where((eb) => eb.and([eb("body", "is not", null), eb("body", "!=", ""), eb("thread_id", "=", threadId)]))
.where((eb) =>
eb.and([
eb("body", "is not", null),
eb("body", "!=", ""),
eb("thread_id", "=", threadId),
]),
)
.execute();
export const threadSentMessagesOverviewQuery = cached(threadSentMessagesOverviewQueryRaw);
export const threadSentMessagesOverviewQuery = cached(
threadSentMessagesOverviewQueryRaw,
);
const threadMostUsedWordsQueryRaw = (threadId: number, limit = 10) =>
kyselyDb()
?.withRecursive("words", (eb) => {
kyselyDb
.withRecursive("words", (eb) => {
return eb
.selectFrom("message")
.select([
sql`LOWER(substr(body, 1, instr(body || " ", " ") - 1))`.as("word"),
sql`(substr(body, instr(body || " ", " ") + 1))`.as("rest"),
])
.where((eb) => eb.and([eb("body", "is not", null), eb("thread_id", "=", threadId)]))
.where((eb) =>
eb.and([eb("body", "is not", null), eb("thread_id", "=", threadId)]),
)
.unionAll((ebInner) => {
return ebInner
.selectFrom("words")
.select([
sql`LOWER(substr(rest, 1, instr(rest || " ", " ") - 1))`.as("word"),
sql`LOWER(substr(rest, 1, instr(rest || " ", " ") - 1))`.as(
"word",
),
sql`(substr(rest, instr(rest || " ", " ") + 1))`.as("rest"),
])
.where("rest", "<>", "");

View file

@ -1,50 +1,30 @@
import { createEffect, createMemo, createRoot, createSignal } from "solid-js";
import { makePersisted } from "@solid-primitives/storage";
import sqlite3InitModule from "@sqlite.org/sqlite-wasm";
import { Kysely } from "kysely";
import type { DB } from "kysely-codegen";
import { SqlJsDialect } from "kysely-wasm";
import initSqlJS, { type Database } from "sql.js";
import wasmURL from "./assets/sql-wasm.wasm?url";
import { OfficialWasmDialect } from "kysely-wasm";
import { createSignal } from "solid-js";
import workerUrl from "./lib/kysely-official-wasm-worker/worker?url";
export const SELF_ID = 2;
export const SQL = await initSqlJS({
locateFile: () => wasmURL,
const sqlite3 = await sqlite3InitModule({
print: console.log,
printErr: console.error,
});
export const [db, setDb] = createSignal<Database | undefined>();
export const db = new sqlite3.oo1.DB("signal");
const sqlJsDialect = () => {
const currentDb = db();
if (currentDb) {
return new SqlJsDialect({
database: currentDb,
});
}
};
export const kyselyDb = createRoot(() => {
createEffect(() => {
const currentDb = db();
if (currentDb) {
currentDb.create_function("is_not_empty", (str: string | null) => {
return str !== null && str !== "";
});
}
});
return createMemo(() => {
const currentSqlJsDialect = sqlJsDialect();
if (!currentSqlJsDialect) {
return;
}
return new Kysely<DB>({
dialect: currentSqlJsDialect,
});
});
export const worker = new Worker(workerUrl, {
type: "module",
});
const dialect = new OfficialWasmDialect({
database: db,
});
export const kyselyDb = new Kysely<DB>({
dialect,
});
export const [dbHash, setDbHash] = makePersisted(createSignal<number>());

View file

@ -1,11 +1,9 @@
/* @refresh reload */
import { render } from "solid-js/web";
import { Router, useNavigate } from "@solidjs/router";
import { MetaProvider } from "@solidjs/meta";
import { Router } from "@solidjs/router";
import { render } from "solid-js/web";
import App from "./App";
import { createEffect } from "solid-js";
import { db } from "./db";
const root = document.getElementById("root");
@ -21,18 +19,18 @@ if (root) {
<div class="mx-auto max-w-screen-2xl">
<MetaProvider>
<Router
root={(props) => {
const navigate = useNavigate();
const { pathname } = props.location;
// root={(props) => {
// const navigate = useNavigate();
// const { pathname } = props.location;
createEffect(() => {
if (!db() && pathname !== "/") {
navigate("/");
}
});
// createEffect(() => {
// if (!db() && pathname !== "/") {
// navigate("/");
// }
// });
return props.children;
}}
// return props.children;
// }}
>
<App />
</Router>

View file

@ -1,9 +1,9 @@
import { on, createSignal, createEffect, createRoot, createMemo } from "solid-js";
import { serialize, deserialize } from "seroval";
import { createSignaledWorker } from "@solid-primitives/workers";
import { db } from "~/db";
import { deserialize, serialize } from "seroval";
import { createEffect, createMemo, createRoot, on } from "solid-js";
import { dbHash } from "~/db";
import { hashString } from "./hash";
const DATABASE_HASH_PREFIX = "database";
export const DATABASE_HASH_PREFIX = "database";
// clear the cache on new session so that selecting a different database does not result in wrong cache entries
const clearDbCache = () => {
@ -16,64 +16,32 @@ const clearDbCache = () => {
}
};
// https://stackoverflow.com/a/7616484
const hashString = (str: string) => {
let hash = 0,
i,
chr;
if (str.length === 0) return hash;
for (i = 0; i < str.length; i++) {
chr = str.charCodeAt(i);
hash = (hash << 5) - hash + chr;
hash |= 0; // Convert to 32bit integer
}
return hash;
};
const HASH_STORE_KEY = `${DATABASE_HASH_PREFIX}_hash`;
let prevDbHash = dbHash();
createRoot(() => {
const [dbHash, setDbHash] = createSignal(localStorage.getItem(HASH_STORE_KEY));
// offloaded because this can take a long time (>1s easily) and would block the mainthread
createSignaledWorker({
input: db,
output: setDbHash,
func: function hashDb(currentDb: ReturnType<typeof db>) {
const hashString = (str: string) => {
let hash = 0,
i,
chr;
if (str.length === 0) return hash;
for (i = 0; i < str.length; i++) {
chr = str.charCodeAt(i);
hash = (hash << 5) - hash + chr;
hash |= 0; // Convert to 32bit integer
}
return hash;
};
if (currentDb?.export) {
return hashString(new TextDecoder().decode(currentDb.export())).toString();
}
},
});
createEffect(() => {
on(dbHash, (currentDbHash) => {
if (currentDbHash) {
clearDbCache();
localStorage.setItem(HASH_STORE_KEY, currentDbHash);
}
});
on(
dbHash,
(currentDbHash) => {
if (currentDbHash && currentDbHash !== prevDbHash) {
prevDbHash = currentDbHash;
clearDbCache();
}
},
{
defer: true,
},
);
});
});
class LocalStorageCacheAdapter {
keys = new Set<string>(Object.keys(localStorage).filter((key) => key.startsWith(this.prefix)));
keys = new Set<string>(
Object.keys(localStorage).filter((key) => key.startsWith(this.prefix)),
);
prefix = "database";
#dbLoaded = createMemo(() => !!db());
// TODO: real way of detecting if the db is loaded, on loading the db and opfs (if persisted db?)
#dbLoaded = createMemo(() => !!dbHash());
#createKey(cacheName: string, key: string): string {
return `${this.prefix}-${cacheName}-${key}`;
@ -86,7 +54,10 @@ class LocalStorageCacheAdapter {
try {
localStorage.setItem(fullKey, serialize(value));
} catch (error: unknown) {
if (error instanceof DOMException && error.name === "QUOTA_EXCEEDED_ERR") {
if (
error instanceof DOMException &&
error.name === "QUOTA_EXCEEDED_ERR"
) {
console.error("Storage quota exceeded, not caching new function calls");
} else {
console.error(error);
@ -146,7 +117,10 @@ const createHashKey = (...args: unknown[]) => {
return hashString(stringToHash);
};
export const cached = <T extends unknown[], R, TT>(fn: (...args: T) => R, self?: ThisType<TT>): ((...args: T) => R) => {
export const cached = <T extends unknown[], R, TT>(
fn: (...args: T) => R,
self?: ThisType<TT>,
): ((...args: T) => R) => {
const cacheName = hashString(fn.toString()).toString();
// important to return a promise on follow-up calls even if the data is immediately available

47
src/lib/decryptor.ts Normal file
View file

@ -0,0 +1,47 @@
import {
BackupDecryptor,
type DecryptionResult,
} from "@duskflower/signal-decrypt-backup-wasm";
const CHUNK_SIZE = 1024 * 1024 * 40; // 40MB chunks
export async function decryptBackup(
file: File,
passphrase: string,
progressCallback: (progress: number) => void,
): Promise<DecryptionResult> {
const fileSize = file.size;
const decryptor = new BackupDecryptor();
decryptor.set_progress_callback(fileSize, progressCallback);
let offset = 0;
try {
while (offset < file.size) {
const chunk = file.slice(offset, offset + CHUNK_SIZE);
const arrayBuffer = await chunk.arrayBuffer();
const uint8Array = new Uint8Array(arrayBuffer);
decryptor.feed_data(uint8Array);
let done = false;
while (!done) {
try {
done = decryptor.process_chunk(passphrase);
} catch (e) {
console.error("Error processing chunk:", e);
throw e;
}
}
offset += CHUNK_SIZE;
}
const result = decryptor.finish();
return result;
} catch (e) {
console.error("Decryption failed:", e);
throw e;
}
}

13
src/lib/hash.ts Normal file
View file

@ -0,0 +1,13 @@
// https://stackoverflow.com/a/7616484
export const hashString = (str: string) => {
let hash = 0,
i,
chr;
if (str.length === 0) return hash;
for (i = 0; i < str.length; i++) {
chr = str.charCodeAt(i);
hash = (hash << 5) - hash + chr;
hash |= 0; // Convert to 32bit integer
}
return hash;
};

View file

@ -1,47 +1,90 @@
import { useNavigate, type RouteSectionProps } from "@solidjs/router";
import { createSignal, Show, type Component, type JSX } from "solid-js";
import { type RouteSectionProps, useNavigate } from "@solidjs/router";
import { Title } from "@solidjs/meta";
import { Portal } from "solid-js/web";
import { Flex } from "~/components/ui/flex";
import { Title } from "@solidjs/meta";
import { setDb, SQL } from "~/db";
import {
Progress,
ProgressLabel,
ProgressValueLabel,
} from "~/components/ui/progress";
// import { db } from "~/db";
import { loadDb } from "~/db-queries";
import { decryptBackup } from "~/lib/decryptor";
export const Home: Component<RouteSectionProps> = () => {
const [isLoadingDb, setIsLoadingDb] = createSignal(false);
const [decryptionProgress, setDecryptionProgress] = createSignal<number>();
const [isLoadingDatabase, setIsLoadingDatabase] = createSignal(false);
const [passphrase, setPassphrase] = createSignal("");
const navigate = useNavigate();
const onFileChange: JSX.ChangeEventHandler<HTMLInputElement, Event> = (event) => {
const onFileChange: JSX.ChangeEventHandler<HTMLInputElement, Event> = (
event,
) => {
const file = event.currentTarget.files?.[0];
if (file) {
const reader = new FileReader();
const currentPassphrase = passphrase();
reader.addEventListener("load", () => {
setIsLoadingDb(true);
if (file && currentPassphrase) {
decryptBackup(file, currentPassphrase, setDecryptionProgress)
.then((result) => {
setDecryptionProgress(undefined);
setIsLoadingDatabase(true);
setTimeout(() => {
const Uints = new Uint8Array(reader.result as ArrayBuffer);
setDb(new SQL.Database(Uints));
setIsLoadingDb(false);
navigate("/overview");
}, 10);
});
setTimeout(() => {
loadDb(result.database_statements);
reader.readAsArrayBuffer(file);
setIsLoadingDatabase(false);
navigate("/overview");
}, 0);
})
.catch((error) => {
console.error("Decryption failed:", error);
});
}
};
return (
<>
<Portal>
<Show when={isLoadingDb()}>
<Flex alignItems="center" justifyContent="center" class="fixed inset-0 backdrop-blur-lg backdrop-filter">
<Flex
flexDirection="col"
alignItems="center"
justifyContent="center"
class="fixed inset-0 backdrop-blur-lg backdrop-filter gap-y-8"
classList={{
hidden: decryptionProgress() === undefined && !isLoadingDatabase(),
}}
>
<Show when={decryptionProgress() !== undefined}>
<p class="font-bold text-2xl">Decrypting database</p>
<Progress
value={decryptionProgress()}
minValue={0}
maxValue={100}
getValueLabel={({ value }) => `${value}%`}
class="w-[300px] space-y-1"
>
<div class="flex justify-between">
<ProgressLabel>Processing...</ProgressLabel>
<ProgressValueLabel />
</div>
</Progress>
</Show>
<Show when={isLoadingDatabase()}>
<p class="font-bold text-2xl">Loading database</p>
</Flex>
</Show>
</Show>
</Flex>
</Portal>
<Title>Signal stats</Title>
<div>
<input type="file" accept=".sqlite" onChange={onFileChange}></input>
<input
type="password"
onChange={(event) => setPassphrase(event.currentTarget.value)}
/>
<input type="file" accept=".backup" onChange={onFileChange} />
</div>
</>
);

View file

@ -1,48 +1,70 @@
import { type Component, createResource, Show } from "solid-js";
import type { RouteSectionProps } from "@solidjs/router";
import { type Component, createResource, Show } from "solid-js";
import { allThreadsOverviewQuery, overallSentMessagesQuery } from "~/db-queries";
import {
allThreadsOverviewQuery,
overallSentMessagesQuery,
} from "~/db-queries";
import { OverviewTable, type RoomOverview } from "./overview-table";
import { getNameFromRecipient } from "~/lib/get-name-from-recipient";
import { Title } from "@solidjs/meta";
import { SELF_ID } from "~/db";
import { getNameFromRecipient } from "~/lib/get-name-from-recipient";
import { OverviewTable, type RoomOverview } from "./overview-table";
export const Overview: Component<RouteSectionProps> = () => {
const [allSelfSentMessagesCount] = createResource(() => overallSentMessagesQuery(SELF_ID));
console.log(overallSentMessagesQuery(SELF_ID));
const [roomOverview] = createResource<RoomOverview[] | undefined>(async () => {
return (await allThreadsOverviewQuery())?.map((row) => {
const isGroup = row.title !== null;
const [allSelfSentMessagesCount] = createResource(() =>
overallSentMessagesQuery(SELF_ID),
);
let name = "";
const [roomOverview] = createResource<RoomOverview[] | undefined>(
async () => {
return (await allThreadsOverviewQuery())?.map((row) => {
const isGroup = row.title !== null;
if (row.title !== null) {
name = row.title;
} else {
name = getNameFromRecipient(row.nickname_joined_name, row.system_joined_name, row.profile_joined_name);
}
let name = "";
return {
threadId: row.thread_id,
recipientId: row.recipient_id,
archived: Boolean(row.archived),
messageCount: row.message_count,
lastMessageDate: row.last_message_date ? new Date(row.last_message_date) : undefined,
name,
isGroup,
};
});
});
if (row.title !== null) {
name = row.title;
} else {
name = getNameFromRecipient(
row.nickname_joined_name,
row.system_joined_name,
row.profile_joined_name,
);
}
return {
threadId: row.thread_id,
recipientId: row.recipient_id,
archived: Boolean(row.archived),
messageCount: row.message_count,
lastMessageDate: row.last_message_date
? new Date(row.last_message_date)
: undefined,
name,
isGroup,
};
});
},
);
return (
<>
<Title>Signal statistics overview</Title>
<div>
<p>All messages: {allSelfSentMessagesCount()?.messageCount as number}</p>
<Show when={!roomOverview.loading && roomOverview()} fallback="Loading...">
{(currentRoomOverview) => <OverviewTable data={currentRoomOverview()} />}
<p>
All messages: {allSelfSentMessagesCount()?.messageCount as number}
</p>
<Show
when={!roomOverview.loading && roomOverview()}
fallback="Loading..."
>
{(currentRoomOverview) => (
<OverviewTable data={currentRoomOverview()} />
)}
</Show>
</div>
</>