Add subcommand to add new user to the DB

This commit is contained in:
pjht 2024-10-18 19:26:45 -05:00
parent 2967284956
commit a1856aaeaa
Signed by: pjht
GPG Key ID: 7B5F6AFBEC7EE78E
9 changed files with 79 additions and 57 deletions

View File

@ -2,6 +2,6 @@
resolver = "2" resolver = "2"
members = [ "common", members = [ "common",
"frontend", "hash_pw", "frontend",
"server", "server",
] ]

View File

@ -1,7 +0,0 @@
[package]
name = "hash_pw"
version = "0.1.0"
edition = "2021"
[dependencies]
argon2 = "0.5.3"

View File

@ -1,22 +0,0 @@
use std::io::{stdin, stdout, Write};
use argon2::{
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher,
};
fn main() {
print!("Enter password:");
stdout().flush().unwrap();
let mut pw_buf = String::new();
stdin().read_line(&mut pw_buf).unwrap();
let password = pw_buf.trim_end();
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let pw_hash = argon2
.hash_password(password.as_bytes(), &salt)
.unwrap()
.to_string();
println!("{pw_hash}");
}

View File

@ -6,4 +6,4 @@ pushd frontend
trunk build trunk build
popd popd
cargo run --bin server --release -- --port 8080 --addr :: --static-dir ./dist cargo run --bin server --release -- --addr :: $@

View File

@ -38,7 +38,6 @@ impl AuthBackend {
#[async_trait] #[async_trait]
impl AuthnBackend for AuthBackend { impl AuthnBackend for AuthBackend {
type User = User; type User = User;
type Credentials = UserCreds; type Credentials = UserCreds;
type Error = diesel::result::Error; type Error = diesel::result::Error;
@ -75,9 +74,7 @@ impl AuthnBackend for AuthBackend {
{ {
Ok(user) => Ok(Some(user)), Ok(user) => Ok(Some(user)),
Err(e) => match e { Err(e) => match e {
NotFound => { NotFound => Ok(None),
Ok(None)
},
_ => Err(e), _ => Err(e),
}, },
} }

View File

@ -1,10 +1,17 @@
mod auth;
mod models; mod models;
mod schema; mod schema;
mod auth;
mod server; mod server;
use clap::Parser; use argon2::{
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher,
};
use clap::{Parser, Subcommand};
use diesel::prelude::*; use diesel::prelude::*;
use models::NewUser;
use schema::users;
use std::io::{self, Write};
use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
use std::str::FromStr; use std::str::FromStr;
@ -28,6 +35,17 @@ struct Opt {
/// set the directory where static files are to be found /// set the directory where static files are to be found
#[clap(long = "static-dir", default_value = "./dist")] #[clap(long = "static-dir", default_value = "./dist")]
static_dir: PathBuf, static_dir: PathBuf,
#[command(subcommand)]
command: Option<Command>,
}
#[derive(Subcommand, Debug)]
enum Command {
AddUser {
username: String,
display_name: Option<String>,
},
} }
#[tokio::main] #[tokio::main]
@ -44,9 +62,47 @@ async fn main() {
tracing_subscriber::fmt::init(); tracing_subscriber::fmt::init();
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let db_conn = SqliteConnection::establish(&database_url) let mut db_conn = SqliteConnection::establish(&database_url)
.unwrap_or_else(|_| panic!("Error connecting to database at {}", database_url)); .unwrap_or_else(|_| panic!("Error connecting to database at {}", database_url));
if let Some(command) = opt.command {
match command {
Command::AddUser {
username,
display_name,
} => {
print!("Enter password for {username}: ");
io::stdout().flush().unwrap();
let mut pw_buf = String::new();
io::stdin().read_line(&mut pw_buf).unwrap();
let password = pw_buf.trim_end();
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let pw_hash = argon2.hash_password(password.as_bytes(), &salt).unwrap().to_string();
let user = NewUser {
username: &username,
pw_hash: &pw_hash,
display_name: display_name.as_deref().unwrap_or(&username),
};
diesel::insert_into(users::table)
.values(&user)
.execute(&mut db_conn)
.unwrap();
}
}
return;
}
if users::table
.count()
.get_result::<i64>(&mut db_conn)
.unwrap()
== 0
{
log::error!("No users are registered. The app is unusable in this state, refusing to start. Please add a user with add-user.");
return;
}
let sock_addr = SocketAddr::from(( let sock_addr = SocketAddr::from((
IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)), IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)),

View File

@ -12,6 +12,14 @@ pub struct User {
pub display_name: String, pub display_name: String,
} }
#[derive(Insertable)]
#[diesel(table_name = users)]
pub struct NewUser<'a> {
pub username: &'a str,
pub pw_hash: &'a str,
pub display_name: &'a str,
}
#[derive(Insertable)] #[derive(Insertable)]
#[diesel(table_name = messages)] #[diesel(table_name = messages)]
#[diesel(belongs_to(User))] #[diesel(belongs_to(User))]

View File

@ -20,7 +20,4 @@ diesel::table! {
diesel::joinable!(messages -> users (user_id)); diesel::joinable!(messages -> users (user_id));
diesel::allow_tables_to_appear_in_same_query!( diesel::allow_tables_to_appear_in_same_query!(messages, users,);
messages,
users,
);

View File

@ -1,4 +1,6 @@
use crate::auth::{AuthBackend, AuthSession, UserCreds}; use crate::auth::{AuthBackend, AuthSession, UserCreds};
use crate::models::{Message as DbMessage, NewMessage};
use crate::schema::messages;
use axum::body::Body; use axum::body::Body;
use axum::extract::ws::{self, CloseFrame, Message as WsMessage, WebSocket}; use axum::extract::ws::{self, CloseFrame, Message as WsMessage, WebSocket};
use axum::extract::{Request, State, WebSocketUpgrade}; use axum::extract::{Request, State, WebSocketUpgrade};
@ -10,16 +12,15 @@ use axum::Json;
use axum::{routing::get, Router}; use axum::{routing::get, Router};
use axum_login::AuthManagerLayerBuilder; use axum_login::AuthManagerLayerBuilder;
use common::{ChatMessage as ApiMessage, LoggedInResponse, LoginRequest, LoginResponse}; use common::{ChatMessage as ApiMessage, LoggedInResponse, LoginRequest, LoginResponse};
use diesel::prelude::*;
use futures::stream::SplitSink; use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt}; use futures::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use crate::models::{NewMessage, Message as DbMessage};
use crate::schema::messages;
use slab::Slab; use slab::Slab;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use time::Duration; use time::Duration;
use tokio::fs; use tokio::fs;
use tokio::net::TcpListener;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tower::{ServiceBuilder, ServiceExt}; use tower::{ServiceBuilder, ServiceExt};
use tower_http::services::ServeDir; use tower_http::services::ServeDir;
@ -27,7 +28,6 @@ use tower_http::trace::TraceLayer;
use tower_sessions::cookie::Key; use tower_sessions::cookie::Key;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer}; use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
use tracing::{debug, warn}; use tracing::{debug, warn};
use diesel::prelude::*;
struct ServState { struct ServState {
client_sends: Slab<SplitSink<WebSocket, WsMessage>>, client_sends: Slab<SplitSink<WebSocket, WsMessage>>,
@ -43,8 +43,6 @@ impl ServState {
} }
} }
async fn serve_index(static_dir: &Path) -> Response { async fn serve_index(static_dir: &Path) -> Response {
let index_path = PathBuf::from(&static_dir).join("index.html"); let index_path = PathBuf::from(&static_dir).join("index.html");
let index_content = match fs::read_to_string(index_path).await { let index_content = match fs::read_to_string(index_path).await {
@ -63,10 +61,7 @@ async fn serve_index(static_dir: &Path) -> Response {
.unwrap() .unwrap()
} }
async fn login( async fn login(mut auth_session: AuthSession, Json(req): Json<LoginRequest>) -> Response {
mut auth_session: AuthSession,
Json(req): Json<LoginRequest>,
) -> Response {
let Some(user) = auth_session let Some(user) = auth_session
.authenticate(UserCreds { .authenticate(UserCreds {
username: req.username, username: req.username,
@ -100,7 +95,6 @@ async fn logged_in(auth_session: AuthSession) -> Response {
Json(resp).into_response() Json(resp).into_response()
} }
async fn chat_ws( async fn chat_ws(
State(state): State<Arc<Mutex<ServState>>>, State(state): State<Arc<Mutex<ServState>>>,
auth_session: AuthSession, auth_session: AuthSession,
@ -118,7 +112,9 @@ async fn chat_ws(
return; return;
} }
debug!("Client connected"); debug!("Client connected");
let messages: Vec<DbMessage> = messages::table.load(&mut *(state.lock().await.db_conn.lock().await)).unwrap(); let messages: Vec<DbMessage> = messages::table
.load(&mut *(state.lock().await.db_conn.lock().await))
.unwrap();
for db_msg in messages { for db_msg in messages {
let api_msg = ApiMessage { let api_msg = ApiMessage {
message: db_msg.message, message: db_msg.message,
@ -176,12 +172,10 @@ async fn chat_ws(
pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: PathBuf) { pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: PathBuf) {
let db_conn = Arc::new(Mutex::new(db_conn)); let db_conn = Arc::new(Mutex::new(db_conn));
let session_layer = SessionManagerLayer::new(MemoryStore::default()) let session_layer = SessionManagerLayer::new(MemoryStore::default())
.with_secure(false) .with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::days(1))); .with_expiry(Expiry::OnInactivity(Duration::days(1)));
let auth_backend = AuthBackend::new(db_conn.clone()); let auth_backend = AuthBackend::new(db_conn.clone());
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build(); let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
@ -224,7 +218,6 @@ pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: P
.layer(auth_layer) .layer(auth_layer)
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http())); .layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
axum::serve(listener, app) axum::serve(listener, app)
.await .await
.expect("Unable to start server"); .expect("Unable to start server");