466 lines
12 KiB
Go
Raw Normal View History

2020-11-09 10:05:29 -03:00
package gormigrate
import (
"errors"
"fmt"
"github.com/jinzhu/gorm"
)
// initSchemaMigrationID is the reserved ID written to the migration table
// when the InitSchema function has run. User-defined migrations may not use it.
const initSchemaMigrationID = "SCHEMA_INIT"
type (
	// MigrateFunc applies one migration. It receives the handle the current
	// operation runs on: a transaction when Options.UseTransaction is set,
	// otherwise the database passed to New.
	MigrateFunc func(*gorm.DB) error

	// RollbackFunc undoes one migration, using the same handle as MigrateFunc.
	RollbackFunc func(*gorm.DB) error

	// InitSchemaFunc creates the whole schema of a fresh database, instead of
	// running every migration one by one.
	InitSchemaFunc func(*gorm.DB) error
)
// Options configures how a Gormigrate instance records and applies migrations.
// New replaces an empty TableName, an empty IDColumnName and a zero
// IDColumnSize with the corresponding values from DefaultOptions.
type Options struct {
	// TableName is the name of the table that records which migrations ran.
	TableName string
	// IDColumnName is the name of the column where the migration ID is stored.
	IDColumnName string
	// IDColumnSize is the VARCHAR length of the ID column, used when the
	// migration table is created.
	IDColumnSize int
	// UseTransaction makes Gormigrate execute migrations inside a single transaction.
	// Keep in mind that not all databases support DDL commands inside transactions.
	UseTransaction bool
	// ValidateUnknownMigrations makes migrating fail with ErrUnknownPastMigration
	// when the migration table contains IDs that are not in the list of migrations.
	ValidateUnknownMigrations bool
}
// Migration represents a database migration (a modification to be made on the database).
type Migration struct {
	// ID is the migration identifier. Usually a timestamp like "201601021504".
	// It must be non-empty, unique within the list of migrations, and must not
	// be the reserved "SCHEMA_INIT" ID.
	ID string
	// Migrate is the function that will be executed while running this migration.
	Migrate MigrateFunc
	// Rollback will be executed on rollback. Can be nil, in which case rolling
	// back this migration fails with ErrRollbackImpossible.
	Rollback RollbackFunc
}
// Gormigrate holds the ordered migrations of one database schema and runs or
// rolls them back. Instances are created with New.
type Gormigrate struct {
	// db is the handle given to New; tx is the handle the current operation
	// runs on (set by begin: a transaction when UseTransaction is set, db
	// otherwise).
	db, tx *gorm.DB

	options    *Options
	migrations []*Migration
	initSchema InitSchemaFunc // set by InitSchema; nil when unused
}
// ReservedIDError is returned when a migration uses an ID that Gormigrate
// reserves for its own bookkeeping.
type ReservedIDError struct {
	ID string
}

// Error implements the error interface.
func (e *ReservedIDError) Error() string {
	const format = `gormigrate: Reserved migration ID: "%s"`
	return fmt.Sprintf(format, e.ID)
}
// DuplicatedIDError is returned when two or more migrations share the same ID.
type DuplicatedIDError struct {
	ID string
}

// Error implements the error interface.
func (e *DuplicatedIDError) Error() string {
	const format = `gormigrate: Duplicated migration ID: "%s"`
	return fmt.Sprintf(format, e.ID)
}
// DefaultOptions can be used if you don't want to think about options. New
// also takes TableName, IDColumnName and IDColumnSize from it whenever the
// given Options leaves those fields empty.
var DefaultOptions = &Options{
	TableName:                 "migrations",
	IDColumnName:              "id",
	IDColumnSize:              255,
	UseTransaction:            false,
	ValidateUnknownMigrations: false,
}

var (
	// ErrRollbackImpossible is returned when trying to roll back a migration
	// that has no Rollback function.
	ErrRollbackImpossible = errors.New("gormigrate: It's impossible to rollback this migration")

	// ErrNoMigrationDefined is returned when there is nothing to migrate or
	// roll back.
	ErrNoMigrationDefined = errors.New("gormigrate: No migration defined")

	// ErrMissingID is returned when the ID of a migration is the empty string.
	ErrMissingID = errors.New("gormigrate: Missing ID in migration")

	// ErrNoRunMigration is returned by RollbackLast when none of the defined
	// migrations has been run.
	ErrNoRunMigration = errors.New("gormigrate: Could not find last run migration")

	// ErrMigrationIDDoesNotExist is returned when migrating or rolling back to
	// a migration ID that is not in the list of migrations.
	ErrMigrationIDDoesNotExist = errors.New("gormigrate: Tried to migrate to an ID that doesn't exist")

	// ErrUnknownPastMigration is returned, when Options.ValidateUnknownMigrations
	// is set, if the migration table holds an ID that is not defined in code.
	ErrUnknownPastMigration = errors.New("gormigrate: Found migration in DB that does not exist in code")
)
// New returns a new Gormigrate that runs the given migrations, in order,
// against db.
//
// A nil options is treated as DefaultOptions. Otherwise an empty TableName,
// an empty IDColumnName or a zero IDColumnSize is filled in from
// DefaultOptions. Those defaults are written into the caller's *Options.
func New(db *gorm.DB, options *Options, migrations []*Migration) *Gormigrate {
	if options == nil {
		// Work on a private copy so the defaulting below can never write into
		// the shared DefaultOptions value.
		defaults := *DefaultOptions
		options = &defaults
	}
	if options.TableName == "" {
		options.TableName = DefaultOptions.TableName
	}
	if options.IDColumnName == "" {
		options.IDColumnName = DefaultOptions.IDColumnName
	}
	if options.IDColumnSize == 0 {
		options.IDColumnSize = DefaultOptions.IDColumnSize
	}
	return &Gormigrate{
		db:         db,
		options:    options,
		migrations: migrations,
	}
}
// InitSchema registers a function that builds the whole schema on a fresh
// database. If the migration table is empty when migrating, this function runs
// instead of the individual migrations, and every migration ID is then
// recorded as applied. It avoids running all migrations against a new, clean
// database, so it should create all the tables and foreign keys the
// application needs.
func (g *Gormigrate) InitSchema(initSchema InitSchemaFunc) {
	g.initSchema = initSchema
}
// Migrate runs, in list order, every migration that has not been applied yet.
// On a fresh database with an InitSchema function, that function runs instead.
func (g *Gormigrate) Migrate() error {
	if !g.hasMigrations() {
		return ErrNoMigrationDefined
	}
	// Target the last migration in the list; with only an InitSchema function
	// there is none, and the empty target means "no upper bound".
	target := ""
	if n := len(g.migrations); n > 0 {
		target = g.migrations[n-1].ID
	}
	return g.migrate(target)
}
// MigrateTo runs the pending migrations up to and including the one whose ID
// is migrationID. It fails with ErrMigrationIDDoesNotExist when no migration
// has that ID.
func (g *Gormigrate) MigrateTo(migrationID string) error {
	err := g.checkIDExist(migrationID)
	if err != nil {
		return err
	}
	return g.migrate(migrationID)
}
// migrate is the shared implementation of Migrate and MigrateTo.
//
// It validates the migration list, makes sure the migration table exists and,
// if ValidateUnknownMigrations is set, rejects unknown recorded IDs. It then
// either runs the InitSchema function (on a fresh database) or applies the
// pending migrations up to and including migrationID. An empty migrationID
// applies all of them. With UseTransaction, the work is committed only when
// every step succeeds.
func (g *Gormigrate) migrate(migrationID string) error {
	if !g.hasMigrations() {
		return ErrNoMigrationDefined
	}

	// Validate the migration list before touching the database.
	for _, check := range []func() error{g.checkReservedID, g.checkDuplicatedID} {
		if err := check(); err != nil {
			return err
		}
	}

	g.begin()
	defer g.rollback()

	if err := g.createMigrationTableIfNotExists(); err != nil {
		return err
	}

	if g.options.ValidateUnknownMigrations {
		switch unknown, err := g.unknownMigrationsHaveHappened(); {
		case err != nil:
			return err
		case unknown:
			return ErrUnknownPastMigration
		}
	}

	if g.initSchema != nil {
		canInit, err := g.canInitializeSchema()
		if err != nil {
			return err
		}
		if canInit {
			if err := g.runInitSchema(); err != nil {
				return err
			}
			return g.commit()
		}
	}

	for _, m := range g.migrations {
		if err := g.runMigration(m); err != nil {
			return err
		}
		if migrationID != "" && m.ID == migrationID {
			break
		}
	}
	return g.commit()
}
// hasMigrations reports whether there is anything to apply: at least one
// migration in the list, or an InitSchema function.
func (g *Gormigrate) hasMigrations() bool {
	if len(g.migrations) > 0 {
		return true
	}
	return g.initSchema != nil
}
// checkReservedID returns a *ReservedIDError for the first migration whose ID
// is reserved by Gormigrate. Only initSchemaMigrationID is reserved today,
// though more IDs may be reserved in the future.
func (g *Gormigrate) checkReservedID() error {
	for i := range g.migrations {
		if id := g.migrations[i].ID; id == initSchemaMigrationID {
			return &ReservedIDError{ID: id}
		}
	}
	return nil
}
// checkDuplicatedID returns a *DuplicatedIDError for the first migration whose
// ID already appeared earlier in the list.
func (g *Gormigrate) checkDuplicatedID() error {
	seen := make(map[string]bool, len(g.migrations))
	for _, m := range g.migrations {
		if seen[m.ID] {
			return &DuplicatedIDError{ID: m.ID}
		}
		seen[m.ID] = true
	}
	return nil
}
// checkIDExist returns ErrMigrationIDDoesNotExist unless some migration in the
// list has the given ID.
func (g *Gormigrate) checkIDExist(migrationID string) error {
	found := false
	for _, m := range g.migrations {
		if m.ID == migrationID {
			found = true
			break
		}
	}
	if !found {
		return ErrMigrationIDDoesNotExist
	}
	return nil
}
// RollbackLast undoes the last migration, in list order, that has been
// applied. It fails with ErrNoRunMigration when none of them has run.
func (g *Gormigrate) RollbackLast() error {
	if len(g.migrations) == 0 {
		return ErrNoMigrationDefined
	}

	g.begin()
	defer g.rollback()

	last, err := g.getLastRunMigration()
	if err != nil {
		return err
	}
	if err = g.rollbackMigration(last); err != nil {
		return err
	}
	return g.commit()
}
// RollbackTo undoes, newest first, every applied migration that comes after
// the one whose ID is migrationID. The migration with that ID is not rolled
// back, and migrations that never ran are skipped.
func (g *Gormigrate) RollbackTo(migrationID string) error {
	if len(g.migrations) == 0 {
		return ErrNoMigrationDefined
	}
	if err := g.checkIDExist(migrationID); err != nil {
		return err
	}

	g.begin()
	defer g.rollback()

	for i := len(g.migrations) - 1; i >= 0; i-- {
		m := g.migrations[i]
		if m.ID == migrationID {
			break
		}
		ran, err := g.migrationRan(m)
		if err != nil {
			return err
		}
		if !ran {
			continue
		}
		if err := g.rollbackMigration(m); err != nil {
			return err
		}
	}
	return g.commit()
}
// getLastRunMigration walks the migrations from last to first and returns the
// first one recorded in the migration table. It returns ErrNoRunMigration if
// none is recorded.
func (g *Gormigrate) getLastRunMigration() (*Migration, error) {
	for i := len(g.migrations) - 1; i >= 0; i-- {
		m := g.migrations[i]
		ran, err := g.migrationRan(m)
		switch {
		case err != nil:
			return nil, err
		case ran:
			return m, nil
		}
	}
	return nil, ErrNoRunMigration
}
// RollbackMigration undoes m with its Rollback function and deletes m's ID
// from the migration table. With UseTransaction, both steps share one
// transaction.
func (g *Gormigrate) RollbackMigration(m *Migration) error {
	g.begin()
	defer g.rollback()

	err := g.rollbackMigration(m)
	if err != nil {
		return err
	}
	return g.commit()
}
// rollbackMigration runs m.Rollback on g.tx and then deletes m's row from the
// migration table. It fails with ErrRollbackImpossible when m has no Rollback
// function.
func (g *Gormigrate) rollbackMigration(m *Migration) error {
	if m.Rollback == nil {
		return ErrRollbackImpossible
	}
	if err := m.Rollback(g.tx); err != nil {
		return err
	}
	opts := g.options
	query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", opts.TableName, opts.IDColumnName)
	return g.tx.Exec(query, m.ID).Error
}
// runInitSchema calls the InitSchema function, then records the reserved
// SCHEMA_INIT ID followed by every migration ID, so later runs treat all
// migrations as already applied.
func (g *Gormigrate) runInitSchema() error {
	if err := g.initSchema(g.tx); err != nil {
		return err
	}

	ids := make([]string, 0, len(g.migrations)+1)
	ids = append(ids, initSchemaMigrationID)
	for _, m := range g.migrations {
		ids = append(ids, m.ID)
	}

	for _, id := range ids {
		if err := g.insertMigration(id); err != nil {
			return err
		}
	}
	return nil
}
// runMigration applies migration and records its ID, unless the ID is already
// in the migration table. An empty ID is rejected with ErrMissingID.
func (g *Gormigrate) runMigration(migration *Migration) error {
	if migration.ID == "" {
		return ErrMissingID
	}

	ran, err := g.migrationRan(migration)
	if err != nil {
		return err
	}
	if ran {
		return nil
	}

	if err := migration.Migrate(g.tx); err != nil {
		return err
	}
	return g.insertMigration(migration.ID)
}
// createMigrationTableIfNotExists creates the migration table, with a single
// VARCHAR primary-key ID column, unless the table already exists.
func (g *Gormigrate) createMigrationTableIfNotExists() error {
	opts := g.options
	if exists := g.tx.HasTable(opts.TableName); exists {
		return nil
	}
	ddl := fmt.Sprintf("CREATE TABLE %s (%s VARCHAR(%d) PRIMARY KEY)", opts.TableName, opts.IDColumnName, opts.IDColumnSize)
	return g.tx.Exec(ddl).Error
}
// migrationRan reports whether m's ID is present in the migration table.
// The boolean is meaningless when the returned error is non-nil.
func (g *Gormigrate) migrationRan(m *Migration) (bool, error) {
	var count int
	where := fmt.Sprintf("%s = ?", g.options.IDColumnName)
	result := g.tx.Table(g.options.TableName).Where(where, m.ID).Count(&count)
	return count > 0, result.Error
}
// canInitializeSchema reports whether the InitSchema function may run. This
// requires both that the reserved SCHEMA_INIT ID is absent and that the
// migration table holds no other IDs either.
func (g *Gormigrate) canInitializeSchema() (bool, error) {
	alreadyInitialized, err := g.migrationRan(&Migration{ID: initSchemaMigrationID})
	if err != nil || alreadyInitialized {
		return false, err
	}

	// No SCHEMA_INIT row: initialize only if no other migration was recorded.
	var count int
	err = g.tx.Table(g.options.TableName).Count(&count).Error
	return count == 0, err
}
// unknownMigrationsHaveHappened reports whether the migration table contains
// an ID that is neither a defined migration nor the reserved SCHEMA_INIT ID.
// Errors from the query, from scanning a row, or from iterating the result
// set are returned to the caller.
func (g *Gormigrate) unknownMigrationsHaveHappened() (bool, error) {
	validIDSet := make(map[string]struct{}, len(g.migrations)+1)
	validIDSet[initSchemaMigrationID] = struct{}{}
	for _, migration := range g.migrations {
		validIDSet[migration.ID] = struct{}{}
	}

	query := fmt.Sprintf("SELECT %s FROM %s", g.options.IDColumnName, g.options.TableName)
	rows, err := g.tx.Raw(query).Rows()
	if err != nil {
		return false, err
	}
	defer rows.Close()

	for rows.Next() {
		var pastMigrationID string
		if err := rows.Scan(&pastMigrationID); err != nil {
			return false, err
		}
		if _, ok := validIDSet[pastMigrationID]; !ok {
			return true, nil
		}
	}
	// rows.Next returns false both at the end of the result set and when
	// iteration fails. Without this check, a failed read would be reported
	// as "no unknown migrations".
	if err := rows.Err(); err != nil {
		return false, err
	}
	return false, nil
}
// insertMigration records id as applied in the migration table.
func (g *Gormigrate) insertMigration(id string) error {
	stmt := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?)", g.options.TableName, g.options.IDColumnName)
	result := g.tx.Exec(stmt, id)
	return result.Error
}
// begin sets g.tx, the handle used by the current operation. It is a new
// transaction when UseTransaction is set, and the plain database otherwise.
func (g *Gormigrate) begin() {
	g.tx = g.db
	if g.options.UseTransaction {
		g.tx = g.db.Begin()
	}
}
// commit commits the transaction opened by begin when UseTransaction is set.
// Without a transaction, every statement has already taken effect, so commit
// does nothing and returns nil.
func (g *Gormigrate) commit() error {
	if !g.options.UseTransaction {
		return nil
	}
	return g.tx.Commit().Error
}
// rollback discards the transaction opened by begin when UseTransaction is set.
// Callers defer it right after begin, so it also runs after a successful
// commit. Its result is deliberately ignored.
func (g *Gormigrate) rollback() {
	if !g.options.UseTransaction {
		return
	}
	g.tx.Rollback()
}