146 lines
4.9 KiB
Rust
146 lines
4.9 KiB
Rust
// Adapted from https://github.com/chriswk/actix-session-sqlx-postgres for this project's use case.
|
|
use actix_session::storage::{LoadError, SaveError, SessionKey, SessionStore, UpdateError};
|
|
use actix_web::cookie::time::Duration;
|
|
use chrono::Utc;
|
|
use rand::{distributions::Alphanumeric, rngs::OsRng, Rng as _};
|
|
use serde_json::{self, Value};
|
|
use sqlx::{Pool, Postgres, Row};
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
/// Controls how a session key is turned into the value used in the `session.key` column.
#[derive(Clone)]
struct CacheConfiguration {
    /// Key-mapping function applied to every session key before it reaches the database.
    /// Held behind an `Arc`, so cloning the configuration shares the same closure.
    cache_keygen: Arc<dyn Fn(&str) -> String + Sync + Send>,
}
|
|
|
|
impl Default for CacheConfiguration {
|
|
fn default() -> Self {
|
|
Self {
|
|
cache_keygen: Arc::new(str::to_owned),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct SqlxPostgresqlSessionStore {
|
|
client_pool: Arc<Pool<Postgres>>,
|
|
configuration: CacheConfiguration,
|
|
}
|
|
|
|
fn generate_session_key() -> SessionKey {
|
|
let value = std::iter::repeat(())
|
|
.map(|()| OsRng.sample(Alphanumeric))
|
|
.take(64)
|
|
.collect::<Vec<_>>();
|
|
|
|
// These unwraps will never panic because pre-conditions are always verified
|
|
// (i.e. length and character set)
|
|
String::from_utf8(value).unwrap().try_into().unwrap()
|
|
}
|
|
|
|
impl SqlxPostgresqlSessionStore {
|
|
pub fn from_pool(pool: Arc<Pool<Postgres>>) -> SqlxPostgresqlSessionStore {
|
|
SqlxPostgresqlSessionStore {
|
|
client_pool: pool,
|
|
configuration: CacheConfiguration::default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// In-memory session state: string keys mapped to string values. This backend stores it as
/// JSON in the `sessionstate` column.
pub(crate) type SessionState = HashMap<String, String>;
|
|
|
|
impl SessionStore for SqlxPostgresqlSessionStore {
|
|
async fn load(&self, session_key: &SessionKey) -> Result<Option<SessionState>, LoadError> {
|
|
let key = (self.configuration.cache_keygen)(session_key.as_ref());
|
|
let row =
|
|
sqlx::query("SELECT sessionstate FROM session WHERE key = $1 AND expires > NOW()")
|
|
.bind(key)
|
|
.fetch_optional(self.client_pool.as_ref())
|
|
.await
|
|
.map_err(Into::into)
|
|
.map_err(LoadError::Other)?;
|
|
match row {
|
|
None => Ok(None),
|
|
Some(r) => {
|
|
let data: Value = r.get("sessionstate");
|
|
let state: SessionState = serde_json::from_value(data)
|
|
.map_err(Into::into)
|
|
.map_err(LoadError::Deserialization)?;
|
|
Ok(Some(state))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn save(
|
|
&self,
|
|
session_state: SessionState,
|
|
ttl: &Duration,
|
|
) -> Result<SessionKey, SaveError> {
|
|
let body = serde_json::to_value(&session_state)
|
|
.map_err(Into::into)
|
|
.map_err(SaveError::Serialization)?;
|
|
let key = generate_session_key();
|
|
let cache_key = (self.configuration.cache_keygen)(key.as_ref());
|
|
let expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
|
|
sqlx::query("INSERT INTO session(key, sessionstate, expires) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING")
|
|
.bind(cache_key)
|
|
.bind(body)
|
|
.bind(expires)
|
|
.execute(self.client_pool.as_ref())
|
|
.await
|
|
.map_err(Into::into)
|
|
.map_err(SaveError::Other)?;
|
|
Ok(key)
|
|
}
|
|
|
|
async fn update(
|
|
&self,
|
|
session_key: SessionKey,
|
|
session_state: SessionState,
|
|
ttl: &Duration,
|
|
) -> Result<SessionKey, UpdateError> {
|
|
let body = serde_json::to_value(&session_state)
|
|
.map_err(Into::into)
|
|
.map_err(UpdateError::Serialization)?;
|
|
let cache_key = (self.configuration.cache_keygen)(session_key.as_ref());
|
|
let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
|
|
sqlx::query("UPDATE session SET sessionstate = $1, expires = $2 WHERE key = $3")
|
|
.bind(body)
|
|
.bind(new_expires)
|
|
.bind(cache_key)
|
|
.execute(self.client_pool.as_ref())
|
|
.await
|
|
.map_err(Into::into)
|
|
.map_err(UpdateError::Other)?;
|
|
Ok(session_key)
|
|
}
|
|
|
|
async fn update_ttl(
|
|
&self,
|
|
session_key: &SessionKey,
|
|
ttl: &Duration,
|
|
) -> Result<(), anyhow::Error> {
|
|
let new_expires = Utc::now() + chrono::Duration::seconds(ttl.whole_seconds());
|
|
let key = (self.configuration.cache_keygen)(session_key.as_ref());
|
|
sqlx::query("UPDATE session SET expires = $1 WHERE key = $2")
|
|
.bind(new_expires)
|
|
.bind(key)
|
|
.execute(self.client_pool.as_ref())
|
|
.await
|
|
.map_err(Into::into)
|
|
.map_err(UpdateError::Other)?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn delete(&self, session_key: &SessionKey) -> Result<(), anyhow::Error> {
|
|
let key = (self.configuration.cache_keygen)(session_key.as_ref());
|
|
sqlx::query("DELETE FROM session WHERE key = $1")
|
|
.bind(key)
|
|
.execute(self.client_pool.as_ref())
|
|
.await
|
|
.map_err(Into::into)
|
|
.map_err(UpdateError::Other)?;
|
|
Ok(())
|
|
}
|
|
}
|