brass/web/src/middleware/load_current_user_from_db.rs

73 lines
1.9 KiB
Rust

use std::{
future::{ready, Ready},
rc::Rc,
};
use actix_identity::IdentityExt;
use actix_web::{
dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
web, Error, HttpMessage,
};
use futures_util::{future::LocalBoxFuture, FutureExt};
use sqlx::PgPool;
use crate::models::User;
/// Middleware factory that loads the logged-in [`User`] from the database and
/// stores it in the request extensions for downstream handlers.
///
/// It implements actix-web's `Transform`, producing a
/// [`LoadCurrentUserMiddleware`] around the wrapped service.
#[derive(Debug, Clone, Copy, Default)]
pub struct LoadCurrentUser;
/// Builds a [`LoadCurrentUserMiddleware`] around the next service in the chain.
impl<S, B> Transform<S, ServiceRequest> for LoadCurrentUser
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Transform = LoadCurrentUserMiddleware<S>;
    type InitError = ();
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` immediately; this construction step always succeeds.
    fn new_transform(&self, service: S) -> Self::Future {
        let shared_service = Rc::new(service);
        let middleware = LoadCurrentUserMiddleware {
            service: shared_service,
        };
        ready(Ok(middleware))
    }
}
/// Service created by `LoadCurrentUser`; wraps the next service in the chain.
pub struct LoadCurrentUserMiddleware<S> {
    // Held in an `Rc` so `call` can move its own handle to the inner service
    // into the `'static` boxed future it returns.
    service: Rc<S>,
}
impl<S, B> Service<ServiceRequest> for LoadCurrentUserMiddleware<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let srv = self.service.clone();
async move {
if let Ok(identity) = req.get_identity() {
if let Ok(id) = identity.id() {
let pool = req.app_data::<web::Data<PgPool>>().unwrap();
let user = User::read_by_id(pool.get_ref(), id.parse().unwrap())
.await
.unwrap()
.unwrap();
req.extensions_mut().insert::<User>(user);
}
}
let res = srv.call(req).await?;
Ok(res)
}
.boxed_local()
}
}