Add subcommand to add new user to the DB

This commit is contained in:
pjht 2024-10-18 19:26:45 -05:00
parent 2967284956
commit 2429fa269d
Signed by: pjht
GPG Key ID: 7B5F6AFBEC7EE78E
8 changed files with 78 additions and 56 deletions

View File

@ -1,7 +0,0 @@
[package]
name = "hash_pw"
version = "0.1.0"
edition = "2021"
[dependencies]
argon2 = "0.5.3"

View File

@ -1,22 +0,0 @@
use std::io::{stdin, stdout, Write};
use argon2::{
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher,
};
fn main() {
print!("Enter password:");
stdout().flush().unwrap();
let mut pw_buf = String::new();
stdin().read_line(&mut pw_buf).unwrap();
let password = pw_buf.trim_end();
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let pw_hash = argon2
.hash_password(password.as_bytes(), &salt)
.unwrap()
.to_string();
println!("{pw_hash}");
}

View File

@ -6,4 +6,4 @@ pushd frontend
trunk build
popd
cargo run --bin server --release -- --port 8080 --addr :: --static-dir ./dist
cargo run --bin server --release -- --addr :: $@

View File

@ -38,7 +38,6 @@ impl AuthBackend {
#[async_trait]
impl AuthnBackend for AuthBackend {
type User = User;
type Credentials = UserCreds;
type Error = diesel::result::Error;
@ -75,9 +74,7 @@ impl AuthnBackend for AuthBackend {
{
Ok(user) => Ok(Some(user)),
Err(e) => match e {
NotFound => {
Ok(None)
},
NotFound => Ok(None),
_ => Err(e),
},
}

View File

@ -1,10 +1,17 @@
mod auth;
mod models;
mod schema;
mod auth;
mod server;
use clap::Parser;
use argon2::{
password_hash::{rand_core::OsRng, SaltString},
Argon2, PasswordHasher,
};
use clap::{Parser, Subcommand};
use diesel::prelude::*;
use models::NewUser;
use schema::users;
use std::io::{self, Write};
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::path::PathBuf;
use std::str::FromStr;
@ -28,6 +35,17 @@ struct Opt {
/// set the directory where static files are to be found
#[clap(long = "static-dir", default_value = "./dist")]
static_dir: PathBuf,
#[command(subcommand)]
command: Option<Command>,
}
#[derive(Subcommand, Debug)]
enum Command {
AddUser {
username: String,
display_name: Option<String>,
},
}
#[tokio::main]
@ -44,9 +62,47 @@ async fn main() {
tracing_subscriber::fmt::init();
let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let db_conn = SqliteConnection::establish(&database_url)
let mut db_conn = SqliteConnection::establish(&database_url)
.unwrap_or_else(|_| panic!("Error connecting to database at {}", database_url));
if let Some(command) = opt.command {
match command {
Command::AddUser {
username,
display_name,
} => {
print!("Enter password for {username}: ");
io::stdout().flush().unwrap();
let mut pw_buf = String::new();
io::stdin().read_line(&mut pw_buf).unwrap();
let password = pw_buf.trim_end();
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let pw_hash = argon2.hash_password(password.as_bytes(), &salt).unwrap().to_string();
let user = NewUser {
username: &username,
pw_hash: &pw_hash,
display_name: display_name.as_deref().unwrap_or(&username),
};
diesel::insert_into(users::table)
.values(&user)
.execute(&mut db_conn)
.unwrap();
}
}
return;
}
if users::table
.count()
.get_result::<i64>(&mut db_conn)
.unwrap()
== 0
{
log::error!("No users are registered. The app is unusable in this state, refusing to start. Please add a user.");
return;
}
let sock_addr = SocketAddr::from((
IpAddr::from_str(opt.addr.as_str()).unwrap_or(IpAddr::V6(Ipv6Addr::LOCALHOST)),

View File

@ -12,6 +12,14 @@ pub struct User {
pub display_name: String,
}
#[derive(Insertable)]
#[diesel(table_name = users)]
pub struct NewUser<'a> {
pub username: &'a str,
pub pw_hash: &'a str,
pub display_name: &'a str,
}
#[derive(Insertable)]
#[diesel(table_name = messages)]
#[diesel(belongs_to(User))]

View File

@ -20,7 +20,4 @@ diesel::table! {
diesel::joinable!(messages -> users (user_id));
diesel::allow_tables_to_appear_in_same_query!(
messages,
users,
);
diesel::allow_tables_to_appear_in_same_query!(messages, users,);

View File

@ -1,4 +1,6 @@
use crate::auth::{AuthBackend, AuthSession, UserCreds};
use crate::models::{Message as DbMessage, NewMessage};
use crate::schema::messages;
use axum::body::Body;
use axum::extract::ws::{self, CloseFrame, Message as WsMessage, WebSocket};
use axum::extract::{Request, State, WebSocketUpgrade};
@ -10,16 +12,15 @@ use axum::Json;
use axum::{routing::get, Router};
use axum_login::AuthManagerLayerBuilder;
use common::{ChatMessage as ApiMessage, LoggedInResponse, LoginRequest, LoginResponse};
use diesel::prelude::*;
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use tokio::net::TcpListener;
use crate::models::{NewMessage, Message as DbMessage};
use crate::schema::messages;
use slab::Slab;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use time::Duration;
use tokio::fs;
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tower::{ServiceBuilder, ServiceExt};
use tower_http::services::ServeDir;
@ -27,7 +28,6 @@ use tower_http::trace::TraceLayer;
use tower_sessions::cookie::Key;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
use tracing::{debug, warn};
use diesel::prelude::*;
struct ServState {
client_sends: Slab<SplitSink<WebSocket, WsMessage>>,
@ -43,8 +43,6 @@ impl ServState {
}
}
async fn serve_index(static_dir: &Path) -> Response {
let index_path = PathBuf::from(&static_dir).join("index.html");
let index_content = match fs::read_to_string(index_path).await {
@ -63,10 +61,7 @@ async fn serve_index(static_dir: &Path) -> Response {
.unwrap()
}
async fn login(
mut auth_session: AuthSession,
Json(req): Json<LoginRequest>,
) -> Response {
async fn login(mut auth_session: AuthSession, Json(req): Json<LoginRequest>) -> Response {
let Some(user) = auth_session
.authenticate(UserCreds {
username: req.username,
@ -100,7 +95,6 @@ async fn logged_in(auth_session: AuthSession) -> Response {
Json(resp).into_response()
}
async fn chat_ws(
State(state): State<Arc<Mutex<ServState>>>,
auth_session: AuthSession,
@ -118,7 +112,9 @@ async fn chat_ws(
return;
}
debug!("Client connected");
let messages: Vec<DbMessage> = messages::table.load(&mut *(state.lock().await.db_conn.lock().await)).unwrap();
let messages: Vec<DbMessage> = messages::table
.load(&mut *(state.lock().await.db_conn.lock().await))
.unwrap();
for db_msg in messages {
let api_msg = ApiMessage {
message: db_msg.message,
@ -176,12 +172,10 @@ async fn chat_ws(
pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: PathBuf) {
let db_conn = Arc::new(Mutex::new(db_conn));
let session_layer = SessionManagerLayer::new(MemoryStore::default())
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::days(1)));
let auth_backend = AuthBackend::new(db_conn.clone());
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
@ -224,7 +218,6 @@ pub async fn run(listener: TcpListener, db_conn: SqliteConnection, static_dir: P
.layer(auth_layer)
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
axum::serve(listener, app)
.await
.expect("Unable to start server");