feat: move progress reporting to rust side

This commit is contained in:
Samuel 2024-12-25 13:30:23 +01:00
parent 6f477f9cbc
commit 5e354fc7ae
2 changed files with 158 additions and 105 deletions

View file

@ -271,8 +271,6 @@ fn get_frame_length(
header_version: Option<u32>,
) -> Result<Option<u32>, JsValue> {
if reader.remaining_length() < 4 {
web_sys::console::log_1(&"too less data to decrypt frame length".into());
return Ok(None); // Not enough data to read the frame length
}
@ -354,7 +352,10 @@ pub struct BackupDecryptor {
database_bytes: Vec<u8>,
ciphertext_buf: Vec<u8>,
plaintext_buf: Vec<u8>,
total_bytes_received: usize,
total_file_size: usize,
total_bytes_processed: usize,
processed_percentage: usize,
progress_callback: Option<js_sys::Function>,
is_initialized: bool,
// this is stored if the frame has been decrypted but it is an attachment for which we don't have enough data available
// so we don't need to decrypt the whole frame again
@ -375,7 +376,10 @@ impl BackupDecryptor {
database_bytes: Vec::new(),
ciphertext_buf: Vec::new(),
plaintext_buf: Vec::new(),
total_bytes_received: 0,
total_file_size: 0,
total_bytes_processed: 0,
processed_percentage: 0,
progress_callback: None,
is_initialized: false,
current_backup_frame: None,
}
@ -389,10 +393,41 @@ impl BackupDecryptor {
new_data.extend_from_slice(self.reader.remaining_data());
new_data.extend_from_slice(chunk);
self.total_bytes_received += chunk.len();
// self.total_bytes_received += chunk.len();
self.reader = ByteReader::new(new_data);
}
#[wasm_bindgen]
pub fn set_progress_callback(
&mut self,
total_file_size: usize,
progress_callback: js_sys::Function,
) {
self.total_file_size = total_file_size;
self.progress_callback = Some(progress_callback);
}
pub fn call_progress_callback(&self) -> usize {
let prev_percentage = self.processed_percentage;
if let Some(ref progress_callback) = self.progress_callback {
let percentage =
(self.total_bytes_processed as f32 / self.total_file_size as f32) * 100.0;
let rounded = percentage.round() as usize;
if rounded != prev_percentage {
progress_callback
.call1(&JsValue::NULL, &JsValue::from(rounded))
.unwrap();
return rounded;
}
}
return prev_percentage;
}
// process available data
// returns Ok if the decryption of the current frame was successful
// Ok(false) if there is enough data left
@ -431,6 +466,10 @@ impl BackupDecryptor {
if self.reader.remaining_length() < length as usize {
return Ok(true);
} else {
self.total_bytes_processed += (length + 10) as usize;
let new_percentage = self.call_progress_callback();
self.processed_percentage = new_percentage;
// attachments are encoded as length, which would have to be read using decode_frame_payload
// +10 because in decrypt_frame_payload we would read `their_mac` from reader which is 10 bytes long
self.reader.increment_position((length + 10) as usize);
@ -442,19 +481,22 @@ impl BackupDecryptor {
return Ok(false);
}
}
} else {
// we need to do this here so that during get_frame_length and decrypt_frame we use the same hmac and ctr
let mut hmac = <HmacSha256 as Mac>::new_from_slice(&keys.hmac_key)
.map_err(|_| JsValue::from_str("Invalid HMAC key"))?;
// we need to do this here so that during get_frame_length and decrypt_frame we use the same hmac and ctr
let mut hmac = <HmacSha256 as Mac>::new_from_slice(&keys.hmac_key)
.map_err(|_| JsValue::from_str("Invalid HMAC key"))?;
let mut ctr = <Ctr32BE<Aes256> as KeyIvInit>::new_from_slices(&keys.cipher_key, iv)
.map_err(|_| JsValue::from_str("Invalid CTR parameters"))?;
let mut ctr = <Ctr32BE<Aes256> as KeyIvInit>::new_from_slices(&keys.cipher_key, iv)
.map_err(|_| JsValue::from_str("Invalid CTR parameters"))?;
let initial_reader_position = self.reader.get_position();
let initial_reader_position = self.reader.get_position();
let frame_length =
match get_frame_length(&mut self.reader, &mut hmac, &mut ctr, header_data.version) {
let frame_length = match get_frame_length(
&mut self.reader,
&mut hmac,
&mut ctr,
header_data.version,
) {
Ok(None) => {
// need to reset the position here because getting the length and decrypting the frame rely on
// the same hmac / ctr and if we don't read the position first they won't be correct
@ -465,102 +507,113 @@ impl BackupDecryptor {
Err(e) => return Err(e),
};
// if we got to an attachment, but there we demand more data, it will be faulty, because we try to decrypt the frame although we would need
// to decrypt the attachment
match decrypt_frame(
&mut self.reader,
hmac,
&mut ctr,
&mut self.ciphertext_buf,
&mut self.plaintext_buf,
frame_length,
) {
Ok(None) => {
return Ok(true);
}
Ok(Some(backup_frame)) => {
// can not assign right here because of borrowing issues
let mut new_iv = increment_initialisation_vector(iv);
if backup_frame.end.unwrap_or(false) {
self.initialisation_vector = Some(new_iv);
// if we got to an attachment, but there we demand more data, it will be faulty, because we try to decrypt the frame although we would need
// to decrypt the attachment
match decrypt_frame(
&mut self.reader,
hmac,
&mut ctr,
&mut self.ciphertext_buf,
&mut self.plaintext_buf,
frame_length,
) {
Ok(None) => {
return Ok(true);
}
Ok(Some(backup_frame)) => {
// +4 because the length which was read is 4 bytes long
self.total_bytes_processed += (frame_length + 4) as usize;
let new_percentage = self.call_progress_callback();
self.processed_percentage = new_percentage;
// Handle all frame types
if let Some(version) = backup_frame.version {
if let Some(ver_num) = version.version {
let pragma_sql = format!("PRAGMA user_version = {}", ver_num);
self.database_bytes.extend_from_slice(pragma_sql.as_bytes());
self.database_bytes.push(b';');
// can not assign right here because of borrowing issues
let mut new_iv = increment_initialisation_vector(iv);
if backup_frame.end.unwrap_or(false) {
self.initialisation_vector = Some(new_iv);
return Ok(true);
}
} else if let Some(statement) = backup_frame.statement {
if let Some(sql) = statement.statement {
if !sql.to_lowercase().starts_with("create table sqlite_")
&& !sql.contains("sms_fts_")
&& !sql.contains("mms_fts_")
{
let processed_sql = if !statement.parameters.is_empty() {
let params: Vec<String> = statement
.parameters
.iter()
.map(|param| sql_parameter_to_string(param))
.collect::<Result<_, _>>()?;
process_parameter_placeholders(&sql, &params)?
} else {
sql
};
// Add to concatenated string
self.database_bytes
.extend_from_slice(processed_sql.as_bytes());
// Handle all frame types
if let Some(version) = backup_frame.version {
if let Some(ver_num) = version.version {
let pragma_sql = format!("PRAGMA user_version = {}", ver_num);
self.database_bytes.extend_from_slice(pragma_sql.as_bytes());
self.database_bytes.push(b';');
}
}
} else if backup_frame.preference.is_some() || backup_frame.key_value.is_some() {
} else {
// we just skip these types here
let backup_frame_cloned = backup_frame.clone();
} else if let Some(statement) = backup_frame.statement {
if let Some(sql) = statement.statement {
if !sql.to_lowercase().starts_with("create table sqlite_")
&& !sql.contains("sms_fts_")
&& !sql.contains("mms_fts_")
{
let processed_sql = if !statement.parameters.is_empty() {
let params: Vec<String> = statement
.parameters
.iter()
.map(|param| sql_parameter_to_string(param))
.collect::<Result<_, _>>()?;
let length = if let Some(attachment) = backup_frame_cloned.attachment {
attachment.length.unwrap_or(0)
} else if let Some(sticker) = backup_frame_cloned.sticker {
sticker.length.unwrap_or(0)
} else if let Some(avatar) = backup_frame_cloned.avatar {
avatar.length.unwrap_or(0)
process_parameter_placeholders(&sql, &params)?
} else {
sql
};
// Add to concatenated string
self.database_bytes
.extend_from_slice(processed_sql.as_bytes());
self.database_bytes.push(b';');
}
}
} else if backup_frame.preference.is_some() || backup_frame.key_value.is_some()
{
} else {
return Err(JsValue::from_str("Invalid field type found"));
};
// we just skip these types here
let backup_frame_cloned = backup_frame.clone();
if self.reader.remaining_length() < length as usize {
// important: we need to apply the first new_iv here, else it won't be correct when resuming payload decryption
// as we return, we don't get to the final assignment below
self.initialisation_vector = Some(new_iv);
let length = if let Some(attachment) = backup_frame_cloned.attachment {
attachment.length.unwrap_or(0)
} else if let Some(sticker) = backup_frame_cloned.sticker {
sticker.length.unwrap_or(0)
} else if let Some(avatar) = backup_frame_cloned.avatar {
avatar.length.unwrap_or(0)
} else {
return Err(JsValue::from_str("Invalid field type found"));
};
self.current_backup_frame = Some(backup_frame.clone());
if self.reader.remaining_length() < length as usize {
// important: we need to apply the first new_iv here, else it won't be correct when resuming payload decryption
// as we return, we don't get to the final assignment below
self.initialisation_vector = Some(new_iv);
return Ok(true);
} else {
// attachments are encoded as length, which would have to be read using decode_frame_payload
// +10 because in decrypt_frame_payload we would read `their_mac` from reader which is 10 bytes long
self.reader.increment_position((length + 10) as usize);
self.current_backup_frame = Some(backup_frame.clone());
new_iv = increment_initialisation_vector(&new_iv);
return Ok(true);
} else {
self.total_bytes_processed += (length + 10) as usize;
let new_percentage = self.call_progress_callback();
self.processed_percentage = new_percentage;
// attachments are encoded as length, which would have to be read using decode_frame_payload
// +10 because in decrypt_frame_payload we would read `their_mac` from reader which is 10 bytes long
self.reader.increment_position((length + 10) as usize);
new_iv = increment_initialisation_vector(&new_iv);
}
}
}
// here we can finally assign
self.initialisation_vector = Some(new_iv);
Ok(false)
}
Err(e) => {
if e.as_string()
.map_or(false, |s| s.contains("unexpected end of file"))
{
// here we can finally assign
self.initialisation_vector = Some(new_iv);
Ok(false)
} else {
Err(e)
}
Err(e) => {
if e.as_string()
.map_or(false, |s| s.contains("unexpected end of file"))
{
Ok(false)
} else {
Err(e)
}
}
}
}

20
test/dist/index.js vendored
View file

@ -14,21 +14,17 @@ async function initialize() {
export async function decryptBackup(file, passphrase, progressCallback) {
await initialize();
console.log("Starting decryption of file size:", file.size);
const fileSize = file.size;
console.log("Starting decryption of file size:", fileSize);
const decryptor = new BackupDecryptor();
decryptor.set_progress_callback(fileSize, (percentage) =>
console.info(`${percentage}% done`),
);
const chunkSize = 1024 * 1024 * 40; // 40MB chunks
let offset = 0;
let percent;
try {
while (offset < file.size) {
const newPercent = Math.round((100 / file.size) * offset);
if (newPercent !== percent) {
percent = newPercent;
console.info(`${percent}% done`);
}
// console.log(`Processing chunk at offset ${offset}`);
const chunk = file.slice(offset, offset + chunkSize);
const arrayBuffer = await chunk.arrayBuffer();
@ -59,7 +55,11 @@ export async function decryptBackup(file, passphrase, progressCallback) {
}
// console.log("All chunks processed, finishing up");
return decryptor.finish();
const result = decryptor.finish();
decryptor.free();
return result;
} catch (e) {
console.error("Decryption failed:", e);
throw e;