Switch to using a Sqlite DB intstead of a text file and some code cleanup

This commit is contained in:
pjht 2024-10-18 16:07:06 -05:00
parent cbd2792ad3
commit 2967284956
Signed by: pjht
GPG Key ID: CA239FC6934E6F3A
18 changed files with 546 additions and 285 deletions

1
.envrc
View File

@ -1 +1,2 @@
export RUSTFMT=yew-fmt export RUSTFMT=yew-fmt
export DATABASE_URL=database.db

2
.gitignore vendored
View File

@ -1,3 +1,3 @@
/target /target
/dist /dist
chat_log database.db

113
Cargo.lock generated
View File

@ -509,6 +509,41 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "darling"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn 2.0.79",
]
[[package]]
name = "darling_macro"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
"syn 2.0.79",
]
[[package]] [[package]]
name = "data-encoding" name = "data-encoding"
version = "2.6.0" version = "2.6.0"
@ -525,6 +560,40 @@ dependencies = [
"serde", "serde",
] ]
[[package]]
name = "diesel"
version = "2.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "158fe8e2e68695bd615d7e4f3227c0727b151330d3e253b525086c348d055d5e"
dependencies = [
"chrono",
"diesel_derives",
"libsqlite3-sys",
"time",
]
[[package]]
name = "diesel_derives"
version = "2.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7f2c3de51e2ba6bf2a648285696137aaf0f5f487bcbea93972fe8a364e131a4"
dependencies = [
"diesel_table_macro_syntax",
"dsl_auto_type",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "diesel_table_macro_syntax"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25"
dependencies = [
"syn 2.0.79",
]
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.7" version = "0.10.7"
@ -536,6 +605,26 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "dsl_auto_type"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5d9abe6314103864cc2d8901b7ae224e0ab1a103a0a416661b4097b0779b607"
dependencies = [
"darling",
"either",
"heck",
"proc-macro2",
"quote",
"syn 2.0.79",
]
[[package]]
name = "either"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]] [[package]]
name = "equivalent" name = "equivalent"
version = "1.0.1" version = "1.0.1"
@ -1278,6 +1367,12 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]] [[package]]
name = "implicit-clone" name = "implicit-clone"
version = "0.4.9" version = "0.4.9"
@ -1360,6 +1455,16 @@ version = "0.2.159"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
[[package]]
name = "libsqlite3-sys"
version = "0.30.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
dependencies = [
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "lock_api" name = "lock_api"
version = "0.4.12" version = "0.4.12"
@ -1817,9 +1922,11 @@ dependencies = [
"argon2", "argon2",
"axum", "axum",
"axum-login", "axum-login",
"chrono",
"ciborium", "ciborium",
"clap", "clap",
"common", "common",
"diesel",
"futures", "futures",
"log", "log",
"slab", "slab",
@ -2348,6 +2455,12 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.5" version = "0.9.5"

2
dev.sh
View File

@ -4,4 +4,4 @@ IFS=$'\n\t'
(trap 'kill 0' SIGINT; \ (trap 'kill 0' SIGINT; \
bash -c 'cd frontend; trunk serve' & \ bash -c 'cd frontend; trunk serve' & \
bash -c 'cargo watch -i chat_log -- cargo run --bin server -- --port 8081') bash -c 'cargo watch -i database.db -- cargo run --bin server -- --port 8081')

9
diesel.toml Normal file
View File

@ -0,0 +1,9 @@
# For documentation on how to configure this file,
# see https://diesel.rs/guides/configuring-diesel-cli
[print_schema]
file = "server/src/schema.rs"
custom_type_derives = ["diesel::query_builder::QueryId", "Clone"]
[migrations_directory]
dir = "/home/pterpstra/projects/local_chat/migrations"

View File

@ -85,7 +85,10 @@ pub fn Login() -> Html {
Ok(_) => BrowserHistory::new().push(&redirect), Ok(_) => BrowserHistory::new().push(&redirect),
Err(err) => { Err(err) => {
login_error.set(Some(err)); login_error.set(Some(err));
password_ref.cast::<HtmlInputElement>().unwrap().set_value(""); password_ref
.cast::<HtmlInputElement>()
.unwrap()
.set_value("");
disable_input.set(false); disable_input.set(false);
} }
}; };

View File

@ -1,6 +1,9 @@
use std::io::{stdin, stdout, Write}; use std::io::{stdin, stdout, Write};
use argon2::{password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHasher}; use argon2::{
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher,
};
fn main() { fn main() {
print!("Enter password:"); print!("Enter password:");
@ -11,6 +14,9 @@ fn main() {
let salt = SaltString::generate(&mut OsRng); let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default(); let argon2 = Argon2::default();
let pw_hash = argon2.hash_password(password.as_bytes(), &salt).unwrap().to_string(); let pw_hash = argon2
.hash_password(password.as_bytes(), &salt)
.unwrap()
.to_string();
println!("{pw_hash}"); println!("{pw_hash}");
} }

0
migrations/.keep Normal file
View File

View File

@ -0,0 +1,2 @@
-- This file should undo anything in `up.sql`
DROP TABLE messages;

View File

@ -0,0 +1,8 @@
-- Your SQL goes here
CREATE TABLE messages (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
message VARCHAR NOT NULL,
time VARCHAR NOT NULL,
user_id INTEGER NOT NULL,
FOREIGN KEY (user_id) REFERENCES users (id)
)

View File

@ -0,0 +1,2 @@
-- This file should undo anything in `up.sql`
DROP TABLE users;

View File

@ -0,0 +1,7 @@
-- Your SQL goes here
CREATE TABLE users (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
username VARCHAR UNIQUE NOT NULL,
pw_hash VARCHAR NOT NULL,
display_name VARCHAR NOT NULL
)

View File

@ -7,9 +7,11 @@ edition = "2021"
argon2 = "0.5.3" argon2 = "0.5.3"
axum = { version = "0.7.7", features = ["ws"] } axum = { version = "0.7.7", features = ["ws"] }
axum-login = "0.16.0" axum-login = "0.16.0"
chrono = "0.4.38"
ciborium = "0.2.2" ciborium = "0.2.2"
clap = { version = "4.5.19", features = ["derive"] } clap = { version = "4.5.19", features = ["derive"] }
common = { version = "0.1.0", path = "../common" } common = { version = "0.1.0", path = "../common" }
diesel = { version = "2.2.4", features = ["chrono", "sqlite"] }
futures = "0.3.31" futures = "0.3.31"
log = "0.4.22" log = "0.4.22"
slab = "0.4.9" slab = "0.4.9"

87
server/src/auth.rs Normal file
View File

@ -0,0 +1,87 @@
use std::sync::Arc;
use argon2::{Argon2, PasswordHash, PasswordVerifier};
use axum::async_trait;
use axum_login::{AuthUser, AuthnBackend, UserId};
use diesel::{prelude::*, result::Error::NotFound};
use tokio::sync::Mutex;
use crate::{models::User, schema::users};
pub struct UserCreds {
pub username: String,
pub password: String,
}
impl AuthUser for User {
type Id = i32;
fn id(&self) -> Self::Id {
self.id
}
fn session_auth_hash(&self) -> &[u8] {
self.pw_hash.as_bytes()
}
}
#[derive(Clone)]
pub struct AuthBackend {
db_conn: Arc<Mutex<SqliteConnection>>,
}
impl AuthBackend {
pub fn new(db_conn: Arc<Mutex<SqliteConnection>>) -> Self {
Self { db_conn }
}
}
#[async_trait]
impl AuthnBackend for AuthBackend {
type User = User;
type Credentials = UserCreds;
type Error = diesel::result::Error;
async fn authenticate(
&self,
creds: Self::Credentials,
) -> Result<Option<Self::User>, Self::Error> {
let user: User = match users::table
.filter(users::username.eq(&creds.username))
.first(&mut *self.db_conn.lock().await)
{
Ok(user) => user,
Err(e) => match e {
NotFound => return Ok(None),
_ => return Err(e),
},
};
let pw_hash = PasswordHash::new(&user.pw_hash).unwrap();
if Argon2::default()
.verify_password(creds.password.as_bytes(), &pw_hash)
.is_ok()
{
Ok(Some(user.clone()))
} else {
Ok(None)
}
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
match users::table
.find(user_id)
.first(&mut *self.db_conn.lock().await)
{
Ok(user) => Ok(Some(user)),
Err(e) => match e {
NotFound => {
Ok(None)
},
_ => Err(e),
},
}
}
}
pub type AuthSession = axum_login::AuthSession<AuthBackend>;

View File

@ -1,34 +1,13 @@
use argon2::{Argon2, PasswordHash, PasswordVerifier}; mod models;
use axum::body::Body; mod schema;
use axum::extract::ws::{self, CloseFrame, Message, WebSocket}; mod auth;
use axum::extract::{Request, State, WebSocketUpgrade}; mod server;
use axum::http::header::CONTENT_TYPE;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::{async_trait, Json};
use axum::{routing::get, Router};
use axum_login::{AuthManagerLayerBuilder, AuthSession, AuthUser, AuthnBackend, UserId};
use clap::Parser; use clap::Parser;
use common::{ChatMessage, LoggedInResponse, LoginRequest, LoginResponse}; use diesel::prelude::*;
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use slab::Slab;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::path::{Path, PathBuf}; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc;
use time::Duration;
use tokio::fs::{self, OpenOptions};
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tower::{ServiceBuilder, ServiceExt};
use tower_http::services::ServeDir;
use tower_http::trace::TraceLayer;
use tower_sessions::cookie::Key;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
use tracing::{debug, warn};
// Setup the command line interface with clap. // Setup the command line interface with clap.
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -51,132 +30,6 @@ struct Opt {
static_dir: PathBuf, static_dir: PathBuf,
} }
struct ServState {
client_sends: Slab<SplitSink<WebSocket, Message>>,
chat_log: Vec<ChatMessage>,
}
impl ServState {
fn new() -> Self {
let chat_log = std::fs::read("chat_log").map_or(Vec::new(), |chat_log| {
ciborium::from_reader(chat_log.as_slice()).unwrap()
});
Self {
client_sends: Slab::new(),
chat_log,
}
}
}
//const USERS: HashMap<String, User> = {
// let mut map = HashMap::new();
// map.insert("pjht".to_string(), User {
// username: "pjht".to_string(),
// pw_hash: "$argon2id$v=19$m=19456,t=2,p=1$miwLkAOUyJa7NxWX9ueBJQ$w43pjFVBqnqRvWWiPaW2cbhlxk0Dq5sdYjmy4I+Yh+U".to_string(),
// id: 0,
// }).unwrap();
//map
//};
struct UserCreds {
username: String,
password: String,
}
#[derive(Clone, Debug)]
struct User {
username: String,
pw_hash: String,
}
impl AuthUser for User {
type Id = String;
fn id(&self) -> Self::Id {
self.username.clone()
}
fn session_auth_hash(&self) -> &[u8] {
self.pw_hash.as_bytes()
}
}
#[derive(Clone, Debug)]
struct AuthBackend {
users: HashMap<String, User>
}
#[async_trait]
impl AuthnBackend for AuthBackend {
type User = User;
type Credentials = UserCreds;
type Error = std::convert::Infallible;
async fn authenticate(
&self,
creds: Self::Credentials,
) -> Result<Option<Self::User>, Self::Error> {
Ok((|| {
let user = self.users.get(&creds.username)?;
let pw_hash = PasswordHash::new(&user.pw_hash).unwrap();
if Argon2::default().verify_password(creds.password.as_bytes(), &pw_hash).is_ok() {
Some(user.clone())
} else {
None
}
})())
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
Ok(self.users.get(user_id).cloned())
}
}
async fn serve_index(static_dir: &Path) -> Response {
let index_path = PathBuf::from(&static_dir).join("index.html");
let index_content = match fs::read_to_string(index_path).await {
Err(_) => {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("index file not found"))
.unwrap()
}
Ok(index_content) => index_content,
};
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "text/html")
.body(Body::from(index_content))
.unwrap()
}
async fn login(
mut auth_session: AuthSession<AuthBackend>,
Json(req): Json<LoginRequest>,
) -> Response {
let Some(user) = auth_session.authenticate(UserCreds {username: req.username, password: req.password}).await.unwrap() else {
return Json(LoginResponse(Err("Invalid username or password".to_string()))).into_response();
};
if auth_session.login(&user).await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
Json(LoginResponse(Ok(()))).into_response()
}
async fn logout(mut auth_session: AuthSession<AuthBackend>) -> Response {
if auth_session.logout().await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
StatusCode::OK.into_response()
}
async fn logged_in(auth_session: AuthSession<AuthBackend>) -> Response {
let resp = LoggedInResponse {
logged_in: auth_session.user.is_some(),
};
Json(resp).into_response()
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let opt = Opt::parse(); let opt = Opt::parse();
@ -190,60 +43,10 @@ async fn main() {
// enable console logging // enable console logging
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let signing_key = Key::generate(); let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let db_conn = SqliteConnection::establish(&database_url)
.unwrap_or_else(|_| panic!("Error connecting to database at {}", database_url));
let session_layer = SessionManagerLayer::new(MemoryStore::default())
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::days(1)))
.with_signed(signing_key);
let mut users = HashMap::new();
users.insert("pjht".to_string(), User {
username: "pjht".to_string(),
pw_hash: "$argon2id$v=19$m=19456,t=2,p=1$aI7n6SgiTcVhSBjk7pXo+Q$wcJriSwcIj5Al/oNlZJdxMVOA/15e13t2AaWs4VQmVM".to_string(),
});
let auth_backend = AuthBackend {
users
};
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
let api = Router::new()
//.route_layer(login_required!(DummyAuthBackend))
.route("/chat_ws", get(chat_ws))
.route("/login", post(login))
.route("/logout", post(logout))
.route("/logged_in", get(logged_in));
let app = Router::new()
.nest("/api", api)
.fallback(|req: Request| async move {
if req.uri().path().starts_with("/api") {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from(""))
.unwrap();
}
match ServeDir::new(&opt.static_dir)
.append_index_html_on_directories(false)
.oneshot(req)
.await
{
Ok(res) => {
if res.status() == StatusCode::NOT_FOUND {
serve_index(&opt.static_dir).await
} else {
res.map(Body::new)
}
}
Err(err) => Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(format!("error: {err}")))
.expect("error response"),
}
})
.with_state(Arc::new(Mutex::new(ServState::new())))
.layer(auth_layer)
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
let sock_addr = SocketAddr::from(( let sock_addr = SocketAddr::from((
IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)), IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)),
@ -254,76 +57,5 @@ async fn main() {
log::info!("listening on http://{}", sock_addr); log::info!("listening on http://{}", sock_addr);
axum::serve(listener, app) server::run(listener, db_conn, opt.static_dir).await;
.await
.expect("Unable to start server");
}
async fn chat_ws(
State(state): State<Arc<Mutex<ServState>>>,
auth_session: AuthSession<AuthBackend>,
ws: WebSocketUpgrade,
) -> Response {
ws.on_upgrade(move |socket| async move {
let (mut tx, mut rx) = socket.split();
if auth_session.user.is_none() {
let _ = tx
.send(Message::Close(Some(CloseFrame {
code: 4000,
reason: "Unauthorized".into(),
})))
.await;
return;
}
debug!("Client connected");
for msg in &state.lock().await.chat_log {
let mut buf = Vec::new();
ciborium::into_writer(&msg, &mut buf).unwrap();
tx.send(Message::Binary(buf)).await.unwrap();
}
let tx_idx = state.lock().await.client_sends.insert(tx);
let mut close_code = ws::close_code::NORMAL;
while let Some(msg) = rx.next().await {
if let Ok(msg) = msg {
let msg_bytes = match msg {
Message::Binary(msg) => msg,
_ => {
close_code = ws::close_code::UNSUPPORTED;
warn!("Got unsupported message");
break;
}
};
let msg: ChatMessage = ciborium::from_reader(msg_bytes.as_slice()).unwrap();
state.lock().await.chat_log.push(msg);
let mut log_file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.open("chat_log")
.await
.unwrap();
let mut buf = Vec::new();
ciborium::into_writer(&state.lock().await.chat_log, &mut buf).unwrap();
log_file.write_all(&buf).await.unwrap();
for (i, client_tx) in state.lock().await.client_sends.iter_mut() {
if i == tx_idx {
continue;
}
let _ = client_tx.send(Message::Binary(msg_bytes.clone())).await;
}
} else {
close_code = ws::close_code::PROTOCOL;
warn!("Websocket protocol error");
break;
};
}
debug!("Client disconnected");
let mut tx = state.lock().await.client_sends.remove(tx_idx);
let _ = tx
.send(Message::Close(Some(CloseFrame {
code: close_code,
reason: "".into(),
})))
.await;
})
} }

32
server/src/models.rs Normal file
View File

@ -0,0 +1,32 @@
use crate::schema::*;
use chrono::{DateTime, Local};
use diesel::prelude::*;
#[derive(Queryable, Selectable, Debug, Clone)]
#[diesel(table_name = users)]
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
pub struct User {
pub id: i32,
pub username: String,
pub pw_hash: String,
pub display_name: String,
}
#[derive(Insertable)]
#[diesel(table_name = messages)]
#[diesel(belongs_to(User))]
pub struct NewMessage<'a> {
pub message: &'a str,
pub time: DateTime<Local>,
pub user_id: i32,
}
#[derive(Queryable, Selectable, Debug, Clone)]
#[diesel(table_name = messages)]
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
pub struct Message {
pub id: i32,
pub message: String,
pub time: DateTime<Local>,
pub user_id: i32,
}

26
server/src/schema.rs Normal file
View File

@ -0,0 +1,26 @@
// @generated automatically by Diesel CLI.
diesel::table! {
messages (id) {
id -> Integer,
message -> Text,
time -> TimestamptzSqlite,
user_id -> Integer,
}
}
diesel::table! {
users (id) {
id -> Integer,
username -> Text,
pw_hash -> Text,
display_name -> Text,
}
}
diesel::joinable!(messages -> users (user_id));
diesel::allow_tables_to_appear_in_same_query!(
messages,
users,
);

231
server/src/server.rs Normal file
View File

@ -0,0 +1,231 @@
use crate::auth::{AuthBackend, AuthSession, UserCreds};
use axum::body::Body;
use axum::extract::ws::{self, CloseFrame, Message as WsMessage, WebSocket};
use axum::extract::{Request, State, WebSocketUpgrade};
use axum::http::header::CONTENT_TYPE;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::Json;
use axum::{routing::get, Router};
use axum_login::AuthManagerLayerBuilder;
use common::{ChatMessage as ApiMessage, LoggedInResponse, LoginRequest, LoginResponse};
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use crate::models::{NewMessage, Message as DbMessage};
use crate::schema::messages;
use slab::Slab;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use time::Duration;
use tokio::fs;
use tokio::sync::Mutex;
use tower::{ServiceBuilder, ServiceExt};
use tower_http::services::ServeDir;
use tower_http::trace::TraceLayer;
use tower_sessions::cookie::Key;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
use tracing::{debug, warn};
use diesel::prelude::*;
struct ServState {
client_sends: Slab<SplitSink<WebSocket, WsMessage>>,
db_conn: Arc<Mutex<SqliteConnection>>,
}
impl ServState {
fn new(db_conn: Arc<Mutex<SqliteConnection>>) -> Self {
Self {
client_sends: Slab::new(),
db_conn,
}
}
}
async fn serve_index(static_dir: &Path) -> Response {
let index_path = PathBuf::from(&static_dir).join("index.html");
let index_content = match fs::read_to_string(index_path).await {
Err(_) => {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("index file not found"))
.unwrap()
}
Ok(index_content) => index_content,
};
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "text/html")
.body(Body::from(index_content))
.unwrap()
}
async fn login(
mut auth_session: AuthSession,
Json(req): Json<LoginRequest>,
) -> Response {
let Some(user) = auth_session
.authenticate(UserCreds {
username: req.username,
password: req.password,
})
.await
.unwrap()
else {
return Json(LoginResponse(Err(
"Invalid username or password".to_string()
)))
.into_response();
};
if auth_session.login(&user).await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
Json(LoginResponse(Ok(()))).into_response()
}
async fn logout(mut auth_session: AuthSession) -> impl IntoResponse {
if auth_session.logout().await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
StatusCode::OK.into_response()
}
async fn logged_in(auth_session: AuthSession) -> Response {
let resp = LoggedInResponse {
logged_in: auth_session.user.is_some(),
};
Json(resp).into_response()
}
async fn chat_ws(
State(state): State<Arc<Mutex<ServState>>>,
auth_session: AuthSession,
ws: WebSocketUpgrade,
) -> Response {
ws.on_upgrade(move |socket| async move {
let (mut tx, mut rx) = socket.split();
if auth_session.user.is_none() {
let _ = tx
.send(WsMessage::Close(Some(CloseFrame {
code: 4000,
reason: "Unauthorized".into(),
})))
.await;
return;
}
debug!("Client connected");
let messages: Vec<DbMessage> = messages::table.load(&mut *(state.lock().await.db_conn.lock().await)).unwrap();
for db_msg in messages {
let api_msg = ApiMessage {
message: db_msg.message,
time: db_msg.time,
};
let mut buf = Vec::new();
ciborium::into_writer(&api_msg, &mut buf).unwrap();
tx.send(WsMessage::Binary(buf)).await.unwrap();
}
let tx_idx = state.lock().await.client_sends.insert(tx);
let mut close_code = ws::close_code::NORMAL;
while let Some(ws_msg) = rx.next().await {
if let Ok(ws_msg) = ws_msg {
let msg_bytes = match ws_msg {
WsMessage::Binary(msg) => msg,
_ => {
close_code = ws::close_code::UNSUPPORTED;
warn!("Got unsupported message");
break;
}
};
let api_msg: ApiMessage = ciborium::from_reader(msg_bytes.as_slice()).unwrap();
let db_msg = NewMessage {
message: &api_msg.message,
time: api_msg.time,
user_id: auth_session.user.as_ref().unwrap().id,
};
diesel::insert_into(messages::table)
.values(&db_msg)
.execute(&mut *(state.lock().await.db_conn.lock().await))
.unwrap();
for (i, client_tx) in state.lock().await.client_sends.iter_mut() {
if i == tx_idx {
continue;
}
let _ = client_tx.send(WsMessage::Binary(msg_bytes.clone())).await;
}
} else {
close_code = ws::close_code::PROTOCOL;
warn!("Websocket protocol error");
break;
};
}
debug!("Client disconnected");
let mut tx = state.lock().await.client_sends.remove(tx_idx);
let _ = tx
.send(WsMessage::Close(Some(CloseFrame {
code: close_code,
reason: "".into(),
})))
.await;
})
}
pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: PathBuf) {
let db_conn = Arc::new(Mutex::new(db_conn));
let session_layer = SessionManagerLayer::new(MemoryStore::default())
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::days(1)));
let auth_backend = AuthBackend::new(db_conn.clone());
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
let api = Router::new()
//.route_layer(login_required!(DummyAuthBackend))
.route("/chat_ws", get(chat_ws))
.route("/login", post(login))
.route("/logout", post(logout))
.route("/logged_in", get(logged_in));
let app = Router::new()
.nest("/api", api)
.fallback(|req: Request| async move {
if req.uri().path().starts_with("/api") {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from(""))
.unwrap();
}
match ServeDir::new(&static_dir)
.append_index_html_on_directories(false)
.oneshot(req)
.await
{
Ok(res) => {
if res.status() == StatusCode::NOT_FOUND {
serve_index(&static_dir).await
} else {
res.map(Body::new)
}
}
Err(err) => Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::from(format!("error: {err}")))
.expect("error response"),
}
})
.with_state(Arc::new(Mutex::new(ServState::new(db_conn))))
.layer(auth_layer)
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
axum::serve(listener, app)
.await
.expect("Unable to start server");
}