mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2025-01-13 07:53:24 -05:00
Migrate old ws crate to tungstenite, which is async and also removes over 20 old dependencies
This commit is contained in:
parent
303eaabeea
commit
54c78cf06d
704
Cargo.lock
generated
704
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
12
Cargo.toml
12
Cargo.toml
@ -45,7 +45,7 @@ backtrace = "0.3.65" # Logging panics to logfile instead stderr only
|
|||||||
dotenvy = { version = "0.15.1", default-features = false }
|
dotenvy = { version = "0.15.1", default-features = false }
|
||||||
|
|
||||||
# Lazy initialization
|
# Lazy initialization
|
||||||
once_cell = "1.10.0"
|
once_cell = "1.11.0"
|
||||||
|
|
||||||
# Numerical libraries
|
# Numerical libraries
|
||||||
num-traits = "0.2.15"
|
num-traits = "0.2.15"
|
||||||
@ -55,9 +55,9 @@ num-derive = "0.3.3"
|
|||||||
rocket = { version = "0.5.0-rc.2", features = ["tls", "json"], default-features = false }
|
rocket = { version = "0.5.0-rc.2", features = ["tls", "json"], default-features = false }
|
||||||
|
|
||||||
# WebSockets libraries
|
# WebSockets libraries
|
||||||
ws = { version = "0.11.1", package = "parity-ws" }
|
tokio-tungstenite = "0.17.1"
|
||||||
rmpv = "1.0.0" # MessagePack library
|
rmpv = "1.0.0" # MessagePack library
|
||||||
chashmap = "2.2.2" # Concurrent hashmap implementation
|
dashmap = "5.3.3" # Concurrent hashmap implementation
|
||||||
|
|
||||||
# Async futures
|
# Async futures
|
||||||
futures = "0.3.21"
|
futures = "0.3.21"
|
||||||
@ -96,7 +96,7 @@ data-encoding = "2.3.2"
|
|||||||
jsonwebtoken = "8.1.0"
|
jsonwebtoken = "8.1.0"
|
||||||
|
|
||||||
# TOTP library
|
# TOTP library
|
||||||
totp-lite = "1.0.3"
|
totp-lite = "2.0.0"
|
||||||
|
|
||||||
# Yubico Library
|
# Yubico Library
|
||||||
yubico = { version = "0.11.0", features = ["online-tokio"], default-features = false }
|
yubico = { version = "0.11.0", features = ["online-tokio"], default-features = false }
|
||||||
@ -112,14 +112,14 @@ lettre = { version = "0.10.0-rc.7", features = ["smtp-transport", "builder", "se
|
|||||||
percent-encoding = "2.1.0" # URL encoding library used for URL's in the emails
|
percent-encoding = "2.1.0" # URL encoding library used for URL's in the emails
|
||||||
|
|
||||||
# Template library
|
# Template library
|
||||||
handlebars = { version = "4.2.2", features = ["dir_source"] }
|
handlebars = { version = "4.3.0", features = ["dir_source"] }
|
||||||
|
|
||||||
# HTTP client
|
# HTTP client
|
||||||
reqwest = { version = "0.11.10", features = ["stream", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
|
reqwest = { version = "0.11.10", features = ["stream", "json", "gzip", "brotli", "socks", "cookies", "trust-dns"] }
|
||||||
|
|
||||||
# For favicon extraction from main website
|
# For favicon extraction from main website
|
||||||
html5gum = "0.4.0"
|
html5gum = "0.4.0"
|
||||||
regex = { version = "1.5.5", features = ["std", "perf", "unicode-perl"], default-features = false }
|
regex = { version = "1.5.6", features = ["std", "perf", "unicode-perl"], default-features = false }
|
||||||
data-url = "0.1.1"
|
data-url = "0.1.1"
|
||||||
bytes = "1.1.0"
|
bytes = "1.1.0"
|
||||||
cached = "0.34.0"
|
cached = "0.34.0"
|
||||||
|
@ -464,7 +464,7 @@ pub async fn update_cipher_from_data(
|
|||||||
cipher.set_favorite(data.Favorite, &headers.user.uuid, conn).await?;
|
cipher.set_favorite(data.Favorite, &headers.user.uuid, conn).await?;
|
||||||
|
|
||||||
if ut != UpdateType::None {
|
if ut != UpdateType::None {
|
||||||
nt.send_cipher_update(ut, cipher, &cipher.update_users_revision(conn).await);
|
nt.send_cipher_update(ut, cipher, &cipher.update_users_revision(conn).await).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -527,7 +527,7 @@ async fn post_ciphers_import(
|
|||||||
|
|
||||||
let mut user = headers.user;
|
let mut user = headers.user;
|
||||||
user.update_revision(&conn).await?;
|
user.update_revision(&conn).await?;
|
||||||
nt.send_user_update(UpdateType::Vault, &user);
|
nt.send_user_update(UpdateType::Vault, &user).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1000,7 +1000,7 @@ async fn save_attachment(
|
|||||||
|
|
||||||
data.data.persist_to(file_path).await?;
|
data.data.persist_to(file_path).await?;
|
||||||
|
|
||||||
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(&conn).await);
|
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(&conn).await).await;
|
||||||
|
|
||||||
Ok((cipher, conn))
|
Ok((cipher, conn))
|
||||||
}
|
}
|
||||||
@ -1266,7 +1266,7 @@ async fn move_cipher_selected(
|
|||||||
// Move cipher
|
// Move cipher
|
||||||
cipher.move_to_folder(data.FolderId.clone(), &user_uuid, &conn).await?;
|
cipher.move_to_folder(data.FolderId.clone(), &user_uuid, &conn).await?;
|
||||||
|
|
||||||
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &[user_uuid.clone()]);
|
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &[user_uuid.clone()]).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1313,7 +1313,7 @@ async fn delete_all(
|
|||||||
Some(user_org) => {
|
Some(user_org) => {
|
||||||
if user_org.atype == UserOrgType::Owner {
|
if user_org.atype == UserOrgType::Owner {
|
||||||
Cipher::delete_all_by_organization(&org_data.org_id, &conn).await?;
|
Cipher::delete_all_by_organization(&org_data.org_id, &conn).await?;
|
||||||
nt.send_user_update(UpdateType::Vault, &user);
|
nt.send_user_update(UpdateType::Vault, &user).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
err!("You don't have permission to purge the organization vault");
|
err!("You don't have permission to purge the organization vault");
|
||||||
@ -1334,7 +1334,7 @@ async fn delete_all(
|
|||||||
}
|
}
|
||||||
|
|
||||||
user.update_revision(&conn).await?;
|
user.update_revision(&conn).await?;
|
||||||
nt.send_user_update(UpdateType::Vault, &user);
|
nt.send_user_update(UpdateType::Vault, &user).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1359,10 +1359,10 @@ async fn _delete_cipher_by_uuid(
|
|||||||
if soft_delete {
|
if soft_delete {
|
||||||
cipher.deleted_at = Some(Utc::now().naive_utc());
|
cipher.deleted_at = Some(Utc::now().naive_utc());
|
||||||
cipher.save(conn).await?;
|
cipher.save(conn).await?;
|
||||||
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn).await);
|
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn).await).await;
|
||||||
} else {
|
} else {
|
||||||
cipher.delete(conn).await?;
|
cipher.delete(conn).await?;
|
||||||
nt.send_cipher_update(UpdateType::CipherDelete, &cipher, &cipher.update_users_revision(conn).await);
|
nt.send_cipher_update(UpdateType::CipherDelete, &cipher, &cipher.update_users_revision(conn).await).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -1407,7 +1407,7 @@ async fn _restore_cipher_by_uuid(uuid: &str, headers: &Headers, conn: &DbConn, n
|
|||||||
cipher.deleted_at = None;
|
cipher.deleted_at = None;
|
||||||
cipher.save(conn).await?;
|
cipher.save(conn).await?;
|
||||||
|
|
||||||
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn).await);
|
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn).await).await;
|
||||||
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, conn).await))
|
Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, conn).await))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1469,7 +1469,7 @@ async fn _delete_cipher_attachment_by_id(
|
|||||||
|
|
||||||
// Delete attachment
|
// Delete attachment
|
||||||
attachment.delete(conn).await?;
|
attachment.delete(conn).await?;
|
||||||
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn).await);
|
nt.send_cipher_update(UpdateType::CipherUpdate, &cipher, &cipher.update_users_revision(conn).await).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ async fn post_folders(data: JsonUpcase<FolderData>, headers: Headers, conn: DbCo
|
|||||||
let mut folder = Folder::new(headers.user.uuid, data.Name);
|
let mut folder = Folder::new(headers.user.uuid, data.Name);
|
||||||
|
|
||||||
folder.save(&conn).await?;
|
folder.save(&conn).await?;
|
||||||
nt.send_folder_update(UpdateType::FolderCreate, &folder);
|
nt.send_folder_update(UpdateType::FolderCreate, &folder).await;
|
||||||
|
|
||||||
Ok(Json(folder.to_json()))
|
Ok(Json(folder.to_json()))
|
||||||
}
|
}
|
||||||
@ -88,7 +88,7 @@ async fn put_folder(
|
|||||||
folder.name = data.Name;
|
folder.name = data.Name;
|
||||||
|
|
||||||
folder.save(&conn).await?;
|
folder.save(&conn).await?;
|
||||||
nt.send_folder_update(UpdateType::FolderUpdate, &folder);
|
nt.send_folder_update(UpdateType::FolderUpdate, &folder).await;
|
||||||
|
|
||||||
Ok(Json(folder.to_json()))
|
Ok(Json(folder.to_json()))
|
||||||
}
|
}
|
||||||
@ -112,6 +112,6 @@ async fn delete_folder(uuid: String, headers: Headers, conn: DbConn, nt: Notify<
|
|||||||
// Delete the actual folder entry
|
// Delete the actual folder entry
|
||||||
folder.delete(&conn).await?;
|
folder.delete(&conn).await?;
|
||||||
|
|
||||||
nt.send_folder_update(UpdateType::FolderDelete, &folder);
|
nt.send_folder_update(UpdateType::FolderDelete, &folder).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -173,7 +173,7 @@ async fn post_send(data: JsonUpcase<SendData>, headers: Headers, conn: DbConn, n
|
|||||||
|
|
||||||
let mut send = create_send(data, headers.user.uuid).await?;
|
let mut send = create_send(data, headers.user.uuid).await?;
|
||||||
send.save(&conn).await?;
|
send.save(&conn).await?;
|
||||||
nt.send_send_update(UpdateType::SyncSendCreate, &send, &send.update_users_revision(&conn).await);
|
nt.send_send_update(UpdateType::SyncSendCreate, &send, &send.update_users_revision(&conn).await).await;
|
||||||
|
|
||||||
Ok(Json(send.to_json()))
|
Ok(Json(send.to_json()))
|
||||||
}
|
}
|
||||||
@ -237,7 +237,7 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, conn: DbCo
|
|||||||
|
|
||||||
// Save the changes in the database
|
// Save the changes in the database
|
||||||
send.save(&conn).await?;
|
send.save(&conn).await?;
|
||||||
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await);
|
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await).await;
|
||||||
|
|
||||||
Ok(Json(send.to_json()))
|
Ok(Json(send.to_json()))
|
||||||
}
|
}
|
||||||
@ -418,7 +418,7 @@ async fn put_send(
|
|||||||
}
|
}
|
||||||
|
|
||||||
send.save(&conn).await?;
|
send.save(&conn).await?;
|
||||||
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await);
|
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await).await;
|
||||||
|
|
||||||
Ok(Json(send.to_json()))
|
Ok(Json(send.to_json()))
|
||||||
}
|
}
|
||||||
@ -435,7 +435,7 @@ async fn delete_send(id: String, headers: Headers, conn: DbConn, nt: Notify<'_>)
|
|||||||
}
|
}
|
||||||
|
|
||||||
send.delete(&conn).await?;
|
send.delete(&conn).await?;
|
||||||
nt.send_send_update(UpdateType::SyncSendDelete, &send, &send.update_users_revision(&conn).await);
|
nt.send_send_update(UpdateType::SyncSendDelete, &send, &send.update_users_revision(&conn).await).await;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -455,7 +455,7 @@ async fn put_remove_password(id: String, headers: Headers, conn: DbConn, nt: Not
|
|||||||
|
|
||||||
send.set_password(None);
|
send.set_password(None);
|
||||||
send.save(&conn).await?;
|
send.save(&conn).await?;
|
||||||
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await);
|
nt.send_send_update(UpdateType::SyncSendUpdate, &send, &send.update_users_revision(&conn).await).await;
|
||||||
|
|
||||||
Ok(Json(send.to_json()))
|
Ok(Json(send.to_json()))
|
||||||
}
|
}
|
||||||
|
@ -1,19 +1,41 @@
|
|||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
use std::{
|
||||||
|
net::SocketAddr,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
use rocket::serde::json::Json;
|
use chrono::NaiveDateTime;
|
||||||
use rocket::Route;
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use rmpv::Value;
|
||||||
|
use rocket::{serde::json::Json, Route};
|
||||||
use serde_json::Value as JsonValue;
|
use serde_json::Value as JsonValue;
|
||||||
|
use tokio::{
|
||||||
|
net::{TcpListener, TcpStream},
|
||||||
|
sync::mpsc::Sender,
|
||||||
|
};
|
||||||
|
use tokio_tungstenite::{
|
||||||
|
accept_hdr_async,
|
||||||
|
tungstenite::{handshake, Message},
|
||||||
|
};
|
||||||
|
|
||||||
use crate::{api::EmptyResult, auth::Headers, Error, CONFIG};
|
use crate::{
|
||||||
|
api::EmptyResult,
|
||||||
|
auth::Headers,
|
||||||
|
db::models::{Cipher, Folder, Send, User},
|
||||||
|
Error, CONFIG,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn routes() -> Vec<Route> {
|
pub fn routes() -> Vec<Route> {
|
||||||
routes![negotiate, websockets_err]
|
routes![negotiate, websockets_err]
|
||||||
}
|
}
|
||||||
|
|
||||||
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true);
|
|
||||||
|
|
||||||
#[get("/hub")]
|
#[get("/hub")]
|
||||||
fn websockets_err() -> EmptyResult {
|
fn websockets_err() -> EmptyResult {
|
||||||
|
static SHOW_WEBSOCKETS_MSG: AtomicBool = AtomicBool::new(true);
|
||||||
|
|
||||||
if CONFIG.websocket_enabled()
|
if CONFIG.websocket_enabled()
|
||||||
&& SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok()
|
&& SHOW_WEBSOCKETS_MSG.compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed).is_ok()
|
||||||
{
|
{
|
||||||
@ -55,19 +77,6 @@ fn negotiate(_headers: Headers) -> Json<JsonValue> {
|
|||||||
//
|
//
|
||||||
// Websockets server
|
// Websockets server
|
||||||
//
|
//
|
||||||
use std::io;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::thread;
|
|
||||||
|
|
||||||
use ws::{self, util::Token, Factory, Handler, Handshake, Message, Sender};
|
|
||||||
|
|
||||||
use chashmap::CHashMap;
|
|
||||||
use chrono::NaiveDateTime;
|
|
||||||
use serde_json::from_str;
|
|
||||||
|
|
||||||
use crate::db::models::{Cipher, Folder, Send, User};
|
|
||||||
|
|
||||||
use rmpv::Value;
|
|
||||||
|
|
||||||
fn serialize(val: Value) -> Vec<u8> {
|
fn serialize(val: Value) -> Vec<u8> {
|
||||||
use rmpv::encode::write_value;
|
use rmpv::encode::write_value;
|
||||||
@ -118,192 +127,49 @@ fn convert_option<T: Into<Value>>(option: Option<T>) -> Value {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server WebSocket handler
|
|
||||||
pub struct WsHandler {
|
|
||||||
out: Sender,
|
|
||||||
user_uuid: Option<String>,
|
|
||||||
users: WebSocketUsers,
|
|
||||||
}
|
|
||||||
|
|
||||||
const RECORD_SEPARATOR: u8 = 0x1e;
|
const RECORD_SEPARATOR: u8 = 0x1e;
|
||||||
const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS>
|
const INITIAL_RESPONSE: [u8; 3] = [0x7b, 0x7d, RECORD_SEPARATOR]; // {, }, <RS>
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize, Copy, Clone, Eq, PartialEq)]
|
||||||
struct InitialMessage {
|
struct InitialMessage<'a> {
|
||||||
protocol: String,
|
protocol: &'a str,
|
||||||
version: i32,
|
version: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
const PING_MS: u64 = 15_000;
|
static INITIAL_MESSAGE: InitialMessage<'static> = InitialMessage {
|
||||||
const PING: Token = Token(1);
|
protocol: "messagepack",
|
||||||
|
version: 1,
|
||||||
const ACCESS_TOKEN_KEY: &str = "access_token=";
|
};
|
||||||
|
|
||||||
impl WsHandler {
|
|
||||||
fn err(&self, msg: &'static str) -> ws::Result<()> {
|
|
||||||
self.out.close(ws::CloseCode::Invalid)?;
|
|
||||||
|
|
||||||
// We need to specifically return an IO error so ws closes the connection
|
|
||||||
let io_error = io::Error::from(io::ErrorKind::InvalidData);
|
|
||||||
Err(ws::Error::new(ws::ErrorKind::Io(io_error), msg))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_request_token(&self, hs: Handshake) -> Option<String> {
|
|
||||||
use std::str::from_utf8;
|
|
||||||
|
|
||||||
// Verify we have a token header
|
|
||||||
if let Some(header_value) = hs.request.header("Authorization") {
|
|
||||||
if let Ok(converted) = from_utf8(header_value) {
|
|
||||||
if let Some(token_part) = converted.split("Bearer ").nth(1) {
|
|
||||||
return Some(token_part.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Otherwise verify the query parameter value
|
|
||||||
let path = hs.request.resource();
|
|
||||||
if let Some(params) = path.split('?').nth(1) {
|
|
||||||
let params_iter = params.split('&').take(1);
|
|
||||||
for val in params_iter {
|
|
||||||
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) {
|
|
||||||
return Some(stripped.into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Handler for WsHandler {
|
|
||||||
fn on_open(&mut self, hs: Handshake) -> ws::Result<()> {
|
|
||||||
// Path == "/notifications/hub?id=<id>==&access_token=<access_token>"
|
|
||||||
//
|
|
||||||
// We don't use `id`, and as of around 2020-03-25, the official clients
|
|
||||||
// no longer seem to pass `id` (only `access_token`).
|
|
||||||
|
|
||||||
// Get user token from header or query parameter
|
|
||||||
let access_token = match self.get_request_token(hs) {
|
|
||||||
Some(token) => token,
|
|
||||||
_ => return self.err("Missing access token"),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Validate the user
|
|
||||||
use crate::auth;
|
|
||||||
let claims = match auth::decode_login(access_token.as_str()) {
|
|
||||||
Ok(claims) => claims,
|
|
||||||
Err(_) => return self.err("Invalid access token provided"),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Assign the user to the handler
|
|
||||||
let user_uuid = claims.sub;
|
|
||||||
self.user_uuid = Some(user_uuid.clone());
|
|
||||||
|
|
||||||
// Add the current Sender to the user list
|
|
||||||
let handler_insert = self.out.clone();
|
|
||||||
let handler_update = self.out.clone();
|
|
||||||
|
|
||||||
self.users.map.upsert(user_uuid, || vec![handler_insert], |ref mut v| v.push(handler_update));
|
|
||||||
|
|
||||||
// Schedule a ping to keep the connection alive
|
|
||||||
self.out.timeout(PING_MS, PING)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_message(&mut self, msg: Message) -> ws::Result<()> {
|
|
||||||
if let Message::Text(text) = msg.clone() {
|
|
||||||
let json = &text[..text.len() - 1]; // Remove last char
|
|
||||||
|
|
||||||
if let Ok(InitialMessage {
|
|
||||||
protocol,
|
|
||||||
version,
|
|
||||||
}) = from_str::<InitialMessage>(json)
|
|
||||||
{
|
|
||||||
if &protocol == "messagepack" && version == 1 {
|
|
||||||
return self.out.send(&INITIAL_RESPONSE[..]); // Respond to initial message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If it's not the initial message, just echo the message
|
|
||||||
self.out.send(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn on_timeout(&mut self, event: Token) -> ws::Result<()> {
|
|
||||||
if event == PING {
|
|
||||||
// send ping
|
|
||||||
self.out.send(create_ping())?;
|
|
||||||
|
|
||||||
// reschedule the timeout
|
|
||||||
self.out.timeout(PING_MS, PING)
|
|
||||||
} else {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
struct WsFactory {
|
|
||||||
pub users: WebSocketUsers,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WsFactory {
|
|
||||||
pub fn init() -> Self {
|
|
||||||
WsFactory {
|
|
||||||
users: WebSocketUsers {
|
|
||||||
map: Arc::new(CHashMap::new()),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Factory for WsFactory {
|
|
||||||
type Handler = WsHandler;
|
|
||||||
|
|
||||||
fn connection_made(&mut self, out: Sender) -> Self::Handler {
|
|
||||||
WsHandler {
|
|
||||||
out,
|
|
||||||
user_uuid: None,
|
|
||||||
users: self.users.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn connection_lost(&mut self, handler: Self::Handler) {
|
|
||||||
// Remove handler
|
|
||||||
if let Some(user_uuid) = &handler.user_uuid {
|
|
||||||
if let Some(mut user_conn) = self.users.map.get_mut(user_uuid) {
|
|
||||||
if let Some(pos) = user_conn.iter().position(|x| x == &handler.out) {
|
|
||||||
user_conn.remove(pos);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// We attach the UUID to the sender so we can differentiate them when we need to remove them from the Vec
|
||||||
|
type UserSenders = (uuid::Uuid, Sender<Message>);
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct WebSocketUsers {
|
pub struct WebSocketUsers {
|
||||||
map: Arc<CHashMap<String, Vec<Sender>>>,
|
map: Arc<dashmap::DashMap<String, Vec<UserSenders>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WebSocketUsers {
|
impl WebSocketUsers {
|
||||||
fn send_update(&self, user_uuid: &str, data: &[u8]) -> ws::Result<()> {
|
async fn send_update(&self, user_uuid: &str, data: &[u8]) {
|
||||||
if let Some(user) = self.map.get(user_uuid) {
|
if let Some(user) = self.map.get(user_uuid).map(|v| v.clone()) {
|
||||||
for sender in user.iter() {
|
for (_, sender) in user.iter() {
|
||||||
sender.send(data)?;
|
if sender.send(Message::binary(data)).await.is_err() {
|
||||||
|
// TODO: Delete from map here too?
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: The last modified date needs to be updated before calling these methods
|
// NOTE: The last modified date needs to be updated before calling these methods
|
||||||
pub fn send_user_update(&self, ut: UpdateType, user: &User) {
|
pub async fn send_user_update(&self, ut: UpdateType, user: &User) {
|
||||||
let data = create_update(
|
let data = create_update(
|
||||||
vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
|
vec![("UserId".into(), user.uuid.clone().into()), ("Date".into(), serialize_date(user.updated_at))],
|
||||||
ut,
|
ut,
|
||||||
);
|
);
|
||||||
|
|
||||||
self.send_update(&user.uuid, &data).ok();
|
self.send_update(&user.uuid, &data).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_folder_update(&self, ut: UpdateType, folder: &Folder) {
|
pub async fn send_folder_update(&self, ut: UpdateType, folder: &Folder) {
|
||||||
let data = create_update(
|
let data = create_update(
|
||||||
vec![
|
vec![
|
||||||
("Id".into(), folder.uuid.clone().into()),
|
("Id".into(), folder.uuid.clone().into()),
|
||||||
@ -313,10 +179,10 @@ impl WebSocketUsers {
|
|||||||
ut,
|
ut,
|
||||||
);
|
);
|
||||||
|
|
||||||
self.send_update(&folder.user_uuid, &data).ok();
|
self.send_update(&folder.user_uuid, &data).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) {
|
pub async fn send_cipher_update(&self, ut: UpdateType, cipher: &Cipher, user_uuids: &[String]) {
|
||||||
let user_uuid = convert_option(cipher.user_uuid.clone());
|
let user_uuid = convert_option(cipher.user_uuid.clone());
|
||||||
let org_uuid = convert_option(cipher.organization_uuid.clone());
|
let org_uuid = convert_option(cipher.organization_uuid.clone());
|
||||||
|
|
||||||
@ -332,11 +198,11 @@ impl WebSocketUsers {
|
|||||||
);
|
);
|
||||||
|
|
||||||
for uuid in user_uuids {
|
for uuid in user_uuids {
|
||||||
self.send_update(uuid, &data).ok();
|
self.send_update(uuid, &data).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) {
|
pub async fn send_send_update(&self, ut: UpdateType, send: &Send, user_uuids: &[String]) {
|
||||||
let user_uuid = convert_option(send.user_uuid.clone());
|
let user_uuid = convert_option(send.user_uuid.clone());
|
||||||
|
|
||||||
let data = create_update(
|
let data = create_update(
|
||||||
@ -349,7 +215,7 @@ impl WebSocketUsers {
|
|||||||
);
|
);
|
||||||
|
|
||||||
for uuid in user_uuids {
|
for uuid in user_uuids {
|
||||||
self.send_update(uuid, &data).ok();
|
self.send_update(uuid, &data).await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -416,27 +282,129 @@ pub enum UpdateType {
|
|||||||
None = 100,
|
None = 100,
|
||||||
}
|
}
|
||||||
|
|
||||||
use rocket::State;
|
pub type Notify<'a> = &'a rocket::State<WebSocketUsers>;
|
||||||
pub type Notify<'a> = &'a State<WebSocketUsers>;
|
|
||||||
|
|
||||||
pub fn start_notification_server() -> WebSocketUsers {
|
pub fn start_notification_server() -> WebSocketUsers {
|
||||||
let factory = WsFactory::init();
|
let users = WebSocketUsers {
|
||||||
let users = factory.users.clone();
|
map: Arc::new(dashmap::DashMap::new()),
|
||||||
|
};
|
||||||
|
|
||||||
if CONFIG.websocket_enabled() {
|
if CONFIG.websocket_enabled() {
|
||||||
thread::spawn(move || {
|
let users2 = users.clone();
|
||||||
let mut settings = ws::Settings::default();
|
tokio::spawn(async move {
|
||||||
settings.max_connections = 500;
|
let addr = (CONFIG.websocket_address(), CONFIG.websocket_port());
|
||||||
settings.queue_size = 2;
|
let listener = TcpListener::bind(addr).await.expect("Can't listen on websocket port");
|
||||||
settings.panic_on_internal = false;
|
|
||||||
|
|
||||||
let ws = ws::Builder::new().with_settings(settings).build(factory).unwrap();
|
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
||||||
CONFIG.set_ws_shutdown_handle(ws.broadcaster());
|
CONFIG.set_ws_shutdown_handle(shutdown_tx);
|
||||||
ws.listen((CONFIG.websocket_address().as_str(), CONFIG.websocket_port())).unwrap();
|
|
||||||
|
|
||||||
warn!("WS Server stopped!");
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
Ok((stream, addr)) = listener.accept() => {
|
||||||
|
tokio::spawn(handle_connection(stream, users2.clone(), addr));
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = &mut shutdown_rx => {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Shutting down WebSockets server!")
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
users
|
users
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_connection(stream: TcpStream, users: WebSocketUsers, _addr: SocketAddr) -> Result<(), Error> {
|
||||||
|
let mut user_uuid: Option<String> = None;
|
||||||
|
|
||||||
|
// Accept connection, do initial handshake, validate auth token and get the user ID
|
||||||
|
use handshake::server::{Request, Response};
|
||||||
|
let mut stream = accept_hdr_async(stream, |req: &Request, res: Response| {
|
||||||
|
if let Some(token) = get_request_token(req) {
|
||||||
|
if let Ok(claims) = crate::auth::decode_login(&token) {
|
||||||
|
user_uuid = Some(claims.sub);
|
||||||
|
return Ok(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(Response::builder().status(401).body(None).unwrap())
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let user_uuid = user_uuid.expect("User UUID should be set after the handshake");
|
||||||
|
|
||||||
|
// Add a channel to send messages to this client to the map
|
||||||
|
let entry_uuid = uuid::Uuid::new_v4();
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(100);
|
||||||
|
users.map.entry(user_uuid.clone()).or_default().push((entry_uuid, tx));
|
||||||
|
|
||||||
|
let mut interval = tokio::time::interval(Duration::from_secs(15));
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
res = stream.next() => {
|
||||||
|
match res {
|
||||||
|
Some(Ok(message)) => {
|
||||||
|
// We should receive an initial message with the protocol and version, and we will reply to it
|
||||||
|
if let Message::Text(ref message) = message {
|
||||||
|
let msg = message.strip_suffix('\u{1e}').unwrap_or(message);
|
||||||
|
|
||||||
|
if serde_json::from_str(msg).ok() == Some(INITIAL_MESSAGE) {
|
||||||
|
stream.send(Message::binary(INITIAL_RESPONSE)).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Just echo anything else the client sends
|
||||||
|
if stream.send(message).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res = rx.recv() => {
|
||||||
|
match res {
|
||||||
|
Some(res) => {
|
||||||
|
if stream.send(res).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_= interval.tick() => {
|
||||||
|
if stream.send(Message::Binary(create_ping())).await.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete from map
|
||||||
|
users.map.entry(user_uuid).or_default().retain(|(uuid, _)| uuid != &entry_uuid);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_request_token(req: &handshake::server::Request) -> Option<String> {
|
||||||
|
const ACCESS_TOKEN_KEY: &str = "access_token=";
|
||||||
|
|
||||||
|
if let Some(Ok(auth)) = req.headers().get("Authorization").map(|a| a.to_str()) {
|
||||||
|
if let Some(token_part) = auth.strip_prefix("Bearer ") {
|
||||||
|
return Some(token_part.to_owned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(params) = req.uri().query() {
|
||||||
|
let params_iter = params.split('&').take(1);
|
||||||
|
for val in params_iter {
|
||||||
|
if let Some(stripped) = val.strip_prefix(ACCESS_TOKEN_KEY) {
|
||||||
|
return Some(stripped.to_owned());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
@ -37,7 +37,7 @@ macro_rules! make_config {
|
|||||||
|
|
||||||
struct Inner {
|
struct Inner {
|
||||||
rocket_shutdown_handle: Option<rocket::Shutdown>,
|
rocket_shutdown_handle: Option<rocket::Shutdown>,
|
||||||
ws_shutdown_handle: Option<ws::Sender>,
|
ws_shutdown_handle: Option<tokio::sync::oneshot::Sender<()>>,
|
||||||
|
|
||||||
templates: Handlebars<'static>,
|
templates: Handlebars<'static>,
|
||||||
config: ConfigItems,
|
config: ConfigItems,
|
||||||
@ -948,19 +948,17 @@ impl Config {
|
|||||||
self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
|
self.inner.write().unwrap().rocket_shutdown_handle = Some(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_ws_shutdown_handle(&self, handle: ws::Sender) {
|
pub fn set_ws_shutdown_handle(&self, handle: tokio::sync::oneshot::Sender<()>) {
|
||||||
self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
|
self.inner.write().unwrap().ws_shutdown_handle = Some(handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shutdown(&self) {
|
pub fn shutdown(&self) {
|
||||||
if let Ok(c) = self.inner.read() {
|
if let Ok(mut c) = self.inner.write() {
|
||||||
if let Some(handle) = c.ws_shutdown_handle.clone() {
|
if let Some(handle) = c.ws_shutdown_handle.take() {
|
||||||
handle.shutdown().ok();
|
handle.send(()).ok();
|
||||||
}
|
}
|
||||||
// Wait a bit before stopping the web server
|
|
||||||
tokio::runtime::Handle::current()
|
if let Some(handle) = c.rocket_shutdown_handle.take() {
|
||||||
.block_on(async move { tokio::time::sleep(tokio::time::Duration::from_secs(1)).await });
|
|
||||||
if let Some(handle) = c.rocket_shutdown_handle.clone() {
|
|
||||||
handle.notify();
|
handle.notify();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -49,6 +49,7 @@ use rocket::error::Error as RocketErr;
|
|||||||
use serde_json::{Error as SerdeErr, Value};
|
use serde_json::{Error as SerdeErr, Value};
|
||||||
use std::io::Error as IoErr;
|
use std::io::Error as IoErr;
|
||||||
use std::time::SystemTimeError as TimeErr;
|
use std::time::SystemTimeError as TimeErr;
|
||||||
|
use tokio_tungstenite::tungstenite::Error as TungstError;
|
||||||
use webauthn_rs::error::WebauthnError as WebauthnErr;
|
use webauthn_rs::error::WebauthnError as WebauthnErr;
|
||||||
use yubico::yubicoerror::YubicoError as YubiErr;
|
use yubico::yubicoerror::YubicoError as YubiErr;
|
||||||
|
|
||||||
@ -88,6 +89,7 @@ make_error! {
|
|||||||
DieselCon(DieselConErr): _has_source, _api_error,
|
DieselCon(DieselConErr): _has_source, _api_error,
|
||||||
DieselMig(DieselMigErr): _has_source, _api_error,
|
DieselMig(DieselMigErr): _has_source, _api_error,
|
||||||
Webauthn(WebauthnErr): _has_source, _api_error,
|
Webauthn(WebauthnErr): _has_source, _api_error,
|
||||||
|
WebSocket(TungstError): _has_source, _api_error,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for Error {
|
impl std::fmt::Debug for Error {
|
||||||
|
Loading…
Reference in New Issue
Block a user