diff --git a/Cargo.lock b/Cargo.lock index 90b5c28..68fbec6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1148,7 +1148,7 @@ dependencies = [ [[package]] name = "rmmm" -version = "0.4.1" +version = "0.5.0" dependencies = [ "anyhow", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 70a84f5..e9b259c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rmmm" -version = "0.4.2" +version = "0.5.0" description = "Rust MySQL Migration Manager" repository = "https://github.com/EasyPost/rmmm" authors = ["James Brown "] diff --git a/src/go_database_dsn.rs b/src/go_database_dsn.rs index d618582..d6c42dc 100644 --- a/src/go_database_dsn.rs +++ b/src/go_database_dsn.rs @@ -103,6 +103,7 @@ static DSN_REGEX: Lazy = Lazy::new(|| { pub(crate) struct GoDatabaseDsn { username: Option, password: Option, + protocol: String, address: Address, database: String, } @@ -118,9 +119,15 @@ impl FromStr for GoDatabaseDsn { let password = caps.name("password").map(|s| s.as_str().to_owned()); match caps.name("protocol").map(|s| s.as_str()) { Some("tcp") => {} + Some("unix") => {} Some(other) => anyhow::bail!("unhandled DSN protocol {}", other), None => {} } + let protocol = caps + .name("protocol") + .ok_or_else(|| anyhow::anyhow!("no protocol in DSN {}", s))? + .as_str() + .to_owned(); let address = caps .name("address") .ok_or_else(|| anyhow::anyhow!("no address in DSN {}", s))? @@ -134,6 +141,7 @@ impl FromStr for GoDatabaseDsn { Ok(GoDatabaseDsn { username, password, + protocol, address, database, }) @@ -144,19 +152,28 @@ impl TryInto for GoDatabaseDsn { type Error = anyhow::Error; fn try_into(self) -> Result { - Ok(mysql::OptsBuilder::new() - .user(self.username) - .pass(self.password) - .db_name(Some(self.database)) - .tcp_port(self.address.port) - .ip_or_hostname(Some(self.address.name.into_mysql_string())) - .into()) + if self.protocol == "unix" { + Ok(mysql::OptsBuilder::new() + .user(self.username) + .pass(self.password) + .db_name(Some(self.database)) + .socket(Some(self.address.name.into_mysql_string())) + .into()) + } else { + Ok(mysql::OptsBuilder::new() + .user(self.username) + .pass(self.password) + .db_name(Some(self.database)) + .tcp_port(self.address.port) + .ip_or_hostname(Some(self.address.name.into_mysql_string())) + .into()) + } } } #[cfg(test)] mod tests { - use super::{Address, AddressName, GoDatabaseDsn}; + use super::{Address, AddressName, GoDatabaseDsn, DEFAULT_PORT}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use anyhow::Context; @@ -183,7 +200,7 @@ mod tests { "127.0.0.1".parse::
().unwrap(), Address { name: AddressName::Address(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))), - port: 3306, + port: DEFAULT_PORT, } ); assert_eq!( @@ -197,7 +214,7 @@ mod tests { "[::2]".parse::
().unwrap(), Address { name: AddressName::Address(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2))), - port: 3306, + port: DEFAULT_PORT, } ); assert_eq!( @@ -207,6 +224,13 @@ mod tests { port: 3307, } ); + assert_eq!( + "/var/lib/mysql.sock".parse::
().unwrap(), + Address { + name: AddressName::Name("/var/lib/mysql.sock".to_string()), + port: DEFAULT_PORT, + }, + ); } #[test] @@ -221,6 +245,20 @@ mod tests { port: 33606 } ); + let parsed: GoDatabaseDsn = "foo:bar@unix(/var/lib/mysql.sock)/foodb?ignored=true" + .parse() + .expect("should parse"); + assert_eq!( + parsed.address, + Address { + name: AddressName::Name("/var/lib/mysql.sock".parse().unwrap()), + port: DEFAULT_PORT, + } + ); + assert_eq!( + parsed.protocol, + "unix", + ); assert_eq!(parsed.username.as_deref(), Some("foo")); assert_eq!(parsed.password.as_deref(), Some("bar")); assert_eq!(parsed.database, "foodb".to_string()); @@ -229,7 +267,9 @@ mod tests { "foo:bar@tcp([::1]:3300)/foo", "foo@tcp([::1])/foo", "tcp(127.0.0.1)/baz", - "usps:sekret@tcp(dblb.local.easypo.net:36060)/usps", + "user:sekret@tcp(hostname:36060)/dbname", + "user@unix(/var/lib/mysql/mysql.sock)/dbname?parseTime=true&loc=UTC&sql_mode='STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION'", + "user:pass@unix(/var/lib/mysql/mysql.sock)/dbname?parseTime=true&sql_mode='STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION'", ] { s.parse::() .context(format!("attempting to parse {}", s)) diff --git a/src/main.rs b/src/main.rs index da7e54a..957483c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,7 +60,7 @@ enum MigrationStatus { #[derive(Tabled, Debug)] struct MigrationStatusRow { - id: usize, + id: u32, label: String, status: MigrationStatus, executed_at: String, @@ -71,9 +71,9 @@ fn command_status(state: MigrationState, runner: MigrationRunner) -> anyhow::Res let run_so_far = runner.list_run_migrations()?; let all_ids = state .all_ids() - .union(&run_so_far.iter().map(|m| m.id).collect::>()) + .union(&run_so_far.iter().map(|m| m.id).collect::>()) .cloned() - .collect::>(); + .collect::>(); let migrations_by_id = state.migrations_by_id(); let run_so_far_by_id = run_so_far .into_iter() @@ -111,15 +111,15 @@ fn command_status(state: MigrationState, runner: MigrationRunner) -> anyhow::Res #[derive(Tabled, Debug)] struct MigrationPlanRow { - id: usize, - prev_id: String, + id: u32, sql_text: String, } -fn command_upgrade( +fn command_apply_migrations( matches: &clap::ArgMatches, state: MigrationState, runner: MigrationRunner, + is_upgrade: bool, ) -> anyhow::Result<()> { debug!("Starting command_upgrade"); let target_revision = { @@ -132,7 +132,7 @@ fn command_upgrade( .context("revision must be an integer or 'latest'")? } }; - let plan = runner.plan(&state, target_revision)?; + let plan = runner.plan(&state, target_revision, is_upgrade)?; if plan.is_empty() { info!("Nothing to do!"); return Ok(()); @@ -142,16 +142,13 @@ fn command_upgrade( .iter() .map(|ps| MigrationPlanRow { id: ps.id, - prev_id: ps - .prev_id - .map(|u| u.to_string()) - .unwrap_or_else(|| "(none)".to_string()), sql_text: ps.sql.clone(), }) .collect::>(); let table = tabled::Table::new(&plan_data) .with(tabled::Style::modern().horizontal_off()) - .with(tabled::Modify::new(tabled::Column(2..=2)).with(tabled::Alignment::left())); + .with(tabled::Modify::new(tabled::Column(1..=1)).with(tabled::Alignment::left())) + ; println!("Migration plan:"); println!("{}", table); if matches.is_present("execute") { @@ -159,9 +156,11 @@ fn command_upgrade( runner.execute(plan)?; info!("done!"); println!("New version: {}", target_revision); - let schema = runner.dump_schema()?; if !matches.is_present("no-dump") { + let schema = runner.dump_schema()?; state.write_schema(&schema)?; + } else { + println!("not writing schema file"); } } else { error!("rerun with --execute to execute this plan"); @@ -296,7 +295,6 @@ fn cli() -> clap::Command<'static> { Arg::new("no-dump") .long("--no-write-schema") .env("NO_WRITE_SCHEMA") - .action(clap::ArgAction::SetTrue) .help("Do not write updated db/structure.sql when done"), ), ) @@ -323,6 +321,12 @@ fn cli() -> clap::Command<'static> { .short('x') .long("execute") .help("Actually upgrade (otherwise will just print what will be done"), + ) + .arg( + Arg::new("no-dump") + .long("--no-write-schema") + .env("NO_WRITE_SCHEMA") + .help("Do not write updated db/structure.sql when done"), ), ) .subcommand( @@ -354,10 +358,10 @@ fn main() -> anyhow::Result<()> { command_status(current_state, runner)?; } Some(("upgrade", smatches)) => { - command_upgrade(smatches, current_state, runner)?; + command_apply_migrations(smatches, current_state, runner, true)?; } Some(("downgrade", smatches)) => { - command_upgrade(smatches, current_state, runner)?; + command_apply_migrations(smatches, current_state, runner, false)?; } Some(("apply-snapshot", smatches)) => { command_apply_snapshot( diff --git a/src/migration_runner.rs b/src/migration_runner.rs index beb9fb8..115e058 100644 --- a/src/migration_runner.rs +++ b/src/migration_runner.rs @@ -16,14 +16,13 @@ pub(crate) struct MigrationRunner { #[derive(Debug)] pub struct ExecutedMigration { - pub id: usize, + pub id: u32, pub executed_at: Option>, } #[derive(Debug)] pub struct MigrationStep { - pub prev_id: Option, - pub id: usize, + pub id: u32, pub label: Option, pub sql: String, } @@ -31,6 +30,8 @@ pub struct MigrationStep { #[derive(Debug)] pub struct MigrationPlan { steps: Vec, + + // determines if INSERTs or DELETEs are done on the migrations tracking table is_upgrade: bool, } @@ -86,51 +87,73 @@ impl MigrationRunner { pub fn plan( &self, state: &MigrationState, - target_revision: usize, + target_revision: u32, + is_upgrade: bool, + ) -> anyhow::Result { + if is_upgrade { + return self.plan_upgrade(state, target_revision) + } else { + return self.plan_downgrade(state, target_revision) + } + } + + pub fn plan_upgrade( + &self, + state: &MigrationState, + target_revision: u32, ) -> anyhow::Result { let highest_id = state.highest_id(); - if target_revision > highest_id { + if target_revision == 0 || target_revision > highest_id { anyhow::bail!("Invalid target revision {}", target_revision); } let run_ids = self .list_run_migrations()? .into_iter() .map(|m| m.id) - .collect::>(); - let is_upgrade = if let Some(highest_run_id) = run_ids.iter().max() { - *highest_run_id <= target_revision - } else { - target_revision != 0 - }; + .collect::>(); + let state_by_id = state.migrations_by_id(); - let to_run = if is_upgrade { - state + let to_run = state .all_ids() .difference(&run_ids) + .filter(|&&i| i <= target_revision) .cloned() .sorted() - .collect::>() - } else { - run_ids - .iter() - .filter(|&&i| i > target_revision) - .cloned() - .collect::>() - }; - let steps = if is_upgrade { + .collect::>(); + let steps = to_run .into_iter() .map(|id| { let step = state_by_id.get(&id).unwrap(); MigrationStep { - prev_id: if id == 1 { None } else { Some(id - 1) }, id, label: step.label.clone(), sql: step.upgrade_text.clone(), } }) - .collect::>() - } else { + .collect::>(); + Ok(MigrationPlan { steps: steps, is_upgrade: true }) + } + + pub fn plan_downgrade( + &self, + state: &MigrationState, + target_revision: u32 + ) -> anyhow::Result { + let state_by_id = state.migrations_by_id(); + let run_ids = self + .list_run_migrations()? + .into_iter() + .map(|m| m.id) + .collect::>(); + + let to_run = run_ids + .iter() + .filter(|&&i| i > target_revision) + .cloned() + .collect::>(); + + let steps = to_run .into_iter() .rev() @@ -138,7 +161,6 @@ impl MigrationRunner { let step = state_by_id.get(&id).unwrap(); if let Some(sql) = step.downgrade_text.as_ref() { Ok(MigrationStep { - prev_id: if id == highest_id { None } else { Some(id + 1) }, id, label: step.label.clone(), sql: sql.clone(), @@ -147,9 +169,8 @@ impl MigrationRunner { anyhow::bail!("step {:?} is irreversible", id); } }) - .collect::>>()? - }; - Ok(MigrationPlan { steps, is_upgrade }) + .collect::>>()?; + Ok(MigrationPlan { steps: steps, is_upgrade: false }) } fn now(&self) -> u64 { @@ -237,6 +258,7 @@ impl MigrationRunner { }, )?; lines.extend(schema); + lines.extend(vec!["".to_string()]); } if tables.contains(&"rmmm_migrations".to_owned()) { lines.extend(vec!["".to_string()]); @@ -251,6 +273,7 @@ impl MigrationRunner { }, )?); } + lines.extend(vec!["\n".to_string()]); // make sure the output ends in a newline and a blank line Ok(lines.join("\n")) } } diff --git a/src/migration_state.rs b/src/migration_state.rs index f891463..d699531 100644 --- a/src/migration_state.rs +++ b/src/migration_state.rs @@ -7,11 +7,11 @@ use anyhow::Context; use itertools::Itertools; use log::debug; -const DEFAULT_EDITOR: &str = "nano"; +const DEFAULT_EDITOR: &str = "vim"; #[derive(Debug)] pub(crate) struct Migration { - pub id: usize, + pub id: u32, pub label: Option, pub upgrade_text: String, pub downgrade_text: Option, @@ -34,7 +34,7 @@ impl Migration { Ok(s.to_string()) } - fn from_path(id: usize, p: &Path) -> anyhow::Result { + fn from_path(id: u32, p: &Path) -> anyhow::Result { let upgrade_file = std::fs::read_to_string(p)?; lazy_static::lazy_static! { static ref LABEL_RE: regex::Regex = @@ -66,7 +66,7 @@ impl Migration { pub(crate) struct MigrationState { root_path: PathBuf, pub migrations: Vec, - next_id: usize, + next_id: u32, } impl MigrationState { @@ -137,15 +137,15 @@ impl MigrationState { Ok(()) } - pub fn migrations_by_id(&self) -> BTreeMap { + pub fn migrations_by_id(&self) -> BTreeMap { self.migrations.iter().map(|m| (m.id, m)).collect() } - pub fn all_ids(&self) -> BTreeSet { + pub fn all_ids(&self) -> BTreeSet { self.migrations.iter().map(|m| m.id).collect() } - pub fn highest_id(&self) -> usize { + pub fn highest_id(&self) -> u32 { self.next_id - 1 }