diff --git a/.gitignore b/.gitignore
index a5e75f4b54..a46d50a104 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 website/build
-target/
+target
 Cargo.lock
+credentials.json
diff --git a/Cargo.toml b/Cargo.toml
index bdff826e0c..44051a1ff1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,6 @@
 [workspace]
 members = [
     "boxlocker",
-    "fxa-rust-client"
+    "sync15-adapter",
+    "sync15-adapter/ffi"
 ]
diff --git a/sync15-adapter/Cargo.toml b/sync15-adapter/Cargo.toml
new file mode 100644
index 0000000000..fe59b6d711
--- /dev/null
+++ b/sync15-adapter/Cargo.toml
@@ -0,0 +1,21 @@
+[package]
+name = "sync15-adapter"
+version = "0.1.0"
+authors = ["Thom Chiovoloni "]
+
+[dependencies]
+base64 = "0.9.0"
+serde = "1.0"
+serde_derive = "1.0"
+serde_json = "1.0"
+url = "1.6.0"
+reqwest = "0.8.2"
+error-chain = "0.11"
+openssl = "0.10.7"
+hawk = { git = "https://github.com/eoger/rust-hawk", branch = "use-openssl" }
+hyper = "0.11"
+log = "0.4"
+lazy_static = "1.0"
+
+[dev-dependencies]
+env_logger = "0.5"
diff --git a/sync15-adapter/examples/boxlocker-parity.rs b/sync15-adapter/examples/boxlocker-parity.rs
new file mode 100644
index 0000000000..73f3ac2dc8
--- /dev/null
+++ b/sync15-adapter/examples/boxlocker-parity.rs
@@ -0,0 +1,188 @@
+
+extern crate sync15_adapter as sync;
+extern crate error_chain;
+extern crate url;
+extern crate base64;
+extern crate reqwest;
+
+extern crate serde;
+#[macro_use]
+extern crate serde_derive;
+extern crate serde_json;
+
+extern crate env_logger;
+
+use std::io::{self, Read, Write};
+use std::error::Error;
+use std::fs;
+use std::process;
+use std::time;
+use std::collections::HashMap;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+#[derive(Debug, Deserialize)]
+struct OAuthCredentials {
+    access_token: String,
+    refresh_token: String,
+    keys: HashMap<String, ScopedKeyData>,
+    expires_in: u64,
+    auth_at: u64,
+}
+
+#[derive(Debug, Deserialize)]
+struct ScopedKeyData {
+    k: String,
+    kid: String,
+    scope: String,
+}
+
+fn do_auth(recur: bool) -> Result<OAuthCredentials, Box<Error>> {
+    match fs::File::open("./credentials.json") {
+        Err(_) => {
+            if recur {
+                panic!("Failed to open credentials 2nd time");
+            }
+            println!("No credentials found, invoking boxlocker.py...");
+            process::Command::new("python")
+                .arg("../boxlocker/boxlocker.py").output()
+                .expect("Failed to run boxlocker.py");
+            return do_auth(true);
+        },
+        Ok(mut file) => {
+            let mut s = String::new();
+            file.read_to_string(&mut s)?;
+            let creds: OAuthCredentials = serde_json::from_str(&s)?;
+            let time = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
+            if creds.expires_in + creds.auth_at < time {
+                println!("Warning, credentials may be stale.");
+            }
+            Ok(creds)
+        }
+    }
+}
+
+fn prompt_string<S: AsRef<str>>(prompt: S) -> Option<String> {
+    print!("{}: ", prompt.as_ref());
+    let _ = io::stdout().flush(); // Don't care if flush fails really.
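+    // Read a whole line, then strip the trailing '\n' (and a '\r' on Windows) by hand.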
+    let mut s = String::new();
+    io::stdin().read_line(&mut s).expect("Failed to read line...");
+    if let Some('\n') = s.chars().next_back() { s.pop(); }
+    if let Some('\r') = s.chars().next_back() { s.pop(); }
+    if s.len() == 0 {
+        None
+    } else {
+        Some(s)
+    }
+}
+
+fn read_login() -> sync::record_types::PasswordRecord {
+    let username = prompt_string("username").unwrap_or(String::new());
+    let password = prompt_string("password").unwrap_or(String::new());
+    let form_submit_url = prompt_string("form_submit_url");
+    let hostname = prompt_string("hostname");
+    let http_realm = prompt_string("http_realm");
+    let username_field = prompt_string("username_field").unwrap_or(String::new());
+    let password_field = prompt_string("password_field").unwrap_or(String::new());
+    let since_unix_epoch = time::SystemTime::now().duration_since(time::UNIX_EPOCH).unwrap();
+    let dur_ms = since_unix_epoch.as_secs() * 1000 + ((since_unix_epoch.subsec_nanos() / 1_000_000) as u64);
+    let ms_i64 = dur_ms as i64;
+    sync::record_types::PasswordRecord {
+        id: sync::util::random_guid().unwrap(),
+        username,
+        password,
+        username_field,
+        password_field,
+        form_submit_url,
+        http_realm,
+        hostname,
+        time_created: ms_i64,
+        time_password_changed: ms_i64,
+        times_used: None,
+        time_last_used: Some(ms_i64),
+    }
+}
+
+fn prompt_bool(msg: &str) -> Option<bool> {
+    let result = prompt_string(msg);
+    result.and_then(|r| match r.chars().next().unwrap() {
+        'y' | 'Y' | 't' | 'T' => Some(true),
+        'n' | 'N' | 'f' | 'F' => Some(false),
+        _ => None
+    })
+}
+
+fn prompt_chars(msg: &str) -> Option<char> {
+    prompt_string(msg).and_then(|r| r.chars().next())
+}
+
+fn start() -> Result<(), Box<Error>> {
+    let oauth_data = do_auth(false)?;
+
+    let scope = &oauth_data.keys["https://identity.mozilla.com/apps/oldsync"];
+
+    let mut svc = sync::Sync15Service::new(
+        sync::Sync15ServiceInit {
+            key_id: scope.kid.clone(),
+            sync_key: scope.k.clone(),
+            access_token: oauth_data.access_token.clone(),
+            tokenserver_base_url: "https://oauth-sync.dev.lcip.org/syncserver/token".into(),
+        }
+    )?;
+
+    svc.remote_setup()?;
+    let passwords = svc.all_records::<sync::record_types::PasswordRecord>("passwords")?
+        .into_iter()
+        .filter_map(|r| r.record())
+        .collect::<Vec<_>>();
+
+    println!("Found {} passwords", passwords.len());
+
+    for pw in passwords.iter() {
+        println!("{:?}", pw.payload);
+    }
+
+    if !prompt_bool("Would you like to make changes? [y/N]").unwrap_or(false) {
+        return Ok(());
+    }
+
+    let mut ids: Vec<String> = passwords.iter().map(|p| p.id.clone()).collect();
+
+    let mut upd = sync::CollectionUpdate::new(&svc, false);
+    loop {
+        match prompt_chars("Add, delete, or commit [adc]:").unwrap_or('s') {
+            'A' | 'a' => {
+                let record = read_login();
+                upd.add_record(record);
+            },
+            'D' | 'd' => {
+                for (i, id) in ids.iter().enumerate() {
+                    println!("{}: {}", i, id);
+                }
+                if let Some(index) = prompt_string("Index to delete (enter index)").and_then(|x| x.parse::<usize>().ok()) {
+                    let result = ids.swap_remove(index);
+                    upd.add_tombstone(result);
+                } else {
+                    println!("???");
+                }
+            },
+            'C' | 'c' => {
+                println!("committing!");
+                let (good, bad) = upd.upload()?;
+                println!("Uploaded {} ids successfully, and {} unsuccessfully",
+                         good.len(), bad.len());
+                break;
+            },
+            c => {
+                println!("Unknown action '{}', exiting.", c);
+                break;
+            }
+        }
+    }
+
+    Ok(())
+}
+
+fn main() {
+    env_logger::init();
+    start().unwrap();
+}
diff --git a/sync15-adapter/ffi/Cargo.toml b/sync15-adapter/ffi/Cargo.toml
new file mode 100644
index 0000000000..76c03d590e
--- /dev/null
+++ b/sync15-adapter/ffi/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "sync-adapter-ffi"
+version = "0.1.0"
+authors = ["Thom Chiovoloni "]
+
+[lib]
+name = "sync_adapter"
+crate-type = ["staticlib"]
+
+[dependencies]
+libc = "0.2"
+
+[dependencies.sync15-adapter]
+path = "../"
diff --git a/sync15-adapter/ffi/README.md b/sync15-adapter/ffi/README.md
new file mode 100644
index 0000000000..9afa06533a
--- /dev/null
+++ b/sync15-adapter/ffi/README.md
@@ -0,0 +1,9 @@
+# Sync 1.5 Client FFI
+
+This README is shamelessly stolen from the one in the fxa client directory.
+
+## iOS build
+
+- Make sure you have the nightly compiler in order to get LLVM Bitcode generation.
+- Install [cargo-lipo](https://github.com/TimNN/cargo-lipo/#installation).
+- Build with: `OPENSSL_DIR=/usr/local/opt/openssl cargo +nightly lipo --release`
diff --git a/sync15-adapter/ffi/src/lib.rs b/sync15-adapter/ffi/src/lib.rs
new file mode 100644
index 0000000000..812d7b0b14
--- /dev/null
+++ b/sync15-adapter/ffi/src/lib.rs
@@ -0,0 +1,206 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+extern crate sync15_adapter as sync;
+extern crate libc;
+
+use sync::record_types::{PasswordRecord};
+use std::ffi::{CStr, CString};
+use libc::c_char;
+use std::ptr;
+
+fn c_char_to_string(cchar: *const c_char) -> String {
+    let c_str = unsafe { CStr::from_ptr(cchar) };
+    let r_str = c_str.to_str().unwrap_or("");
+    r_str.to_string()
+}
+
+fn string_to_c_char(s: String) -> *mut c_char {
+    CString::new(s).unwrap().into_raw()
+}
+
+fn opt_string_to_c_char(os: Option<String>) -> *mut c_char {
+    match os {
+        Some(s) => string_to_c_char(s),
+        _ => ptr::null_mut(),
+    }
+}
+
+#[repr(C)]
+pub struct PasswordRecordC {
+    pub id: *mut c_char,
+
+    /// Might be null!
+    pub hostname: *mut c_char,
+
+    /// Might be null!
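+    /// (In practice a login record sets either form_submit_url or http_realm, not both.)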
+ pub form_submit_url: *mut c_char, + pub http_realm: *mut c_char, + + pub username: *mut c_char, + pub password: *mut c_char, + + pub username_field: *mut c_char, + pub password_field: *mut c_char, + + /// In ms since unix epoch + pub time_created: i64, + + /// In ms since unix epoch + pub time_password_changed: i64, + + /// -1 for missing, otherwise in ms_since_unix_epoch + pub time_last_used: i64, + + /// -1 for missing + pub times_used: i64, +} + +unsafe fn drop_cchar_ptr(s: *mut c_char) { + if !s.is_null() { + let _ = CString::from_raw(s); + } +} + +impl Drop for PasswordRecordC { + fn drop(&mut self) { + unsafe { + drop_cchar_ptr(self.id); + drop_cchar_ptr(self.hostname); + drop_cchar_ptr(self.form_submit_url); + drop_cchar_ptr(self.http_realm); + drop_cchar_ptr(self.username); + drop_cchar_ptr(self.password); + drop_cchar_ptr(self.username_field); + drop_cchar_ptr(self.password_field); + } + } +} + +impl From for PasswordRecordC { + fn from(record: PasswordRecord) -> PasswordRecordC { + PasswordRecordC { + id: string_to_c_char(record.id), + hostname: opt_string_to_c_char(record.hostname), + form_submit_url: opt_string_to_c_char(record.form_submit_url), + http_realm: opt_string_to_c_char(record.http_realm), + username: string_to_c_char(record.username), + password: string_to_c_char(record.password), + username_field: string_to_c_char(record.username_field), + password_field: string_to_c_char(record.password_field), + time_created: record.time_created, + time_password_changed: record.time_password_changed, + time_last_used: record.time_last_used.unwrap_or(-1), + times_used: record.times_used.unwrap_or(-1), + } + } +} + +#[no_mangle] +pub extern "C" fn sync15_service_create( + key_id: *const c_char, + access_token: *const c_char, + sync_key: *const c_char, + tokenserver_base_url: *const c_char +) -> *mut sync::Sync15Service { + let params = sync::Sync15ServiceInit { + key_id: c_char_to_string(key_id), + access_token: c_char_to_string(access_token), + sync_key: c_char_to_string(sync_key), + tokenserver_base_url: c_char_to_string(tokenserver_base_url), + }; + let mut boxed = match sync::Sync15Service::new(params) { + Ok(svc) => Box::new(svc), + Err(e) => { + println!("Unexpected error initializing Sync15Service: {}", e); + // TODO: have thoughts about error handling. + return ptr::null_mut(); + } + }; + if let Err(e) = boxed.remote_setup() { + println!("Unexpected error performing remote sync setup: {}", e); + // TODO: have thoughts about error handling here too. + return ptr::null_mut(); + } + Box::into_raw(boxed) +} + +#[no_mangle] +pub extern "C" fn sync15_service_destroy(svc: *mut sync::Sync15Service) { + let _ = unsafe { Box::from_raw(svc) }; +} + +// This is opaque to C +pub struct PasswordCollection { + pub records: Vec, + pub tombstones: Vec, +} + +#[no_mangle] +pub extern "C" fn sync15_service_request_passwords( + svc: *mut sync::Sync15Service +) -> *mut PasswordCollection { + let service = unsafe { &mut *svc }; + let passwords = match service.all_records::("passwords") { + Ok(pws) => pws, + Err(e) => { + // TODO: error handling... + println!("Unexpected error downloading passwords {}", e); + return ptr::null_mut(); + } + }; + let mut tombstones = vec![]; + let mut records = vec![]; + for obj in passwords { + match obj.payload { + sync::Tombstone { id, .. 
} => tombstones.push(id), + sync::NonTombstone(record) => records.push(record), + } + } + Box::into_raw(Box::new(PasswordCollection { records, tombstones })) +} + +#[no_mangle] +pub extern "C" fn sync15_passwords_destroy(coll: *mut PasswordCollection) { + let _ = unsafe { Box::from_raw(coll) }; +} + +#[no_mangle] +pub extern "C" fn sync15_passwords_tombstone_count(coll: *const PasswordCollection) -> libc::size_t { + let coll = unsafe { &*coll }; + coll.tombstones.len() as libc::size_t +} + +#[no_mangle] +pub extern "C" fn sync15_passwords_record_count(coll: *const PasswordCollection) -> libc::size_t { + let coll = unsafe { &*coll }; + coll.records.len() as libc::size_t +} + +#[no_mangle] +pub extern "C" fn sync15_passwords_get_tombstone_at( + coll: *const PasswordCollection, + index: libc::size_t +) -> *mut c_char { + let coll = unsafe { &*coll }; + opt_string_to_c_char(coll.tombstones.get(index as usize).cloned()) +} + +#[no_mangle] +pub extern "C" fn sync15_passwords_get_record_at( + coll: *const PasswordCollection, + index: libc::size_t +) -> *mut PasswordRecordC { + let coll = unsafe { &*coll }; + match coll.records.get(index as usize) { + Some(r) => Box::into_raw(Box::new(r.clone().into())), + None => ptr::null_mut(), + } +} + +#[no_mangle] +pub extern "C" fn sync15_password_record_destroy(pw: *mut PasswordRecordC) { + // Our drop impl takes care of our strings. + let _ = unsafe { Box::from_raw(pw) }; +} diff --git a/sync15-adapter/ffi/sync_adapter.h b/sync15-adapter/ffi/sync_adapter.h new file mode 100644 index 0000000000..db3f4a9b8f --- /dev/null +++ b/sync15-adapter/ffi/sync_adapter.h @@ -0,0 +1,63 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ +#ifndef SYNC_ADAPTER_15_H +#define SYNC_ADAPTER_15_H +// size_t +#include +// int64_t +#include + +typedef struct sync15_PasswordRecord sync15_PasswordRecord; +typedef struct sync15_PasswordCollection sync15_PasswordCollection; +typedef struct sync15_Service sync15_Service; + +struct sync15_PasswordRecord { + const char* id; + // Might be null! + const char* hostname; + // Might be null! + const char* form_submit_url; + const char* http_realm; + + const char* username; + const char* password; + + const char* username_field; + const char* password_field; + + // In ms since unix epoch + int64_t time_created; + + // In ms since unix epoch + int64_t time_password_changed; + + // -1 for missing, otherwise in ms_since_unix_epoch + int64_t time_last_used; + + // -1 for missing + int64_t times_used; +}; + +sync15_Service *sync15_service_create(const char* key_id, + const char* access_token, + const char* sync_key, + const char* tokenserver_base_url); + +void sync15_service_destroy(sync15_Service* svc); + +sync15_PasswordCollection* sync15_service_request_passwords(sync15_Service* svc); +void sync15_passwords_destroy(sync15_PasswordCollection *passwords); + +size_t sync15_passwords_record_count(const sync15_PasswordCollection* passwords); +size_t sync15_passwords_tombstone_count(const sync15_PasswordCollection* passwords); + +// Caller frees! 
Returns null if index >= sync15_passwords_tombstone_count(passwords)
char *sync15_passwords_get_tombstone_at(const sync15_PasswordCollection* pws, size_t i);

// Caller frees (via sync15_password_record_destroy). Returns null if index >= sync15_passwords_record_count(pws)
sync15_PasswordRecord* sync15_passwords_get_record_at(const sync15_PasswordCollection* pws, size_t i);

void sync15_password_record_destroy(sync15_PasswordRecord *record);

#endif
diff --git a/sync15-adapter/src/bso_record.rs b/sync15-adapter/src/bso_record.rs
new file mode 100644
index 0000000000..73c25f5db6
--- /dev/null
+++ b/sync15-adapter/src/bso_record.rs
@@ -0,0 +1,287 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use serde::de::DeserializeOwned;
+use serde::ser::Serialize;
+use serde_json;
+use error;
+use base64;
+use std::ops::{Deref, DerefMut};
+use std::convert::From;
+use key_bundle::KeyBundle;
+use util::ServerTimestamp;
+
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+pub struct BsoRecord<T> {
+    pub id: String,
+
+    // It's not clear to me if this actually can be empty in practice.
+    // firefox-ios seems to think it can...
+    #[serde(default = "String::new")]
+    pub collection: String,
+
+    #[serde(skip_serializing)]
+    // If we don't give it a default, we fail to deserialize
+    // items we wrote out during tests and such.
+    #[serde(default = "ServerTimestamp::default")]
+    pub modified: ServerTimestamp,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sortindex: Option<i32>,
+
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ttl: Option<u32>,
+
+    // We do some serde magic here to parse the payload from JSON as we deserialize.
+    // This avoids having a separate intermediate type that only exists so that we can deserialize
+    // its payload field as JSON (especially since this one is going to exist more-or-less just so
+    // that we can decrypt the data...)
+    #[serde(with = "as_json", bound(
+        serialize = "T: Serialize",
+        deserialize = "T: DeserializeOwned"))]
+    pub payload: T,
+}
+
+impl<T> BsoRecord<T> {
+    #[inline]
+    pub fn map_payload<P, F>(self, mapper: F) -> BsoRecord<P> where F: FnOnce(T) -> P {
+        BsoRecord {
+            id: self.id,
+            collection: self.collection,
+            modified: self.modified,
+            sortindex: self.sortindex,
+            ttl: self.ttl,
+            payload: mapper(self.payload),
+        }
+    }
+
+    #[inline]
+    pub fn with_payload<P>(self, payload: P) -> BsoRecord<P>
{ + self.map_payload(|_| payload) + } +} + +/// Marker trait that indicates that something is a sync record type. By not implementing this +/// for EncryptedPayload, we can statically prevent double-encrypting. +pub trait Sync15Record: Clone + DeserializeOwned + Serialize { + fn collection_tag() -> &'static str; + fn record_id(&self) -> &str; + + // Max TTL is actually 31536000, weirdly. + #[inline] + fn ttl() -> Option { None } + + #[inline] + fn sortindex(&self) -> Option { None } +} + +impl From for BsoRecord where T: Sync15Record { + #[inline] + fn from(payload: T) -> BsoRecord { + let id = payload.record_id().into(); + let collection = T::collection_tag().into(); + let sortindex = payload.sortindex(); + BsoRecord { + id, collection, payload, sortindex, + modified: ServerTimestamp(0.0), + ttl: T::ttl(), + } + } +} + +impl BsoRecord> { + /// Helper to improve ergonomics for handling records that might be tombstones. + #[inline] + pub fn transpose(self) -> Option> { + let BsoRecord { id, collection, modified, sortindex, ttl, payload } = self; + match payload { + Some(p) => Some(BsoRecord { id, collection, modified, sortindex, ttl, payload: p }), + None => None + } + } +} + +impl Deref for BsoRecord { + type Target = T; + #[inline] + fn deref(&self) -> &T { + &self.payload + } +} + +impl DerefMut for BsoRecord { + #[inline] + fn deref_mut(&mut self) -> &mut T { + &mut self.payload + } +} + +impl BsoRecord { + /// If T is a Sync15Record, then you can/should just use record.into() instead! + #[inline] + pub fn new_non_record, C: Into>(id: I, coll: C, payload: T) -> BsoRecord { + BsoRecord { + id: id.into(), + collection: coll.into(), + ttl: None, + sortindex: None, + modified: ServerTimestamp::default(), + payload, + } + } +} + +// Contains the methods to automatically deserialize the payload to/from json. +mod as_json { + use serde_json; + use serde::de::{self, Deserialize, DeserializeOwned, Deserializer}; + use serde::ser::{self, Serialize, Serializer}; + + pub fn serialize(t: &T, serializer: S) -> Result + where T: Serialize, S: Serializer { + let j = serde_json::to_string(t).map_err(ser::Error::custom)?; + serializer.serialize_str(&j) + } + + pub fn deserialize<'de, T, D>(deserializer: D) -> Result + where T: DeserializeOwned, D: Deserializer<'de> { + let j = String::deserialize(deserializer)?; + serde_json::from_str(&j).map_err(de::Error::custom) + } +} + +#[derive(Deserialize, Serialize, Clone, Debug)] +pub struct EncryptedPayload { + #[serde(rename = "IV")] + pub iv: String, + pub hmac: String, + pub ciphertext: String, +} + +// This is a little cludgey but I couldn't think of another way to have easy deserialization +// without a bunch of wrapper types, while still only serializing a single time in the +// postqueue. +lazy_static! { + // The number of bytes taken up by padding in a EncryptedPayload. + static ref EMPTY_ENCRYPTED_PAYLOAD_SIZE: usize = serde_json::to_string( + &EncryptedPayload { iv: "".into(), hmac: "".into(), ciphertext: "".into() } + ).unwrap().len(); +} + +impl EncryptedPayload { + #[inline] + pub fn serialized_len(&self) -> usize { + (*EMPTY_ENCRYPTED_PAYLOAD_SIZE) + self.ciphertext.len() + self.hmac.len() + self.iv.len() + } +} + +impl BsoRecord { + pub fn decrypt(self, key: &KeyBundle) -> error::Result> where T: DeserializeOwned { + if !key.verify_hmac_string(&self.payload.hmac, &self.payload.ciphertext)? 
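+        // (The HMAC is verified over the base64-encoded ciphertext string, matching how encrypt() computes it.)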
{ + return Err(error::ErrorKind::HmacMismatch.into()); + } + + let iv = base64::decode(&self.payload.iv)?; + let ciphertext = base64::decode(&self.payload.ciphertext)?; + let cleartext = key.decrypt(&ciphertext, &iv)?; + + let new_payload = serde_json::from_str::(&cleartext)?; + + let result = self.with_payload(new_payload); + Ok(result) + } +} + +impl BsoRecord where T: Sync15Record { + pub fn encrypt(self, key: &KeyBundle) -> error::Result> { + let cleartext = serde_json::to_string(&self.payload)?; + let (enc_bytes, iv) = key.encrypt_bytes_rand_iv(&cleartext.as_bytes())?; + let iv_base64 = base64::encode(&iv); + let enc_base64 = base64::encode(&enc_bytes); + let hmac = key.hmac_string(enc_base64.as_bytes())?; + let result = self.with_payload(EncryptedPayload { + iv: iv_base64, + hmac: hmac, + ciphertext: enc_base64, + }); + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_deserialize_enc() { + let serialized = r#"{ + "id": "1234", + "collection": "passwords", + "modified": 12344321.0, + "payload": "{\"IV\": \"aaaaa\", \"hmac\": \"bbbbb\", \"ciphertext\": \"ccccc\"}" + }"#; + let record: BsoRecord = serde_json::from_str(serialized).unwrap(); + assert_eq!(&record.id, "1234"); + assert_eq!(&record.collection, "passwords"); + assert_eq!(record.modified.0, 12344321.0); + assert_eq!(&record.payload.iv, "aaaaa"); + assert_eq!(&record.payload.hmac, "bbbbb"); + assert_eq!(&record.payload.ciphertext, "ccccc"); + } + + #[test] + fn test_serialize_enc() { + let goal = r#"{"id":"1234","collection":"passwords","payload":"{\"IV\":\"aaaaa\",\"hmac\":\"bbbbb\",\"ciphertext\":\"ccccc\"}"}"#; + let record = BsoRecord { + id: "1234".into(), + modified: ServerTimestamp(999.0), // shouldn't be serialized by client no matter what it's value is + collection: "passwords".into(), + sortindex: None, + ttl: None, + payload: EncryptedPayload { + iv: "aaaaa".into(), + hmac: "bbbbb".into(), + ciphertext: "ccccc".into(), + } + }; + let actual = serde_json::to_string(&record).unwrap(); + assert_eq!(actual, goal); + + let val_str_payload: serde_json::Value = serde_json::from_str(goal).unwrap(); + assert_eq!(val_str_payload["payload"].as_str().unwrap().len(), + record.payload.serialized_len()) + } + + #[derive(Debug, Clone, Serialize, Deserialize)] + struct MyRecordType { + id: String, + data: String, + idx: i32, + } + + impl Sync15Record for MyRecordType { + fn collection_tag() -> &'static str { "my_cool_records" } + fn record_id(&self) -> &str { &self.id } + // 3 years in seconds + fn ttl() -> Option { Some(3 * 365 * 24 * 60 * 60) } + fn sortindex(&self) -> Option { Some(self.idx) } + } + + #[test] + fn test_sync15record() { + let record: MyRecordType = MyRecordType { + id: "aaabbbcccddd".into(), + data: "this is extremely good and cool data".into(), + idx: 9001 + }; + let bso: BsoRecord = record.into(); + let s = serde_json::to_string(&bso).unwrap(); + let out: serde_json::Value = serde_json::from_str(&s).unwrap(); + let ttl = 3*365*24*60*60; + assert_eq!(out["ttl"], json!(ttl)); + assert_eq!(out["sortindex"], json!(9001)); + assert_eq!(out["id"], json!("aaabbbcccddd")); + assert_eq!(out["collection"], json!("my_cool_records")); + } + +} diff --git a/sync15-adapter/src/collection_keys.rs b/sync15-adapter/src/collection_keys.rs new file mode 100644 index 0000000000..13be35fdfc --- /dev/null +++ b/sync15-adapter/src/collection_keys.rs @@ -0,0 +1,61 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. 
If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use bso_record::{BsoRecord, Sync15Record, EncryptedPayload}; +use key_bundle::KeyBundle; +use std::collections::HashMap; +use error::Result; + +#[derive(Deserialize, Serialize, Clone, Debug, Eq, PartialEq)] +struct CryptoKeysRecord { + pub id: String, + pub collection: String, + pub default: [String; 2], + pub collections: HashMap +} + +impl Sync15Record for CryptoKeysRecord { + fn collection_tag() -> &'static str { "crypto" } + fn record_id(&self) -> &str { "keys" } +} + +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct CollectionKeys { + pub default: KeyBundle, + pub collections: HashMap +} + +impl CollectionKeys { + pub fn from_encrypted_bso(record: BsoRecord, root_key: &KeyBundle) -> Result { + let keys: BsoRecord = record.decrypt(root_key)?; + Ok(CollectionKeys { + default: KeyBundle::from_base64(&keys.payload.default[0], &keys.payload.default[1])?, + collections: + keys.payload.collections + .into_iter() + .map(|kv| Ok((kv.0, KeyBundle::from_base64(&kv.1[0], &kv.1[1])?))) + .collect::>>()? + }) + } + + fn to_bso(&self) -> BsoRecord { + CryptoKeysRecord { + id: "keys".into(), + collection: "crypto".into(), + default: self.default.to_b64_array(), + collections: self.collections.iter().map(|kv| + (kv.0.clone(), kv.1.to_b64_array())).collect() + }.into() + } + + #[inline] + pub fn to_encrypted_bso(&self, root_key: &KeyBundle) -> Result> { + self.to_bso().encrypt(root_key) + } + + #[inline] + pub fn key_for_collection<'a>(&'a self, collection: &str) -> &'a KeyBundle { + self.collections.get(collection).unwrap_or(&self.default) + } +} diff --git a/sync15-adapter/src/error.rs b/sync15-adapter/src/error.rs new file mode 100644 index 0000000000..61509d0744 --- /dev/null +++ b/sync15-adapter/src/error.rs @@ -0,0 +1,89 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +error_chain! { + foreign_links { + Base64Decode(::base64::DecodeError); + OpensslError(::openssl::error::ErrorStack); + BadCleartextUtf8(::std::string::FromUtf8Error); + JsonError(::serde_json::Error); + BadUrl(::reqwest::UrlError); + RequestError(::reqwest::Error); + HawkError(::hawk::Error); + } + errors { + BadKeyLength(which_key: &'static str, length: usize) { + description("Incorrect key length") + display("Incorrect key length for key {}: {}", which_key, length) + } + // Not including `expected` and `is`, since they don't seem useful and are inconvenient + // to include. If we decide we want them it's not too bad to include. + HmacMismatch { + description("SHA256 HMAC Mismatch error") + display("SHA256 HMAC Mismatch error") + } + + // Used when a BSO should be decrypted but is encrypted, or vice versa. + BsoWrongCryptState(is_decrypted: bool) { + description("BSO in wrong encryption state for operation") + display("Expected {} BSO, but got a(n) {} one", + if *is_decrypted { "encrypted" } else { "decrypted" }, + if *is_decrypted { "decrypted" } else { "encrypted" }) + } + + // Error from tokenserver. Ideally we should probably do a better job here... 
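+        // (A 401 here generally means the OAuth token is expired or invalid and should be refreshed.)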
+ TokenserverHttpError(code: ::reqwest::StatusCode) { + description("HTTP status when requesting a token from the tokenserver") + display("HTTP status {} when requesting a token from the tokenserver", code) + } + + // As above, but for storage requests + StorageHttpError(code: ::reqwest::StatusCode, route: String) { + description("HTTP error status when making a request to storage server") + display("HTTP status {} during a storage request to \"{}\"", code, route) + } + + BackoffError(retry_after_secs: f64) { + description("Server requested backoff") + display("Server requested backoff. Retry after {} seconds.", retry_after_secs) + } + + // This might just be a NYI, since IDK if we want to upload this. + NoMetaGlobal { + description("No meta global on server for user") + display("No meta global on server for user") + } + + // We should probably get rid of the ones of these that are actually possible, + // but I'd like to get this done rather than spend tons of time worrying about + // the right error types for now (but at the same time, I'd rather not unwrap) + UnexpectedError(message: String) { + description("Unexpected error") + display("Unexpected error: {}", message) + } + + RecordTooLargeError { + description("Record is larger than the maximum size allowed by the server") + display("Record is larger than the maximum size allowed by the server") + } + + BatchInterrupted { + description("Batch interrupted: server responded with 412") + display("Batch interrupted: server responded with 412") + } + + RecordUploadFailed(problems: ::std::collections::HashMap) { + description("Some records failed to upload, but success was required for the collection") + display("Several records failed to upload ({}), but success was required for the collection", + problems.len()) + } + } +} + +// Boilerplate helper... +pub fn unexpected(s: S) -> Error where S: Into { + ErrorKind::UnexpectedError(s.into()).into() +} + + diff --git a/sync15-adapter/src/key_bundle.rs b/sync15-adapter/src/key_bundle.rs new file mode 100644 index 0000000000..2f8590f703 --- /dev/null +++ b/sync15-adapter/src/key_bundle.rs @@ -0,0 +1,211 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use error::{Result, ErrorKind}; +use util::base16_encode; +use base64; +use openssl::{self, symm}; +use openssl::hash::MessageDigest; +use openssl::pkey::PKey; +use openssl::sign::Signer; + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct KeyBundle { + enc_key: Vec, + mac_key: Vec, +} + +impl KeyBundle { + + /// Construct a key bundle from the already-decoded encrypt and hmac keys. + /// Panics (asserts) if they aren't both 32 bytes. + pub fn new(enc: Vec, mac: Vec) -> Result { + if enc.len() != 32 { + // We probably should say which is bad... 
+ return Err(ErrorKind::BadKeyLength("enc_key", enc.len()).into()); + } + if mac.len() != 32 { + return Err(ErrorKind::BadKeyLength("mac_key", mac.len()).into()); + } + Ok(KeyBundle { enc_key: enc, mac_key: mac }) + } + + pub fn new_random() -> Result { + let mut buffer = [0u8; 64]; + openssl::rand::rand_bytes(&mut buffer)?; + KeyBundle::from_ksync_bytes(&buffer) + } + + pub fn from_ksync_bytes(ksync: &[u8]) -> Result { + if ksync.len() != 64 { + return Err(ErrorKind::BadKeyLength("kSync", ksync.len()).into()); + } + Ok(KeyBundle { + enc_key: ksync[0..32].into(), + mac_key: ksync[32..64].into() + }) + } + + pub fn from_ksync_base64(ksync: &str) -> Result { + let bytes = base64::decode_config(&ksync, base64::URL_SAFE_NO_PAD)?; + KeyBundle::from_ksync_bytes(&bytes) + } + + pub fn from_base64(enc: &str, mac: &str) -> Result { + let enc_bytes = base64::decode(&enc)?; + let mac_bytes = base64::decode(&mac)?; + KeyBundle::new(enc_bytes.into(), mac_bytes.into()) + } + + #[inline] + pub fn encryption_key(&self) -> &[u8] { + &self.enc_key + } + + #[inline] + pub fn hmac_key(&self) -> &[u8] { + &self.mac_key + } + + #[inline] + pub fn to_b64_array(&self) -> [String; 2] { + [base64::encode(&self.enc_key), base64::encode(&self.mac_key)] + } + + /// Returns the 32 byte digest by value since it's small enough to be passed + /// around cheaply, and easily convertable into a slice or vec if you want. + fn hmac(&self, ciphertext: &[u8]) -> Result<[u8; 32]> { + let mut out = [0u8; 32]; + let key = PKey::hmac(self.hmac_key())?; + let mut signer = Signer::new(MessageDigest::sha256(), &key)?; + signer.update(ciphertext)?; + let size = signer.sign(&mut out)?; + // This isn't an Err since it really should not be possible. + assert!(size == 32, "Somehow the 256 bits from sha256 do not add up into 32 bytes..."); + Ok(out) + } + + pub fn hmac_string(&self, ciphertext: &[u8]) -> Result { + Ok(base16_encode(&self.hmac(ciphertext)?)) + } + + pub fn verify_hmac(&self, expected_hmac: &[u8], ciphertext_base64: &str) -> Result { + let computed_hmac = self.hmac(ciphertext_base64.as_bytes())?; + // I suspect this is unnecessary for our case, but the rust-openssl docs + // want us to use this over == to avoid sidechannels, and who am I to argue? + Ok(openssl::memcmp::eq(&expected_hmac, &computed_hmac)) + } + + pub fn verify_hmac_string(&self, expected_hmac: &str, ciphertext_base64: &str) -> Result { + let computed_hmac = self.hmac(ciphertext_base64.as_bytes())?; + let computed_hmac_string = base16_encode(&computed_hmac); + Ok(openssl::memcmp::eq(&expected_hmac.as_bytes(), &computed_hmac_string.as_bytes())) + } + + /// Decrypt the provided ciphertext with the given iv, and decodes the + /// result as a utf8 string. Important: Caller must check verify_hmac first! + pub fn decrypt(&self, ciphertext: &[u8], iv: &[u8]) -> Result { + let cleartext_bytes = symm::decrypt(symm::Cipher::aes_256_cbc(), + self.encryption_key(), + Some(iv), + ciphertext)?; + let cleartext = String::from_utf8(cleartext_bytes)?; + Ok(cleartext) + } + + /// Encrypt using the provided IV. + pub fn encrypt_bytes_with_iv(&self, cleartext_bytes: &[u8], iv: &[u8]) -> Result> { + let ciphertext = symm::encrypt(symm::Cipher::aes_256_cbc(), + self.encryption_key(), + Some(iv), + cleartext_bytes)?; + Ok(ciphertext) + } + + /// Generate a random iv and encrypt with it. Return both the encrypted bytes + /// and the generated iv. 
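+    /// (Never reuse an IV with a given key: AES-CBC with a repeated IV leaks shared plaintext prefixes.)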
+ pub fn encrypt_bytes_rand_iv(&self, cleartext_bytes: &[u8]) -> Result<(Vec, [u8; 16])> { + let mut iv = [0u8; 16]; + openssl::rand::rand_bytes(&mut iv)?; + let ciphertext = self.encrypt_bytes_with_iv(cleartext_bytes, &iv)?; + Ok((ciphertext, iv)) + } + + pub fn encrypt_with_iv(&self, cleartext: &str, iv: &[u8]) -> Result> { + self.encrypt_bytes_with_iv(cleartext.as_bytes(), iv) + } + + pub fn encrypt_rand_iv(&self, cleartext: &str) -> Result<(Vec, [u8; 16])> { + self.encrypt_bytes_rand_iv(cleartext.as_bytes()) + } +} + +#[cfg(test)] +mod test { + use super::*; + + static HMAC_B16: &'static str = "b1e6c18ac30deb70236bc0d65a46f7a4dce3b8b0e02cf92182b914e3afa5eebc"; + static IV_B64: &'static str = "GX8L37AAb2FZJMzIoXlX8w=="; + static HMAC_KEY_B64: &'static str = "MMntEfutgLTc8FlTLQFms8/xMPmCldqPlq/QQXEjx70="; + static ENC_KEY_B64: &'static str ="9K/wLdXdw+nrTtXo4ZpECyHFNr4d7aYHqeg3KW9+m6Q="; + + static CIPHERTEXT_B64_PIECES: &'static [&'static str] = &[ + "NMsdnRulLwQsVcwxKW9XwaUe7ouJk5Wn80QhbD80l0HEcZGCynh45qIbeYBik0lgcHbK", + "mlIxTJNwU+OeqipN+/j7MqhjKOGIlvbpiPQQLC6/ffF2vbzL0nzMUuSyvaQzyGGkSYM2", + "xUFt06aNivoQTvU2GgGmUK6MvadoY38hhW2LCMkoZcNfgCqJ26lO1O0sEO6zHsk3IVz6", + "vsKiJ2Hq6VCo7hu123wNegmujHWQSGyf8JeudZjKzfi0OFRRvvm4QAKyBWf0MgrW1F8S", + "FDnVfkq8amCB7NhdwhgLWbN+21NitNwWYknoEWe1m6hmGZDgDT32uxzWxCV8QqqrpH/Z", + "ggViEr9uMgoy4lYaWqP7G5WKvvechc62aqnsNEYhH26A5QgzmlNyvB+KPFvPsYzxDnSC", + "jOoRSLx7GG86wT59QZw=" + ]; + + static CLEARTEXT_B64_PIECES: &'static [&'static str] = &[ + "eyJpZCI6IjVxUnNnWFdSSlpYciIsImhpc3RVcmkiOiJmaWxlOi8vL1VzZXJzL2phc29u", + "L0xpYnJhcnkvQXBwbGljYXRpb24lMjBTdXBwb3J0L0ZpcmVmb3gvUHJvZmlsZXMva3Nn", + "ZDd3cGsuTG9jYWxTeW5jU2VydmVyL3dlYXZlL2xvZ3MvIiwidGl0bGUiOiJJbmRleCBv", + "ZiBmaWxlOi8vL1VzZXJzL2phc29uL0xpYnJhcnkvQXBwbGljYXRpb24gU3VwcG9ydC9G", + "aXJlZm94L1Byb2ZpbGVzL2tzZ2Q3d3BrLkxvY2FsU3luY1NlcnZlci93ZWF2ZS9sb2dz", + "LyIsInZpc2l0cyI6W3siZGF0ZSI6MTMxOTE0OTAxMjM3MjQyNSwidHlwZSI6MX1dfQ==" + ]; + + #[test] + fn test_hmac() { + let key_bundle = KeyBundle::from_base64(ENC_KEY_B64, HMAC_KEY_B64).unwrap(); + let ciphertext_base64 = CIPHERTEXT_B64_PIECES.join(""); + let hmac = key_bundle.hmac_string(ciphertext_base64.as_bytes()).unwrap(); + assert_eq!(hmac, HMAC_B16); + assert!(key_bundle.verify_hmac_string(HMAC_B16, &ciphertext_base64).unwrap()); + } + + #[test] + fn test_decrypt() { + let key_bundle = KeyBundle::from_base64(ENC_KEY_B64, HMAC_KEY_B64).unwrap(); + let ciphertext = base64::decode(&CIPHERTEXT_B64_PIECES.join("")).unwrap(); + let iv = base64::decode(IV_B64).unwrap(); + let s = key_bundle.decrypt(&ciphertext, &iv).unwrap(); + + let cleartext = String::from_utf8( + base64::decode(&CLEARTEXT_B64_PIECES.join("")).unwrap()).unwrap(); + assert_eq!(&cleartext, &s); + } + + #[test] + fn test_encrypt() { + let key_bundle = KeyBundle::from_base64(ENC_KEY_B64, HMAC_KEY_B64).unwrap(); + let iv = base64::decode(IV_B64).unwrap(); + + let cleartext_bytes = base64::decode(&CLEARTEXT_B64_PIECES.join("")).unwrap(); + let encrypted_bytes = key_bundle.encrypt_bytes_with_iv(&cleartext_bytes, &iv).unwrap(); + + let expect_ciphertext = base64::decode(&CIPHERTEXT_B64_PIECES.join("")).unwrap(); + + assert_eq!(&encrypted_bytes, &expect_ciphertext); + + let (enc_bytes2, iv2) = key_bundle.encrypt_bytes_rand_iv(&cleartext_bytes).unwrap(); + assert_ne!(&enc_bytes2, &expect_ciphertext); + + let s = key_bundle.decrypt(&enc_bytes2, &iv2).unwrap(); + assert_eq!(&cleartext_bytes, &s.as_bytes()); + } +} diff --git a/sync15-adapter/src/lib.rs b/sync15-adapter/src/lib.rs new file mode 100644 
index 0000000000..47ef63a838 --- /dev/null +++ b/sync15-adapter/src/lib.rs @@ -0,0 +1,52 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +// `error_chain!` can recurse deeply and I guess we're just supposed to live with that... +#![recursion_limit = "1024"] + +extern crate serde; +extern crate base64; +extern crate openssl; +extern crate reqwest; +extern crate hawk; +#[macro_use] +extern crate hyper; + +#[macro_use] +extern crate lazy_static; + +#[macro_use] +extern crate serde_derive; + +#[macro_use] +extern crate log; + +// Right now we only need the `json!` macro in tests, and a raw `#[macro_use]` gives us a warning +#[cfg_attr(test, macro_use)] +extern crate serde_json; + +#[macro_use] +extern crate error_chain; + +extern crate url; + +// TODO: Some of these don't need to be pub... +pub mod key_bundle; +pub mod error; +pub mod bso_record; +pub mod record_types; +pub mod token; +pub mod collection_keys; +pub mod util; +pub mod request; +pub mod service; +pub mod tombstone; + +// Re-export some of the types callers are likely to want for convenience. +pub use bso_record::{BsoRecord, Sync15Record}; +pub use tombstone::{MaybeTombstone, Tombstone, NonTombstone}; +pub use service::{Sync15ServiceInit, Sync15Service, CollectionUpdate}; +pub use error::{Result, Error, ErrorKind}; + +pub use MaybeTombstone::*; diff --git a/sync15-adapter/src/record_types.rs b/sync15-adapter/src/record_types.rs new file mode 100644 index 0000000000..a0872c5ea3 --- /dev/null +++ b/sync15-adapter/src/record_types.rs @@ -0,0 +1,69 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use bso_record::Sync15Record; +use std::collections::HashMap; + +// Known record formats. + +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PasswordRecord { + pub id: String, + pub hostname: Option, + + // rename_all = "camelCase" by default will do formSubmitUrl, but we can just + // override this one field. 
+ #[serde(rename = "formSubmitURL")] + pub form_submit_url: Option, + + pub http_realm: Option, + + #[serde(default = "String::new")] + pub username: String, + + pub password: String, + + #[serde(default = "String::new")] + pub username_field: String, + + #[serde(default = "String::new")] + pub password_field: String, + + pub time_created: i64, + pub time_password_changed: i64, + + #[serde(skip_serializing_if = "Option::is_none")] + pub time_last_used: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub times_used: Option, +} + +impl Sync15Record for PasswordRecord { + fn collection_tag() -> &'static str { "passwords" } + fn record_id(&self) -> &str { &self.id } +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MetaGlobalEngine { + pub version: usize, + #[serde(rename = "syncID")] + pub sync_id: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MetaGlobalRecord { + #[serde(rename = "syncID")] + pub sync_id: String, + #[serde(rename = "storageVersion")] + pub storage_version: usize, + pub engines: HashMap, + pub declined: Vec, +} + +impl Sync15Record for MetaGlobalRecord { + fn collection_tag() -> &'static str { "meta" } + fn record_id(&self) -> &str { "global" } +} diff --git a/sync15-adapter/src/request.rs b/sync15-adapter/src/request.rs new file mode 100644 index 0000000000..19b94398d1 --- /dev/null +++ b/sync15-adapter/src/request.rs @@ -0,0 +1,1202 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +use util::ServerTimestamp; +use bso_record::{BsoRecord, EncryptedPayload}; + +use serde_json; +use std::fmt; +use std::collections::HashMap; +use std::default::Default; +use url::{Url, UrlQuery, form_urlencoded::Serializer}; +use error::{self, Result}; +use hyper::{StatusCode}; +use reqwest::Response; + +#[derive(Debug, Clone, Copy, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub enum RequestOrder { Oldest, Newest, Index } + +header! { (XIfUnmodifiedSince, "X-If-Unmodified-Since") => [ServerTimestamp] } +header! { (XLastModified, "X-Last-Modified") => [ServerTimestamp] } +header! 
{ (XWeaveTimestamp, "X-Weave-Timestamp") => [ServerTimestamp] } + +impl fmt::Display for RequestOrder { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + &RequestOrder::Oldest => f.write_str("oldest"), + &RequestOrder::Newest => f.write_str("newest"), + &RequestOrder::Index => f.write_str("index") + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct CollectionRequest { + pub collection: String, + pub full: bool, + pub ids: Option>, + pub limit: usize, + pub older: Option, + pub newer: Option, + pub order: Option, + pub commit: bool, + pub batch: Option, +} + +impl CollectionRequest { + #[inline] + pub fn new(collection: S) -> CollectionRequest where S: Into { + CollectionRequest { + collection: collection.into(), + full: false, + ids: None, + limit: 0, + older: None, + newer: None, + order: None, + commit: false, + batch: None, + } + } + + #[inline] + pub fn ids(&mut self, v: V) -> &mut CollectionRequest where V: Into> { + self.ids = Some(v.into()); + self + } + + #[inline] + pub fn full(&mut self) -> &mut CollectionRequest { + self.full = true; + self + } + + #[inline] + pub fn older_than(&mut self, ts: ServerTimestamp) -> &mut CollectionRequest { + self.older = Some(ts); + self + } + + #[inline] + pub fn newer_than(&mut self, ts: ServerTimestamp) -> &mut CollectionRequest { + self.newer = Some(ts); + self + } + + #[inline] + pub fn sort_by(&mut self, order: RequestOrder) -> &mut CollectionRequest { + self.order = Some(order); + self + } + + #[inline] + pub fn limit(&mut self, num: usize) -> &mut CollectionRequest { + self.limit = num; + self + } + + #[inline] + pub fn batch(&mut self, batch: Option) -> &mut CollectionRequest { + self.batch = batch; + self + } + + #[inline] + pub fn commit(&mut self, v: bool) -> &mut CollectionRequest { + self.commit = v; + self + } + + fn build_query(&self, pairs: &mut Serializer) { + if self.full { + pairs.append_pair("full", "1"); + } + if self.limit > 0 { + pairs.append_pair("limit", &format!("{}", self.limit)); + } + if let &Some(ref ids) = &self.ids { + pairs.append_pair("ids", &ids.join(",")); + } + if let &Some(ref batch) = &self.batch { + pairs.append_pair("batch", &batch); + } + if self.commit { + pairs.append_pair("commit", "true"); + } + if let Some(ts) = self.older { + pairs.append_pair("older", &format!("{}", ts)); + } + if let Some(ts) = self.newer { + pairs.append_pair("newer", &format!("{}", ts)); + } + if let Some(o) = self.order { + pairs.append_pair("sort", &format!("{}", o)); + } + pairs.finish(); + } + + pub fn build_url(&self, mut base_url: Url) -> Result { + base_url.path_segments_mut() + .map_err(|_| error::unexpected("Not base URL??"))? + .extend(&["storage", &self.collection]); + self.build_query(&mut base_url.query_pairs_mut()); + // This is strange but just accessing query_pairs_mut makes you have + // a trailing question mark on your url. I don't think anything bad + // would happen here, but I don't know, and also, it looks dumb so + // I'd rather not have it. + if base_url.query() == Some("") { + base_url.set_query(None); + } + Ok(base_url) + } +} + +/// Manages a pair of (byte, count) limits for a PostQueue, such as +/// (max_post_bytes, max_post_records) or (max_total_bytes, max_total_records). 
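+/// E.g. with max_post_records = 100, enqueueing the 101st record makes PostQueue flush the
+/// current POST first, even if only a few bytes are queued.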
+#[derive(Debug, Clone)] +struct LimitTracker { + max_bytes: usize, + max_records: usize, + cur_bytes: usize, + cur_records: usize, +} + +impl LimitTracker { + pub fn new(max_bytes: usize, max_records: usize) -> LimitTracker { + LimitTracker { + max_bytes, + max_records, + cur_bytes: 0, + cur_records: 0 + } + } + + pub fn clear(&mut self) { + self.cur_records = 0; + self.cur_bytes = 0; + } + + pub fn can_add_record(&self, payload_size: usize) -> bool { + // Desktop does the cur_bytes check as exclusive, but we shouldn't see any servers that + // don't have https://github.com/mozilla-services/server-syncstorage/issues/73 + self.cur_records + 1 <= self.max_records && + self.cur_bytes + payload_size <= self.max_bytes + } + + pub fn can_never_add(&self, record_size: usize) -> bool { + record_size >= self.max_bytes + } + + pub fn record_added(&mut self, record_size: usize) { + assert!(self.can_add_record(record_size), + "LimitTracker::record_added caller must check can_add_record"); + self.cur_records += 1; + self.cur_bytes += record_size; + } +} + +#[derive(Deserialize, Debug, Clone)] +pub struct InfoConfiguration { + /// The maximum size in bytes of the overall HTTP request body that will be accepted by the + /// server. + #[serde(default = "default_max_request_bytes")] + pub max_request_bytes: usize, + + /// The maximum number of records that can be uploaded to a collection in a single POST request. + #[serde(default = "usize::max_value")] + pub max_post_records: usize, + + /// The maximum combined size in bytes of the record payloads that can be uploaded to a + /// collection in a single POST request. + #[serde(default = "usize::max_value")] + pub max_post_bytes: usize, + + /// The maximum total number of records that can be uploaded to a collection as part of a + /// batched upload. + #[serde(default = "usize::max_value")] + pub max_total_records: usize, + + /// The maximum total combined size in bytes of the record payloads that can be uploaded to a + /// collection as part of a batched upload. + #[serde(default = "usize::max_value")] + pub max_total_bytes: usize, + + /// The maximum size of an individual BSO payload, in bytes. + #[serde(default = "default_max_record_payload_bytes")] + pub max_record_payload_bytes: usize, +} + +// This is annoying but seems to be the only way to do it... +fn default_max_request_bytes() -> usize { 260 * 1024 } +fn default_max_record_payload_bytes() -> usize { 256 * 1024 } + +impl Default for InfoConfiguration { + #[inline] + fn default() -> InfoConfiguration { + InfoConfiguration { + max_request_bytes: default_max_request_bytes(), + max_record_payload_bytes: default_max_record_payload_bytes(), + max_post_records: usize::max_value(), + max_post_bytes: usize::max_value(), + max_total_records: usize::max_value(), + max_total_bytes: usize::max_value(), + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct UploadResult { + batch: Option, + /// Maps record id => why failed + #[serde(default = "HashMap::new")] + pub failed: HashMap, + /// Vec of ids + #[serde(default = "Vec::new")] + pub success: Vec +} + +// Easier to fake during tests +#[derive(Debug, Clone)] +pub struct PostResponse { + pub status: StatusCode, + pub result: UploadResult, // This is lazy... + pub last_modified: ServerTimestamp, +} + +impl PostResponse { + pub fn from_response(r: &mut Response) -> Result { + let result: UploadResult = r.json()?; + // TODO Can this happen in error cases? 
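+        // (Presumably a missing header is possible on error responses, hence the ok_or_else below rather than an unwrap.)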
+ let last_modified = r.headers().get::().map(|h| **h).ok_or_else(|| + error::unexpected("Server didn't send X-Last-Modified header"))?; + let status = r.status(); + Ok(PostResponse { status, result, last_modified }) + } +} + + +#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +pub enum BatchState { + Unsupported, + NoBatch, + InBatch(String), +} + +#[derive(Debug)] +pub struct PostQueue { + poster: Post, + on_response: OnResponse, + post_limits: LimitTracker, + batch_limits: LimitTracker, + max_payload_bytes: usize, + max_request_bytes: usize, + queued: Vec, + batch: BatchState, + last_modified: ServerTimestamp, +} + +pub trait BatchPoster { + /// Note: Last argument (reference to the batch poster) is provided for the purposes of testing + /// Important: Poster should not report non-success HTTP statuses as errors!! + fn post(&self, + body: &[u8], + xius: ServerTimestamp, + batch: Option, + commit: bool, + queue: &PostQueue) -> Result; +} + +// We don't just use a FnMut here since we want to override it in mocking for RefCell, +// which we can't do for FnMut since neither FnMut nor RefCell are defined here. Also, this +// is somewhat better for documentation. +pub trait PostResponseHandler { + fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> Result<()>; +} + + +#[derive(Debug, Clone)] +pub(crate) struct NormalResponseHandler { + pub failed_ids: Vec, + pub successful_ids: Vec, + pub allow_failed: bool, + pub pending_failed: Vec, + pub pending_success: Vec, +} + +impl NormalResponseHandler { + pub fn new(allow_failed: bool) -> NormalResponseHandler { + NormalResponseHandler { + failed_ids: vec![], + successful_ids: vec![], + pending_failed: vec![], + pending_success: vec![], + allow_failed, + } + } +} + +impl PostResponseHandler for NormalResponseHandler { + fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> error::Result<()> { + if !r.status.is_success() { + warn!("Got failure status from server while posting: {}", r.status); + if r.status == StatusCode::PreconditionFailed { + bail!(error::ErrorKind::BatchInterrupted); + } else { + bail!(error::ErrorKind::StorageHttpError(r.status, + "collection storage (TODO: record route somewhere)".into())); + } + } + if r.result.failed.len() > 0 && !self.allow_failed { + bail!(error::ErrorKind::RecordUploadFailed(r.result.failed.clone())); + } + for id in r.result.success.iter() { + self.pending_success.push(id.clone()); + } + for kv in r.result.failed.iter() { + self.pending_failed.push(kv.0.clone()); + } + if !mid_batch { + self.successful_ids.append(&mut self.pending_success); + self.failed_ids.append(&mut self.pending_failed); + } + Ok(()) + } +} + +impl PostQueue +where + Poster: BatchPoster, + OnResponse: PostResponseHandler +{ + pub fn new(config: &InfoConfiguration, + ts: ServerTimestamp, + poster: Poster, + on_response: OnResponse) -> PostQueue { + PostQueue { + poster, + on_response, + last_modified: ts, + post_limits: LimitTracker::new(config.max_post_bytes, config.max_post_records), + batch_limits: LimitTracker::new(config.max_total_bytes, config.max_total_records), + batch: BatchState::NoBatch, + max_payload_bytes: config.max_record_payload_bytes, + max_request_bytes: config.max_request_bytes, + queued: Vec::new(), + } + } + + #[inline] + fn in_batch(&self) -> bool { + match &self.batch { + &BatchState::Unsupported | + &BatchState::NoBatch => false, + _ => true + } + } + + pub fn enqueue(&mut self, record: &BsoRecord) -> Result { + let payload_length = record.payload.serialized_len(); + + if 
self.post_limits.can_never_add(payload_length) || + self.batch_limits.can_never_add(payload_length) || + payload_length >= self.max_payload_bytes { + warn!("Single record too large to submit to server ({} b)", payload_length); + return Ok(false); + } + + // Write directly into `queued` but undo if necessary (the vast majority of the time + // it won't be necessary). If we hit a problem we need to undo that, but the only error + // case we have to worry about right now is in flush() + let item_start = self.queued.len(); + + // This is conservative but can't hurt. + self.queued.reserve(payload_length + 2); + + // Either the first character in an array, or a comma separating + // it from the previous item. + let c = if self.queued.is_empty() { b'[' } else { b',' }; + self.queued.push(c); + + // This unwrap is fine, since serde_json's failure case is HashMaps that have non-object + // keys, which is impossible. If you decide to change this part, you *need* to call + // `self.queued.truncate(item_start)` here in the failure case! + serde_json::to_writer(&mut self.queued, &record).unwrap(); + + let item_end = self.queued.len(); + + debug_assert!(item_end >= payload_length, + "EncryptedPayload::serialized_len is bugged"); + + // The + 1 is only relevant for the final record, which will have a trailing ']'. + let item_len = item_end - item_start + 1; + + if item_len >= self.max_request_bytes { + self.queued.truncate(item_start); + warn!("Single record too large to submit to server ({} b)", item_len); + return Ok(false); + } + + let can_post_record = self.post_limits.can_add_record(payload_length); + let can_batch_record = self.batch_limits.can_add_record(payload_length); + let can_send_record = self.queued.len() < self.max_request_bytes; + + if !can_post_record || !can_send_record || !can_batch_record { + debug!("PostQueue flushing! (can_post = {}, can_send = {}, can_batch = {})", + can_post_record, can_send_record, can_batch_record); + // "unwrite" the record. + self.queued.truncate(item_start); + // Flush whatever we have queued. + self.flush(!can_batch_record)?; + // And write it again. + let c = if self.queued.is_empty() { b'[' } else { b',' }; + self.queued.push(c); + serde_json::to_writer(&mut self.queued, &record).unwrap(); + } + + self.post_limits.record_added(payload_length); + self.batch_limits.record_added(payload_length); + + Ok(true) + } + + pub fn flush(&mut self, want_commit: bool) -> Result<()> { + if self.queued.len() == 0 { + assert!(!self.in_batch(), + "Bug: Somehow we're in a batch but have no queued records"); + // Nothing to do! + return Ok(()); + } + + self.queued.push(b']'); + let batch_id = match &self.batch { + // Not the first post and we know we have no batch semantics. + &BatchState::Unsupported => None, + // First commit in possible batch + &BatchState::NoBatch => Some("true".into()), + // In a batch and we have a batch id. + &BatchState::InBatch(ref s) => Some(s.clone()) + }; + + info!("Posting {} records of {} bytes", self.post_limits.cur_records, self.queued.len()); + + let is_commit = want_commit && !batch_id.is_none(); + // Weird syntax for calling a function object that is a property. 
+ let resp_or_error = self.poster.post(&self.queued, + self.last_modified, + batch_id, + is_commit, + self); + + self.queued.truncate(0); + + if want_commit || self.batch == BatchState::Unsupported { + self.batch_limits.clear(); + } + self.post_limits.clear(); + + let resp = resp_or_error?; + + if !resp.status.is_success() { + self.on_response.handle_response(resp, !want_commit)?; + bail!(error::unexpected("Expected OnResponse to have bailed out!")); + } + + if want_commit { + debug!("Committed batch {:?}", self.batch); + self.batch = BatchState::NoBatch; + self.last_modified = resp.last_modified; + self.on_response.handle_response(resp, false)?; + return Ok(()); + } + + if resp.status != StatusCode::Accepted { + if self.in_batch() { + bail!(error::unexpected( + "Server responded non-202 success code while a batch was in progress")); + } + self.last_modified = resp.last_modified; + self.batch = BatchState::Unsupported; + self.batch_limits.clear(); + self.on_response.handle_response(resp, false)?; + return Ok(()); + } + + let batch_id = resp.result.batch.as_ref().ok_or_else(|| + error::unexpected("Invalid server response: 202 without a batch ID"))?.clone(); + + match &self.batch { + &BatchState::Unsupported => { + warn!("Server changed it's mind about supporting batching mid-batch..."); + }, + + &BatchState::InBatch(ref cur_id) => { + if cur_id != &batch_id { + bail!(error::unexpected("Server changed batch id mid-batch!")); + } + }, + _ => {} + } + + // Can't change this in match arms without NLL + self.batch = BatchState::InBatch(batch_id); + self.last_modified = resp.last_modified; + + self.on_response.handle_response(resp, true)?; + + Ok(()) + } +} + +impl PostQueue { + pub(crate) fn successful_and_failed_ids(&mut self) -> (Vec, Vec) { + let mut good = Vec::with_capacity(self.on_response.successful_ids.len()); + // includes pending_success since they weren't committed! 
+        let mut bad = Vec::with_capacity(self.on_response.failed_ids.len() +
+                                         self.on_response.pending_failed.len() +
+                                         self.on_response.pending_success.len());
+        good.append(&mut self.on_response.successful_ids);
+
+        bad.append(&mut self.on_response.failed_ids);
+        bad.append(&mut self.on_response.pending_failed);
+        bad.append(&mut self.on_response.pending_success);
+
+        (good, bad)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use std::collections::VecDeque;
+    use std::cell::RefCell;
+    use std::rc::Rc;
+    #[test]
+    fn test_url_building() {
+        let base = Url::parse("https://example.com/sync").unwrap();
+        let empty = CollectionRequest::new("foo").build_url(base.clone()).unwrap();
+        assert_eq!(empty.as_str(), "https://example.com/sync/storage/foo");
+        let batch_start = CollectionRequest::new("bar").batch(Some("true".into())).commit(false)
+                                            .build_url(base.clone()).unwrap();
+        assert_eq!(batch_start.as_str(), "https://example.com/sync/storage/bar?batch=true");
+        let batch_commit = CollectionRequest::new("asdf").batch(Some("1234abc".into())).commit(true)
+                                             .build_url(base.clone())
+                                             .unwrap();
+        assert_eq!(batch_commit.as_str(),
+                   "https://example.com/sync/storage/asdf?batch=1234abc&commit=true");
+
+        let idreq = CollectionRequest::new("wutang").full().ids(vec!["rza".into(), "gza".into()])
+                                      .build_url(base.clone()).unwrap();
+        assert_eq!(idreq.as_str(), "https://example.com/sync/storage/wutang?full=1&ids=rza%2Cgza");
+
+        let complex = CollectionRequest::new("specific").full().limit(10).sort_by(RequestOrder::Oldest)
+                                        .older_than(ServerTimestamp(9876.54))
+                                        .newer_than(ServerTimestamp(1234.56))
+                                        .build_url(base.clone()).unwrap();
+        assert_eq!(complex.as_str(),
+                   "https://example.com/sync/storage/specific?full=1&limit=10&older=9876.54&newer=1234.56&sort=oldest");
+    }
+
+    #[derive(Debug, Clone)]
+    struct PostedData {
+        body: String,
+        xius: ServerTimestamp,
+        batch: Option<String>,
+        commit: bool,
+        payload_bytes: usize,
+        records: usize
+    }
+
+    impl PostedData {
+        fn records_as_json(&self) -> Vec<serde_json::Value> {
+            let values = serde_json::from_str::<serde_json::Value>(&self.body).expect("Posted invalid json");
+            // Check that they actually deserialize as what we want
+            let records_or_err = serde_json::from_value::<Vec<BsoRecord<EncryptedPayload>>>(values.clone());
+            records_or_err.expect("Failed to deserialize data");
+            serde_json::from_value(values).unwrap()
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    struct BatchInfo {
+        id: Option<String>,
+        posts: Vec<PostedData>,
+        bytes: usize,
+        records: usize,
+    }
+
+    #[derive(Debug, Clone)]
+    struct TestPoster {
+        all_posts: Vec<PostedData>,
+        responses: VecDeque<PostResponse>,
+        batches: Vec<BatchInfo>,
+        cur_batch: Option<BatchInfo>,
+        cfg: InfoConfiguration,
+    }
+
+    type TestPosterRef = Rc<RefCell<TestPoster>>;
+    impl TestPoster {
+        pub fn new<T>(cfg: &InfoConfiguration, responses: T) -> TestPosterRef
+        where T: Into<VecDeque<PostResponse>> {
+            Rc::new(RefCell::new(TestPoster {
+                all_posts: vec![],
+                responses: responses.into(),
+                batches: vec![],
+                cur_batch: None,
+                cfg: cfg.clone(),
+            }))
+        }
+        // Like BatchPoster::post, but takes &mut self; the trait impl on TestPosterRef
+        // borrows the RefCell mutably and forwards here.
+        fn do_post<T, O>(
+            &mut self,
+            body: &[u8],
+            xius: ServerTimestamp,
+            batch: Option<String>,
+            commit: bool,
+            queue: &PostQueue<T, O>
+        ) -> Result<PostResponse> {
+
+            let mut post = PostedData {
+                body: String::from_utf8(body.into()).expect("Posted invalid utf8..."),
+                batch: batch.clone(),
+                xius,
+                commit,
+                payload_bytes: 0,
+                records: 0,
+            };
+
+            assert!(body.len() <= self.cfg.max_request_bytes);
+
+            let (num_records, record_payload_bytes) = {
+                let recs = post.records_as_json();
+                assert!(recs.len() <= self.cfg.max_post_records);
+                assert!(recs.len() <= self.cfg.max_total_records);
+                let payload_bytes: usize = recs.iter().map(|r| {
+                    let len = r["payload"].as_str().expect("Non-string payload property").len();
+                    assert!(len <= self.cfg.max_record_payload_bytes);
+                    len
+                }).sum();
+                assert!(payload_bytes <= self.cfg.max_post_bytes);
+                assert!(payload_bytes <= self.cfg.max_total_bytes);
+
+                assert_eq!(queue.post_limits.cur_bytes, payload_bytes);
+                assert_eq!(queue.post_limits.cur_records, recs.len());
+                (recs.len(), payload_bytes)
+            };
+            post.payload_bytes = record_payload_bytes;
+            post.records = num_records;
+
+            self.all_posts.push(post.clone());
+            let response = self.responses.pop_front().unwrap();
+
+            if self.cur_batch.is_none() {
+                assert!(batch.is_none() || batch == Some("true".into()),
+                        "We shouldn't be in a batch now");
+                self.cur_batch = Some(BatchInfo {
+                    id: response.result.batch.clone(),
+                    posts: vec![],
+                    records: 0,
+                    bytes: 0,
+                });
+            } else {
+                assert_eq!(batch, self.cur_batch.as_ref().unwrap().id,
+                           "We're in a batch but got the wrong batch id");
+            }
+
+            {
+                let batch = self.cur_batch.as_mut().unwrap();
+                batch.posts.push(post.clone());
+                batch.records += num_records;
+                batch.bytes += record_payload_bytes;
+
+                assert!(batch.bytes <= self.cfg.max_total_bytes);
+                assert!(batch.records <= self.cfg.max_total_records);
+
+                assert_eq!(batch.records, queue.batch_limits.cur_records);
+                assert_eq!(batch.bytes, queue.batch_limits.cur_bytes);
+            }
+
+            if commit || response.result.batch.is_none() {
+                let batch = self.cur_batch.take().unwrap();
+                self.batches.push(batch);
+            }
+
+            Ok(response)
+        }
+
+        fn do_handle_response(&mut self, _: PostResponse, mid_batch: bool) -> Result<()> {
+            assert_eq!(mid_batch, self.cur_batch.is_some());
+            Ok(())
+        }
+    }
+    impl BatchPoster for TestPosterRef {
+        fn post<T, O>(&self,
+                      body: &[u8],
+                      xius: ServerTimestamp,
+                      batch: Option<String>,
+                      commit: bool,
+                      queue: &PostQueue<T, O>) -> Result<PostResponse> {
+            self.borrow_mut().do_post(body, xius, batch, commit, queue)
+        }
+    }
+
+    impl PostResponseHandler for TestPosterRef {
+        fn handle_response(&mut self, r: PostResponse, mid_batch: bool) -> Result<()> {
+            self.borrow_mut().do_handle_response(r, mid_batch)
+        }
+    }
+
+    type MockedPostQueue = PostQueue<TestPosterRef, TestPosterRef>;
+
+    fn pq_test_setup(cfg: InfoConfiguration, lm: f64, resps: Vec<PostResponse>) -> (MockedPostQueue, TestPosterRef) {
+        let tester = TestPoster::new(&cfg, resps);
+        let pq = PostQueue::new(&cfg, ServerTimestamp(lm), tester.clone(), tester.clone());
+        (pq, tester)
+    }
+
+    fn fake_response<'a, T: Into<Option<&'a str>>>(status: StatusCode, lm: f64, batch: T) -> PostResponse {
+        PostResponse {
+            status,
+            last_modified: ServerTimestamp(lm),
+            result: UploadResult {
+                batch: batch.into().map(|x| x.into()),
+                failed: HashMap::new(),
+                success: vec![],
+            }
+        }
+    }
+
+    lazy_static! {
+        // ~40b
+        static ref PAYLOAD_OVERHEAD: usize = {
+            let payload = EncryptedPayload {
+                iv: "".into(),
+                hmac: "".into(),
+                ciphertext: "".into()
+            };
+            serde_json::to_string(&payload).unwrap().len()
+        };
+        // ~80b
+        static ref TOTAL_RECORD_OVERHEAD: usize = {
+            let val = serde_json::to_value(BsoRecord {
+                id: "".into(),
+                collection: "".into(),
+                modified: ServerTimestamp(0.0),
+                sortindex: None,
+                ttl: None,
+                payload: EncryptedPayload {
+                    iv: "".into(),
+                    hmac: "".into(),
+                    ciphertext: "".into()
+                },
+            }).unwrap();
+            serde_json::to_string(&val).unwrap().len()
+        };
+        // There's some subtlety in how we calculate this, because the quotes in the
+        // payload are escaped, and the escape characters count toward the request
+        // length but *not* toward the payload length (the payload-size check happens
+        // after the top-level JSON object has been parsed).
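+        // For example, a payload of the form {"iv":"..."} is embedded in the record
+        // as "payload": "{\"iv\":\"...\"}", so each quote costs two bytes of request
+        // length but only one byte of payload length.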
+        static ref NON_PAYLOAD_OVERHEAD: usize = {
+            *TOTAL_RECORD_OVERHEAD - *PAYLOAD_OVERHEAD
+        };
+    }
+
+    // Actual record size (for max_request_len) will be larger by some amount.
+    fn make_record(payload_size: usize) -> BsoRecord<EncryptedPayload> {
+        assert!(payload_size > *PAYLOAD_OVERHEAD);
+        let ciphertext_len = payload_size - *PAYLOAD_OVERHEAD;
+        BsoRecord {
+            id: "".into(),
+            collection: "".into(),
+            modified: ServerTimestamp(0.0),
+            sortindex: None,
+            ttl: None,
+            payload: EncryptedPayload {
+                iv: "".into(),
+                hmac: "".into(),
+                ciphertext: "x".repeat(ciphertext_len)
+            }
+        }
+    }
+
+    fn request_bytes_for_payloads(payloads: &[usize]) -> usize {
+        1 + payloads.iter().map(|&size| size + 1 + *NON_PAYLOAD_OVERHEAD).sum::<usize>()
+    }
+
+    #[test]
+    fn test_pq_basic() {
+        let cfg = InfoConfiguration {
+            max_request_bytes: 1000,
+            max_record_payload_bytes: 1000,
+            ..InfoConfiguration::default()
+        };
+        let time = 11111111.0;
+        let (mut pq, tester) = pq_test_setup(cfg, time, vec![
+            fake_response(StatusCode::Ok, time + 100.0, None),
+        ]);
+
+        pq.enqueue(&make_record(100)).unwrap();
+        pq.flush(true).unwrap();
+
+        let t = tester.borrow();
+        assert!(t.cur_batch.is_none());
+        assert_eq!(t.all_posts.len(), 1);
+        assert_eq!(t.batches.len(), 1);
+        assert_eq!(t.batches[0].posts.len(), 1);
+        assert_eq!(t.batches[0].records, 1);
+        assert_eq!(t.batches[0].bytes, 100);
+        assert_eq!(t.batches[0].posts[0].body.len(),
+                   request_bytes_for_payloads(&[100]));
+    }
+
+    #[test]
+    fn test_pq_max_request_bytes_no_batch() {
+        let cfg = InfoConfiguration {
+            max_request_bytes: 250,
+            ..InfoConfiguration::default()
+        };
+        let time = 11111111.0;
+        let (mut pq, tester) = pq_test_setup(cfg, time, vec![
+            fake_response(StatusCode::Ok, time + 100.0, None),
+            fake_response(StatusCode::Ok, time + 200.0, None),
+        ]);
+
+        // Note that the total record overhead is around 85 bytes
+        let payload_size = 100 - *NON_PAYLOAD_OVERHEAD;
+        pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 102; [r]
+        pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 203; [r,r]
+        pq.enqueue(&make_record(payload_size)).unwrap(); // too big, forces a 2nd post.
+        pq.flush(true).unwrap();
+
+        let t = tester.borrow();
+        assert!(t.cur_batch.is_none());
+        assert_eq!(t.all_posts.len(), 2);
+        assert_eq!(t.batches.len(), 2);
+        assert_eq!(t.batches[0].posts.len(), 1);
+        assert_eq!(t.batches[0].records, 2);
+        assert_eq!(t.batches[0].bytes, payload_size * 2);
+        assert_eq!(t.batches[0].posts[0].batch, Some("true".into()));
+        assert_eq!(t.batches[0].posts[0].body.len(),
+                   request_bytes_for_payloads(&[payload_size, payload_size]));
+
+        assert_eq!(t.batches[1].posts.len(), 1);
+        assert_eq!(t.batches[1].records, 1);
+        assert_eq!(t.batches[1].bytes, payload_size);
+        // We know at this point that the server does not support batching.
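+        // (the first response carried no batch id, so the queue switched to
+        // BatchState::Unsupported and posts the remainder unbatched)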
+ assert_eq!(t.batches[1].posts[0].batch, None); + assert_eq!(t.batches[1].posts[0].commit, false); + assert_eq!(t.batches[1].posts[0].body.len(), + request_bytes_for_payloads(&[payload_size])); + } + + #[test] + fn test_pq_max_record_payload_bytes_no_batch() { + let cfg = InfoConfiguration { + max_record_payload_bytes: 150, + max_request_bytes: 350, + ..InfoConfiguration::default() + }; + let time = 11111111.0; + let (mut pq, tester) = pq_test_setup(cfg, time, vec![ + fake_response(StatusCode::Ok, time + 100.0, None), + fake_response(StatusCode::Ok, time + 200.0, None), + ]); + + // Note that the total record overhead is around 85 bytes + let payload_size = 100 - *NON_PAYLOAD_OVERHEAD; + pq.enqueue(&make_record(payload_size)).unwrap(); // total size == 102; [r] + let enqueued = pq.enqueue(&make_record(151)).unwrap(); // still 102 + assert!(!enqueued, "Should not have fit"); + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.flush(true).unwrap(); + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 1); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].posts.len(), 1); + assert_eq!(t.batches[0].records, 2); + assert_eq!(t.batches[0].bytes, payload_size * 2); + assert_eq!(t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[payload_size, payload_size])); + } + + #[test] + fn test_pq_single_batch() { + let cfg = InfoConfiguration::default(); + let time = 11111111.0; + let (mut pq, tester) = pq_test_setup(cfg, time, vec![ + fake_response(StatusCode::Accepted, time, Some("1234")), + ]); + + let payload_size = 100 - *NON_PAYLOAD_OVERHEAD; + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.enqueue(&make_record(payload_size)).unwrap(); + pq.flush(true).unwrap(); + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 1); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].id.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts.len(), 1); + assert_eq!(t.batches[0].records, 3); + assert_eq!(t.batches[0].bytes, payload_size * 3); + assert_eq!(t.batches[0].posts[0].commit, true); + assert_eq!(t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[payload_size, payload_size, payload_size])); + } + + #[test] + fn test_pq_multi_post_batch_bytes() { + let cfg = InfoConfiguration { + max_post_bytes: 200, + ..InfoConfiguration::default() + }; + let time = 11111111.0; + let (mut pq, tester) = pq_test_setup(cfg, time, vec![ + fake_response(StatusCode::Accepted, time, Some("1234")), + fake_response(StatusCode::Accepted, time, Some("1234")), + ]); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); // COMMIT + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 2); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].posts.len(), 2); + assert_eq!(t.batches[0].records, 3); + assert_eq!(t.batches[0].bytes, 300); + + assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 2); + assert_eq!(t.batches[0].posts[0].payload_bytes, 200); + assert_eq!(t.batches[0].posts[0].commit, false); + assert_eq!(t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100])); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 1); + 
assert_eq!(t.batches[0].posts[1].payload_bytes, 100); + assert_eq!(t.batches[0].posts[1].commit, true); + assert_eq!(t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100])); + } + + + #[test] + fn test_pq_multi_post_batch_records() { + let cfg = InfoConfiguration { + max_post_records: 3, + ..InfoConfiguration::default() + }; + let time = 11111111.0; + let (mut pq, tester) = pq_test_setup(cfg, time, vec![ + fake_response(StatusCode::Accepted, time, Some("1234")), + fake_response(StatusCode::Accepted, time, Some("1234")), + fake_response(StatusCode::Accepted, time, Some("1234")), + ]); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); // COMMIT + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 3); + assert_eq!(t.batches.len(), 1); + assert_eq!(t.batches[0].posts.len(), 3); + assert_eq!(t.batches[0].records, 7); + assert_eq!(t.batches[0].bytes, 700); + + assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 3); + assert_eq!(t.batches[0].posts[0].payload_bytes, 300); + assert_eq!(t.batches[0].posts[0].commit, false); + assert_eq!(t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100])); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 3); + assert_eq!(t.batches[0].posts[1].payload_bytes, 300); + assert_eq!(t.batches[0].posts[1].commit, false); + assert_eq!(t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100, 100, 100])); + + assert_eq!(t.batches[0].posts[2].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[2].records, 1); + assert_eq!(t.batches[0].posts[2].payload_bytes, 100); + assert_eq!(t.batches[0].posts[2].commit, true); + assert_eq!(t.batches[0].posts[2].body.len(), + request_bytes_for_payloads(&[100])); + } + + #[test] + fn test_pq_multi_post_multi_batch_records() { + let cfg = InfoConfiguration { + max_post_records: 3, + max_total_records: 5, + ..InfoConfiguration::default() + }; + let time = 11111111.0; + let (mut pq, tester) = pq_test_setup(cfg, time, vec![ + fake_response(StatusCode::Accepted, time, Some("1234")), + fake_response(StatusCode::Accepted, time, Some("1234")), + fake_response(StatusCode::Accepted, time, Some("abcd")), + fake_response(StatusCode::Accepted, time, Some("abcd")), + ]); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + COMMIT + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); // COMMIT + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 4); + assert_eq!(t.batches.len(), 2); + assert_eq!(t.batches[0].posts.len(), 2); + assert_eq!(t.batches[1].posts.len(), 2); + + assert_eq!(t.batches[0].records, 5); + assert_eq!(t.batches[1].records, 4); + + assert_eq!(t.batches[0].bytes, 500); + assert_eq!(t.batches[1].bytes, 400); + + 
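+        // Batch 1 splits into posts of 3 + 2 records (max_total_records == 5 forces
+        // the early commit); batch 2 splits into 3 + 1, committed by flush(true).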
assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 3); + assert_eq!(t.batches[0].posts[0].payload_bytes, 300); + assert_eq!(t.batches[0].posts[0].commit, false); + assert_eq!(t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100])); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 2); + assert_eq!(t.batches[0].posts[1].payload_bytes, 200); + assert_eq!(t.batches[0].posts[1].commit, true); + assert_eq!(t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100, 100])); + + + assert_eq!(t.batches[1].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[1].posts[0].records, 3); + assert_eq!(t.batches[1].posts[0].payload_bytes, 300); + assert_eq!(t.batches[1].posts[0].commit, false); + assert_eq!(t.batches[1].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100])); + + assert_eq!(t.batches[1].posts[1].batch.as_ref().unwrap(), "abcd"); + assert_eq!(t.batches[1].posts[1].records, 1); + assert_eq!(t.batches[1].posts[1].payload_bytes, 100); + assert_eq!(t.batches[1].posts[1].commit, true); + assert_eq!(t.batches[1].posts[1].body.len(), + request_bytes_for_payloads(&[100])); + } + + #[test] + fn test_pq_multi_post_multi_batch_bytes() { + let cfg = InfoConfiguration { + max_post_bytes: 300, + max_total_bytes: 500, + ..InfoConfiguration::default() + }; + let time = 11111111.0; + let (mut pq, tester) = pq_test_setup(cfg, time, vec![ + fake_response(StatusCode::Accepted, time, Some("1234")), + fake_response(StatusCode::Accepted, time, Some("1234")), + fake_response(StatusCode::Accepted, time, Some("abcd")), + fake_response(StatusCode::Accepted, time, Some("abcd")), + ]); + + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + COMMIT + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + pq.enqueue(&make_record(100)).unwrap(); + // POST + pq.enqueue(&make_record(100)).unwrap(); + pq.flush(true).unwrap(); // COMMIT + + let t = tester.borrow(); + assert!(t.cur_batch.is_none()); + assert_eq!(t.all_posts.len(), 4); + assert_eq!(t.batches.len(), 2); + assert_eq!(t.batches[0].posts.len(), 2); + assert_eq!(t.batches[1].posts.len(), 2); + + assert_eq!(t.batches[0].records, 5); + assert_eq!(t.batches[1].records, 4); + + assert_eq!(t.batches[0].bytes, 500); + assert_eq!(t.batches[1].bytes, 400); + + assert_eq!(t.batches[0].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[0].posts[0].records, 3); + assert_eq!(t.batches[0].posts[0].payload_bytes, 300); + assert_eq!(t.batches[0].posts[0].commit, false); + assert_eq!(t.batches[0].posts[0].body.len(), + request_bytes_for_payloads(&[100, 100, 100])); + + assert_eq!(t.batches[0].posts[1].batch.as_ref().unwrap(), "1234"); + assert_eq!(t.batches[0].posts[1].records, 2); + assert_eq!(t.batches[0].posts[1].payload_bytes, 200); + assert_eq!(t.batches[0].posts[1].commit, true); + assert_eq!(t.batches[0].posts[1].body.len(), + request_bytes_for_payloads(&[100, 100])); + + + assert_eq!(t.batches[1].posts[0].batch.as_ref().unwrap(), "true"); + assert_eq!(t.batches[1].posts[0].records, 3); + assert_eq!(t.batches[1].posts[0].payload_bytes, 300); + assert_eq!(t.batches[1].posts[0].commit, false); + assert_eq!(t.batches[1].posts[0].body.len(), + request_bytes_for_payloads(&[100, 
100, 100]));
+
+        assert_eq!(t.batches[1].posts[1].batch.as_ref().unwrap(), "abcd");
+        assert_eq!(t.batches[1].posts[1].records, 1);
+        assert_eq!(t.batches[1].posts[1].payload_bytes, 100);
+        assert_eq!(t.batches[1].posts[1].commit, true);
+        assert_eq!(t.batches[1].posts[1].body.len(),
+                   request_bytes_for_payloads(&[100]));
+    }
+
+    // TODO: Test
+    //
+    // - error cases!!! We don't test our handling of server errors at all!
+    // - mixed bytes/record limits
+    //
+    // A lot of these have good examples in test_postqueue.js on desktop Sync.
+
+}
diff --git a/sync15-adapter/src/service.rs b/sync15-adapter/src/service.rs
new file mode 100644
index 0000000000..02ac7cd04b
--- /dev/null
+++ b/sync15-adapter/src/service.rs
@@ -0,0 +1,306 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+
+use std::cell::Cell;
+use std::time::{Duration};
+use std::collections::{HashMap, HashSet};
+
+use reqwest::{
+    Client,
+    Request,
+    Response,
+    Url,
+    header::{self, Accept}
+};
+use hyper::{Method, StatusCode};
+use serde;
+
+use util::{ServerTimestamp, SERVER_EPOCH};
+use token;
+use error;
+use key_bundle::KeyBundle;
+use bso_record::{BsoRecord, Sync15Record, EncryptedPayload};
+use tombstone::{MaybeTombstone, NonTombstone};
+use record_types::MetaGlobalRecord;
+use collection_keys::CollectionKeys;
+use request::{
+    CollectionRequest,
+    InfoConfiguration,
+    XWeaveTimestamp,
+    XIfUnmodifiedSince,
+    PostResponse,
+    BatchPoster,
+    PostQueue,
+    PostResponseHandler,
+    NormalResponseHandler,
+};
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct Sync15ServiceInit {
+    pub key_id: String,
+    pub access_token: String,
+    pub sync_key: String,
+    pub tokenserver_base_url: String,
+}
+
+#[derive(Debug)]
+pub struct Sync15Service {
+    init_params: Sync15ServiceInit,
+    root_key: KeyBundle,
+    client: Client,
+    // We update this when we make requests.
+    last_server_time: Cell<ServerTimestamp>,
+    tsc: token::TokenserverClient,
+
+    keys: Option<CollectionKeys>,
+    server_config: Option<InfoConfiguration>,
+    last_sync_remote: HashMap<String, ServerTimestamp>,
+}
+
+impl Sync15Service {
+    pub fn new(init_params: Sync15ServiceInit) -> error::Result<Sync15Service> {
+        let root_key = KeyBundle::from_ksync_base64(&init_params.sync_key)?;
+        let client = Client::builder().timeout(Duration::from_secs(30)).build()?;
+        // TODO: Should we be doing this here? What if we get backoff? Who handles that?
+        let tsc = token::TokenserverClient::new(&client,
+                                                &init_params.tokenserver_base_url,
+                                                init_params.access_token.clone(),
+                                                init_params.key_id.clone())?;
+        let timestamp = tsc.server_timestamp();
+        Ok(Sync15Service {
+            client,
+            init_params,
+            root_key,
+            tsc,
+            last_server_time: Cell::new(timestamp),
+            keys: None,
+            server_config: None,
+            last_sync_remote: HashMap::new(),
+        })
+    }
+
+    #[inline]
+    fn authorized(&self, mut req: Request) -> error::Result<Request> {
+        let header = self.tsc.authorization(&req)?;
+        req.headers_mut().set(header);
+        Ok(req)
+    }
+
+    // TODO: probably want a builder-like API to do collection requests (e.g. something
+    // that occupies roughly the same conceptual role as the Collection class in desktop)
+    fn build_request(&self, method: Method, url: Url) -> error::Result<Request> {
+        self.authorized(self.client.request(method, url).header(Accept::json()).build()?)
+    }
+
+    fn relative_storage_request<T>(&self, method: Method, relative_path: T) -> error::Result<Response> where T: AsRef<str> {
+        let s = self.tsc.token().api_endpoint.clone() + "/";
+        let url = Url::parse(&s)?.join(relative_path.as_ref())?;
+        Ok(self.make_storage_request(method, url)?)
+    }
+
+    fn make_storage_request(&self, method: Method, url: Url) -> error::Result<Response> {
+        // I'm shocked that Method isn't Copy...
+        Ok(self.exec_request(self.build_request(method.clone(), url)?, true)?)
+    }
+
+    fn exec_request(&self, req: Request, require_success: bool) -> error::Result<Response> {
+        let resp = self.client.execute(req)?;
+
+        self.update_timestamp(resp.headers());
+
+        if require_success && !resp.status().is_success() {
+            error!("HTTP error {} ({}) during storage request to {}",
+                   resp.status().as_u16(), resp.status(), resp.url().path());
+            bail!(error::ErrorKind::StorageHttpError(
+                resp.status(), resp.url().path().into()));
+        }
+
+        // TODO:
+        // - handle backoff
+        // - x-weave-quota?
+        // - ... almost certainly other things too...
+
+        Ok(resp)
+    }
+
+    fn collection_request(&self, method: Method, r: &CollectionRequest) -> error::Result<Response> {
+        self.make_storage_request(method.clone(),
+                                  r.build_url(Url::parse(&self.tsc.token().api_endpoint)?)?)
+    }
+
+    fn fetch_info<T>(&self, path: &str) -> error::Result<T> where for<'a> T: serde::de::Deserialize<'a> {
+        let mut resp = self.relative_storage_request(Method::Get, path)?;
+        let result: T = resp.json()?;
+        Ok(result)
+    }
+
+    pub fn remote_setup(&mut self) -> error::Result<()> {
+        let server_config = self.fetch_info::<InfoConfiguration>("info/configuration")?;
+        self.server_config = Some(server_config);
+        let mut resp = match self.relative_storage_request(Method::Get, "storage/meta/global") {
+            Ok(r) => r,
+            // This is gross, but at least it works. Replace 404s on meta/global with NoMetaGlobal.
+            Err(error::Error(error::ErrorKind::StorageHttpError(StatusCode::NotFound, ..), _)) =>
+                bail!(error::ErrorKind::NoMetaGlobal),
+            Err(e) => return Err(e),
+        };
+        // Note: meta/global is not encrypted!
+        let meta_global: BsoRecord<MetaGlobalRecord> = resp.json()?;
+        info!("Meta global: {:?}", meta_global.payload);
+        let collections = self.fetch_info::<HashMap<String, ServerTimestamp>>("info/collections")?;
+        self.update_keys(&collections)?;
+        self.last_sync_remote = collections;
+        Ok(())
+    }
+
+    fn update_keys(&mut self, _info_collections: &HashMap<String, ServerTimestamp>) -> error::Result<()> {
+        // TODO: if info/collections says we should, upload keys.
+        // TODO: This should be handled in collection_keys.rs, which should track modified time, etc.
+        let mut keys_resp = self.relative_storage_request(Method::Get, "storage/crypto/keys")?;
+        let keys: BsoRecord<EncryptedPayload> = keys_resp.json()?;
+        self.keys = Some(CollectionKeys::from_encrypted_bso(keys, &self.root_key)?);
+        // TODO: error handling... key upload?
+        Ok(())
+    }
+
+    pub fn key_for_collection(&self, collection: &str) -> error::Result<&KeyBundle> {
+        Ok(self.keys.as_ref()
+               .ok_or_else(|| error::unexpected("Don't have keys (yet?)"))?
+               .key_for_collection(collection))
+    }
+
+    pub fn all_records<T>(&mut self, collection: &str) ->
+            error::Result<Vec<BsoRecord<MaybeTombstone<T>>>> where T: Sync15Record {
+        let key = self.key_for_collection(collection)?;
+        let mut resp = self.collection_request(Method::Get, &CollectionRequest::new(collection).full())?;
+        let records: Vec<BsoRecord<EncryptedPayload>> = resp.json()?;
+        records.into_iter()
+               .map(|record| record.decrypt::<MaybeTombstone<T>>(key))
+               .collect()
+    }
+
+    fn update_timestamp(&self, hs: &header::Headers) {
+        if let Some(ts) = hs.get::<XWeaveTimestamp>().map(|h| **h) {
+            self.last_server_time.set(ts);
+        } else {
+            // Should we complain more loudly here?
+            warn!("No X-Weave-Timestamp from storage server!");
+        }
+    }
+
+    pub fn last_modified(&self, coll: &str) -> Option<ServerTimestamp> {
+        self.last_sync_remote.get(coll).cloned()
+    }
+
+    pub fn last_modified_or_zero(&self, coll: &str) -> ServerTimestamp {
+        self.last_modified(coll).unwrap_or(SERVER_EPOCH)
+    }
+
+    fn new_post_queue<'a, F: PostResponseHandler>(&'a self, coll: &str, lm: Option<ServerTimestamp>, on_response: F)
+            -> error::Result<PostQueue<PostWrapper<'a>, F>> {
+        let ts = lm.unwrap_or_else(|| self.last_modified_or_zero(coll));
+        let pw = PostWrapper { svc: self, coll: coll.into() };
+        Ok(PostQueue::new(self.server_config.as_ref().unwrap(), ts, pw, on_response))
+    }
+}
+
+struct PostWrapper<'a> {
+    svc: &'a Sync15Service,
+    coll: String,
+}
+
+impl<'a> BatchPoster for PostWrapper<'a> {
+    fn post<T, O>(&self,
+                  bytes: &[u8],
+                  xius: ServerTimestamp,
+                  batch: Option<String>,
+                  commit: bool,
+                  _: &PostQueue<T, O>) -> error::Result<PostResponse>
+    {
+        let url = CollectionRequest::new(self.coll.clone())
+                      .batch(batch)
+                      .commit(commit)
+                      .build_url(Url::parse(&self.svc.tsc.token().api_endpoint)?)?;
+
+        let mut req = self.svc.build_request(Method::Post, url)?;
+        req.headers_mut().set(header::ContentType::json());
+        req.headers_mut().set(XIfUnmodifiedSince(xius));
+        // It's rather annoying that we need to copy the body here; the request
+        // shouldn't need to take ownership of it...
+        *req.body_mut() = Some(Vec::from(bytes).into());
+        let mut resp = self.svc.exec_request(req, false)?;
+        Ok(PostResponse::from_response(&mut resp)?)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct CollectionUpdate<'a, T> {
+    svc: &'a Sync15Service,
+    last_sync: ServerTimestamp,
+    to_update: Vec<MaybeTombstone<T>>,
+    allow_dropped_records: bool,
+    queued_ids: HashSet<String>
+}
+
+impl<'a, T> CollectionUpdate<'a, T> where T: Sync15Record {
+    pub fn new(svc: &'a Sync15Service, allow_dropped_records: bool) -> CollectionUpdate<'a, T> {
+        let coll = T::collection_tag();
+        let ts = svc.last_modified_or_zero(coll);
+        CollectionUpdate {
+            svc,
+            last_sync: ts,
+            to_update: vec![],
+            allow_dropped_records,
+            queued_ids: HashSet::new(),
+        }
+    }
+
+    pub fn add(&mut self, rec_or_tombstone: MaybeTombstone<T>) {
+        // Block to limit the scope of the `id` borrow.
+        {
+            let id = rec_or_tombstone.record_id();
+            // Should this be an Err rather than an assertion?
+            assert!(!self.queued_ids.contains(id),
+                    "Attempt to update ID multiple times in the same batch {}", id);
+            self.queued_ids.insert(id.into());
+        }
+        self.to_update.push(rec_or_tombstone);
+    }
+
+    pub fn add_record(&mut self, record: T) {
+        self.add(NonTombstone(record));
+    }
+
+    pub fn add_tombstone(&mut self, id: String) {
+        self.add(MaybeTombstone::tombstone(id));
+    }
+
+    /// Uploads the queued records, returning `(successful_ids, failed_ids)`.
+    /// `failed_ids` can only be non-empty when `allow_dropped_records` is true;
+    /// otherwise a dropped record is reported as an error.
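+    ///
+    /// A hypothetical usage sketch (assumes an initialized `svc: Sync15Service`
+    /// and a `record` of a suitable `Sync15Record` type):
+    ///
+    /// ```ignore
+    /// let mut update = CollectionUpdate::new(&svc, true);
+    /// update.add_record(record);
+    /// update.add_tombstone("deadbeef1234".into());
+    /// let (succeeded, failed) = update.upload()?;
+    /// ```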
+    pub fn upload(self) -> error::Result<(Vec<String>, Vec<String>)> {
+        let mut failed = vec![];
+        let key = self.svc.key_for_collection(T::collection_tag())?;
+        let mut q = self.svc.new_post_queue(T::collection_tag(), Some(self.last_sync),
+                                            NormalResponseHandler::new(self.allow_dropped_records))?;
+
+        for record in self.to_update.into_iter() {
+            let record_cleartext: BsoRecord<MaybeTombstone<T>> = record.into();
+            let encrypted = record_cleartext.encrypt(key)?;
+            let enqueued = q.enqueue(&encrypted)?;
+            if !enqueued && !self.allow_dropped_records {
+                bail!(error::ErrorKind::RecordTooLargeError);
+            }
+        }
+
+        q.flush(true)?;
+        let (successful_ids, mut failed_ids) = q.successful_and_failed_ids();
+        failed_ids.append(&mut failed);
+        if !self.allow_dropped_records {
+            assert_eq!(failed_ids.len(), 0,
+                       "Bug: Should have failed by now if we aren't allowing dropped records");
+        }
+        Ok((successful_ids, failed_ids))
+    }
+}
diff --git a/sync15-adapter/src/token.rs b/sync15-adapter/src/token.rs
new file mode 100644
index 0000000000..d368884d62
--- /dev/null
+++ b/sync15-adapter/src/token.rs
@@ -0,0 +1,133 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use hawk;
+
+use reqwest::{Client, Request, Url};
+use hyper::header::{Authorization, Bearer};
+use error::{self, Result};
+use std::fmt;
+use std::borrow::{Borrow, Cow};
+use util::ServerTimestamp;
+
+/// The server tells us how long to back off via Retry-After, in seconds.
+header! { (RetryAfter, "Retry-After") => [f64] }
+
+/// Tokenserver's timestamp is X-Timestamp and not X-Weave-Timestamp. The value is in seconds.
+header! { (XTimestamp, "X-Timestamp") => [ServerTimestamp] }
+
+/// The OAuth tokenserver API uses this instead of X-Client-State.
+header! { (XKeyID, "X-KeyID") => [String] }
+
+#[derive(Deserialize, Clone, Debug, PartialEq, Eq)]
+pub struct TokenserverToken {
+    pub id: String,
+    pub key: String,
+    pub api_endpoint: String,
+    pub uid: u64,
+    pub duration: u64,
+    // This is treated as optional by at least the desktop client,
+    // but AFAICT it's always present.
+    pub hashed_fxa_uid: String,
+}
+
+/// This is really more of a TokenAuthenticator.
+pub struct TokenserverClient {
+    token: TokenserverToken,
+    server_timestamp: ServerTimestamp,
+    credentials: hawk::Credentials,
+}
+
+// hawk::Credentials doesn't implement Debug -_-
+impl fmt::Debug for TokenserverClient {
+    fn fmt(&self, f: &mut fmt::Formatter) -> ::std::result::Result<(), fmt::Error> {
+        f.debug_struct("TokenserverClient")
+         .field("token", &self.token)
+         .field("server_timestamp", &self.server_timestamp)
+         .field("credentials", &"(omitted)")
+         .finish()
+    }
+}
+
+fn token_url(base_url: &str) -> Result<Url> {
+    let mut url = Url::parse(base_url)?;
+    // Kind of gross, but it avoids problems if base_url has a trailing slash.
+    url.path_segments_mut()
+       // We can't do anything anyway if this is the case.
+       .map_err(|_| error::unexpected("Bad tokenserver url (cannot be base)"))?
+       .extend(&["1.0", "sync", "1.5"]);
+    Ok(url)
+}
+
+impl TokenserverClient {
+    #[inline]
+    pub fn server_timestamp(&self) -> ServerTimestamp {
+        self.server_timestamp
+    }
+
+    #[inline]
+    pub fn token(&self) -> &TokenserverToken {
+        &self.token
+    }
+
+    pub fn new(request_client: &Client, base_url: &str, access_token: String, key_id: String) -> Result<TokenserverClient> {
+        let mut resp = request_client.get(token_url(base_url)?)
+                                      .header(Authorization(Bearer { token: access_token }))
+                                      .header(XKeyID(key_id))
+                                      .send()?;
+
+        if !resp.status().is_success() {
+            warn!("Non-success status when fetching token: {}", resp.status());
+            // TODO: the body should be JSON and contain a status parameter we might need?
+            debug!("  Response body {}", resp.text().unwrap_or("???".into()));
+            if let Some(seconds) = resp.headers().get::<RetryAfter>().map(|h| **h) {
+                bail!(error::ErrorKind::BackoffError(seconds));
+            }
+            bail!(error::ErrorKind::TokenserverHttpError(resp.status()));
+        }
+
+        let token: TokenserverToken = resp.json()?;
+
+        let timestamp = resp.headers()
+                            .get::<XTimestamp>()
+                            .map(|h| **h)
+                            .ok_or_else(|| error::unexpected(
+                                "Missing or corrupted X-Timestamp header from token server"))?;
+        let credentials = hawk::Credentials {
+            id: token.id.clone(),
+            key: hawk::Key::new(token.key.as_bytes(), hawk::Digest::sha256())?,
+        };
+        Ok(TokenserverClient {
+            token,
+            credentials,
+            server_timestamp: timestamp
+        })
+    }
+
+    pub fn authorization(&self, req: &Request) -> Result<Authorization<String>> {
+        let url = req.url();
+
+        let path_and_query = match url.query() {
+            None => Cow::from(url.path()),
+            Some(qs) => Cow::from(format!("{}?{}", url.path(), qs))
+        };
+
+        let host = url.host_str().ok_or_else(||
+            error::unexpected("Tried to authorize bad URL using hawk (no host)"))?;
+
+        // Known defaults exist for https? (among others), so this should be impossible.
+        let port = url.port_or_known_default().ok_or_else(||
+            error::unexpected(
+                "Tried to authorize bad URL using hawk (no port -- unknown protocol?)"))?;
+
+        let header = hawk::RequestBuilder::new(
+            req.method().as_ref(),
+            host,
+            port,
+            path_and_query.borrow()
+        ).request().make_header(&self.credentials)?;
+
+        Ok(Authorization(format!("Hawk {}", header)))
+    }
+}
diff --git a/sync15-adapter/src/tombstone.rs b/sync15-adapter/src/tombstone.rs
new file mode 100644
index 0000000000..d15a027da5
--- /dev/null
+++ b/sync15-adapter/src/tombstone.rs
@@ -0,0 +1,179 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use bso_record::{BsoRecord, Sync15Record};
+
+pub use MaybeTombstone::*;
+
+#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Eq)]
+#[serde(untagged)]
+pub enum MaybeTombstone<T> {
+    Tombstone { id: String, deleted: bool },
+    NonTombstone(T)
+}
+
+impl<T> MaybeTombstone<T> {
+    #[inline]
+    pub fn tombstone<R: Into<String>>(id: R) -> MaybeTombstone<T> {
+        Tombstone { id: id.into(), deleted: true }
+    }
+
+    #[inline]
+    pub fn is_tombstone(&self) -> bool {
+        match self {
+            &NonTombstone(_) => false,
+            _ => true
+        }
+    }
+
+    #[inline]
+    pub fn unwrap(self) -> T {
+        match self {
+            NonTombstone(record) => record,
+            _ => panic!("called `MaybeTombstone::unwrap()` on a Tombstone!"),
+        }
+    }
+
+    #[inline]
+    pub fn expect(self, msg: &str) -> T {
+        match self {
+            NonTombstone(record) => record,
+            _ => panic!("{}", msg),
+        }
+    }
+
+    #[inline]
+    pub fn ok_or<E>(self, err: E) -> ::std::result::Result<T, E> {
+        match self {
+            NonTombstone(record) => Ok(record),
+            _ => Err(err)
+        }
+    }
+
+    #[inline]
+    pub fn record(self) -> Option<T> {
+        match self {
+            NonTombstone(record) => Some(record),
+            _ => None
+        }
+    }
+}
+
+impl<T> Sync15Record for MaybeTombstone<T> where T: Sync15Record {
+    fn collection_tag() -> &'static str { T::collection_tag() }
+    fn ttl() -> Option<u32> { T::ttl() }
+    fn record_id(&self) -> &str {
+        match self {
+            &Tombstone { ref id, .. } => id,
+            &NonTombstone(ref record) => record.record_id()
+        }
+    }
+    fn sortindex(&self) -> Option<i32> {
+        match self {
+            &Tombstone { .. } => None,
+            &NonTombstone(ref record) => record.sortindex()
+        }
+    }
+}
+
+impl<T> BsoRecord<MaybeTombstone<T>> {
+    #[inline]
+    pub fn is_tombstone(&self) -> bool {
+        self.payload.is_tombstone()
+    }
+
+    #[inline]
+    pub fn record(self) -> Option<BsoRecord<T>> where T: Clone {
+        self.map_payload(|payload| payload.record()).transpose()
+    }
+}
+
+pub type MaybeTombstoneRecord<T> = BsoRecord<MaybeTombstone<T>>;
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use key_bundle::KeyBundle;
+    use util::ServerTimestamp;
+    use serde_json;
+
+    #[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug)]
+    struct DummyRecord {
+        id: String,
+        age: i64,
+        meta: String,
+    }
+
+    impl Sync15Record for DummyRecord {
+        fn collection_tag() -> &'static str { "dummy" }
+        fn record_id(&self) -> &str { &self.id }
+    }
+
+    #[test]
+    fn test_roundtrip_crypt_tombstone() {
+        let orig_record: MaybeTombstoneRecord<DummyRecord> = BsoRecord {
+            id: "aaaaaaaaaaaa".into(),
+            collection: "dummy".into(),
+            modified: ServerTimestamp(1234.0),
+            sortindex: None,
+            ttl: None,
+            payload: MaybeTombstone::tombstone("aaaaaaaaaaaa")
+        };
+
+        assert!(orig_record.is_tombstone());
+
+        let keybundle = KeyBundle::new_random().unwrap();
+
+        let encrypted = orig_record.clone().encrypt(&keybundle).unwrap();
+
+        assert!(keybundle.verify_hmac_string(
+            &encrypted.payload.hmac, &encrypted.payload.ciphertext).unwrap());
+
+        // While we're here, check on EncryptedPayload::serialized_len
+        let val_rec = serde_json::from_str::<serde_json::Value>(
+            &serde_json::to_string(&encrypted).unwrap()).unwrap();
+        assert_eq!(encrypted.payload.serialized_len(),
+                   val_rec["payload"].as_str().unwrap().len());
+
+        let decrypted: MaybeTombstoneRecord<DummyRecord> = encrypted.decrypt(&keybundle).unwrap();
+        assert!(decrypted.is_tombstone());
+        assert_eq!(decrypted, orig_record);
+    }
+
+    #[test]
+    fn test_roundtrip_crypt_record() {
+        let orig_record: MaybeTombstoneRecord<DummyRecord> = BsoRecord {
+            id: "aaaaaaaaaaaa".into(),
+            collection: "dummy".into(),
+            modified: ServerTimestamp(1234.0),
+            sortindex: None,
+            ttl: None,
+            payload: NonTombstone(DummyRecord {
+                id: "aaaaaaaaaaaa".into(),
+                age: 105,
+                meta: "data".into()
+            })
+        };
+
+        assert!(!orig_record.is_tombstone());
+
+        let keybundle = KeyBundle::new_random().unwrap();
+
+        let encrypted = orig_record.clone().encrypt(&keybundle).unwrap();
+
+        assert!(keybundle.verify_hmac_string(
+            &encrypted.payload.hmac, &encrypted.payload.ciphertext).unwrap());
+
+        // While we're here, check on EncryptedPayload::serialized_len
+        let val_rec = serde_json::from_str::<serde_json::Value>(
+            &serde_json::to_string(&encrypted).unwrap()).unwrap();
+        assert_eq!(encrypted.payload.serialized_len(),
+                   val_rec["payload"].as_str().unwrap().len());
+
+        let decrypted: MaybeTombstoneRecord<DummyRecord> = encrypted.decrypt(&keybundle).unwrap();
+        assert!(!decrypted.is_tombstone());
+        assert_eq!(decrypted, orig_record);
+    }
+}
diff --git a/sync15-adapter/src/util.rs b/sync15-adapter/src/util.rs
new file mode 100644
index 0000000000..27dbc6a67f
--- /dev/null
+++ b/sync15-adapter/src/util.rs
@@ -0,0 +1,135 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this
+ * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+use std::convert::From;
+use std::time::Duration;
+use std::{fmt, num};
+use std::str::FromStr;
+use openssl;
+use base64;
+
+pub fn base16_encode(bytes: &[u8]) -> String {
+    // This seems to be the fastest way of doing this without using a bunch of unsafe:
+    // https://gist.github.com/thomcc/c4860d68cf31f9b0283c692f83a239f3
+    static HEX_CHARS: &'static [u8] = b"0123456789abcdef";
+    let mut result = vec![0u8; bytes.len() * 2];
+    let mut index = 0;
+    for &byte in bytes {
+        result[index + 0] = HEX_CHARS[(byte >> 4) as usize];
+        result[index + 1] = HEX_CHARS[(byte & 15) as usize];
+        index += 2;
+    }
+    // We know statically that this unwrap is safe, since we only write ASCII.
+    String::from_utf8(result).unwrap()
+}
+
+pub fn random_guid() -> Result<String, openssl::error::ErrorStack> {
+    let mut bytes = vec![0u8; 9];
+    openssl::rand::rand_bytes(&mut bytes)?;
+    Ok(base64::encode_config(&bytes, base64::URL_SAFE_NO_PAD))
+}
+
+/// Typesafe way to manage server timestamps without accidentally mixing them up with
+/// local ones.
+///
+/// TODO: We should probably store this as milliseconds (or something) for stability and to get
+/// Eq/Ord. The server guarantees that these are formatted to the hundredths place (not sure if
+/// this is documented, but the code does it intentionally...). This would also let us throw out
+/// negative and NaN timestamps, which the server certainly won't send, but the guarantee would
+/// make me feel better.
+#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Deserialize, Serialize, Default)]
+pub struct ServerTimestamp(pub f64);
+
+impl From<ServerTimestamp> for f64 {
+    #[inline]
+    fn from(ts: ServerTimestamp) -> Self { ts.0 }
+}
+
+impl From<f64> for ServerTimestamp {
+    #[inline]
+    fn from(ts: f64) -> Self {
+        assert!(ts >= 0.0);
+        ServerTimestamp(ts)
+    }
+}
+
+// This lets us use these in hyper `header!` blocks.
+impl FromStr for ServerTimestamp {
+    type Err = num::ParseFloatError;
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        Ok(ServerTimestamp(f64::from_str(s)?))
+    }
+}
+
+impl fmt::Display for ServerTimestamp {
+    #[inline]
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{}", self.0)
+    }
+}
+
+pub const SERVER_EPOCH: ServerTimestamp = ServerTimestamp(0.0);
+
+impl ServerTimestamp {
+    /// Returns None if `other` is later than `self` (a `Duration` cannot represent
+    /// a negative timespan in Rust).
+    #[inline]
+    pub fn duration_since(self, other: ServerTimestamp) -> Option<Duration> {
+        let delta = self.0 - other.0;
+        if delta < 0.0 {
+            None
+        } else {
+            let secs = delta.floor();
+            // We don't want to round here, since it could round up, and
+            // Duration::new will panic if it rounds up to 1e9 nanoseconds.
+            let nanos = ((delta - secs) * 1_000_000_000.0).floor() as u32;
+            Some(Duration::new(secs as u64, nanos))
+        }
+    }
+
+    /// Get the milliseconds for the timestamp.
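+    /// E.g. `ServerTimestamp(1234.5).as_millis() == 1234500`; fractions of a
+    /// millisecond are truncated.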
+ #[inline] + pub fn as_millis(self) -> u64 { + (self.0 * 1000.0).floor() as u64 + } +} + +#[cfg(test)] +mod test { + use super::*; + use std::collections::HashSet; + + #[test] + fn test_server_timestamp() { + let t0 = ServerTimestamp(10300.15); + let t1 = ServerTimestamp(10100.05); + assert!(t1.duration_since(t0).is_none()); + assert!(t0.duration_since(t1).is_some()); + let dur = t0.duration_since(t1).unwrap(); + assert_eq!(dur.as_secs(), 200); + assert_eq!(dur.subsec_nanos(), 100_000_000); + } + + #[test] + fn test_base16_encode() { + assert_eq!(base16_encode(&[0x01, 0x10, 0x00, 0x00, 0xab, 0xbc, 0xde, 0xff]), + "01100000abbcdeff"); + assert_eq!(base16_encode(&[]), ""); + assert_eq!(base16_encode(&[0, 0, 0, 0]), "00000000"); + assert_eq!(base16_encode(&[0xff, 0xff, 0xff, 0xff]), "ffffffff"); + assert_eq!(base16_encode(&[0x00, 0x01, 0x02, 0x03, 0x0a]), "000102030a"); + assert_eq!(base16_encode(&[0x00, 0x10, 0x20, 0x30, 0xa0]), "00102030a0"); + } + + #[test] + fn test_gen_guid() { + let mut set = HashSet::new(); + for _ in 0..100 { + let res = random_guid().unwrap(); + assert_eq!(res.len(), 12); + assert!(!set.contains(&res)); + set.insert(res); + } + } +}