Skip to content

Commit

Permalink
Created worker pool and migrate states as well
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Geens committed Jan 10, 2025
1 parent f452341 commit c78d0be
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 50 deletions.
3 changes: 2 additions & 1 deletion cmd/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ func main() {
name := flag.String("name", "cernboxngcopy", "Database name")
gatewaysvc := flag.String("gatewaysvc", "localhost:9142", "Gateway service location")
token := flag.String("token", "", "JWT token for gateway svc")
dryRun := flag.Bool("dryrun", true, "Use dry run?")

flag.Parse()

fmt.Printf("Connecting to %s@%s:%d\n", *username, *host, *port)
sql.RunMigration(*username, *password, *host, *name, *gatewaysvc, *token, *port)
sql.RunMigration(*username, *password, *host, *name, *gatewaysvc, *token, *port, *dryRun)
}
195 changes: 146 additions & 49 deletions share/sql/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,42 @@ type ShareOrLink struct {
Link *model.PublicLink
}

func RunMigration(username, password, host, name, gatewaysvc, token string, port int) {
type OldShareEntry struct {
ID int
UIDOwner string
UIDInitiator string
Prefix string
ItemSource string
ItemType string
ShareWith string
Token string
Expiration string
Permissions int
ShareType int
ShareName string
STime int
FileTarget string
State int
Quicklink bool
Description string
NotifyUploads bool
NotifyUploadsExtraRecipients sql.NullString
Orphan bool
}

type OldShareState struct {
id int
recipient string
state int
}

const (
bufferSize = 10
numWorkers = 10
)

func RunMigration(username, password, host, name, gatewaysvc, token string, port int, dryRun bool) {
// Config
config := map[string]interface{}{
"engine": "mysql",
"db_username": username,
Expand All @@ -38,13 +73,15 @@ func RunMigration(username, password, host, name, gatewaysvc, token string, port
"db_port": port,
"db_name": name,
"gatewaysvc": gatewaysvc,
"dry_run": false,
"dry_run": dryRun,
}
// Authenticate to gateway service
tokenlessCtx, cancel := context.WithCancel(context.Background())
ctx := appctx.ContextSetToken(tokenlessCtx, token)
ctx = metadata.AppendToOutgoingContext(ctx, appctx.TokenHeader, token)
defer cancel()

// Set up migrator
shareManager, err := New(ctx, config)
if err != nil {
fmt.Println("Failed to create shareManager: " + err.Error())
Expand All @@ -62,80 +99,140 @@ func RunMigration(username, password, host, name, gatewaysvc, token string, port
ShareMgr: sharemgr,
}

ch := make(chan *ShareOrLink, 100)
go getAllShares(ctx, migrator, ch)
for share := range ch {
// TODO error handling
if share.IsShare {
fmt.Printf("Creating share %d\n", share.Share.ID)
migrator.NewDb.Create(&share.Share)
} else {
fmt.Printf("Creating share %d\n", share.Link.ID)
migrator.NewDb.Create(&share.Link)
}
if dryRun {
migrator.NewDb = migrator.NewDb.Debug()
}

migrateShares(ctx, migrator)
fmt.Println("---------------------------------")
migrateShareStatuses(ctx, migrator)

}

func getAllShares(ctx context.Context, migrator Migrator, ch chan *ShareOrLink) {
// First we find out what the highest ID is
count, err := getCount(migrator)
func migrateShares(ctx context.Context, migrator Migrator) {
// Check how many shares are to be migrated
count, err := getCount(migrator, "oc_share")
if err != nil {
fmt.Println("Error getting highest id: " + err.Error())
close(ch)
fmt.Println("Error getting count: " + err.Error())
return
}
fmt.Printf("Migrating %d shares\n", count)

// Get all old shares
query := "select id, coalesce(uid_owner, '') as uid_owner, coalesce(uid_initiator, '') as uid_initiator, lower(coalesce(share_with, '')) as share_with, coalesce(fileid_prefix, '') as fileid_prefix, coalesce(item_source, '') as item_source, coalesce(item_type, '') as item_type, stime, permissions, share_type, orphan FROM oc_share order by id desc" // AND id=?"
params := []interface{}{}

res, err := migrator.OldDb.Query(query, params...)

if err != nil {
fmt.Printf("Fatal error: %s", err.Error())
close(ch)
return
os.Exit(1)
}

// Create channel for workers
ch := make(chan *OldShareEntry, bufferSize)
defer close(ch)

// Start all workers
for range numWorkers {
go workerShare(ctx, migrator, ch)
}

for res.Next() {
var s OldShareEntry
res.Scan(&s.ID, &s.UIDOwner, &s.UIDInitiator, &s.ShareWith, &s.Prefix, &s.ItemSource, &s.ItemType, &s.STime, &s.Permissions, &s.ShareType, &s.Orphan)
newShare, err := oldShareToNewShare(ctx, migrator, s)
if err == nil {
ch <- newShare
ch <- &s
} else {
fmt.Printf("Error occured for share %s: %s\n", s.ID, err.Error())
fmt.Printf("Error occured for share %d: %s\n", s.ID, err.Error())
}
}
}

func migrateShareStatuses(ctx context.Context, migrator Migrator) {
// Check how many shares are to be migrated
count, err := getCount(migrator, "oc_share")
if err != nil {
fmt.Println("Error getting count: " + err.Error())
return
}
fmt.Printf("Migrating %d share statuses\n", count)

// Get all old shares
query := "select id, coalesce(recipient, '') as recipient, state FROM oc_share_status order by id desc"
params := []interface{}{}

close(ch)
res, err := migrator.OldDb.Query(query, params...)

if err != nil {
fmt.Printf("Fatal error: %s", err.Error())
os.Exit(1)
}

// Create channel for workers
ch := make(chan *OldShareState, bufferSize)
defer close(ch)

// Start all workers
for range numWorkers {
go workerState(ctx, migrator, ch)
}

for res.Next() {
var s OldShareState
res.Scan(&s.id, &s.recipient, &s.state)
if err == nil {
ch <- &s
} else {
fmt.Printf("Error occured for share status%d: %s\n", s.id, err.Error())
}
}
}

type OldShareEntry struct {
ID int
UIDOwner string
UIDInitiator string
Prefix string
ItemSource string
ItemType string
ShareWith string
Token string
Expiration string
Permissions int
ShareType int
ShareName string
STime int
FileTarget string
State int
Quicklink bool
Description string
NotifyUploads bool
NotifyUploadsExtraRecipients sql.NullString
Orphan bool
func workerShare(ctx context.Context, migrator Migrator, ch chan *OldShareEntry) {
for share := range ch {
handleSingleShare(ctx, migrator, share)
}
}

func workerState(ctx context.Context, migrator Migrator, ch chan *OldShareState) {
for state := range ch {
handleSingleState(ctx, migrator, state)
}
}

func handleSingleShare(ctx context.Context, migrator Migrator, s *OldShareEntry) {
share, err := oldShareToNewShare(ctx, migrator, s)
if err != nil {
return
}
// TODO error handling
if share.IsShare {
migrator.NewDb.Create(&share.Share)
} else {
migrator.NewDb.Create(&share.Link)
}
}

func handleSingleState(ctx context.Context, migrator Migrator, s *OldShareState) {
// case collaboration.ShareState_SHARE_STATE_REJECTED:
// state = -1
// case collaboration.ShareState_SHARE_STATE_ACCEPTED:
// state = 1

newShareState := &model.ShareState{
ShareID: uint(s.id),
Model: gorm.Model{
ID: uint(s.id),
},
User: s.recipient,
Hidden: s.state == -1, // Hidden if REJECTED
Synced: true, // for now, we always sync? or not? TODO
}
migrator.NewDb.Create(&newShareState)
}

func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry) (*ShareOrLink, error) {
func oldShareToNewShare(ctx context.Context, migrator Migrator, s *OldShareEntry) (*ShareOrLink, error) {
expirationDate, expirationError := time.Parse("2006-01-02 15:04:05", s.Expiration)

protoShare := model.ProtoShare{
Expand Down Expand Up @@ -171,7 +268,7 @@ func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry)
protoShare.Orphan = true
} else {
// We do not set, because of a general error
fmt.Printf("An error occured while statting (%s, %s): %s\n", protoShare.Instance, protoShare.Inode, err.Error())
fmt.Printf("An error occured for share %d while statting (%s, %s): %s\n", s.ID, protoShare.Instance, protoShare.Inode, err.Error())
}
}

Expand Down Expand Up @@ -211,9 +308,9 @@ func oldShareToNewShare(ctx context.Context, migrator Migrator, s OldShareEntry)
}
}

func getCount(migrator Migrator) (int, error) {
func getCount(migrator Migrator, table string) (int, error) {
res := 0
query := "select count(*) from oc_share"
query := "select count(*) from " + table
params := []interface{}{}

if err := migrator.OldDb.QueryRow(query, params...).Scan(&res); err != nil {
Expand Down

0 comments on commit c78d0be

Please sign in to comment.