diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 8a6ec4fb..aea5c59b 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -13,4 +13,5 @@ path = "src/db.rs" clap = { version = "4.5.23", features = ["derive"] } brass-config = { path = "../config" } async-std = { version = "1.13.0", features = ["attributes"] } -sqlx = { version = "0.8.2", features = ["runtime-async-std"] } +sqlx = { version = "0.8.2", features = ["runtime-async-std", "postgres"] } +anyhow = "1.0.94" diff --git a/cli/src/db.rs b/cli/src/db.rs index 1721518d..2423baf6 100644 --- a/cli/src/db.rs +++ b/cli/src/db.rs @@ -1,12 +1,18 @@ -use sqlx::Executor; -use std::str::FromStr; +use anyhow::Context; +use sqlx::migrate::Migrate; +use sqlx::{migrate::Migrator, Executor}; +use std::{ + collections::HashMap, + path::{Path, PathBuf}, + str::FromStr, +}; use brass_config::{load_config, parse_env, Environment}; -use clap::{Args, Parser, Subcommand}; +use clap::{Parser, Subcommand}; use sqlx::{postgres::PgConnectOptions, Connection, PgConnection}; #[derive(Parser)] -#[command(about, long_about = None)] +#[command(about = "A CLI tool for managing the projects database.", long_about = None)] struct Cli { #[command(subcommand)] command: Command, @@ -16,36 +22,129 @@ struct Cli { #[derive(Subcommand)] enum Command { + #[command(about = "Create the database and run all migrations")] Setup, + #[command(about = "Drop and recreate the database and run all migrations")] Reset, + #[command(about = "Run all pending migrations on database")] + Migrate, } -#[derive(Args)] -struct CommandArgs {} - #[async_std::main] async fn main() { let cli = Cli::parse(); let config = load_config(&cli.environment).expect("Could not load config!"); + let db_config = + PgConnectOptions::from_str(&config.database_url).expect("Invalid DATABASE_URL!"); match cli.command { - Command::Setup => {} - Command::Reset => { - let db_config = PgConnectOptions::from_str(&config.database_url).expect("Invalid DATABASE_URL!"); - let db_name = db_config - .get_database() - .expect("Failed to get database name!"); - - let root_db_config = db_config.clone().database("postgres"); - let mut root_connection: PgConnection = Connection::connect_with(&root_db_config).await.unwrap(); - - let query = format!("DROP DATABASE {}", db_name); - root_connection - .execute(query.as_str()) + Command::Setup => { + create_db(&db_config) .await - .expect("Failed to drop database!"); + .expect("Failed creating database."); - //Ok(String::from(db_name)) + migrate_db(&db_config) + .await + .expect("Failed migrating database."); + } + Command::Reset => { + drop_db(&db_config) + .await + .expect("Failed dropping database."); + + create_db(&db_config) + .await + .expect("Failed creating database."); + + migrate_db(&db_config) + .await + .expect("Failed migrating database."); + }, + Command::Migrate => { + migrate_db(&db_config) + .await + .expect("Failed migrating database."); } } } + +async fn drop_db(db_config: &PgConnectOptions) -> anyhow::Result<()> { + let db_name = db_config + .get_database() + .context("Failed to get database name!")?; + + let root_db_config = db_config.clone().database("postgres"); + let mut connection: PgConnection = Connection::connect_with(&root_db_config) + .await + .context("Connection to database failed!")?; + + let query_drop = format!("DROP DATABASE {}", db_name); + connection + .execute(query_drop.as_str()) + .await + .context("Failed to drop database!")?; + + Ok(()) +} + +async fn create_db(db_config: &PgConnectOptions) -> anyhow::Result<()> { + let db_name = db_config + .get_database() + .context("Failed to get database name!")?; + + let root_db_config = db_config.clone().database("postgres"); + let mut connection: PgConnection = Connection::connect_with(&root_db_config) + .await + .context("Connection to database failed!")?; + + let query_create = format!("CREATE DATABASE {}", db_name); + connection + .execute(query_create.as_str()) + .await + .context("Failed to create database!")?; + + Ok(()) +} + +async fn migrate_db(db_config: &PgConnectOptions) -> anyhow::Result<()> { + let mut connection: PgConnection = Connection::connect_with(db_config) + .await + .context("Connection to database failed!")?; + + let migrations_path = PathBuf::from( + std::env::var("CARGO_MANIFEST_DIR").expect("This command needs to be invoked using cargo"), + ) + .join("..") + .join("migrations") + .canonicalize() + .unwrap(); + + let migrator = Migrator::new(Path::new(&migrations_path)) + .await + .context("Failed to create migrator!")?; + + connection + .ensure_migrations_table() + .await + .context("Failed to ensure migrations table!")?; + + let applied_migrations: HashMap<_, _> = connection + .list_applied_migrations() + .await + .context("Failed to list applied migrations!")? + .into_iter() + .map(|m| (m.version, m)) + .collect(); + + for migration in migrator.iter() { + if !applied_migrations.contains_key(&migration.version) { + connection + .apply(migration) + .await + .context("Failed to apply migration {}!")?; + println!("Applied migration {}.", migration.version); + } + } + + Ok(()) +}