init.go 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. package migration
  2. import (
  3. "log"
  4. "path/filepath"
  5. "sort"
  6. "sync"
  7. "gorm.io/gorm"
  8. )
  9. var Migrate = &Migration{
  10. version: make(map[string]func(db *gorm.DB, version string) error),
  11. }
  12. type Migration struct {
  13. db *gorm.DB
  14. version map[string]func(db *gorm.DB, version string) error
  15. mutex sync.Mutex
  16. }
  17. func (e *Migration) GetDb() *gorm.DB {
  18. return e.db
  19. }
  20. func (e *Migration) SetDb(db *gorm.DB) {
  21. e.db = db
  22. }
  23. func (e *Migration) SetVersion(k string, f func(db *gorm.DB, version string) error) {
  24. e.mutex.Lock()
  25. defer e.mutex.Unlock()
  26. e.version[k] = f
  27. }
  28. func (e *Migration) Migrate() {
  29. versions := make([]string, 0)
  30. for k := range e.version {
  31. versions = append(versions, k)
  32. }
  33. if !sort.StringsAreSorted(versions) {
  34. sort.Strings(versions)
  35. }
  36. var err error
  37. var count int64
  38. for _, v := range versions {
  39. err = e.db.Table("sys_migration").Where("version = ?", v).Count(&count).Error
  40. if err != nil {
  41. log.Fatalln(err)
  42. }
  43. if count > 0 {
  44. log.Println(count)
  45. count = 0
  46. continue
  47. }
  48. err = (e.version[v])(e.db.Debug(), v)
  49. if err != nil {
  50. log.Fatalln(err)
  51. }
  52. }
  53. }
  // GetFilename returns the first 13 bytes of the base name of path s.
  // NOTE(review): 13 presumably matches a millisecond Unix-timestamp prefix on
  // migration file names; confirm against the callers.
  //
  // When the base name is shorter than 13 bytes it is returned whole instead
  // of panicking on an out-of-range slice. This includes the empty path, whose
  // base name is ".".
  func GetFilename(s string) string {
  	const prefixLen = 13

  	base := filepath.Base(s)
  	if len(base) < prefixLen {
  		return base
  	}
  	return base[:prefixLen]
  }