Switch to using a Sqlite DB intstead of a text file and some code cleanup
This commit is contained in:
parent
cbd2792ad3
commit
2967284956
2
.gitignore
vendored
2
.gitignore
vendored
@ -1,3 +1,3 @@
|
||||
/target
|
||||
/dist
|
||||
chat_log
|
||||
database.db
|
||||
|
113
Cargo.lock
generated
113
Cargo.lock
generated
@ -509,6 +509,41 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.20.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"darling_macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_core"
|
||||
version = "0.20.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5"
|
||||
dependencies = [
|
||||
"fnv",
|
||||
"ident_case",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"strsim",
|
||||
"syn 2.0.79",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling_macro"
|
||||
version = "0.20.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
|
||||
dependencies = [
|
||||
"darling_core",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.6.0"
|
||||
@ -525,6 +560,40 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "diesel"
|
||||
version = "2.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "158fe8e2e68695bd615d7e4f3227c0727b151330d3e253b525086c348d055d5e"
|
||||
dependencies = [
|
||||
"chrono",
|
||||
"diesel_derives",
|
||||
"libsqlite3-sys",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "diesel_derives"
|
||||
version = "2.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7f2c3de51e2ba6bf2a648285696137aaf0f5f487bcbea93972fe8a364e131a4"
|
||||
dependencies = [
|
||||
"diesel_table_macro_syntax",
|
||||
"dsl_auto_type",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "diesel_table_macro_syntax"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25"
|
||||
dependencies = [
|
||||
"syn 2.0.79",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.7"
|
||||
@ -536,6 +605,26 @@ dependencies = [
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dsl_auto_type"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c5d9abe6314103864cc2d8901b7ae224e0ab1a103a0a416661b4097b0779b607"
|
||||
dependencies = [
|
||||
"darling",
|
||||
"either",
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.79",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.1"
|
||||
@ -1278,6 +1367,12 @@ dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ident_case"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
|
||||
|
||||
[[package]]
|
||||
name = "implicit-clone"
|
||||
version = "0.4.9"
|
||||
@ -1360,6 +1455,16 @@ version = "0.2.159"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5"
|
||||
|
||||
[[package]]
|
||||
name = "libsqlite3-sys"
|
||||
version = "0.30.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149"
|
||||
dependencies = [
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.12"
|
||||
@ -1817,9 +1922,11 @@ dependencies = [
|
||||
"argon2",
|
||||
"axum",
|
||||
"axum-login",
|
||||
"chrono",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"common",
|
||||
"diesel",
|
||||
"futures",
|
||||
"log",
|
||||
"slab",
|
||||
@ -2348,6 +2455,12 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d"
|
||||
|
||||
[[package]]
|
||||
name = "vcpkg"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.5"
|
||||
|
2
dev.sh
2
dev.sh
@ -4,4 +4,4 @@ IFS=$'\n\t'
|
||||
|
||||
(trap 'kill 0' SIGINT; \
|
||||
bash -c 'cd frontend; trunk serve' & \
|
||||
bash -c 'cargo watch -i chat_log -- cargo run --bin server -- --port 8081')
|
||||
bash -c 'cargo watch -i database.db -- cargo run --bin server -- --port 8081')
|
||||
|
9
diesel.toml
Normal file
9
diesel.toml
Normal file
@ -0,0 +1,9 @@
|
||||
# For documentation on how to configure this file,
|
||||
# see https://diesel.rs/guides/configuring-diesel-cli
|
||||
|
||||
[print_schema]
|
||||
file = "server/src/schema.rs"
|
||||
custom_type_derives = ["diesel::query_builder::QueryId", "Clone"]
|
||||
|
||||
[migrations_directory]
|
||||
dir = "/home/pterpstra/projects/local_chat/migrations"
|
@ -85,7 +85,10 @@ pub fn Login() -> Html {
|
||||
Ok(_) => BrowserHistory::new().push(&redirect),
|
||||
Err(err) => {
|
||||
login_error.set(Some(err));
|
||||
password_ref.cast::<HtmlInputElement>().unwrap().set_value("");
|
||||
password_ref
|
||||
.cast::<HtmlInputElement>()
|
||||
.unwrap()
|
||||
.set_value("");
|
||||
disable_input.set(false);
|
||||
}
|
||||
};
|
||||
|
@ -1,6 +1,9 @@
|
||||
use std::io::{stdin, stdout, Write};
|
||||
|
||||
use argon2::{password_hash::{rand_core::OsRng, SaltString}, Argon2, PasswordHasher};
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, SaltString},
|
||||
Argon2, PasswordHasher,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
print!("Enter password:");
|
||||
@ -11,6 +14,9 @@ fn main() {
|
||||
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let pw_hash = argon2.hash_password(password.as_bytes(), &salt).unwrap().to_string();
|
||||
let pw_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.unwrap()
|
||||
.to_string();
|
||||
println!("{pw_hash}");
|
||||
}
|
||||
|
0
migrations/.keep
Normal file
0
migrations/.keep
Normal file
2
migrations/2024-10-18-181813_create_messages/down.sql
Normal file
2
migrations/2024-10-18-181813_create_messages/down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
-- This file should undo anything in `up.sql`
|
||||
DROP TABLE messages;
|
8
migrations/2024-10-18-181813_create_messages/up.sql
Normal file
8
migrations/2024-10-18-181813_create_messages/up.sql
Normal file
@ -0,0 +1,8 @@
|
||||
-- Your SQL goes here
|
||||
CREATE TABLE messages (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
message VARCHAR NOT NULL,
|
||||
time VARCHAR NOT NULL,
|
||||
user_id INTEGER NOT NULL,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id)
|
||||
)
|
2
migrations/2024-10-18-182242_create_users/down.sql
Normal file
2
migrations/2024-10-18-182242_create_users/down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
-- This file should undo anything in `up.sql`
|
||||
DROP TABLE users;
|
7
migrations/2024-10-18-182242_create_users/up.sql
Normal file
7
migrations/2024-10-18-182242_create_users/up.sql
Normal file
@ -0,0 +1,7 @@
|
||||
-- Your SQL goes here
|
||||
CREATE TABLE users (
|
||||
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
|
||||
username VARCHAR UNIQUE NOT NULL,
|
||||
pw_hash VARCHAR NOT NULL,
|
||||
display_name VARCHAR NOT NULL
|
||||
)
|
@ -7,9 +7,11 @@ edition = "2021"
|
||||
argon2 = "0.5.3"
|
||||
axum = { version = "0.7.7", features = ["ws"] }
|
||||
axum-login = "0.16.0"
|
||||
chrono = "0.4.38"
|
||||
ciborium = "0.2.2"
|
||||
clap = { version = "4.5.19", features = ["derive"] }
|
||||
common = { version = "0.1.0", path = "../common" }
|
||||
diesel = { version = "2.2.4", features = ["chrono", "sqlite"] }
|
||||
futures = "0.3.31"
|
||||
log = "0.4.22"
|
||||
slab = "0.4.9"
|
||||
|
87
server/src/auth.rs
Normal file
87
server/src/auth.rs
Normal file
@ -0,0 +1,87 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use argon2::{Argon2, PasswordHash, PasswordVerifier};
|
||||
use axum::async_trait;
|
||||
use axum_login::{AuthUser, AuthnBackend, UserId};
|
||||
use diesel::{prelude::*, result::Error::NotFound};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::{models::User, schema::users};
|
||||
|
||||
pub struct UserCreds {
|
||||
pub username: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
impl AuthUser for User {
|
||||
type Id = i32;
|
||||
|
||||
fn id(&self) -> Self::Id {
|
||||
self.id
|
||||
}
|
||||
|
||||
fn session_auth_hash(&self) -> &[u8] {
|
||||
self.pw_hash.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthBackend {
|
||||
db_conn: Arc<Mutex<SqliteConnection>>,
|
||||
}
|
||||
|
||||
impl AuthBackend {
|
||||
pub fn new(db_conn: Arc<Mutex<SqliteConnection>>) -> Self {
|
||||
Self { db_conn }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthnBackend for AuthBackend {
|
||||
|
||||
type User = User;
|
||||
type Credentials = UserCreds;
|
||||
type Error = diesel::result::Error;
|
||||
|
||||
async fn authenticate(
|
||||
&self,
|
||||
creds: Self::Credentials,
|
||||
) -> Result<Option<Self::User>, Self::Error> {
|
||||
let user: User = match users::table
|
||||
.filter(users::username.eq(&creds.username))
|
||||
.first(&mut *self.db_conn.lock().await)
|
||||
{
|
||||
Ok(user) => user,
|
||||
Err(e) => match e {
|
||||
NotFound => return Ok(None),
|
||||
_ => return Err(e),
|
||||
},
|
||||
};
|
||||
let pw_hash = PasswordHash::new(&user.pw_hash).unwrap();
|
||||
if Argon2::default()
|
||||
.verify_password(creds.password.as_bytes(), &pw_hash)
|
||||
.is_ok()
|
||||
{
|
||||
Ok(Some(user.clone()))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
|
||||
match users::table
|
||||
.find(user_id)
|
||||
.first(&mut *self.db_conn.lock().await)
|
||||
{
|
||||
Ok(user) => Ok(Some(user)),
|
||||
Err(e) => match e {
|
||||
NotFound => {
|
||||
Ok(None)
|
||||
},
|
||||
_ => Err(e),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type AuthSession = axum_login::AuthSession<AuthBackend>;
|
@ -1,34 +1,13 @@
|
||||
use argon2::{Argon2, PasswordHash, PasswordVerifier};
|
||||
use axum::body::Body;
|
||||
use axum::extract::ws::{self, CloseFrame, Message, WebSocket};
|
||||
use axum::extract::{Request, State, WebSocketUpgrade};
|
||||
use axum::http::header::CONTENT_TYPE;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::post;
|
||||
use axum::{async_trait, Json};
|
||||
use axum::{routing::get, Router};
|
||||
use axum_login::{AuthManagerLayerBuilder, AuthSession, AuthUser, AuthnBackend, UserId};
|
||||
mod models;
|
||||
mod schema;
|
||||
mod auth;
|
||||
mod server;
|
||||
|
||||
use clap::Parser;
|
||||
use common::{ChatMessage, LoggedInResponse, LoginRequest, LoginResponse};
|
||||
use futures::stream::SplitSink;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use slab::Slab;
|
||||
use std::collections::HashMap;
|
||||
use diesel::prelude::*;
|
||||
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use time::Duration;
|
||||
use tokio::fs::{self, OpenOptions};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::Mutex;
|
||||
use tower::{ServiceBuilder, ServiceExt};
|
||||
use tower_http::services::ServeDir;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_sessions::cookie::Key;
|
||||
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
// Setup the command line interface with clap.
|
||||
#[derive(Parser, Debug)]
|
||||
@ -51,132 +30,6 @@ struct Opt {
|
||||
static_dir: PathBuf,
|
||||
}
|
||||
|
||||
struct ServState {
|
||||
client_sends: Slab<SplitSink<WebSocket, Message>>,
|
||||
chat_log: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
impl ServState {
|
||||
fn new() -> Self {
|
||||
let chat_log = std::fs::read("chat_log").map_or(Vec::new(), |chat_log| {
|
||||
ciborium::from_reader(chat_log.as_slice()).unwrap()
|
||||
});
|
||||
Self {
|
||||
client_sends: Slab::new(),
|
||||
chat_log,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//const USERS: HashMap<String, User> = {
|
||||
// let mut map = HashMap::new();
|
||||
// map.insert("pjht".to_string(), User {
|
||||
// username: "pjht".to_string(),
|
||||
// pw_hash: "$argon2id$v=19$m=19456,t=2,p=1$miwLkAOUyJa7NxWX9ueBJQ$w43pjFVBqnqRvWWiPaW2cbhlxk0Dq5sdYjmy4I+Yh+U".to_string(),
|
||||
// id: 0,
|
||||
// }).unwrap();
|
||||
//map
|
||||
//};
|
||||
|
||||
struct UserCreds {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct User {
|
||||
username: String,
|
||||
pw_hash: String,
|
||||
}
|
||||
|
||||
impl AuthUser for User {
|
||||
type Id = String;
|
||||
|
||||
fn id(&self) -> Self::Id {
|
||||
self.username.clone()
|
||||
}
|
||||
|
||||
fn session_auth_hash(&self) -> &[u8] {
|
||||
self.pw_hash.as_bytes()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct AuthBackend {
|
||||
users: HashMap<String, User>
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AuthnBackend for AuthBackend {
|
||||
type User = User;
|
||||
type Credentials = UserCreds;
|
||||
type Error = std::convert::Infallible;
|
||||
|
||||
async fn authenticate(
|
||||
&self,
|
||||
creds: Self::Credentials,
|
||||
) -> Result<Option<Self::User>, Self::Error> {
|
||||
Ok((|| {
|
||||
let user = self.users.get(&creds.username)?;
|
||||
let pw_hash = PasswordHash::new(&user.pw_hash).unwrap();
|
||||
if Argon2::default().verify_password(creds.password.as_bytes(), &pw_hash).is_ok() {
|
||||
Some(user.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})())
|
||||
}
|
||||
|
||||
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
|
||||
Ok(self.users.get(user_id).cloned())
|
||||
}
|
||||
}
|
||||
|
||||
async fn serve_index(static_dir: &Path) -> Response {
|
||||
let index_path = PathBuf::from(&static_dir).join("index.html");
|
||||
let index_content = match fs::read_to_string(index_path).await {
|
||||
Err(_) => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::from("index file not found"))
|
||||
.unwrap()
|
||||
}
|
||||
Ok(index_content) => index_content,
|
||||
};
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, "text/html")
|
||||
.body(Body::from(index_content))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn login(
|
||||
mut auth_session: AuthSession<AuthBackend>,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Response {
|
||||
let Some(user) = auth_session.authenticate(UserCreds {username: req.username, password: req.password}).await.unwrap() else {
|
||||
return Json(LoginResponse(Err("Invalid username or password".to_string()))).into_response();
|
||||
};
|
||||
if auth_session.login(&user).await.is_err() {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
Json(LoginResponse(Ok(()))).into_response()
|
||||
}
|
||||
|
||||
async fn logout(mut auth_session: AuthSession<AuthBackend>) -> Response {
|
||||
if auth_session.logout().await.is_err() {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
StatusCode::OK.into_response()
|
||||
}
|
||||
|
||||
async fn logged_in(auth_session: AuthSession<AuthBackend>) -> Response {
|
||||
let resp = LoggedInResponse {
|
||||
logged_in: auth_session.user.is_some(),
|
||||
};
|
||||
Json(resp).into_response()
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
let opt = Opt::parse();
|
||||
@ -190,60 +43,10 @@ async fn main() {
|
||||
// enable console logging
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let signing_key = Key::generate();
|
||||
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
|
||||
let db_conn = SqliteConnection::establish(&database_url)
|
||||
.unwrap_or_else(|_| panic!("Error connecting to database at {}", database_url));
|
||||
|
||||
let session_layer = SessionManagerLayer::new(MemoryStore::default())
|
||||
.with_secure(false)
|
||||
.with_expiry(Expiry::OnInactivity(Duration::days(1)))
|
||||
.with_signed(signing_key);
|
||||
|
||||
let mut users = HashMap::new();
|
||||
users.insert("pjht".to_string(), User {
|
||||
username: "pjht".to_string(),
|
||||
pw_hash: "$argon2id$v=19$m=19456,t=2,p=1$aI7n6SgiTcVhSBjk7pXo+Q$wcJriSwcIj5Al/oNlZJdxMVOA/15e13t2AaWs4VQmVM".to_string(),
|
||||
});
|
||||
let auth_backend = AuthBackend {
|
||||
users
|
||||
};
|
||||
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
|
||||
|
||||
let api = Router::new()
|
||||
//.route_layer(login_required!(DummyAuthBackend))
|
||||
.route("/chat_ws", get(chat_ws))
|
||||
.route("/login", post(login))
|
||||
.route("/logout", post(logout))
|
||||
.route("/logged_in", get(logged_in));
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api", api)
|
||||
.fallback(|req: Request| async move {
|
||||
if req.uri().path().starts_with("/api") {
|
||||
return Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::from(""))
|
||||
.unwrap();
|
||||
}
|
||||
match ServeDir::new(&opt.static_dir)
|
||||
.append_index_html_on_directories(false)
|
||||
.oneshot(req)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if res.status() == StatusCode::NOT_FOUND {
|
||||
serve_index(&opt.static_dir).await
|
||||
} else {
|
||||
res.map(Body::new)
|
||||
}
|
||||
}
|
||||
Err(err) => Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from(format!("error: {err}")))
|
||||
.expect("error response"),
|
||||
}
|
||||
})
|
||||
.with_state(Arc::new(Mutex::new(ServState::new())))
|
||||
.layer(auth_layer)
|
||||
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
|
||||
|
||||
let sock_addr = SocketAddr::from((
|
||||
IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)),
|
||||
@ -254,76 +57,5 @@ async fn main() {
|
||||
|
||||
log::info!("listening on http://{}", sock_addr);
|
||||
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.expect("Unable to start server");
|
||||
}
|
||||
|
||||
async fn chat_ws(
|
||||
State(state): State<Arc<Mutex<ServState>>>,
|
||||
auth_session: AuthSession<AuthBackend>,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Response {
|
||||
ws.on_upgrade(move |socket| async move {
|
||||
let (mut tx, mut rx) = socket.split();
|
||||
if auth_session.user.is_none() {
|
||||
let _ = tx
|
||||
.send(Message::Close(Some(CloseFrame {
|
||||
code: 4000,
|
||||
reason: "Unauthorized".into(),
|
||||
})))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
debug!("Client connected");
|
||||
for msg in &state.lock().await.chat_log {
|
||||
let mut buf = Vec::new();
|
||||
ciborium::into_writer(&msg, &mut buf).unwrap();
|
||||
tx.send(Message::Binary(buf)).await.unwrap();
|
||||
}
|
||||
let tx_idx = state.lock().await.client_sends.insert(tx);
|
||||
let mut close_code = ws::close_code::NORMAL;
|
||||
while let Some(msg) = rx.next().await {
|
||||
if let Ok(msg) = msg {
|
||||
let msg_bytes = match msg {
|
||||
Message::Binary(msg) => msg,
|
||||
_ => {
|
||||
close_code = ws::close_code::UNSUPPORTED;
|
||||
warn!("Got unsupported message");
|
||||
break;
|
||||
}
|
||||
};
|
||||
let msg: ChatMessage = ciborium::from_reader(msg_bytes.as_slice()).unwrap();
|
||||
state.lock().await.chat_log.push(msg);
|
||||
let mut log_file = OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open("chat_log")
|
||||
.await
|
||||
.unwrap();
|
||||
let mut buf = Vec::new();
|
||||
ciborium::into_writer(&state.lock().await.chat_log, &mut buf).unwrap();
|
||||
log_file.write_all(&buf).await.unwrap();
|
||||
for (i, client_tx) in state.lock().await.client_sends.iter_mut() {
|
||||
if i == tx_idx {
|
||||
continue;
|
||||
}
|
||||
let _ = client_tx.send(Message::Binary(msg_bytes.clone())).await;
|
||||
}
|
||||
} else {
|
||||
close_code = ws::close_code::PROTOCOL;
|
||||
warn!("Websocket protocol error");
|
||||
break;
|
||||
};
|
||||
}
|
||||
debug!("Client disconnected");
|
||||
let mut tx = state.lock().await.client_sends.remove(tx_idx);
|
||||
let _ = tx
|
||||
.send(Message::Close(Some(CloseFrame {
|
||||
code: close_code,
|
||||
reason: "".into(),
|
||||
})))
|
||||
.await;
|
||||
})
|
||||
server::run(listener, db_conn, opt.static_dir).await;
|
||||
}
|
||||
|
32
server/src/models.rs
Normal file
32
server/src/models.rs
Normal file
@ -0,0 +1,32 @@
|
||||
use crate::schema::*;
|
||||
use chrono::{DateTime, Local};
|
||||
use diesel::prelude::*;
|
||||
|
||||
#[derive(Queryable, Selectable, Debug, Clone)]
|
||||
#[diesel(table_name = users)]
|
||||
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
|
||||
pub struct User {
|
||||
pub id: i32,
|
||||
pub username: String,
|
||||
pub pw_hash: String,
|
||||
pub display_name: String,
|
||||
}
|
||||
|
||||
#[derive(Insertable)]
|
||||
#[diesel(table_name = messages)]
|
||||
#[diesel(belongs_to(User))]
|
||||
pub struct NewMessage<'a> {
|
||||
pub message: &'a str,
|
||||
pub time: DateTime<Local>,
|
||||
pub user_id: i32,
|
||||
}
|
||||
|
||||
#[derive(Queryable, Selectable, Debug, Clone)]
|
||||
#[diesel(table_name = messages)]
|
||||
#[diesel(check_for_backend(diesel::sqlite::Sqlite))]
|
||||
pub struct Message {
|
||||
pub id: i32,
|
||||
pub message: String,
|
||||
pub time: DateTime<Local>,
|
||||
pub user_id: i32,
|
||||
}
|
26
server/src/schema.rs
Normal file
26
server/src/schema.rs
Normal file
@ -0,0 +1,26 @@
|
||||
// @generated automatically by Diesel CLI.
|
||||
|
||||
diesel::table! {
|
||||
messages (id) {
|
||||
id -> Integer,
|
||||
message -> Text,
|
||||
time -> TimestamptzSqlite,
|
||||
user_id -> Integer,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
users (id) {
|
||||
id -> Integer,
|
||||
username -> Text,
|
||||
pw_hash -> Text,
|
||||
display_name -> Text,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::joinable!(messages -> users (user_id));
|
||||
|
||||
diesel::allow_tables_to_appear_in_same_query!(
|
||||
messages,
|
||||
users,
|
||||
);
|
231
server/src/server.rs
Normal file
231
server/src/server.rs
Normal file
@ -0,0 +1,231 @@
|
||||
use crate::auth::{AuthBackend, AuthSession, UserCreds};
|
||||
use axum::body::Body;
|
||||
use axum::extract::ws::{self, CloseFrame, Message as WsMessage, WebSocket};
|
||||
use axum::extract::{Request, State, WebSocketUpgrade};
|
||||
use axum::http::header::CONTENT_TYPE;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::{IntoResponse, Response};
|
||||
use axum::routing::post;
|
||||
use axum::Json;
|
||||
use axum::{routing::get, Router};
|
||||
use axum_login::AuthManagerLayerBuilder;
|
||||
use common::{ChatMessage as ApiMessage, LoggedInResponse, LoginRequest, LoginResponse};
|
||||
use futures::stream::SplitSink;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::net::TcpListener;
|
||||
use crate::models::{NewMessage, Message as DbMessage};
|
||||
use crate::schema::messages;
|
||||
use slab::Slab;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use time::Duration;
|
||||
use tokio::fs;
|
||||
use tokio::sync::Mutex;
|
||||
use tower::{ServiceBuilder, ServiceExt};
|
||||
use tower_http::services::ServeDir;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tower_sessions::cookie::Key;
|
||||
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
|
||||
use tracing::{debug, warn};
|
||||
use diesel::prelude::*;
|
||||
|
||||
struct ServState {
|
||||
client_sends: Slab<SplitSink<WebSocket, WsMessage>>,
|
||||
db_conn: Arc<Mutex<SqliteConnection>>,
|
||||
}
|
||||
|
||||
impl ServState {
|
||||
fn new(db_conn: Arc<Mutex<SqliteConnection>>) -> Self {
|
||||
Self {
|
||||
client_sends: Slab::new(),
|
||||
db_conn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
async fn serve_index(static_dir: &Path) -> Response {
|
||||
let index_path = PathBuf::from(&static_dir).join("index.html");
|
||||
let index_content = match fs::read_to_string(index_path).await {
|
||||
Err(_) => {
|
||||
return Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::from("index file not found"))
|
||||
.unwrap()
|
||||
}
|
||||
Ok(index_content) => index_content,
|
||||
};
|
||||
Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header(CONTENT_TYPE, "text/html")
|
||||
.body(Body::from(index_content))
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
async fn login(
|
||||
mut auth_session: AuthSession,
|
||||
Json(req): Json<LoginRequest>,
|
||||
) -> Response {
|
||||
let Some(user) = auth_session
|
||||
.authenticate(UserCreds {
|
||||
username: req.username,
|
||||
password: req.password,
|
||||
})
|
||||
.await
|
||||
.unwrap()
|
||||
else {
|
||||
return Json(LoginResponse(Err(
|
||||
"Invalid username or password".to_string()
|
||||
)))
|
||||
.into_response();
|
||||
};
|
||||
if auth_session.login(&user).await.is_err() {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
Json(LoginResponse(Ok(()))).into_response()
|
||||
}
|
||||
|
||||
async fn logout(mut auth_session: AuthSession) -> impl IntoResponse {
|
||||
if auth_session.logout().await.is_err() {
|
||||
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||
}
|
||||
StatusCode::OK.into_response()
|
||||
}
|
||||
|
||||
async fn logged_in(auth_session: AuthSession) -> Response {
|
||||
let resp = LoggedInResponse {
|
||||
logged_in: auth_session.user.is_some(),
|
||||
};
|
||||
Json(resp).into_response()
|
||||
}
|
||||
|
||||
|
||||
async fn chat_ws(
|
||||
State(state): State<Arc<Mutex<ServState>>>,
|
||||
auth_session: AuthSession,
|
||||
ws: WebSocketUpgrade,
|
||||
) -> Response {
|
||||
ws.on_upgrade(move |socket| async move {
|
||||
let (mut tx, mut rx) = socket.split();
|
||||
if auth_session.user.is_none() {
|
||||
let _ = tx
|
||||
.send(WsMessage::Close(Some(CloseFrame {
|
||||
code: 4000,
|
||||
reason: "Unauthorized".into(),
|
||||
})))
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
debug!("Client connected");
|
||||
let messages: Vec<DbMessage> = messages::table.load(&mut *(state.lock().await.db_conn.lock().await)).unwrap();
|
||||
for db_msg in messages {
|
||||
let api_msg = ApiMessage {
|
||||
message: db_msg.message,
|
||||
time: db_msg.time,
|
||||
};
|
||||
let mut buf = Vec::new();
|
||||
ciborium::into_writer(&api_msg, &mut buf).unwrap();
|
||||
tx.send(WsMessage::Binary(buf)).await.unwrap();
|
||||
}
|
||||
let tx_idx = state.lock().await.client_sends.insert(tx);
|
||||
let mut close_code = ws::close_code::NORMAL;
|
||||
while let Some(ws_msg) = rx.next().await {
|
||||
if let Ok(ws_msg) = ws_msg {
|
||||
let msg_bytes = match ws_msg {
|
||||
WsMessage::Binary(msg) => msg,
|
||||
_ => {
|
||||
close_code = ws::close_code::UNSUPPORTED;
|
||||
warn!("Got unsupported message");
|
||||
break;
|
||||
}
|
||||
};
|
||||
let api_msg: ApiMessage = ciborium::from_reader(msg_bytes.as_slice()).unwrap();
|
||||
let db_msg = NewMessage {
|
||||
message: &api_msg.message,
|
||||
time: api_msg.time,
|
||||
user_id: auth_session.user.as_ref().unwrap().id,
|
||||
};
|
||||
diesel::insert_into(messages::table)
|
||||
.values(&db_msg)
|
||||
.execute(&mut *(state.lock().await.db_conn.lock().await))
|
||||
.unwrap();
|
||||
for (i, client_tx) in state.lock().await.client_sends.iter_mut() {
|
||||
if i == tx_idx {
|
||||
continue;
|
||||
}
|
||||
let _ = client_tx.send(WsMessage::Binary(msg_bytes.clone())).await;
|
||||
}
|
||||
} else {
|
||||
close_code = ws::close_code::PROTOCOL;
|
||||
warn!("Websocket protocol error");
|
||||
break;
|
||||
};
|
||||
}
|
||||
debug!("Client disconnected");
|
||||
let mut tx = state.lock().await.client_sends.remove(tx_idx);
|
||||
let _ = tx
|
||||
.send(WsMessage::Close(Some(CloseFrame {
|
||||
code: close_code,
|
||||
reason: "".into(),
|
||||
})))
|
||||
.await;
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: PathBuf) {
|
||||
let db_conn = Arc::new(Mutex::new(db_conn));
|
||||
|
||||
|
||||
let session_layer = SessionManagerLayer::new(MemoryStore::default())
|
||||
.with_secure(false)
|
||||
.with_expiry(Expiry::OnInactivity(Duration::days(1)));
|
||||
|
||||
|
||||
let auth_backend = AuthBackend::new(db_conn.clone());
|
||||
|
||||
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
|
||||
|
||||
let api = Router::new()
|
||||
//.route_layer(login_required!(DummyAuthBackend))
|
||||
.route("/chat_ws", get(chat_ws))
|
||||
.route("/login", post(login))
|
||||
.route("/logout", post(logout))
|
||||
.route("/logged_in", get(logged_in));
|
||||
|
||||
let app = Router::new()
|
||||
.nest("/api", api)
|
||||
.fallback(|req: Request| async move {
|
||||
if req.uri().path().starts_with("/api") {
|
||||
return Response::builder()
|
||||
.status(StatusCode::NOT_FOUND)
|
||||
.body(Body::from(""))
|
||||
.unwrap();
|
||||
}
|
||||
match ServeDir::new(&static_dir)
|
||||
.append_index_html_on_directories(false)
|
||||
.oneshot(req)
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if res.status() == StatusCode::NOT_FOUND {
|
||||
serve_index(&static_dir).await
|
||||
} else {
|
||||
res.map(Body::new)
|
||||
}
|
||||
}
|
||||
Err(err) => Response::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Body::from(format!("error: {err}")))
|
||||
.expect("error response"),
|
||||
}
|
||||
})
|
||||
.with_state(Arc::new(Mutex::new(ServState::new(db_conn))))
|
||||
.layer(auth_layer)
|
||||
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
|
||||
|
||||
|
||||
axum::serve(listener, app)
|
||||
.await
|
||||
.expect("Unable to start server");
|
||||
}
|
Loading…
Reference in New Issue
Block a user