// took code from https://github.com/chriswk/actix-session-sqlx-postgres and adapted it to own usecase use actix_session::storage::{LoadError, SaveError, SessionKey, SessionStore, UpdateError}; use actix_web::cookie::time::Duration; use chrono::Utc; use rand::{distributions::Alphanumeric, rngs::OsRng, Rng as _}; use serde_json::{self, Value}; use sqlx::{Pool, Postgres, Row}; use std::collections::HashMap; use std::sync::Arc; #[derive(Clone)] struct CacheConfiguration { cache_keygen: Arc String + Send + Sync>, } impl Default for CacheConfiguration { fn default() -> Self { Self { cache_keygen: Arc::new(str::to_owned), } } } #[derive(Clone)] pub struct SqlxPostgresqlSessionStore { client_pool: Arc>, configuration: CacheConfiguration, } fn generate_session_key() -> SessionKey { let value = std::iter::repeat(()) .map(|()| OsRng.sample(Alphanumeric)) .take(64) .collect::>(); // These unwraps will never panic because pre-conditions are always verified // (i.e. length and character set) String::from_utf8(value).unwrap().try_into().unwrap() } impl SqlxPostgresqlSessionStore { pub fn from_pool(pool: Arc>) -> SqlxPostgresqlSessionStore { SqlxPostgresqlSessionStore { client_pool: pool, configuration: CacheConfiguration::default(), } } } pub(crate) type SessionState = HashMap; #[async_trait::async_trait(?Send)] impl SessionStore for SqlxPostgresqlSessionStore { async fn load(&self, session_key: &SessionKey) -> Result, LoadError> { let key = (self.configuration.cache_keygen)(session_key.as_ref()); let row = sqlx::query("SELECT sessionstate FROM session WHERE key = $1 AND expires > NOW()") .bind(key) .fetch_optional(self.client_pool.as_ref()) .await .map_err(Into::into) .map_err(LoadError::Other)?; match row { None => Ok(None), Some(r) => { let data: Value = r.get("sessionstate"); let state: SessionState = serde_json::from_value(data) .map_err(Into::into) .map_err(LoadError::Deserialization)?; Ok(Some(state)) } } } async fn save( &self, session_state: SessionState, ttl: &Duration, ) -> Result { let body = serde_json::to_value(&session_state) .map_err(Into::into) .map_err(SaveError::Serialization)?; let key = generate_session_key(); let cache_key = (self.configuration.cache_keygen)(key.as_ref()); let expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds()); sqlx::query("INSERT INTO session(key, sessionstate, expires) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING") .bind(cache_key) .bind(body) .bind(expires) .execute(self.client_pool.as_ref()) .await .map_err(Into::into) .map_err(SaveError::Other)?; Ok(key) } async fn update( &self, session_key: SessionKey, session_state: SessionState, ttl: &Duration, ) -> Result { let body = serde_json::to_value(&session_state) .map_err(Into::into) .map_err(UpdateError::Serialization)?; let cache_key = (self.configuration.cache_keygen)(session_key.as_ref()); let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds()); sqlx::query("UPDATE session SET sessionstate = $1, expires = $2 WHERE key = $3") .bind(body) .bind(new_expires) .bind(cache_key) .execute(self.client_pool.as_ref()) .await .map_err(Into::into) .map_err(UpdateError::Other)?; Ok(session_key) } async fn update_ttl( &self, session_key: &SessionKey, ttl: &Duration, ) -> Result<(), anyhow::Error> { let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds()); let key = (self.configuration.cache_keygen)(session_key.as_ref()); sqlx::query("UPDATE session SET expires = $1 WHERE key = $2") .bind(new_expires) .bind(key) .execute(self.client_pool.as_ref()) .await .map_err(Into::into) .map_err(UpdateError::Other)?; Ok(()) } async fn delete(&self, session_key: &SessionKey) -> Result<(), anyhow::Error> { let key = (self.configuration.cache_keygen)(session_key.as_ref()); sqlx::query("DELETE FROM session WHERE key = $1") .bind(key) .execute(self.client_pool.as_ref()) .await .map_err(Into::into) .map_err(UpdateError::Other)?; Ok(()) } }