diff --git a/db/Schema.cpp b/db/Schema.cpp index 9b06de0..eaf75cb 100644 --- a/db/Schema.cpp +++ b/db/Schema.cpp @@ -3,7 +3,6 @@ #include "migrations/1624829187_create_events.h" - namespace db { namespace db_builder = db::migrations::builder; @@ -23,37 +22,64 @@ namespace db { void Schema::set_user_version(uint32_t version) { try { m_db->exec("PRAGMA user_version = " + std::to_string(version) + ";"); + m_user_version = version; } catch (std::exception& e) { std::cout << "Exception: " << e.what() << std::endl; } } - void Schema::run_migrations() { + void Schema::set_savepoint() { + try { + m_db->exec("BEGIN;"); + m_db->exec("SAVEPOINT " + m_migration_savepoint + ";"); + } catch (std::exception& e) { + std::cout << "Exception: " << e.what() << std::endl; + } + } + + void Schema::rollback_migrations() { + try { + m_db->exec("ROLLBACK TO " + m_migration_savepoint + ";"); + } catch (std::exception& e) { + std::cout << "Exception: " << e.what() << std::endl; + } + } + + void Schema::commit_migrations() { + try { + m_db->exec("COMMIT;"); + } catch (std::exception& e) { + std::cout << "Exception: " << e.what() << std::endl; + } + + } + + uint32_t Schema::run_migrations() { assemble_migrations(); if(migrations.begin() == migrations.end()) { DEBUG << "No migrations found..."; - return; + return ERROR_NO_MIGRATIONS; } - // TODO try and catch block - // migration should stop on error and restore old schema - + set_savepoint(); uint32_t user_version = get_user_version(); for(auto const& [epoch, migration] : migrations) { if(epoch <= user_version) { DEBUG << "Skipping: " << migration->get_migration_name(); continue; } - DEBUG << "migrating: " << migration->get_migration_name(); + std::cout << "migrating: " << migration->get_migration_name() << std::endl; try { m_db->exec(migration->get_statement()); } catch (std::exception& e) { std::cout << "Exception: " << e.what() << std::endl; + std::cout << "Rolling back migrations..." << std::endl; + rollback_migrations(); + return ERROR_FAULTY_MIGRATION; } } - DEBUG << migrations.rbegin()->first; + commit_migrations(); set_user_version(migrations.rbegin()->first); - user_version = get_user_version(); - DEBUG << user_version; + return 0; } void Schema::assemble_migrations() { diff --git a/db/Schema.h b/db/Schema.h index d545fdc..c5b8983 100644 --- a/db/Schema.h +++ b/db/Schema.h @@ -8,14 +8,19 @@ namespace db { class Schema { public: Schema(sql::Database* db); - void run_migrations(); + uint32_t run_migrations(); private: uint32_t get_user_version() const; void set_user_version(uint32_t user_version); + void set_savepoint(); + void rollback_migrations(); + void commit_migrations(); void assemble_migrations(); sql::Database* m_db; + uint32_t m_user_version; + const std::string m_migration_savepoint = "migrations_savepoint"; std::map migrations; }; } diff --git a/db/db.h b/db/db.h index cf1ab3a..4196cde 100644 --- a/db/db.h +++ b/db/db.h @@ -3,6 +3,9 @@ #include #include +#define ERROR_NO_MIGRATIONS -2 +#define ERROR_FAULTY_MIGRATION -1 + namespace sql = SQLite; #include