From 6434ddd5fb578cc3f344d3927519671556c4f4c8 Mon Sep 17 00:00:00 2001 From: Max Hohlfeld Date: Thu, 6 Jun 2024 22:04:48 +0200 Subject: [PATCH] feat: simplify loading of current user --- src/calendar/routes.rs | 15 +++---- src/load_current_user_from_db.rs | 71 ++++++++++++++++++++++++++++++++ src/main.rs | 2 + 3 files changed, 79 insertions(+), 9 deletions(-) create mode 100644 src/load_current_user_from_db.rs diff --git a/src/calendar/routes.rs b/src/calendar/routes.rs index b4756def..49ebaa39 100644 --- a/src/calendar/routes.rs +++ b/src/calendar/routes.rs @@ -5,7 +5,7 @@ use chrono::{NaiveDate, Utc}; use serde::Deserialize; use sqlx::PgPool; -use crate::models::{Area, Availabillity, Event, Function, User, Role}; +use crate::models::{Area, Availabillity, Event, Function, Role, User}; use super::{ delete_availabillity::delete_availabillity, @@ -40,28 +40,25 @@ struct CalendarTemplate { #[actix_web::get("/")] async fn get_index( - user: Identity, + current_user: web::ReqData, pool: web::Data, query: web::Query, ) -> impl Responder { - let current_user = User::read_by_id(pool.get_ref(), user.id().unwrap().parse().unwrap()) - .await - .unwrap(); let date = match query.date { Some(given_date) => given_date, None => Utc::now().date_naive(), }; - let areas = Area::read_all(pool.get_ref()) + let areas = Area::read_all(pool.get_ref()).await.unwrap(); + + let events = Event::read_by_date_including_location(pool.get_ref(), date) .await .unwrap(); - - let events = Event::read_by_date_including_location(pool.get_ref(), date).await.unwrap(); let availabillities = Availabillity::read_by_date_including_user(pool.get_ref(), date) .await .unwrap(); let template = CalendarTemplate { - user: current_user, + user: current_user.into_inner(), date, areas, events, diff --git a/src/load_current_user_from_db.rs b/src/load_current_user_from_db.rs new file mode 100644 index 00000000..d11fb7be --- /dev/null +++ b/src/load_current_user_from_db.rs @@ -0,0 +1,71 @@ +use std::{ + future::{ready, Ready}, + rc::Rc, +}; + +use actix_identity::IdentityExt; +use actix_web::{ + dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform}, + web, Error, HttpMessage, +}; +use futures_util::{future::LocalBoxFuture, FutureExt}; +use sqlx::PgPool; + +use crate::models::User; + +pub struct LoadUser; + +impl Transform for LoadUser +where + S: Service, Error = Error> + 'static, +{ + type Response = ServiceResponse; + type Error = Error; + type InitError = (); + type Transform = LoadUserMiddleware; + type Future = Ready>; + + fn new_transform(&self, service: S) -> Self::Future { + ready(Ok(LoadUserMiddleware { + service: Rc::new(service), + })) + } +} + +pub struct LoadUserMiddleware { + service: Rc, +} + +impl Service for LoadUserMiddleware +where + S: Service, Error = Error> + 'static, +{ + type Response = ServiceResponse; + type Error = Error; + type Future = LocalBoxFuture<'static, Result>; + + forward_ready!(service); + + fn call(&self, req: ServiceRequest) -> Self::Future { + let srv = self.service.clone(); + + async move { + if let Ok(identity) = req.get_identity() { + if let Ok(id) = identity.id() { + let pool = req.app_data::>().unwrap(); + + let user = User::read_by_id(pool.get_ref(), id.parse().unwrap()) + .await + .unwrap(); + + req.extensions_mut().insert::(user); + } + } + + let res = srv.call(req).await?; + + Ok(res) + } + .boxed_local() + } +} diff --git a/src/main.rs b/src/main.rs index 46620f34..74fd6733 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,6 +21,7 @@ mod endpoints; mod models; mod postgres_session_store; +mod load_current_user_from_db; pub enum Command { Migrate, @@ -120,6 +121,7 @@ async fn main() -> anyhow::Result<()> { .configure(calendar::init) .configure(endpoints::init) .wrap(redirect::CheckLogin) + .wrap(load_current_user_from_db::LoadUser) .wrap( IdentityMiddleware::builder() .visit_deadline(Some(Duration::from_secs(300)))