use std::sync::atomic::{AtomicBool, Ordering}; use rocket::Route; use rocket_contrib::json::Json; use serde_json::Value as JsonValue; use crate::{api::EmptyResult, auth::Headers, db::DbConn, Error, CONFIG}; pub fn routes() -> Vec { routes![negotiate, websockets_err] } static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true); #[get("/hub")] fn websockets_err() -> EmptyResult { if CONFIG.websocket_enabled() && SHOW_WEBSOCKETS_MSG .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed) .is_ok() { err!( " ########################################################### '/notifications/hub' should be proxied to the websocket server or notifications won't work. Go to the Wiki for more info, or disable WebSockets setting WEBSOCKET_ENABLED=false. ###########################################################################################\n" ) } else { Err(Error::empty()) } } #[post("/hub/negotiate")] fn negotiate(_headers: Headers, _conn: DbConn) -> Json { use crate::crypto; use data_encoding::BASE64URL; let conn_id = BASE64URL.encode(&crypto::get_random(vec![0u8; 16])); let mut available_transports: Vec = Vec::new(); if CONFIG.websocket_enabled() { available_transports.push(json!({"transport":"WebSockets", "transferFormats":["Text","Binary"]})); } // TODO: Implement transports // Rocket WS support: https://github.com/SergioBenitez/Rocket/issues/90 // Rocket SSE support: https://github.com/SergioBenitez/Rocket/issues/33 // {"transport":"ServerSentEvents", "transferFormats":["Text"]}, // {"transport":"LongPolling", "transferFormats":["Text","Binary"]} Json(json!({ "connectionId": conn_id, "availableTransports": available_transports })) } // // Websockets server // use std::io; use std::sync::Arc; use std::thread; use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender}; use chashmap::CHashMap; use chrono::NaiveDateTime; use serde_json::from_str; use crate::db::models::{Cipher, Folder, User}; use rmpv::Value; fn serialize(val: Value) -> Vec { use rmpv::encode::write_value; let mut buf = Vec::new(); write_value(&mut buf, &val).expect("Error encoding MsgPack"); // Add size bytes at the start // Extracted from BinaryMessageFormat.js let mut size: usize = buf.len(); let mut len_buf: Vec = Vec::new(); loop { let mut size_part = size & 0x7f; size >>= 7; if size > 0 { size_part |= 0x80; } len_buf.push(size_part as u8); if size == 0 { break; } } len_buf.append(&mut buf); len_buf } fn serialize_date(date: NaiveDateTime) -> Value { let seconds: i64 = date.timestamp(); let nanos: i64 = date.timestamp_subsec_nanos().into(); let timestamp = nanos << 34 | seconds; let bs = timestamp.to_be_bytes(); // -1 is Timestamp // https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type Value::Ext(-1, bs.to_vec()) } fn convert_option>(option: Option) -> Value { match option { Some(a) => a.into(), None => Value::Nil, } } // Server WebSocket handler pub struct WsHandler { out: Sender, user_uuid: Option, users: WebSocketUsers, } const RECORD_SEPARATOR: u8 = 0x1e; const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, #[derive(Deserialize)] struct InitialMessage { protocol: String, version: i32, } const PING_MS: u64 = 15_000; const PING: Token = Token(1); const ACCESS_TOKEN_KEY: &str = "access_token="; impl WsHandler { fn err(&self, msg: &'static str) -> ws::Result<()> { self.out.close(ws::CloseCode::Invalid)?; // We need to specifically return an IO error so ws closes the connection let io_error = io::Error::from(io::ErrorKind::InvalidData); Err(ws::Error::new(ws::ErrorKind::Io(io_error), msg)) } fn get_request_token(&self, hs: Handshake) -> Option { use std::str::from_utf8; // Verify we have a token header if let Some(header_value) = hs.request.header("Authorization") { if let Ok(converted) = from_utf8(header_value) { if let Some(token_part) = converted.split("Bearer ").nth(1) { return Some(token_part.into()); } } }; // Otherwise verify the query parameter value let path = hs.request.resource(); if let Some(params) = path.split('?').nth(1) { let params_iter = params.split('&').take(1); for val in params_iter { if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) { return Some(stripped.into()); } } }; None } } impl Handler for WsHandler { fn on_open(&mut self, hs: Handshake) -> ws::Result<()> { // Path == "/notifications/hub?id===&access_token=" // // We don't use `id`, and as of around 2020-03-25, the official clients // no longer seem to pass `id` (only `access_token`). // Get user token from header or query parameter let access_token = match self.get_request_token(hs) { Some(token) => token, _ => return self.err("Missing access token"), }; // Validate the user use crate::auth; let claims = match auth::decode_login(access_token.as_str()) { Ok(claims) => claims, Err(_) => return self.err("Invalid access token provided"), }; // Assign the user to the handler let user_uuid = claims.sub; self.user_uuid = Some(user_uuid.clone()); // Add the current Sender to the user list let handler_insert = self.out.clone(); let handler_update = self.out.clone(); self.users .map .upsert(user_uuid, || vec![handler_insert], |ref mut v| v.push(handler_update)); // Schedule a ping to keep the connection alive self.out.timeout(PING_MS, PING) } fn on_message(&mut self, msg: Message) -> ws::Result<()> { if let Message::Text(text) = msg.clone() { let json = &text[..text.len() - 1]; // Remove last char if let Ok(InitialMessage { protocol, version }) = from_str::(json) { if &protocol == "messagepack" && version == 1 { return self.out.send(&INITIAL_RESPONSE[..]); // Respond to initial message } } } // If it's not the initial message, just echo the message self.out.send(msg) } fn on_timeout(&mut self, event: Token) -> ws::Result<()> { if event == PING { // send ping self.out.send(create_ping())?; // reschedule the timeout self.out.timeout(PING_MS, PING) } else { Ok(()) } } } struct WsFactory { pub users: WebSocketUsers, } impl WsFactory { pub fn init() -> Self { WsFactory { users: WebSocketUsers { map: Arc::new(CHashMap::new()), }, } } } impl Factory for WsFactory { type Handler = WsHandler; fn connection_made(&mut self, out: Sender) -> Self::Handler { WsHandler { out, user_uuid: None, users: self.users.clone(), } } fn connection_lost(&mut self, handler: Self::Handler) { // Remove handler if let Some(user_uuid) = &handler.user_uuid { if let Some(mut user_conn) = self.users.map.get_mut(user_uuid) { if let Some(pos) = user_conn.iter().position(|x| x == &handler.out) { user_conn.remove(pos); } } } } } #[derive(Clone)] pub struct WebSocketUsers { map: Arc>>, } impl WebSocketUsers { fn send_update(&self, user_uuid: &str, data: &[u8]) -> ws::Result<()> { if let Some(user) = self.map.get(user_uuid) { for sender in user.iter() { sender.send(data)?; } } Ok(()) } // NOTE: The last modified date needs to be updated before calling these methods pub fn send_user_update(&self, ut: UpdateType, user: &User) { let data = create_update( vec![ ("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at)), ], ut, ); self.send_update(&user.uuid, &data).ok(); } pub fn send_folder_update(&self, ut: UpdateType, folder: &Folder) { let data = create_update( vec![ ("Id".into(), folder.uuid.clone().into()), ("UserId".into(), folder.user_uuid.clone().into()), ("RevisionDate".into(), serialize_date(folder.updated_at)), ], ut, ); self.send_update(&folder.user_uuid, &data).ok(); } pub fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) { let user_uuid = convert_option(cipher.user_uuid.clone()); let org_uuid = convert_option(cipher.organization_uuid.clone()); let data = create_update( vec![ ("Id".into(), cipher.uuid.clone().into()), ("UserId".into(), user_uuid), ("OrganizationId".into(), org_uuid), ("CollectionIds".into(), Value::Nil), ("RevisionDate".into(), serialize_date(cipher.updated_at)), ], ut, ); for uuid in user_uuids { self.send_update(&uuid, &data).ok(); } } } /* Message Structure [ 1, // MessageType.Invocation {}, // Headers (map) null, // InvocationId "ReceiveMessage", // Target [ // Arguments { "ContextId": "app_id", "Type": ut as i32, "Payload": {} } ] ] */ fn create_update(payload: Vec<(Value, Value)>, ut: UpdateType) -> Vec { use rmpv::Value as V; let value = V::Array(vec![ 1.into(), V::Map(vec![]), V::Nil, "ReceiveMessage".into(), V::Array(vec![V::Map(vec![ ("ContextId".into(), "app_id".into()), ("Type".into(), (ut as i32).into()), ("Payload".into(), payload.into()), ])]), ]); serialize(value) } fn create_ping() -> Vec { serialize(Value::Array(vec![6.into()])) } #[allow(dead_code)] #[derive(PartialEq)] pub enum UpdateType { CipherUpdate = 0, CipherCreate = 1, LoginDelete = 2, FolderDelete = 3, Ciphers = 4, Vault = 5, OrgKeys = 6, FolderCreate = 7, FolderUpdate = 8, CipherDelete = 9, SyncSettings = 10, LogOut = 11, SyncSendCreate = 12, SyncSendUpdate = 13, SyncSendDelete = 14, None = 100, } use rocket::State; pub type Notify<'a> = State<'a, WebSocketUsers>; pub fn start_notification_server() -> WebSocketUsers { let factory = WsFactory::init(); let users = factory.users.clone(); if CONFIG.websocket_enabled() { thread::spawn(move || { let settings = ws::Settings { max_connections: 500, queue_size: 2, panic_on_internal: false, ..Default::default() }; ws::Builder::new() .with_settings(settings) .build(factory) .unwrap() .listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())) .unwrap(); }); } users }