From a1856aaeaa0759a77408b02f82246f6fe699123b Mon Sep 17 00:00:00 2001 From: pjht Date: Fri, 18 Oct 2024 19:26:45 -0500 Subject: [PATCH] Add subcommand to add new user to the DB --- Cargo.toml | 2 +- hash_pw/Cargo.toml | 7 ----- hash_pw/src/main.rs | 22 ---------------- prod.sh | 2 +- server/src/auth.rs | 5 +--- server/src/main.rs | 62 +++++++++++++++++++++++++++++++++++++++++--- server/src/models.rs | 8 ++++++ server/src/schema.rs | 5 +--- server/src/server.rs | 23 ++++++---------- 9 files changed, 79 insertions(+), 57 deletions(-) delete mode 100644 hash_pw/Cargo.toml delete mode 100644 hash_pw/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index bcfe235..c1436d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,6 @@ resolver = "2" members = [ "common", - "frontend", "hash_pw", + "frontend", "server", ] diff --git a/hash_pw/Cargo.toml b/hash_pw/Cargo.toml deleted file mode 100644 index 3c6b728..0000000 --- a/hash_pw/Cargo.toml +++ /dev/null @@ -1,7 +0,0 @@ -[package] -name = "hash_pw" -version = "0.1.0" -edition = "2021" - -[dependencies] -argon2 = "0.5.3" diff --git a/hash_pw/src/main.rs b/hash_pw/src/main.rs deleted file mode 100644 index 6566809..0000000 --- a/hash_pw/src/main.rs +++ /dev/null @@ -1,22 +0,0 @@ -use std::io::{stdin, stdout, Write}; - -use argon2::{ - password_hash::{rand_core::OsRng, SaltString}, - Argon2, PasswordHasher, -}; - -fn main() { - print!("Enter password:"); - stdout().flush().unwrap(); - let mut pw_buf = String::new(); - stdin().read_line(&mut pw_buf).unwrap(); - let password = pw_buf.trim_end(); - - let salt = SaltString::generate(&mut OsRng); - let argon2 = Argon2::default(); - let pw_hash = argon2 - .hash_password(password.as_bytes(), &salt) - .unwrap() - .to_string(); - println!("{pw_hash}"); -} diff --git a/prod.sh b/prod.sh index 334aa19..4c37896 100755 --- a/prod.sh +++ b/prod.sh @@ -6,4 +6,4 @@ pushd frontend trunk build popd -cargo run --bin server --release -- --port 8080 --addr :: --static-dir ./dist +cargo run --bin server --release -- --addr :: $@ diff --git a/server/src/auth.rs b/server/src/auth.rs index 714a4ca..3c3901a 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -38,7 +38,6 @@ impl AuthBackend { #[async_trait] impl AuthnBackend for AuthBackend { - type User = User; type Credentials = UserCreds; type Error = diesel::result::Error; @@ -75,9 +74,7 @@ impl AuthnBackend for AuthBackend { { Ok(user) => Ok(Some(user)), Err(e) => match e { - NotFound => { - Ok(None) - }, + NotFound => Ok(None), _ => Err(e), }, } diff --git a/server/src/main.rs b/server/src/main.rs index fbc00db..3ba7f5f 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,10 +1,17 @@ +mod auth; mod models; mod schema; -mod auth; mod server; -use clap::Parser; +use argon2::{ + password_hash::{rand_core::OsRng, SaltString}, + Argon2, PasswordHasher, +}; +use clap::{Parser, Subcommand}; use diesel::prelude::*; +use models::NewUser; +use schema::users; +use std::io::{self, Write}; use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::path::PathBuf; use std::str::FromStr; @@ -28,6 +35,17 @@ struct Opt { /// set the directory where static files are to be found #[clap(long = "static-dir", default_value = "./dist")] static_dir: PathBuf, + + #[command(subcommand)] + command: Option, +} + +#[derive(Subcommand, Debug)] +enum Command { + AddUser { + username: String, + display_name: Option, + }, } #[tokio::main] @@ -44,9 +62,47 @@ async fn main() { tracing_subscriber::fmt::init(); let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); - let db_conn = SqliteConnection::establish(&database_url) + let mut db_conn = SqliteConnection::establish(&database_url) .unwrap_or_else(|_| panic!("Error connecting to database at {}", database_url)); + if let Some(command) = opt.command { + match command { + Command::AddUser { + username, + display_name, + } => { + print!("Enter password for {username}: "); + io::stdout().flush().unwrap(); + let mut pw_buf = String::new(); + io::stdin().read_line(&mut pw_buf).unwrap(); + let password = pw_buf.trim_end(); + + let salt = SaltString::generate(&mut OsRng); + let argon2 = Argon2::default(); + let pw_hash = argon2.hash_password(password.as_bytes(), &salt).unwrap().to_string(); + let user = NewUser { + username: &username, + pw_hash: &pw_hash, + display_name: display_name.as_deref().unwrap_or(&username), + }; + diesel::insert_into(users::table) + .values(&user) + .execute(&mut db_conn) + .unwrap(); + } + } + return; + } + + if users::table + .count() + .get_result::(&mut db_conn) + .unwrap() + == 0 + { + log::error!("No users are registered. The app is unusable in this state, refusing to start. Please add a user with add-user."); + return; + } let sock_addr = SocketAddr::from(( IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)), diff --git a/server/src/models.rs b/server/src/models.rs index aeb02a0..c2c557b 100644 --- a/server/src/models.rs +++ b/server/src/models.rs @@ -12,6 +12,14 @@ pub struct User { pub display_name: String, } +#[derive(Insertable)] +#[diesel(table_name = users)] +pub struct NewUser<'a> { + pub username: &'a str, + pub pw_hash: &'a str, + pub display_name: &'a str, +} + #[derive(Insertable)] #[diesel(table_name = messages)] #[diesel(belongs_to(User))] diff --git a/server/src/schema.rs b/server/src/schema.rs index 9fa3cc6..96af492 100644 --- a/server/src/schema.rs +++ b/server/src/schema.rs @@ -20,7 +20,4 @@ diesel::table! { diesel::joinable!(messages -> users (user_id)); -diesel::allow_tables_to_appear_in_same_query!( - messages, - users, -); +diesel::allow_tables_to_appear_in_same_query!(messages, users,); diff --git a/server/src/server.rs b/server/src/server.rs index 2c96e40..dd441f0 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1,4 +1,6 @@ use crate::auth::{AuthBackend, AuthSession, UserCreds}; +use crate::models::{Message as DbMessage, NewMessage}; +use crate::schema::messages; use axum::body::Body; use axum::extract::ws::{self, CloseFrame, Message as WsMessage, WebSocket}; use axum::extract::{Request, State, WebSocketUpgrade}; @@ -10,16 +12,15 @@ use axum::Json; use axum::{routing::get, Router}; use axum_login::AuthManagerLayerBuilder; use common::{ChatMessage as ApiMessage, LoggedInResponse, LoginRequest, LoginResponse}; +use diesel::prelude::*; use futures::stream::SplitSink; use futures::{SinkExt, StreamExt}; -use tokio::net::TcpListener; -use crate::models::{NewMessage, Message as DbMessage}; -use crate::schema::messages; use slab::Slab; use std::path::{Path, PathBuf}; use std::sync::Arc; use time::Duration; use tokio::fs; +use tokio::net::TcpListener; use tokio::sync::Mutex; use tower::{ServiceBuilder, ServiceExt}; use tower_http::services::ServeDir; @@ -27,7 +28,6 @@ use tower_http::trace::TraceLayer; use tower_sessions::cookie::Key; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; use tracing::{debug, warn}; -use diesel::prelude::*; struct ServState { client_sends: Slab>, @@ -43,8 +43,6 @@ impl ServState { } } - - async fn serve_index(static_dir: &Path) -> Response { let index_path = PathBuf::from(&static_dir).join("index.html"); let index_content = match fs::read_to_string(index_path).await { @@ -63,10 +61,7 @@ async fn serve_index(static_dir: &Path) -> Response { .unwrap() } -async fn login( - mut auth_session: AuthSession, - Json(req): Json, -) -> Response { +async fn login(mut auth_session: AuthSession, Json(req): Json) -> Response { let Some(user) = auth_session .authenticate(UserCreds { username: req.username, @@ -100,7 +95,6 @@ async fn logged_in(auth_session: AuthSession) -> Response { Json(resp).into_response() } - async fn chat_ws( State(state): State>>, auth_session: AuthSession, @@ -118,7 +112,9 @@ async fn chat_ws( return; } debug!("Client connected"); - let messages: Vec = messages::table.load(&mut *(state.lock().await.db_conn.lock().await)).unwrap(); + let messages: Vec = messages::table + .load(&mut *(state.lock().await.db_conn.lock().await)) + .unwrap(); for db_msg in messages { let api_msg = ApiMessage { message: db_msg.message, @@ -176,12 +172,10 @@ async fn chat_ws( pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: PathBuf) { let db_conn = Arc::new(Mutex::new(db_conn)); - let session_layer = SessionManagerLayer::new(MemoryStore::default()) .with_secure(false) .with_expiry(Expiry::OnInactivity(Duration::days(1))); - let auth_backend = AuthBackend::new(db_conn.clone()); let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build(); @@ -224,7 +218,6 @@ pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: P .layer(auth_layer) .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())); - axum::serve(listener, app) .await .expect("Unable to start server");