feat: simplify loading of current user

This commit is contained in:
Max Hohlfeld 2024-06-06 22:04:48 +02:00
parent 9c75c10b11
commit 6434ddd5fb
3 changed files with 79 additions and 9 deletions

View File

@ -5,7 +5,7 @@ use chrono::{NaiveDate, Utc};
use serde::Deserialize; use serde::Deserialize;
use sqlx::PgPool; use sqlx::PgPool;
use crate::models::{Area, Availabillity, Event, Function, User, Role}; use crate::models::{Area, Availabillity, Event, Function, Role, User};
use super::{ use super::{
delete_availabillity::delete_availabillity, delete_availabillity::delete_availabillity,
@ -40,28 +40,25 @@ struct CalendarTemplate {
#[actix_web::get("/")] #[actix_web::get("/")]
async fn get_index( async fn get_index(
user: Identity, current_user: web::ReqData<User>,
pool: web::Data<PgPool>, pool: web::Data<PgPool>,
query: web::Query<CalendarQuery>, query: web::Query<CalendarQuery>,
) -> impl Responder { ) -> impl Responder {
let current_user = User::read_by_id(pool.get_ref(), user.id().unwrap().parse().unwrap())
.await
.unwrap();
let date = match query.date { let date = match query.date {
Some(given_date) => given_date, Some(given_date) => given_date,
None => Utc::now().date_naive(), None => Utc::now().date_naive(),
}; };
let areas = Area::read_all(pool.get_ref()) let areas = Area::read_all(pool.get_ref()).await.unwrap();
let events = Event::read_by_date_including_location(pool.get_ref(), date)
.await .await
.unwrap(); .unwrap();
let events = Event::read_by_date_including_location(pool.get_ref(), date).await.unwrap();
let availabillities = Availabillity::read_by_date_including_user(pool.get_ref(), date) let availabillities = Availabillity::read_by_date_including_user(pool.get_ref(), date)
.await .await
.unwrap(); .unwrap();
let template = CalendarTemplate { let template = CalendarTemplate {
user: current_user, user: current_user.into_inner(),
date, date,
areas, areas,
events, events,

View File

@ -0,0 +1,71 @@
use std::{
future::{ready, Ready},
rc::Rc,
};
use actix_identity::IdentityExt;
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
web, Error, HttpMessage,
};
use futures_util::{future::LocalBoxFuture, FutureExt};
use sqlx::PgPool;
use crate::models::User;
pub struct LoadUser;
impl<S, B> Transform<S, ServiceRequest> for LoadUser
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = LoadUserMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(LoadUserMiddleware {
service: Rc::new(service),
}))
}
}
pub struct LoadUserMiddleware<S> {
service: Rc<S>,
}
impl<S, B> Service<ServiceRequest> for LoadUserMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let srv = self.service.clone();
async move {
if let Ok(identity) = req.get_identity() {
if let Ok(id) = identity.id() {
let pool = req.app_data::<web::Data<PgPool>>().unwrap();
let user = User::read_by_id(pool.get_ref(), id.parse().unwrap())
.await
.unwrap();
req.extensions_mut().insert::<User>(user);
}
}
let res = srv.call(req).await?;
Ok(res)
}
.boxed_local()
}
}

View File

@ -21,6 +21,7 @@ mod endpoints;
mod models; mod models;
mod postgres_session_store; mod postgres_session_store;
mod load_current_user_from_db;
pub enum Command { pub enum Command {
Migrate, Migrate,
@ -120,6 +121,7 @@ async fn main() -> anyhow::Result<()> {
.configure(calendar::init) .configure(calendar::init)
.configure(endpoints::init) .configure(endpoints::init)
.wrap(redirect::CheckLogin) .wrap(redirect::CheckLogin)
.wrap(load_current_user_from_db::LoadUser)
.wrap( .wrap(
IdentityMiddleware::builder() IdentityMiddleware::builder()
.visit_deadline(Some(Duration::from_secs(300))) .visit_deadline(Some(Duration::from_secs(300)))