diff --git a/migrate/lastfm.go b/migrate/lastfm.go index 314c60e..8d9e9fd 100644 --- a/migrate/lastfm.go +++ b/migrate/lastfm.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/jackc/pgx/v5" + "muzi/db" ) type LastFMTrack struct { @@ -66,19 +66,6 @@ func ImportLastFM( userId int, progressChan chan<- ProgressUpdate, ) error { - conn, err := pgx.Connect( - context.Background(), - "postgres://postgres:postgres@localhost:5432/muzi", - ) - if err != nil { - fmt.Fprintf(os.Stderr, "Cannot connect to muzi database: %v\n", err) - if progressChan != nil { - progressChan <- ProgressUpdate{Status: "error", Error: err.Error()} - } - return err - } - defer conn.Close(context.Background()) - totalImported := 0 client := &http.Client{ @@ -184,7 +171,7 @@ func ImportLastFM( for len(trackBatch) >= batchSize { batch := trackBatch[:batchSize] trackBatch = trackBatch[batchSize:] - err := insertBatch(conn, batch, &totalImported, batchSize) + err := insertBatch(batch, &totalImported, batchSize) if err != nil { fmt.Fprintf(os.Stderr, "Batch insert failed: %v\n", err) } @@ -210,7 +197,7 @@ func ImportLastFM( } if len(trackBatch) > 0 { - err := insertBatch(conn, trackBatch, &totalImported, batchSize) + err := insertBatch(trackBatch, &totalImported, batchSize) if err != nil { fmt.Fprintf(os.Stderr, "Final batch insert failed: %v\n", err) } @@ -231,8 +218,8 @@ func ImportLastFM( return nil } -func insertBatch(conn *pgx.Conn, tracks []LastFMTrack, totalImported *int, batchSize int) error { - tx, err := conn.Begin(context.Background()) +func insertBatch(tracks []LastFMTrack, totalImported *int, batchSize int) error { + tx, err := db.Pool.Begin(context.Background()) if err != nil { return err } @@ -287,6 +274,7 @@ func insertBatch(conn *pgx.Conn, tracks []LastFMTrack, totalImported *int, batch } if err := tx.Commit(context.Background()); err != nil { + tx.Rollback(context.Background()) return err }