diff --git a/.envrc b/.envrc index b07fdf3..e33eddb 100644 --- a/.envrc +++ b/.envrc @@ -1 +1,2 @@ export RUSTFMT=yew-fmt +export DATABASE_URL=database.db diff --git a/.gitignore b/.gitignore index 951782f..26cb74f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ /target /dist -chat_log +database.db diff --git a/Cargo.lock b/Cargo.lock index 7cf5bf1..1051120 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -509,6 +509,41 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.79", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn 2.0.79", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -525,6 +560,40 @@ dependencies = [ "serde", ] +[[package]] +name = "diesel" +version = "2.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "158fe8e2e68695bd615d7e4f3227c0727b151330d3e253b525086c348d055d5e" +dependencies = [ + "chrono", + "diesel_derives", + "libsqlite3-sys", + "time", +] + +[[package]] +name = "diesel_derives" +version = "2.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7f2c3de51e2ba6bf2a648285696137aaf0f5f487bcbea93972fe8a364e131a4" +dependencies = [ + "diesel_table_macro_syntax", + "dsl_auto_type", + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "diesel_table_macro_syntax" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25" +dependencies = [ + "syn 2.0.79", +] + [[package]] name = "digest" version = "0.10.7" @@ -536,6 +605,26 @@ dependencies = [ "subtle", ] +[[package]] +name = "dsl_auto_type" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5d9abe6314103864cc2d8901b7ae224e0ab1a103a0a416661b4097b0779b607" +dependencies = [ + "darling", + "either", + "heck", + "proc-macro2", + "quote", + "syn 2.0.79", +] + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + [[package]] name = "equivalent" version = "1.0.1" @@ -1278,6 +1367,12 @@ dependencies = [ "cc", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "implicit-clone" version = "0.4.9" @@ -1360,6 +1455,16 @@ version = "0.2.159" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "pkg-config", + "vcpkg", +] + [[package]] name = "lock_api" version = "0.4.12" @@ -1817,9 +1922,11 @@ dependencies = [ "argon2", "axum", "axum-login", + "chrono", "ciborium", "clap", "common", + "diesel", "futures", "log", "slab", @@ -2348,6 +2455,12 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version_check" version = "0.9.5" diff --git a/dev.sh b/dev.sh index 880e393..d6d23ad 100755 --- a/dev.sh +++ b/dev.sh @@ -4,4 +4,4 @@ IFS=$'\n\t' (trap 'kill 0' SIGINT; \ bash -c 'cd frontend; trunk serve' & \ - bash -c 'cargo watch -i chat_log -- cargo run --bin server -- --port 8081') + bash -c 'cargo watch -i database.db -- cargo run --bin server -- --port 8081') diff --git a/diesel.toml b/diesel.toml new file mode 100644 index 0000000..9c675a2 --- /dev/null +++ b/diesel.toml @@ -0,0 +1,9 @@ +# For documentation on how to configure this file, +# see https://diesel.rs/guides/configuring-diesel-cli + +[print_schema] +file = "server/src/schema.rs" +custom_type_derives = ["diesel::query_builder::QueryId", "Clone"] + +[migrations_directory] +dir = "/home/pterpstra/projects/local_chat/migrations" diff --git a/frontend/src/pages/login.rs b/frontend/src/pages/login.rs index 6108564..45266d9 100644 --- a/frontend/src/pages/login.rs +++ b/frontend/src/pages/login.rs @@ -85,7 +85,10 @@ pub fn Login() -> Html { Ok(_) => BrowserHistory::new().push(&redirect), Err(err) => { login_error.set(Some(err)); - password_ref.cast::().unwrap().set_value(""); + password_ref + .cast::() + .unwrap() + .set_value(""); disable_input.set(false); } }; diff --git a/hash_pw/src/main.rs b/hash_pw/src/main.rs index 90585a9..6566809 100644 --- a/hash_pw/src/main.rs +++ b/hash_pw/src/main.rs @@ -1,6 +1,9 @@ use std::io::{stdin, stdout, Write}; -use argon2::{password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHasher}; +use argon2::{ + password_hash::{rand_core::OsRng, SaltString}, + Argon2, PasswordHasher, +}; fn main() { print!("Enter password:"); @@ -8,9 +11,12 @@ fn main() { let mut pw_buf = String::new(); stdin().read_line(&mut pw_buf).unwrap(); let password = pw_buf.trim_end(); - + let salt = SaltString::generate(&mut OsRng); let argon2 = Argon2::default(); - let pw_hash = argon2.hash_password(password.as_bytes(), &salt).unwrap().to_string(); + let pw_hash = argon2 + .hash_password(password.as_bytes(), &salt) + .unwrap() + .to_string(); println!("{pw_hash}"); } diff --git a/migrations/.keep b/migrations/.keep new file mode 100644 index 0000000..e69de29 diff --git a/migrations/2024-10-18-181813_create_messages/down.sql b/migrations/2024-10-18-181813_create_messages/down.sql new file mode 100644 index 0000000..bb9ce09 --- /dev/null +++ b/migrations/2024-10-18-181813_create_messages/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +DROP TABLE messages; diff --git a/migrations/2024-10-18-181813_create_messages/up.sql b/migrations/2024-10-18-181813_create_messages/up.sql new file mode 100644 index 0000000..2c9a614 --- /dev/null +++ b/migrations/2024-10-18-181813_create_messages/up.sql @@ -0,0 +1,8 @@ +-- Your SQL goes here +CREATE TABLE messages ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + message VARCHAR NOT NULL, + time VARCHAR NOT NULL, + user_id INTEGER NOT NULL, + FOREIGN KEY (user_id) REFERENCES users (id) +) diff --git a/migrations/2024-10-18-182242_create_users/down.sql b/migrations/2024-10-18-182242_create_users/down.sql new file mode 100644 index 0000000..dc3714b --- /dev/null +++ b/migrations/2024-10-18-182242_create_users/down.sql @@ -0,0 +1,2 @@ +-- This file should undo anything in `up.sql` +DROP TABLE users; diff --git a/migrations/2024-10-18-182242_create_users/up.sql b/migrations/2024-10-18-182242_create_users/up.sql new file mode 100644 index 0000000..48332d6 --- /dev/null +++ b/migrations/2024-10-18-182242_create_users/up.sql @@ -0,0 +1,7 @@ +-- Your SQL goes here +CREATE TABLE users ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + username VARCHAR UNIQUE NOT NULL, + pw_hash VARCHAR NOT NULL, + display_name VARCHAR NOT NULL +) diff --git a/server/Cargo.toml b/server/Cargo.toml index e670715..e82d5fb 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -7,9 +7,11 @@ edition = "2021" argon2 = "0.5.3" axum = { version = "0.7.7", features = ["ws"] } axum-login = "0.16.0" +chrono = "0.4.38" ciborium = "0.2.2" clap = { version = "4.5.19", features = ["derive"] } common = { version = "0.1.0", path = "../common" } +diesel = { version = "2.2.4", features = ["chrono", "sqlite"] } futures = "0.3.31" log = "0.4.22" slab = "0.4.9" diff --git a/server/src/auth.rs b/server/src/auth.rs new file mode 100644 index 0000000..714a4ca --- /dev/null +++ b/server/src/auth.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use argon2::{Argon2, PasswordHash, PasswordVerifier}; +use axum::async_trait; +use axum_login::{AuthUser, AuthnBackend, UserId}; +use diesel::{prelude::*, result::Error::NotFound}; +use tokio::sync::Mutex; + +use crate::{models::User, schema::users}; + +pub struct UserCreds { + pub username: String, + pub password: String, +} + +impl AuthUser for User { + type Id = i32; + + fn id(&self) -> Self::Id { + self.id + } + + fn session_auth_hash(&self) -> &[u8] { + self.pw_hash.as_bytes() + } +} + +#[derive(Clone)] +pub struct AuthBackend { + db_conn: Arc>, +} + +impl AuthBackend { + pub fn new(db_conn: Arc>) -> Self { + Self { db_conn } + } +} + +#[async_trait] +impl AuthnBackend for AuthBackend { + + type User = User; + type Credentials = UserCreds; + type Error = diesel::result::Error; + + async fn authenticate( + &self, + creds: Self::Credentials, + ) -> Result, Self::Error> { + let user: User = match users::table + .filter(users::username.eq(&creds.username)) + .first(&mut *self.db_conn.lock().await) + { + Ok(user) => user, + Err(e) => match e { + NotFound => return Ok(None), + _ => return Err(e), + }, + }; + let pw_hash = PasswordHash::new(&user.pw_hash).unwrap(); + if Argon2::default() + .verify_password(creds.password.as_bytes(), &pw_hash) + .is_ok() + { + Ok(Some(user.clone())) + } else { + Ok(None) + } + } + + async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { + match users::table + .find(user_id) + .first(&mut *self.db_conn.lock().await) + { + Ok(user) => Ok(Some(user)), + Err(e) => match e { + NotFound => { + Ok(None) + }, + _ => Err(e), + }, + } + } +} + +pub type AuthSession = axum_login::AuthSession; diff --git a/server/src/main.rs b/server/src/main.rs index f0b3e64..fbc00db 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,34 +1,13 @@ -use argon2::{Argon2, PasswordHash, PasswordVerifier}; -use axum::body::Body; -use axum::extract::ws::{self, CloseFrame, Message, WebSocket}; -use axum::extract::{Request, State, WebSocketUpgrade}; -use axum::http::header::CONTENT_TYPE; -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; -use axum::routing::post; -use axum::{async_trait, Json}; -use axum::{routing::get, Router}; -use axum_login::{AuthManagerLayerBuilder, AuthSession, AuthUser, AuthnBackend, UserId}; +mod models; +mod schema; +mod auth; +mod server; + use clap::Parser; -use common::{ChatMessage, LoggedInResponse, LoginRequest, LoginResponse}; -use futures::stream::SplitSink; -use futures::{SinkExt, StreamExt}; -use slab::Slab; -use std::collections::HashMap; +use diesel::prelude::*; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use std::str::FromStr; -use std::sync::Arc; -use time::Duration; -use tokio::fs::{self, OpenOptions}; -use tokio::io::AsyncWriteExt; -use tokio::sync::Mutex; -use tower::{ServiceBuilder, ServiceExt}; -use tower_http::services::ServeDir; -use tower_http::trace::TraceLayer; -use tower_sessions::cookie::Key; -use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; -use tracing::{debug, warn}; // Setup the command line interface with clap. #[derive(Parser, Debug)] @@ -51,132 +30,6 @@ struct Opt { static_dir: PathBuf, } -struct ServState { - client_sends: Slab>, - chat_log: Vec, -} - -impl ServState { - fn new() -> Self { - let chat_log = std::fs::read("chat_log").map_or(Vec::new(), |chat_log| { - ciborium::from_reader(chat_log.as_slice()).unwrap() - }); - Self { - client_sends: Slab::new(), - chat_log, - } - } -} - -//const USERS: HashMap = { -// let mut map = HashMap::new(); -// map.insert("pjht".to_string(), User { -// username: "pjht".to_string(), -// pw_hash: "$argon2id$v=19$m=19456,t=2,p=1$miwLkAOUyJa7NxWX9ueBJQ$w43pjFVBqnqRvWWiPaW2cbhlxk0Dq5sdYjmy4I+Yh+U".to_string(), -// id: 0, -// }).unwrap(); -//map -//}; - -struct UserCreds { - username: String, - password: String, -} - -#[derive(Clone, Debug)] -struct User { - username: String, - pw_hash: String, -} - -impl AuthUser for User { - type Id = String; - - fn id(&self) -> Self::Id { - self.username.clone() - } - - fn session_auth_hash(&self) -> &[u8] { - self.pw_hash.as_bytes() - } -} - -#[derive(Clone, Debug)] -struct AuthBackend { - users: HashMap -} - -#[async_trait] -impl AuthnBackend for AuthBackend { - type User = User; - type Credentials = UserCreds; - type Error = std::convert::Infallible; - - async fn authenticate( - &self, - creds: Self::Credentials, - ) -> Result, Self::Error> { - Ok((|| { - let user = self.users.get(&creds.username)?; - let pw_hash = PasswordHash::new(&user.pw_hash).unwrap(); - if Argon2::default().verify_password(creds.password.as_bytes(), &pw_hash).is_ok() { - Some(user.clone()) - } else { - None - } - })()) - } - - async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { - Ok(self.users.get(user_id).cloned()) - } -} - -async fn serve_index(static_dir: &Path) -> Response { - let index_path = PathBuf::from(&static_dir).join("index.html"); - let index_content = match fs::read_to_string(index_path).await { - Err(_) => { - return Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::from("index file not found")) - .unwrap() - } - Ok(index_content) => index_content, - }; - Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, "text/html") - .body(Body::from(index_content)) - .unwrap() -} - -async fn login( - mut auth_session: AuthSession, - Json(req): Json, -) -> Response { - let Some(user) = auth_session.authenticate(UserCreds {username: req.username, password: req.password}).await.unwrap() else { - return Json(LoginResponse(Err("Invalid username or password".to_string()))).into_response(); - }; - if auth_session.login(&user).await.is_err() { - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - Json(LoginResponse(Ok(()))).into_response() -} - -async fn logout(mut auth_session: AuthSession) -> Response { - if auth_session.logout().await.is_err() { - return StatusCode::INTERNAL_SERVER_ERROR.into_response(); - } - StatusCode::OK.into_response() -} - -async fn logged_in(auth_session: AuthSession) -> Response { - let resp = LoggedInResponse { - logged_in: auth_session.user.is_some(), - }; - Json(resp).into_response() -} - #[tokio::main] async fn main() { let opt = Opt::parse(); @@ -190,60 +43,10 @@ async fn main() { // enable console logging tracing_subscriber::fmt::init(); - let signing_key = Key::generate(); + let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + let db_conn = SqliteConnection::establish(&database_url) + .unwrap_or_else(|_| panic!("Error connecting to database at {}", database_url)); - let session_layer = SessionManagerLayer::new(MemoryStore::default()) - .with_secure(false) - .with_expiry(Expiry::OnInactivity(Duration::days(1))) - .with_signed(signing_key); - - let mut users = HashMap::new(); - users.insert("pjht".to_string(), User { - username: "pjht".to_string(), - pw_hash: "$argon2id$v=19$m=19456,t=2,p=1$aI7n6SgiTcVhSBjk7pXo+Q$wcJriSwcIj5Al/oNlZJdxMVOA/15e13t2AaWs4VQmVM".to_string(), - }); - let auth_backend = AuthBackend { - users - }; - let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build(); - - let api = Router::new() - //.route_layer(login_required!(DummyAuthBackend)) - .route("/chat_ws", get(chat_ws)) - .route("/login", post(login)) - .route("/logout", post(logout)) - .route("/logged_in", get(logged_in)); - - let app = Router::new() - .nest("/api", api) - .fallback(|req: Request| async move { - if req.uri().path().starts_with("/api") { - return Response::builder() - .status(StatusCode::NOT_FOUND) - .body(Body::from("")) - .unwrap(); - } - match ServeDir::new(&opt.static_dir) - .append_index_html_on_directories(false) - .oneshot(req) - .await - { - Ok(res) => { - if res.status() == StatusCode::NOT_FOUND { - serve_index(&opt.static_dir).await - } else { - res.map(Body::new) - } - } - Err(err) => Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .body(Body::from(format!("error: {err}"))) - .expect("error response"), - } - }) - .with_state(Arc::new(Mutex::new(ServState::new()))) - .layer(auth_layer) - .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())); let sock_addr = SocketAddr::from(( IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)), @@ -254,76 +57,5 @@ async fn main() { log::info!("listening on http://{}", sock_addr); - axum::serve(listener, app) - .await - .expect("Unable to start server"); -} - -async fn chat_ws( - State(state): State>>, - auth_session: AuthSession, - ws: WebSocketUpgrade, -) -> Response { - ws.on_upgrade(move |socket| async move { - let (mut tx, mut rx) = socket.split(); - if auth_session.user.is_none() { - let _ = tx - .send(Message::Close(Some(CloseFrame { - code: 4000, - reason: "Unauthorized".into(), - }))) - .await; - return; - } - debug!("Client connected"); - for msg in &state.lock().await.chat_log { - let mut buf = Vec::new(); - ciborium::into_writer(&msg, &mut buf).unwrap(); - tx.send(Message::Binary(buf)).await.unwrap(); - } - let tx_idx = state.lock().await.client_sends.insert(tx); - let mut close_code = ws::close_code::NORMAL; - while let Some(msg) = rx.next().await { - if let Ok(msg) = msg { - let msg_bytes = match msg { - Message::Binary(msg) => msg, - _ => { - close_code = ws::close_code::UNSUPPORTED; - warn!("Got unsupported message"); - break; - } - }; - let msg: ChatMessage = ciborium::from_reader(msg_bytes.as_slice()).unwrap(); - state.lock().await.chat_log.push(msg); - let mut log_file = OpenOptions::new() - .write(true) - .create(true) - .truncate(true) - .open("chat_log") - .await - .unwrap(); - let mut buf = Vec::new(); - ciborium::into_writer(&state.lock().await.chat_log, &mut buf).unwrap(); - log_file.write_all(&buf).await.unwrap(); - for (i, client_tx) in state.lock().await.client_sends.iter_mut() { - if i == tx_idx { - continue; - } - let _ = client_tx.send(Message::Binary(msg_bytes.clone())).await; - } - } else { - close_code = ws::close_code::PROTOCOL; - warn!("Websocket protocol error"); - break; - }; - } - debug!("Client disconnected"); - let mut tx = state.lock().await.client_sends.remove(tx_idx); - let _ = tx - .send(Message::Close(Some(CloseFrame { - code: close_code, - reason: "".into(), - }))) - .await; - }) + server::run(listener, db_conn, opt.static_dir).await; } diff --git a/server/src/models.rs b/server/src/models.rs new file mode 100644 index 0000000..aeb02a0 --- /dev/null +++ b/server/src/models.rs @@ -0,0 +1,32 @@ +use crate::schema::*; +use chrono::{DateTime, Local}; +use diesel::prelude::*; + +#[derive(Queryable, Selectable, Debug, Clone)] +#[diesel(table_name = users)] +#[diesel(check_for_backend(diesel::sqlite::Sqlite))] +pub struct User { + pub id: i32, + pub username: String, + pub pw_hash: String, + pub display_name: String, +} + +#[derive(Insertable)] +#[diesel(table_name = messages)] +#[diesel(belongs_to(User))] +pub struct NewMessage<'a> { + pub message: &'a str, + pub time: DateTime, + pub user_id: i32, +} + +#[derive(Queryable, Selectable, Debug, Clone)] +#[diesel(table_name = messages)] +#[diesel(check_for_backend(diesel::sqlite::Sqlite))] +pub struct Message { + pub id: i32, + pub message: String, + pub time: DateTime, + pub user_id: i32, +} diff --git a/server/src/schema.rs b/server/src/schema.rs new file mode 100644 index 0000000..9fa3cc6 --- /dev/null +++ b/server/src/schema.rs @@ -0,0 +1,26 @@ +// @generated automatically by Diesel CLI. + +diesel::table! { + messages (id) { + id -> Integer, + message -> Text, + time -> TimestamptzSqlite, + user_id -> Integer, + } +} + +diesel::table! { + users (id) { + id -> Integer, + username -> Text, + pw_hash -> Text, + display_name -> Text, + } +} + +diesel::joinable!(messages -> users (user_id)); + +diesel::allow_tables_to_appear_in_same_query!( + messages, + users, +); diff --git a/server/src/server.rs b/server/src/server.rs new file mode 100644 index 0000000..2c96e40 --- /dev/null +++ b/server/src/server.rs @@ -0,0 +1,231 @@ +use crate::auth::{AuthBackend, AuthSession, UserCreds}; +use axum::body::Body; +use axum::extract::ws::{self, CloseFrame, Message as WsMessage, WebSocket}; +use axum::extract::{Request, State, WebSocketUpgrade}; +use axum::http::header::CONTENT_TYPE; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::routing::post; +use axum::Json; +use axum::{routing::get, Router}; +use axum_login::AuthManagerLayerBuilder; +use common::{ChatMessage as ApiMessage, LoggedInResponse, LoginRequest, LoginResponse}; +use futures::stream::SplitSink; +use futures::{SinkExt, StreamExt}; +use tokio::net::TcpListener; +use crate::models::{NewMessage, Message as DbMessage}; +use crate::schema::messages; +use slab::Slab; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use time::Duration; +use tokio::fs; +use tokio::sync::Mutex; +use tower::{ServiceBuilder, ServiceExt}; +use tower_http::services::ServeDir; +use tower_http::trace::TraceLayer; +use tower_sessions::cookie::Key; +use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; +use tracing::{debug, warn}; +use diesel::prelude::*; + +struct ServState { + client_sends: Slab>, + db_conn: Arc>, +} + +impl ServState { + fn new(db_conn: Arc>) -> Self { + Self { + client_sends: Slab::new(), + db_conn, + } + } +} + + + +async fn serve_index(static_dir: &Path) -> Response { + let index_path = PathBuf::from(&static_dir).join("index.html"); + let index_content = match fs::read_to_string(index_path).await { + Err(_) => { + return Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::from("index file not found")) + .unwrap() + } + Ok(index_content) => index_content, + }; + Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "text/html") + .body(Body::from(index_content)) + .unwrap() +} + +async fn login( + mut auth_session: AuthSession, + Json(req): Json, +) -> Response { + let Some(user) = auth_session + .authenticate(UserCreds { + username: req.username, + password: req.password, + }) + .await + .unwrap() + else { + return Json(LoginResponse(Err( + "Invalid username or password".to_string() + ))) + .into_response(); + }; + if auth_session.login(&user).await.is_err() { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + Json(LoginResponse(Ok(()))).into_response() +} + +async fn logout(mut auth_session: AuthSession) -> impl IntoResponse { + if auth_session.logout().await.is_err() { + return StatusCode::INTERNAL_SERVER_ERROR.into_response(); + } + StatusCode::OK.into_response() +} + +async fn logged_in(auth_session: AuthSession) -> Response { + let resp = LoggedInResponse { + logged_in: auth_session.user.is_some(), + }; + Json(resp).into_response() +} + + +async fn chat_ws( + State(state): State>>, + auth_session: AuthSession, + ws: WebSocketUpgrade, +) -> Response { + ws.on_upgrade(move |socket| async move { + let (mut tx, mut rx) = socket.split(); + if auth_session.user.is_none() { + let _ = tx + .send(WsMessage::Close(Some(CloseFrame { + code: 4000, + reason: "Unauthorized".into(), + }))) + .await; + return; + } + debug!("Client connected"); + let messages: Vec = messages::table.load(&mut *(state.lock().await.db_conn.lock().await)).unwrap(); + for db_msg in messages { + let api_msg = ApiMessage { + message: db_msg.message, + time: db_msg.time, + }; + let mut buf = Vec::new(); + ciborium::into_writer(&api_msg, &mut buf).unwrap(); + tx.send(WsMessage::Binary(buf)).await.unwrap(); + } + let tx_idx = state.lock().await.client_sends.insert(tx); + let mut close_code = ws::close_code::NORMAL; + while let Some(ws_msg) = rx.next().await { + if let Ok(ws_msg) = ws_msg { + let msg_bytes = match ws_msg { + WsMessage::Binary(msg) => msg, + _ => { + close_code = ws::close_code::UNSUPPORTED; + warn!("Got unsupported message"); + break; + } + }; + let api_msg: ApiMessage = ciborium::from_reader(msg_bytes.as_slice()).unwrap(); + let db_msg = NewMessage { + message: &api_msg.message, + time: api_msg.time, + user_id: auth_session.user.as_ref().unwrap().id, + }; + diesel::insert_into(messages::table) + .values(&db_msg) + .execute(&mut *(state.lock().await.db_conn.lock().await)) + .unwrap(); + for (i, client_tx) in state.lock().await.client_sends.iter_mut() { + if i == tx_idx { + continue; + } + let _ = client_tx.send(WsMessage::Binary(msg_bytes.clone())).await; + } + } else { + close_code = ws::close_code::PROTOCOL; + warn!("Websocket protocol error"); + break; + }; + } + debug!("Client disconnected"); + let mut tx = state.lock().await.client_sends.remove(tx_idx); + let _ = tx + .send(WsMessage::Close(Some(CloseFrame { + code: close_code, + reason: "".into(), + }))) + .await; + }) +} + +pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: PathBuf) { + let db_conn = Arc::new(Mutex::new(db_conn)); + + + let session_layer = SessionManagerLayer::new(MemoryStore::default()) + .with_secure(false) + .with_expiry(Expiry::OnInactivity(Duration::days(1))); + + + let auth_backend = AuthBackend::new(db_conn.clone()); + + let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build(); + + let api = Router::new() + //.route_layer(login_required!(DummyAuthBackend)) + .route("/chat_ws", get(chat_ws)) + .route("/login", post(login)) + .route("/logout", post(logout)) + .route("/logged_in", get(logged_in)); + + let app = Router::new() + .nest("/api", api) + .fallback(|req: Request| async move { + if req.uri().path().starts_with("/api") { + return Response::builder() + .status(StatusCode::NOT_FOUND) + .body(Body::from("")) + .unwrap(); + } + match ServeDir::new(&static_dir) + .append_index_html_on_directories(false) + .oneshot(req) + .await + { + Ok(res) => { + if res.status() == StatusCode::NOT_FOUND { + serve_index(&static_dir).await + } else { + res.map(Body::new) + } + } + Err(err) => Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Body::from(format!("error: {err}"))) + .expect("error response"), + } + }) + .with_state(Arc::new(Mutex::new(ServState::new(db_conn)))) + .layer(auth_layer) + .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())); + + + axum::serve(listener, app) + .await + .expect("Unable to start server"); +}