Start work on authentication

This commit is contained in:
pjht 2024-10-09 11:58:49 -05:00
parent 0e4d29ae52
commit 74e7a9c2d5
Signed by: pjht
GPG Key ID: CA239FC6934E6F3A
5 changed files with 518 additions and 86 deletions

250
Cargo.lock generated
View File

@ -210,6 +210,26 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-login"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5260ed0ecc8ace8e7e61a7406672faba598c8a86b8f4742fcdde0ddc979a318f"
dependencies = [
"async-trait",
"axum",
"form_urlencoded",
"serde",
"subtle",
"thiserror",
"tower-cookies",
"tower-layer",
"tower-service",
"tower-sessions",
"tracing",
"urlencoding",
]
[[package]]
name = "axum_static"
version = "1.7.1"
@ -433,6 +453,22 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"base64",
"hmac",
"percent-encoding",
"rand",
"sha2",
"subtle",
"time",
"version_check",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.7"
@ -479,6 +515,16 @@ version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2"
[[package]]
name = "deranged"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
dependencies = [
"powerfmt",
"serde",
]
[[package]]
name = "digest"
version = "0.10.7"
@ -487,6 +533,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"crypto-common",
"subtle",
]
[[package]]
@ -544,11 +591,13 @@ dependencies = [
"ciborium",
"common",
"console_error_panic_hook",
"gloo 0.11.0",
"gloo-console 0.3.0",
"gloo-net 0.6.0",
"gloo-timers 0.3.0",
"itertools",
"log",
"wasm-bindgen",
"wasm-logger",
"web-sys",
"yew",
@ -712,6 +761,25 @@ dependencies = [
"gloo-worker 0.4.0",
]
[[package]]
name = "gloo"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d15282ece24eaf4bd338d73ef580c6714c8615155c4190c781290ee3fa0fd372"
dependencies = [
"gloo-console 0.3.0",
"gloo-dialogs 0.2.0",
"gloo-events 0.2.0",
"gloo-file 0.3.0",
"gloo-history 0.2.2",
"gloo-net 0.5.0",
"gloo-render 0.2.0",
"gloo-storage 0.3.0",
"gloo-timers 0.3.0",
"gloo-utils 0.2.0",
"gloo-worker 0.5.0",
]
[[package]]
name = "gloo-console"
version = "0.2.3"
@ -898,6 +966,27 @@ dependencies = [
"web-sys",
]
[[package]]
name = "gloo-net"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43aaa242d1239a8822c15c645f02166398da4f8b5c4bae795c1f5b44e9eee173"
dependencies = [
"futures-channel",
"futures-core",
"futures-sink",
"gloo-utils 0.2.0",
"http 0.2.12",
"js-sys",
"pin-project",
"serde",
"serde_json",
"thiserror",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "gloo-net"
version = "0.6.0"
@ -1053,6 +1142,25 @@ dependencies = [
"web-sys",
]
[[package]]
name = "gloo-worker"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "085f262d7604911c8150162529cefab3782e91adb20202e8658f7275d2aefe5d"
dependencies = [
"bincode",
"futures",
"gloo-utils 0.2.0",
"gloo-worker-macros",
"js-sys",
"pinned",
"serde",
"thiserror",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
]
[[package]]
name = "gloo-worker-macros"
version = "0.1.0"
@ -1093,6 +1201,15 @@ version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024"
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "http"
version = "0.2.12"
@ -1313,6 +1430,7 @@ checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17"
dependencies = [
"autocfg",
"scopeguard",
"serde",
]
[[package]]
@ -1386,6 +1504,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-traits"
version = "0.2.19"
@ -1504,6 +1628,12 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2"
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.20"
@ -1741,8 +1871,10 @@ dependencies = [
name = "server"
version = "0.1.0"
dependencies = [
"async-trait",
"axum",
"axum-client-ip",
"axum-login",
"axum_static",
"ciborium",
"clap",
@ -1750,9 +1882,11 @@ dependencies = [
"futures",
"log",
"slab",
"time",
"tokio",
"tower",
"tower-http 0.6.1",
"tower-sessions",
"tracing",
"tracing-subscriber",
]
@ -1768,6 +1902,17 @@ dependencies = [
"digest",
]
[[package]]
name = "sha2"
version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
@ -1823,6 +1968,12 @@ version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "subtle"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "1.0.109"
@ -1886,6 +2037,37 @@ dependencies = [
"once_cell",
]
[[package]]
name = "time"
version = "0.3.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
dependencies = [
"deranged",
"itoa",
"num-conv",
"powerfmt",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
[[package]]
name = "time-macros"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
dependencies = [
"num-conv",
"time-core",
]
[[package]]
name = "tokio"
version = "1.40.0"
@ -1984,6 +2166,23 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-cookies"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fd0118512cf0b3768f7fcccf0bef1ae41d68f2b45edc1e77432b36c97c56c6d"
dependencies = [
"async-trait",
"axum-core",
"cookie",
"futures-util",
"http 1.1.0",
"parking_lot",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-http"
version = "0.5.2"
@ -2052,6 +2251,57 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tower-sessions"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65856c81ee244e0f8a55ab0f7b769b72fbde387c235f0a73cd97c579818d05eb"
dependencies = [
"async-trait",
"http 1.1.0",
"time",
"tokio",
"tower-cookies",
"tower-layer",
"tower-service",
"tower-sessions-core",
"tower-sessions-memory-store",
"tracing",
]
[[package]]
name = "tower-sessions-core"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb6abbfcaf6436ec5a772cd9f965401da12db793e404ae6134eac066fa5a04f3"
dependencies = [
"async-trait",
"axum-core",
"base64",
"futures",
"http 1.1.0",
"parking_lot",
"rand",
"serde",
"serde_json",
"thiserror",
"time",
"tokio",
"tracing",
]
[[package]]
name = "tower-sessions-memory-store"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fad75660c8afbe74f4e7cbbe8e9090171a056b57370ea4d7d5e9eb3e4af3092"
dependencies = [
"async-trait",
"time",
"tokio",
"tower-sessions-core",
]
[[package]]
name = "tracing"
version = "0.1.40"

View File

@ -8,13 +8,15 @@ chrono = { version = "0.4.38", features = ["wasmbind"] }
ciborium = "0.2.2"
common = { version = "0.1.0", path = "../common" }
console_error_panic_hook = "0.1.7"
gloo = "0.11.0"
gloo-console = "0.3.0"
gloo-net = "0.6.0"
gloo-timers = "0.3.0"
itertools = "0.13.0"
log = "0.4.22"
wasm-bindgen = "0.2.93"
wasm-logger = "0.2.0"
web-sys = { version = "0.3.70", features = ["Navigator"] }
web-sys = { version = "0.3.70", features = ["Navigator", "WebSocket", "EventListener"] }
yew = { version = "0.21.0", features = ["csr"] }
yew-router = "0.18.0"
yew-websocket = "1.21.0"

View File

@ -1,15 +1,14 @@
use std::time::Duration;
use std::{sync::Arc, time::Duration};
use chrono::Local;
use common::ChatMessage;
use gloo::{events::EventListener, net::http::Request};
use gloo_console::log;
use gloo_timers::future::sleep;
use yew::prelude::*;
use wasm_bindgen::JsCast;
use web_sys::{js_sys::Uint8Array, CloseEvent, Event, MessageEvent, WebSocket};
use yew::{platform::spawn_local, prelude::*};
use yew_router::prelude::*;
use yew_websocket::{
format::Binary,
websocket::{WebSocketService, WebSocketStatus, WebSocketTask},
};
#[derive(Clone, Routable, PartialEq)]
enum Route {
@ -32,65 +31,137 @@ fn switch(routes: Route) -> Html {
}
}
#[derive(Clone, Debug)]
enum WsEvent {
#[allow(dead_code)]
Close(CloseEvent),
#[allow(dead_code)]
Open(Event),
}
impl WsEvent {
fn as_close(&self) -> Option<&CloseEvent> {
if let Self::Close(v) = self {
Some(v)
} else {
None
}
}
}
enum HomeMessage {
SubmittedMessage(String),
RecievedMessage(ChatMessage),
WsStateChange(WebSocketStatus),
WsEvent(WsEvent),
WsReconnect,
Authenticated,
}
#[derive(PartialEq, Eq)]
enum WsState {
Closed,
Open,
}
struct Home {
authenticated: bool,
messages: Vec<ChatMessage>,
chat_ws: WebSocketTask,
ws_state: WebSocketStatus,
chat_ws: Option<WebSocketConn>,
ws_state: WsState,
ws_reconnecting: bool,
message_container_ref: NodeRef,
}
struct WebSocketConn {
ws: WebSocket,
#[allow(dead_code)]
listeners: [EventListener; 3],
}
impl Drop for WebSocketConn {
fn drop(&mut self) {
let _ = self.ws.close();
}
}
impl Home {
fn connect_ws(ctx: &Context<Self>) -> WebSocketTask {
fn connect_ws(&mut self, ctx: &Context<Self>) {
let location = web_sys::window().unwrap().location();
let ws_proto = if location.protocol().unwrap() == "https:" {
"wss"
} else {
"ws"
};
let api_url = format!("{}://{}/api/chat_ws", ws_proto, location.host().unwrap());
log!("Connecting to ", &api_url);
WebSocketService::connect_binary(
&api_url,
ctx.link().callback(|msg: Binary| {
let msg = msg.unwrap();
let msg = ciborium::from_reader(msg.as_slice()).unwrap();
HomeMessage::RecievedMessage(msg)
}),
ctx.link().callback(HomeMessage::WsStateChange),
)
.unwrap()
let ws_url = format!("{}://{}/api/chat_ws", ws_proto, location.host().unwrap());
log!("Connecting to ", &ws_url);
let ws = WebSocket::new(&ws_url).unwrap();
ws.set_binary_type(web_sys::BinaryType::Arraybuffer);
let msg_callback = ctx.link().callback(|msg: Vec<u8>| {
let msg = ciborium::from_reader(msg.as_slice()).unwrap();
HomeMessage::RecievedMessage(msg)
});
let state_callback = ctx
.link()
.callback(|state: WsEvent| HomeMessage::WsEvent(state));
let cb = state_callback.clone();
let open_ev = EventListener::new(&ws, "open", move |event: &Event| {
cb.emit(WsEvent::Open(event.clone()))
});
let cb = state_callback.clone();
let close_ev = EventListener::new(&ws, "close", move |event: &Event| {
let event = event.dyn_ref::<CloseEvent>().unwrap();
cb.emit(WsEvent::Close(event.clone()))
});
let msg_ev = EventListener::new(&ws, "message", move |event: &Event| {
let event = event.dyn_ref::<MessageEvent>().unwrap();
let bytes = event.data();
let bytes = Uint8Array::new(&bytes).to_vec();
msg_callback.emit(bytes);
});
self.chat_ws = Some(WebSocketConn {
ws,
listeners: [open_ev, close_ev, msg_ev],
});
}
fn authenticate(ctx: &Context<Self>) {
log!("Authenticating to backend");
let auth_cb = ctx.link().callback(|_: ()| HomeMessage::Authenticated);
spawn_local(async move {
let location = web_sys::window().unwrap().location();
let login_url = format!(
"{}//{}/api/login",
location.protocol().unwrap(),
location.host().unwrap()
);
Request::post(&login_url)
.body("")
.unwrap()
.send()
.await
.unwrap();
auth_cb.emit(());
});
}
}
//fn on_mobile() -> bool {
// let window = web_sys::window().unwrap();
// let navigator = window.navigator();
//
// navigator.max_touch_points() > 0 || window.inner_width().unwrap().as_f64().unwrap() < 768.0
//}
impl Component for Home {
type Message = HomeMessage;
type Properties = ();
fn create(ctx: &Context<Self>) -> Self {
let chat_ws = Self::connect_ws(ctx);
Self {
Self::authenticate(ctx);
let mut slf = Self {
authenticated: false,
messages: Vec::new(),
chat_ws,
ws_state: WebSocketStatus::Closed,
chat_ws: None,
ws_state: WsState::Closed,
ws_reconnecting: false,
message_container_ref: NodeRef::default(),
}
};
slf.connect_ws(ctx);
slf
}
fn view(&self, ctx: &Context<Self>) -> Html {
@ -111,16 +182,21 @@ impl Component for Home {
}
})
};
let disable_input = self.ws_state != WebSocketStatus::Opened;
let disable_input = self.ws_state != WsState::Open || !self.authenticated;
html! {
<div class="myvh-100 d-flex flex-column">
<Nav />
<div class="container-fluid d-flex flex-column flex-grow-1 mt-3">
if disable_input {
if self.ws_state != WsState::Open {
<div class="alert alert-warning" role="alert">
{ "Connection to backend lost, trying to reconnect" }
</div>
}
if !self.authenticated {
<div class="alert alert-warning" role="alert">
{ "Authenticating to backend" }
</div>
}
<div
ref={self.message_container_ref.clone()}
class="d-flex border rounded flex-grow-1 flex-column-reverse overflow-auto mb-3"
@ -157,54 +233,68 @@ impl Component for Home {
self.messages.push(msg.clone());
let mut buf = Vec::new();
ciborium::into_writer(&msg, &mut buf).unwrap();
self.chat_ws.send_binary(buf);
self.chat_ws
.as_ref()
.unwrap()
.ws
.send_with_u8_array(&buf)
.unwrap();
true
}
HomeMessage::RecievedMessage(msg) => {
self.messages.push(msg);
true
}
HomeMessage::WsStateChange(state) => {
if state != self.ws_state {
if state != WebSocketStatus::Opened {
log!("WS connection closed");
if self.ws_state != WebSocketStatus::Opened {
log!("Already closed");
if !self.ws_reconnecting {
log!("Reconnecting in 5s");
self.ws_reconnecting = true;
ctx.link().send_future(async move {
sleep(Duration::from_secs(5)).await;
HomeMessage::WsReconnect
});
}
} else {
log!("Reconnecting");
self.chat_ws = Self::connect_ws(ctx);
}
} else {
log!("WS connection opened");
self.messages.clear();
HomeMessage::WsEvent(event) => {
if let Some(close_event) = event.as_close() {
let code = close_event.code();
log!("WS connection closed with code", code);
if code == 4000 {
log!("Unauthorized, reauthenticating");
self.ws_state = WsState::Closed;
self.authenticated = false;
Self::authenticate(ctx);
return true;
}
self.ws_state = state.clone();
if self.ws_state == WsState::Closed {
log!("Already closed");
if !self.ws_reconnecting {
log!("Reconnecting in 5s");
self.ws_reconnecting = true;
ctx.link().send_future(async move {
sleep(Duration::from_secs(5)).await;
HomeMessage::WsReconnect
});
}
false
} else {
log!("Reconnecting");
self.connect_ws(ctx);
self.ws_state = WsState::Closed;
true
}
} else if self.ws_state != WsState::Open {
log!("WS connection opened");
self.messages.clear();
self.ws_state = WsState::Open;
true
} else {
if state != WebSocketStatus::Opened {
log!("Close/error state while closed/errored, marking closed and reconnecting");
self.ws_state = WebSocketStatus::Closed;
self.chat_ws = Self::connect_ws(ctx);
}
false
}
}
HomeMessage::WsReconnect => {
if self.ws_state != WebSocketStatus::Opened {
if self.ws_state != WsState::Open {
log!("Reconnecting");
self.ws_reconnecting = false;
self.chat_ws = Self::connect_ws(ctx);
self.connect_ws(ctx);
}
false
}
HomeMessage::Authenticated => {
self.connect_ws(ctx);
self.authenticated = true;
true
}
}
}

View File

@ -4,8 +4,10 @@ version = "0.1.0"
edition = "2021"
[dependencies]
async-trait = "0.1.83"
axum = { version = "0.7.7", features = ["ws"] }
axum-client-ip = "0.6.1"
axum-login = "0.16.0"
axum_static = "1.7.1"
ciborium = "0.2.2"
clap = { version = "4.5.19", features = ["derive"] }
@ -13,8 +15,10 @@ common = { version = "0.1.0", path = "../common" }
futures = "0.3.31"
log = "0.4.22"
slab = "0.4.9"
time = "0.3.36"
tokio = { version = "1.40.0", features = ["full"] }
tower = "0.5.1"
tower-http = { version = "0.6.1", features = ["full"] }
tower-sessions = { version = "0.13.0", features = ["memory-store", "signed"] }
tracing = "0.1.40"
tracing-subscriber = "0.3.18"

View File

@ -1,25 +1,33 @@
use axum::async_trait;
use axum::body::Body;
use axum::extract::ws::{self, CloseFrame, Message, WebSocket};
use axum::extract::{Request, State, WebSocketUpgrade};
use axum::http::header::CONTENT_TYPE;
use axum::http::StatusCode;
use axum::response::Response;
use axum::response::{IntoResponse, Response};
use axum::routing::post;
use axum::{routing::get, Router};
use axum_login::{
login_required, AuthManagerLayerBuilder, AuthSession, AuthUser, AuthnBackend, UserId,
};
use clap::Parser;
use common::ChatMessage;
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use slab::Slab;
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use time::Duration;
use tokio::fs::{self, OpenOptions};
use tokio::io::AsyncWriteExt;
use tokio::sync::Mutex;
use tower::{ServiceBuilder, ServiceExt};
use tower_http::services::ServeDir;
use tower_http::trace::TraceLayer;
use tower_sessions::cookie::Key;
use tower_sessions::{Expiry, MemoryStore, SessionManagerLayer};
use tracing::{debug, warn};
// Setup the command line interface with clap.
@ -60,6 +68,71 @@ impl ServState {
}
}
#[derive(Clone, Debug)]
struct DummyAuthUser;
impl AuthUser for DummyAuthUser {
type Id = String;
fn id(&self) -> Self::Id {
"pjht".to_string()
}
fn session_auth_hash(&self) -> &[u8] {
&[0]
}
}
#[derive(Clone, Debug)]
struct DummyAuthBackend;
#[async_trait]
impl AuthnBackend for DummyAuthBackend {
type User = DummyAuthUser;
type Credentials = ();
type Error = std::convert::Infallible;
async fn authenticate(
&self,
_creds: Self::Credentials,
) -> Result<Option<Self::User>, Self::Error> {
Ok(Some(DummyAuthUser))
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
if user_id == "pjht" {
Ok(Some(DummyAuthUser))
} else {
Ok(None)
}
}
}
async fn serve_index(static_dir: &Path) -> Response {
let index_path = PathBuf::from(&static_dir).join("index.html");
let index_content = match fs::read_to_string(index_path).await {
Err(_) => {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("index file not found"))
.unwrap()
}
Ok(index_content) => index_content,
};
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "text/html")
.body(Body::from(index_content))
.unwrap()
}
async fn login(mut auth_session: AuthSession<DummyAuthBackend>) -> Response {
if auth_session.login(&DummyAuthUser).await.is_err() {
return StatusCode::INTERNAL_SERVER_ERROR.into_response();
}
StatusCode::OK.into_response()
}
#[tokio::main]
async fn main() {
let opt = Opt::parse();
@ -73,7 +146,20 @@ async fn main() {
// enable console logging
tracing_subscriber::fmt::init();
let api = Router::new().route("/chat_ws", get(chat_ws));
let signing_key = Key::generate();
let session_layer = SessionManagerLayer::new(MemoryStore::default())
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::days(1)))
.with_signed(signing_key);
let auth_backend = DummyAuthBackend;
let auth_layer = AuthManagerLayerBuilder::new(auth_backend, session_layer).build();
let api = Router::new()
//.route_layer(login_required!(DummyAuthBackend))
.route("/chat_ws", get(chat_ws))
.route("/login", post(login));
let app = Router::new()
.nest("/api", api)
@ -91,21 +177,7 @@ async fn main() {
{
Ok(res) => {
if res.status() == StatusCode::NOT_FOUND {
let index_path = PathBuf::from(&opt.static_dir).join("index.html");
let index_content = match fs::read_to_string(index_path).await {
Err(_) => {
return Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Body::from("index file not found"))
.unwrap()
}
Ok(index_content) => index_content,
};
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "text/html")
.body(Body::from(index_content))
.unwrap()
serve_index(&opt.static_dir).await
} else {
res.map(Body::new)
}
@ -117,6 +189,7 @@ async fn main() {
}
})
.with_state(Arc::new(Mutex::new(ServState::new())))
.layer(auth_layer)
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()));
let sock_addr = SocketAddr::from((
@ -133,9 +206,22 @@ async fn main() {
.expect("Unable to start server");
}
async fn chat_ws(State(state): State<Arc<Mutex<ServState>>>, ws: WebSocketUpgrade) -> Response {
async fn chat_ws(
State(state): State<Arc<Mutex<ServState>>>,
auth_session: AuthSession<DummyAuthBackend>,
ws: WebSocketUpgrade,
) -> Response {
ws.on_upgrade(move |socket| async move {
let (mut tx, mut rx) = socket.split();
if auth_session.user.is_none() {
let _ = tx
.send(Message::Close(Some(CloseFrame {
code: 4000,
reason: "Unauthorized".into(),
})))
.await;
return;
}
debug!("Client connected");
for msg in &state.lock().await.chat_log {
let mut buf = Vec::new();