init.go 1.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package migration
  2. import (
  3. "log"
  4. "path/filepath"
  5. "sort"
  6. "strconv"
  7. "sync"
  8. "github.com/spf13/cast"
  9. "gorm.io/gorm"
  10. )
  11. var Migrate = &Migration{
  12. version: make(map[int]func(db *gorm.DB, version string) error),
  13. }
  14. type Migration struct {
  15. db *gorm.DB
  16. version map[int]func(db *gorm.DB, version string) error
  17. mutex sync.Mutex
  18. }
  19. func (e *Migration) GetDb() *gorm.DB {
  20. return e.db
  21. }
  22. func (e *Migration) SetDb(db *gorm.DB) {
  23. e.db = db
  24. }
  25. func (e *Migration) SetVersion(k int, f func(db *gorm.DB, version string) error) {
  26. e.mutex.Lock()
  27. defer e.mutex.Unlock()
  28. e.version[k] = f
  29. }
  30. func (e *Migration) Migrate() {
  31. versions := make([]int, 0)
  32. for k := range e.version {
  33. versions = append(versions, k)
  34. }
  35. if !sort.IntsAreSorted(versions) {
  36. sort.Ints(versions)
  37. }
  38. var err error
  39. var count int64
  40. for _, v := range versions {
  41. err = e.db.Debug().Table("sys_migration").Where("version = ?", v).Count(&count).Error
  42. if err != nil {
  43. log.Fatalln(err)
  44. }
  45. if count > 0 {
  46. log.Println(count)
  47. count = 0
  48. continue
  49. }
  50. err = (e.version[v])(e.db.Debug(), strconv.Itoa(v))
  51. if err != nil {
  52. log.Fatalln(err)
  53. }
  54. }
  55. }
  56. func GetFilename(s string) int {
  57. s = filepath.Base(s)
  58. return cast.ToInt(s[:13])
  59. }