From eb06ddc35c9d83967837dc93f00325a42cceb302 Mon Sep 17 00:00:00 2001 From: riwiwa Date: Sat, 7 Feb 2026 23:57:43 -0800 Subject: [PATCH] added webui lastfm importing, account sessions, partial codebase cleanup --- db/db.go | 90 +++++--- go.mod | 4 +- go.sum | 1 - main.go | 85 ++----- migrate/lastfm.go | 63 +++++- static/style.css | 196 ++++++++++++++-- templates/login.gohtml | 7 +- templates/profile.gohtml | 17 +- web/web.go | 476 +++++++++++++++++++++++---------------- 9 files changed, 623 insertions(+), 316 deletions(-) diff --git a/db/db.go b/db/db.go index bb97843..e658bf4 100644 --- a/db/db.go +++ b/db/db.go @@ -6,39 +6,23 @@ import ( "os" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) -func TableExists(name string, conn *pgx.Conn) bool { - var exists bool - err := conn.QueryRow( - context.Background(), - `SELECT EXISTS (SELECT 1 FROM pg_tables WHERE schemaname = 'public' AND - tablename = $1);`, - name, - ). - Scan(&exists) - if err != nil { - fmt.Fprintf(os.Stderr, "SELECT EXISTS failed: %v\n", err) - return false - } - return exists -} +var Pool *pgxpool.Pool -func DbExists() bool { - conn, err := pgx.Connect( - context.Background(), - "postgres://postgres:postgres@localhost:5432/muzi", - ) - if err != nil { - return false +func CreateAllTables() error { + if err := CreateHistoryTable(); err != nil { + return err } - defer conn.Close(context.Background()) - return true + if err := CreateUsersTable(); err != nil { + return err + } + return CreateSessionsTable() } func CreateDB() error { - conn, err := pgx.Connect( - context.Background(), + conn, err := pgx.Connect(context.Background(), "postgres://postgres:postgres@localhost:5432", ) if err != nil { @@ -46,16 +30,29 @@ func CreateDB() error { return err } defer conn.Close(context.Background()) + + var exists bool + err = conn.QueryRow(context.Background(), + "SELECT EXISTS(SELECT 1 FROM pg_database WHERE datname = 'muzi')").Scan(&exists) + if err != nil { + fmt.Fprintf(os.Stderr, "Error checking if database exists: %v\n", err) + return err + } + + if exists { + return nil + } + _, err = conn.Exec(context.Background(), "CREATE DATABASE muzi") if err != nil { - fmt.Fprintf(os.Stderr, "Cannot create muzi database: %v\n", err) + fmt.Fprintf(os.Stderr, "Error creating muzi database: %v\n", err) return err } return nil } -func CreateHistoryTable(conn *pgx.Conn) error { - _, err := conn.Exec(context.Background(), +func CreateHistoryTable() error { + _, err := Pool.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS history ( id SERIAL PRIMARY KEY, user_id INTEGER NOT NULL, @@ -64,7 +61,7 @@ func CreateHistoryTable(conn *pgx.Conn) error { artist TEXT NOT NULL, album_name TEXT, ms_played INTEGER, - platform TEXT DEFAULT 'spotify', + platform TEXT, UNIQUE (user_id, song_name, artist, timestamp) ); CREATE INDEX IF NOT EXISTS idx_history_user_timestamp ON history(user_id, timestamp DESC); @@ -77,10 +74,11 @@ func CreateHistoryTable(conn *pgx.Conn) error { return nil } -func CreateUsersTable(conn *pgx.Conn) error { - _, err := conn.Exec(context.Background(), +// TODO: move user settings to jsonb in db +func CreateUsersTable() error { + _, err := Pool.Exec(context.Background(), `CREATE TABLE IF NOT EXISTS users ( - username TEXT NOT NULL, + username TEXT NOT NULL UNIQUE, password TEXT NOT NULL, bio TEXT DEFAULT 'This profile has no bio.', pfp TEXT DEFAULT '/files/assets/default.png', @@ -93,3 +91,29 @@ func CreateUsersTable(conn *pgx.Conn) error { } return nil } + +func CreateSessionsTable() error { + _, err := Pool.Exec(context.Background(), + `CREATE TABLE IF NOT EXISTS sessions ( + session_id TEXT PRIMARY KEY, + username TEXT NOT NULL REFERENCES users(username), + created_at TIMESTAMPTZ DEFAULT NOW(), + expires_at TIMESTAMPTZ DEFAULT NOW() + INTERVAL '30 days' + ); + CREATE INDEX IF NOT EXISTS idx_sessions_expires ON sessions(expires_at);`) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating sessions table: %v\n", err) + return err + } + return nil +} + +func CleanupExpiredSessions() error { + _, err := Pool.Exec(context.Background(), + "DELETE FROM sessions WHERE expires_at < NOW();") + if err != nil { + fmt.Fprintf(os.Stderr, "Error cleaning up expired sessions: %v\n", err) + return err + } + return nil +} diff --git a/go.mod b/go.mod index 5697fcb..8dc3218 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,14 @@ require ( github.com/go-chi/chi/v5 v5.2.3 github.com/jackc/pgtype v1.14.4 github.com/jackc/pgx/v5 v5.7.6 + golang.org/x/crypto v0.45.0 ) require ( github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect - golang.org/x/crypto v0.45.0 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect + golang.org/x/sync v0.18.0 // indirect golang.org/x/text v0.31.0 // indirect ) diff --git a/go.sum b/go.sum index d135778..69d57a3 100644 --- a/go.sum +++ b/go.sum @@ -66,7 +66,6 @@ github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8 github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= -github.com/jackc/puddle v1.3.0 h1:eHK/5clGOatcjX3oWGBO/MpxpbHzSwud5EWTSCI+MX0= github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= diff --git a/main.go b/main.go index e679890..9497a33 100644 --- a/main.go +++ b/main.go @@ -8,23 +8,11 @@ import ( "path/filepath" "muzi/db" - "muzi/migrate" "muzi/web" - "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" ) -func dbCheck() error { - if !db.DbExists() { - err := db.CreateDB() - if err != nil { - fmt.Fprintf(os.Stderr, "Error creating muzi DB: %v\n", err) - return err - } - } - return nil -} - func dirCheck(path string) error { _, err := os.Stat(path) if err != nil { @@ -35,45 +23,28 @@ func dirCheck(path string) error { return err } } - return nil } func main() { - dirImports := filepath.Join(".", "imports") + zipDir := filepath.Join(".", "imports", "spotify", "zip") + extDir := filepath.Join(".", "imports", "spotify", "extracted") - dirSpotify := filepath.Join(".", "imports", "spotify") - dirSpotifyZip := filepath.Join(".", "imports", "spotify", "zip") - dirSpotifyExt := filepath.Join(".", "imports", "spotify", "extracted") - - fmt.Printf("Checking if directory %s exists...\n", dirImports) - err := dirCheck(dirImports) - if err != nil { - return + dirs := []string{zipDir, extDir} + for _, dir := range dirs { + err := dirCheck(dir) + if err != nil { + fmt.Fprintf(os.Stderr, "Error checking dir: %s: %v\n", dir, err) + return + } } - fmt.Printf("Checking if directory %s exists...\n", dirSpotify) - err = dirCheck(dirSpotify) - if err != nil { - return - } - fmt.Printf("Checking if directory %s exists...\n", dirSpotifyZip) - err = dirCheck(dirSpotifyZip) - if err != nil { - return - } - fmt.Printf("Checking if directory %s exists...\n", dirSpotifyExt) - err = dirCheck(dirSpotifyExt) - if err != nil { - return - } - fmt.Println("Checking if muzi database exists...") - err = dbCheck() + err := db.CreateDB() if err != nil { + fmt.Fprintf(os.Stderr, "Error ensuring muzi DB exists: %v\n", err) return } - fmt.Println("Setting up database tables...") - conn, err := pgx.Connect( + db.Pool, err = pgxpool.New( context.Background(), "postgres://postgres:postgres@localhost:5432/muzi", ) @@ -81,33 +52,25 @@ func main() { fmt.Fprintf(os.Stderr, "Cannot connect to muzi database: %v\n", err) return } - defer conn.Close(context.Background()) + defer db.Pool.Close() - err = db.CreateHistoryTable(conn) + err = db.CreateAllTables() if err != nil { - fmt.Fprintf(os.Stderr, "Error creating history table: %v\n", err) + fmt.Fprintf(os.Stderr, "Error ensuring all tables exist: %v\n", err) return } - err = db.CreateUsersTable(conn) + err = db.CleanupExpiredSessions() if err != nil { - fmt.Fprintf(os.Stderr, "Error creating users table: %v\n", err) + fmt.Fprintf(os.Stderr, "Error cleaning expired sessions: %v\n", err) return } - username := "" - apiKey := "" - fmt.Printf("Importing LastFM data for %s\n", username) - // TODO: - // remove hardcoded userID by creating webUI import pages and getting - // userID from login session - err = migrate.ImportLastFM(username, apiKey, 1) - if err != nil { - return - } - err = migrate.ImportSpotify(1) - if err != nil { - return - } + /* + err = migrate.ImportSpotify(1) + if err != nil { + return + } + */ web.Start() } diff --git a/migrate/lastfm.go b/migrate/lastfm.go index c5cc8c9..f623e30 100644 --- a/migrate/lastfm.go +++ b/migrate/lastfm.go @@ -28,6 +28,15 @@ type pageResult struct { err error } +type ProgressUpdate struct { + CurrentPage int `json:"current_page"` + CompletedPages int `json:"completed_pages"` + TotalPages int `json:"total_pages"` + TracksImported int `json:"tracks_imported"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + type Response struct { Recenttracks struct { Track []struct { @@ -51,13 +60,21 @@ type Response struct { } `json:"recenttracks"` } -func ImportLastFM(username string, apiKey string, userId int) error { +func ImportLastFM( + username string, + apiKey string, + userId int, + progressChan chan<- ProgressUpdate, +) error { conn, err := pgx.Connect( context.Background(), "postgres://postgres:postgres@localhost:5432/muzi", ) if err != nil { fmt.Fprintf(os.Stderr, "Cannot connect to muzi database: %v\n", err) + if progressChan != nil { + progressChan <- ProgressUpdate{Status: "error", Error: err.Error()} + } return err } defer conn.Close(context.Background()) @@ -70,6 +87,9 @@ func ImportLastFM(username string, apiKey string, userId int) error { ) if err != nil { fmt.Fprintf(os.Stderr, "Error getting LastFM HTTP response: %v\n", err) + if progressChan != nil { + progressChan <- ProgressUpdate{Status: "error", Error: err.Error()} + } return err } var initialData Response @@ -78,10 +98,21 @@ func ImportLastFM(username string, apiKey string, userId int) error { resp.Body.Close() if err != nil { fmt.Fprintf(os.Stderr, "Error parsing total pages: %v\n", err) + if progressChan != nil { + progressChan <- ProgressUpdate{Status: "error", Error: err.Error()} + } return err } fmt.Printf("Total pages: %d\n", totalPages) + // send initial progress update + if progressChan != nil { + progressChan <- ProgressUpdate{ + TotalPages: totalPages, + Status: "running", + } + } + trackBatch := make([]LastFMTrack, 0, 1000) pageChan := make(chan pageResult, 20) @@ -137,6 +168,8 @@ func ImportLastFM(username string, apiKey string, userId int) error { }() batchSize := 500 + completedPages := 0 + var completedMu sync.Mutex for result := range pageChan { if result.err != nil { @@ -153,6 +186,23 @@ func ImportLastFM(username string, apiKey string, userId int) error { } } fmt.Printf("Processed page %d/%d\n", result.pageNum, totalPages) + + // increment completed pages counter + completedMu.Lock() + completedPages++ + currentCompleted := completedPages + completedMu.Unlock() + + // send progress update after each page + if progressChan != nil { + progressChan <- ProgressUpdate{ + CurrentPage: result.pageNum, + CompletedPages: currentCompleted, + TotalPages: totalPages, + TracksImported: totalImported, + Status: "running", + } + } } if len(trackBatch) > 0 { @@ -163,6 +213,17 @@ func ImportLastFM(username string, apiKey string, userId int) error { } fmt.Printf("%d tracks imported from LastFM for user %s\n", totalImported, username) + + // send completion update + if progressChan != nil { + progressChan <- ProgressUpdate{ + CurrentPage: totalPages, + TotalPages: totalPages, + TracksImported: totalImported, + Status: "completed", + } + } + return nil } diff --git a/static/style.css b/static/style.css index 10911ad..98a1c03 100644 --- a/static/style.css +++ b/static/style.css @@ -1,17 +1,17 @@ -body { - display: flex; - flex-direction: column; - background-color: #222; - color: #AFA; - align-content: center; - justify-content: center; - align-items: center; - text-align: center; - max-width: 70vw; - margin: 0; - width: 70vw; - font-family: sans-serif; -} + body { + display: flex; + flex-direction: column; + background-color: #222; + color: #AFA; + align-content: center; + justify-content: center; + align-items: center; + text-align: center; + max-width: 70vw; + margin: 0 auto; + width: 70vw; + font-family: sans-serif; + } .page_buttons { display: flex; @@ -27,7 +27,19 @@ body { } .user-stats-top { - display: inline-block; + display: flex; + flex-direction: column; + justify-content: center; + width: 20%; + h3 { + color: #FFF; + font-size: 25px; + margin: 0; + } + p { + margin: 0; + color: #EEE; + } } .username-bio { @@ -37,10 +49,15 @@ body { margin-left: 40px; } +.profile-top-blank { + width: 50%; + } + .profile-top { display: flex; flex-direction: row; align-content: center; + width: 100%; h1 { color: #FFFFFF; margin: 0; @@ -50,11 +67,6 @@ body { font-size: 15px; margin: 0; } - h3 { - color: #AAAAAA; - font-size: 25px; - margin: 0; - } img { object-fit: cover; width: 250px; @@ -63,16 +75,24 @@ body { } } +.login-form { + display: flex; + height: 100vh; + align-items: center; + justify-content: center; + } + .login-error { color: #AA0000; } .history { display: flex; + flex-direction: column; justify-content: center; - width: 100vw; + width: 100%; table { - width: 90%; + width: auto; } tr { display: flex; @@ -90,3 +110,135 @@ body { background-color: #111; } } + +.import-section { + margin: 20px 0; + padding: 20px; + background: #1a1a1a; + border-radius: 8px; +} + +.import-section form { + display: flex; + flex-direction: column; + gap: 10px; + margin-top: 15px; +} + +.import-section input { + padding: 8px; + border: 1px solid #333; + border-radius: 4px; + background: #222; + color: #AFA; +} + +.import-section button { + padding: 10px 20px; + background: #333; + color: #AFA; + border: 1px solid #444; + border-radius: 4px; + cursor: pointer; +} + +.import-section button:hover { + background: #444; +} + +.progress-container { + margin-top: 15px; + padding: 15px; + background: #1a1a1a; + border-radius: 8px; + border: 1px solid #333; +} + +.progress-bar-wrapper { + width: 100%; + height: 24px; + background: #2a2a2a; + border-radius: 12px; + overflow: hidden; + position: relative; + margin: 10px 0; + border: 2px solid #444; +} + +.progress-bar-fill { + height: 100%; + width: 0%; + background: linear-gradient(90deg, #5a5 0%, #7f7 50%, #5a5 100%); + background-size: 200% 100%; + border-radius: 10px; + transition: width 0.3s ease-out; + box-shadow: + inset 0 2px 4px rgba(255, 255, 255, 0.3), + inset 0 -2px 4px rgba(0, 0, 0, 0.3), + 0 0 15px rgba(0, 255, 0, 0.4); + position: absolute; + top: 0; + left: 0; + z-index: 1; +} + +.progress-bar-fill.animating { + animation: shimmer 2s linear infinite, pulse-glow 1.5s ease-in-out infinite; +} + +@keyframes pulse-glow { + 0%, 100% { + box-shadow: + inset 0 2px 4px rgba(255, 255, 255, 0.3), + inset 0 -2px 4px rgba(0, 0, 0, 0.3), + 0 0 15px rgba(0, 255, 0, 0.4); + } + 50% { + box-shadow: + inset 0 2px 4px rgba(255, 255, 255, 0.3), + inset 0 -2px 4px rgba(0, 0, 0, 0.3), + 0 0 25px rgba(0, 255, 0, 0.7); + } +} + +@keyframes shimmer { + 0% { background-position: 200% 0; } + 100% { background-position: -200% 0; } +} + +.progress-text { + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + color: #fff; + font-weight: bold; + font-size: 12px; + text-shadow: 0 1px 2px rgba(0, 0, 0, 0.8); + z-index: 2; + pointer-events: none; +} + +.progress-status { + color: #AFA; + font-size: 14px; + margin-bottom: 5px; +} + +.progress-tracks { + color: #888; + font-size: 12px; + margin-bottom: 5px; +} + +.progress-error { + color: #F88; + font-size: 14px; + margin-top: 10px; +} + +.progress-success { + color: #8F8; + font-size: 14px; + margin-top: 10px; +} diff --git a/templates/login.gohtml b/templates/login.gohtml index afdfcaf..d9719ce 100644 --- a/templates/login.gohtml +++ b/templates/login.gohtml @@ -14,11 +14,16 @@

- {{if .ShowError}} + {{if eq .Error "1"}}
Invalid credentials. Please try again.
{{end}} + {{if eq .Error "2"}} +
+ Unable to create session. Please try again. +
+ {{end}} diff --git a/templates/profile.gohtml b/templates/profile.gohtml index cda5bb3..eb8e00c 100644 --- a/templates/profile.gohtml +++ b/templates/profile.gohtml @@ -13,10 +13,15 @@

{{.Username}}

{{.Bio}}

-
-

{{.ScrobbleCount}} Listens

-

{{.ArtistCount}} Artists

+
+
+

{{formatInt .ScrobbleCount}}

Listens

+

{{formatInt .ArtistCount}}

Artists

+

+
+
+ Import Data

Listening History

@@ -38,8 +43,10 @@
- Prev Page - Next Page + {{if gt .Page 1 }} + Prev Page + {{end}} + Next Page
diff --git a/web/web.go b/web/web.go index 30e1aab..3fa1a60 100644 --- a/web/web.go +++ b/web/web.go @@ -2,22 +2,104 @@ package web import ( "context" + "crypto/rand" + "encoding/hex" + "encoding/json" "fmt" "html/template" "net/http" "os" "strconv" + "sync" "muzi/db" + "muzi/migrate" "golang.org/x/crypto/bcrypt" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/jackc/pgtype" - "github.com/jackc/pgx/v5" ) +// will add permissions later +type Session struct { + Username string +} + +var ( + importJobs = make(map[string]chan migrate.ProgressUpdate) + jobsMu sync.RWMutex + templates *template.Template +) + +func init() { + funcMap := template.FuncMap{ + "sub": sub, + "add": add, + "formatInt": formatInt, + } + templates = template.Must(template.New("").Funcs(funcMap).ParseGlob("./templates/*.gohtml")) +} + +func generateID() string { + b := make([]byte, 16) + rand.Read(b) + return hex.EncodeToString(b) +} + +func createSession(username string) string { + sessionID := generateID() + _, err := db.Pool.Exec( + context.Background(), + "INSERT INTO sessions (session_id, username, expires_at) VALUES ($1, $2, NOW() + INTERVAL '30 days');", + sessionID, + username, + ) + if err != nil { + fmt.Fprintf(os.Stderr, "Error creating session: %v\n", err) + return "" + } + return sessionID +} + +func getSession(ctx context.Context, sessionID string) *Session { + var username string + err := db.Pool.QueryRow( + ctx, + "SELECT username FROM sessions WHERE session_id = $1 AND expires_at > NOW();", + sessionID, + ).Scan(&username) + if err != nil { + return nil + } + return &Session{Username: username} +} + +// for account deletion later +func deleteSession(sessionID string) { + _, err := db.Pool.Exec( + context.Background(), + "DELETE FROM sessions WHERE session_id = $1;", + sessionID, + ) + if err != nil { + fmt.Fprintf(os.Stderr, "Error deleting session: %v\n", err) + } +} + +func getLoggedInUsername(r *http.Request) string { + cookie, err := r.Cookie("session") + if err != nil { + return "" + } + session := getSession(r.Context(), cookie.Value) + if session == nil { + return "" + } + return session.Username +} + type ProfileData struct { Username string Bio string @@ -31,118 +113,29 @@ type ProfileData struct { Page int } -func Sub(a int, b int) int { +func sub(a int, b int) int { return a - b } -func Add(a int, b int) int { +func add(a int, b int) int { return a + b } -func getUserIdByUsername(conn *pgx.Conn, username string) (int, error) { +func formatInt(n int) string { + if n < 1000 { + return fmt.Sprintf("%d", n) + } else { + return formatInt(n/1000) + "," + fmt.Sprintf("%03d", n%1000) + } +} + +func getUserIdByUsername(ctx context.Context, username string) (int, error) { var userId int - err := conn.QueryRow(context.Background(), "SELECT pk FROM users WHERE username = $1;", username). + err := db.Pool.QueryRow(ctx, "SELECT pk FROM users WHERE username = $1;", username). Scan(&userId) return userId, err } -func getTimes(conn *pgx.Conn, userId int, lim int, off int) []string { - var times []string - rows, err := conn.Query( - context.Background(), - "SELECT timestamp FROM history WHERE user_id = $1 ORDER BY timestamp DESC LIMIT $2 OFFSET $3;", - userId, - lim, - off, - ) - if err != nil { - fmt.Fprintf(os.Stderr, "SELECT timestamp failed: %v\n", err) - return nil - } - for rows.Next() { - var time pgtype.Timestamptz - err = rows.Scan(&time) - if err != nil { - fmt.Fprintf(os.Stderr, "Scanning time failed: %v\n", err) - return nil - } - times = append(times, time.Time.String()) - } - return times -} - -func getTitles(conn *pgx.Conn, userId int, lim int, off int) []string { - var titles []string - rows, err := conn.Query( - context.Background(), - "SELECT song_name FROM history WHERE user_id = $1 ORDER BY timestamp DESC LIMIT $2 OFFSET $3;", - userId, - lim, - off, - ) - if err != nil { - fmt.Fprintf(os.Stderr, "SELECT song_name failed: %v\n", err) - return nil - } - for rows.Next() { - var title string - err = rows.Scan(&title) - if err != nil { - fmt.Fprintf(os.Stderr, "Scanning title failed: %v\n", err) - return nil - } - titles = append(titles, title) - } - return titles -} - -func getArtists(conn *pgx.Conn, userId int, lim int, off int) []string { - var artists []string - rows, err := conn.Query( - context.Background(), - "SELECT artist FROM history WHERE user_id = $1 ORDER BY timestamp DESC LIMIT $2 OFFSET $3;", - userId, - lim, - off, - ) - if err != nil { - fmt.Fprintf(os.Stderr, "SELECT artist failed: %v\n", err) - return nil - } - for rows.Next() { - var artist string - err = rows.Scan(&artist) - if err != nil { - fmt.Fprintf(os.Stderr, "Scanning artist name failed: %v\n", err) - return nil - } - artists = append(artists, artist) - } - return artists -} - -func getScrobbles(conn *pgx.Conn, userId int) int { - var count int - err := conn.QueryRow(context.Background(), "SELECT COUNT(*) FROM history WHERE user_id = $1;", userId). - Scan(&count) - if err != nil { - fmt.Fprintf(os.Stderr, "SELECT COUNT failed: %v\n", err) - return 0 - } - return count -} - -func getArtistCount(conn *pgx.Conn, userId int) int { - var count int - err := conn.QueryRow(context.Background(), "SELECT COUNT(DISTINCT artist) FROM history WHERE user_id = $1;", userId). - Scan(&count) - if err != nil { - fmt.Fprintf(os.Stderr, "SELECT artist count failed: %v\n", err) - return 0 - } - return count -} - func hashPassword(pass []byte) string { hashedPassword, err := bcrypt.GenerateFromPassword(pass, bcrypt.DefaultCost) if err != nil { @@ -161,31 +154,21 @@ func verifyPassword(hashedPassword string, enteredPassword []byte) bool { } func createAccount(w http.ResponseWriter, r *http.Request) { - conn, err := pgx.Connect( - context.Background(), - "postgres://postgres:postgres@localhost:5432/muzi", - ) - if err != nil { - fmt.Fprintf(os.Stderr, "Cannot connect to muzi database: %v\n", err) - return - } - defer conn.Close(context.Background()) - if r.Method == "POST" { r.ParseForm() username := r.FormValue("uname") hashedPassword := hashPassword([]byte(r.FormValue("pass"))) - err = db.CreateUsersTable(conn) + err := db.CreateUsersTable() if err != nil { fmt.Fprintf(os.Stderr, "Error ensuring users table exists: %v\n", err) http.Redirect(w, r, "/createaccount", http.StatusSeeOther) return } - _, err = conn.Exec( - context.Background(), + _, err = db.Pool.Exec( + r.Context(), `INSERT INTO users (username, password) VALUES ($1, $2);`, username, hashedPassword, @@ -194,6 +177,18 @@ func createAccount(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(os.Stderr, "Cannot add new user to users table: %v\n", err) http.Redirect(w, r, "/createaccount", http.StatusSeeOther) } else { + sessionID := createSession(username) + if sessionID == "" { + http.Redirect(w, r, "/login?error=2", http.StatusSeeOther) + return + } + http.SetCookie(w, &http.Cookie{ + Name: "session", + Value: sessionID, + Path: "/", + HttpOnly: true, + MaxAge: 86400 * 30, // 30 days + }) http.Redirect(w, r, "/profile/"+username, http.StatusSeeOther) } } @@ -201,44 +196,39 @@ func createAccount(w http.ResponseWriter, r *http.Request) { func createAccountPageHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - tmp, err := template.New("create_account.gohtml"). - ParseFiles("./templates/create_account.gohtml") + err := templates.ExecuteTemplate(w, "create_account.gohtml", nil) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - err = tmp.Execute(w, nil) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return } } } func loginSubmit(w http.ResponseWriter, r *http.Request) { - conn, err := pgx.Connect( - context.Background(), - "postgres://postgres:postgres@localhost:5432/muzi", - ) - if err != nil { - fmt.Fprintf(os.Stderr, "Cannot connect to muzi database: %v\n", err) - return - } - defer conn.Close(context.Background()) - if r.Method == "POST" { r.ParseForm() username := r.FormValue("uname") password := r.FormValue("pass") var storedPassword string - err := conn.QueryRow(context.Background(), "SELECT password FROM users WHERE username = $1;", username). + err := db.Pool.QueryRow(r.Context(), "SELECT password FROM users WHERE username = $1;", username). Scan(&storedPassword) if err != nil { fmt.Fprintf(os.Stderr, "Cannot get password for entered username: %v\n", err) } if verifyPassword(storedPassword, []byte(password)) { + sessionID := createSession(username) + if sessionID == "" { + http.Redirect(w, r, "/login?error=2", http.StatusSeeOther) + return + } + http.SetCookie(w, &http.Cookie{ + Name: "session", + Value: sessionID, + Path: "/", + HttpOnly: true, + MaxAge: 86400 * 30, // 30 days + }) http.Redirect(w, r, "/profile/"+username, http.StatusSeeOther) } else { http.Redirect(w, r, "/login?error=1", http.StatusSeeOther) @@ -249,22 +239,12 @@ func loginSubmit(w http.ResponseWriter, r *http.Request) { func loginPageHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { type data struct { - ShowError bool + Error string } - d := data{ShowError: false} - if r.URL.Query().Get("error") != "" { - d.ShowError = true - } - tmp, err := template.New("login.gohtml").ParseFiles("./templates/login.gohtml") + d := data{Error: r.URL.Query().Get("error")} + err := templates.ExecuteTemplate(w, "login.gohtml", d) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - err = tmp.Execute(w, d) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return } } } @@ -273,18 +253,7 @@ func profilePageHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { username := chi.URLParam(r, "username") - conn, err := pgx.Connect( - context.Background(), - "postgres://postgres:postgres@localhost:5432/muzi", - ) - if err != nil { - fmt.Fprintf(os.Stderr, "Cannot connect to muzi database: %v\n", err) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer conn.Close(context.Background()) - - userId, err := getUserIdByUsername(conn, username) + userId, err := getUserIdByUsername(r.Context(), username) if err != nil { fmt.Fprintf(os.Stderr, "Cannot find user %s: %v\n", username, err) http.Error(w, "User not found", http.StatusNotFound) @@ -307,59 +276,66 @@ func profilePageHandler() http.HandlerFunc { off := (pageInt - 1) * lim var profileData ProfileData + profileData.Username = username + profileData.Page = pageInt - err = conn.QueryRow( - context.Background(), - "SELECT bio, pfp, allow_duplicate_edits FROM users WHERE pk = $1;", + err = db.Pool.QueryRow( + r.Context(), + `SELECT bio, pfp, allow_duplicate_edits, + (SELECT COUNT(*) FROM history WHERE user_id = $1) as scrobble_count, + (SELECT COUNT(DISTINCT artist) FROM history WHERE user_id = $1) as artist_count + FROM users WHERE pk = $1;`, userId, - ).Scan(&profileData.Bio, &profileData.Pfp, &profileData.AllowDuplicateEdits) + ).Scan(&profileData.Bio, &profileData.Pfp, &profileData.AllowDuplicateEdits, &profileData.ScrobbleCount, &profileData.ArtistCount) if err != nil { fmt.Fprintf(os.Stderr, "Cannot get profile for %s: %v\n", username, err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - profileData.Username = username - profileData.ScrobbleCount = getScrobbles(conn, userId) - profileData.ArtistCount = getArtistCount(conn, userId) - profileData.Artists = getArtists(conn, userId, lim, off) - profileData.Titles = getTitles(conn, userId, lim, off) - profileData.Times = getTimes(conn, userId, lim, off) - profileData.Page = pageInt - funcMap := template.FuncMap{ - "Sub": Sub, - "Add": Add, - } - - tmp, err := template.New("profile.gohtml"). - Funcs(funcMap). - ParseFiles("./templates/profile.gohtml") + rows, err := db.Pool.Query( + r.Context(), + "SELECT artist, song_name, timestamp FROM history WHERE user_id = $1 ORDER BY timestamp DESC LIMIT $2 OFFSET $3;", + userId, + lim, + off, + ) if err != nil { + fmt.Fprintf(os.Stderr, "SELECT history failed: %v\n", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } - tmp.Execute(w, profileData) + defer rows.Close() + + for rows.Next() { + var artist, title string + var time pgtype.Timestamptz + err = rows.Scan(&artist, &title, &time) + if err != nil { + fmt.Fprintf(os.Stderr, "Scanning history row failed: %v\n", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + profileData.Artists = append(profileData.Artists, artist) + profileData.Titles = append(profileData.Titles, title) + profileData.Times = append(profileData.Times, time.Time.String()) + } + + err = templates.ExecuteTemplate(w, "profile.gohtml", profileData) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } } } func updateDuplicateEditsSetting(w http.ResponseWriter, r *http.Request) { - conn, err := pgx.Connect( - context.Background(), - "postgres://postgres:postgres@localhost:5432/muzi", - ) - if err != nil { - fmt.Fprintf(os.Stderr, "Cannot connect to muzi database: %v\n", err) - return - } - defer conn.Close(context.Background()) - if r.Method == "POST" { r.ParseForm() username := r.FormValue("username") allow := r.FormValue("allow") == "true" - _, err = conn.Exec( - context.Background(), + _, err := db.Pool.Exec( + r.Context(), `UPDATE users SET allow_duplicate_edits = $1 WHERE username = $2;`, allow, username, @@ -371,6 +347,121 @@ func updateDuplicateEditsSetting(w http.ResponseWriter, r *http.Request) { } } +func importPageHandler() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + username := getLoggedInUsername(r) + if username == "" { + http.Redirect(w, r, "/login", http.StatusSeeOther) + return + } + + type ImportData struct { + Username string + } + data := ImportData{Username: username} + + err := templates.ExecuteTemplate(w, "import.gohtml", data) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + } +} + +func importLastFMHandler(w http.ResponseWriter, r *http.Request) { + username := getLoggedInUsername(r) + if username == "" { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + userId, err := getUserIdByUsername(r.Context(), username) + if err != nil { + fmt.Fprintf(os.Stderr, "Cannot find user %s: %v\n", username, err) + http.Error(w, "User not found", http.StatusNotFound) + return + } + + r.ParseForm() + lastfmUsername := r.FormValue("lastfm_username") + lastfmAPIKey := r.FormValue("lastfm_api_key") + + if lastfmUsername == "" || lastfmAPIKey == "" { + http.Error(w, "Missing required fields", http.StatusBadRequest) + return + } + + jobID := generateID() + progressChan := make(chan migrate.ProgressUpdate, 100) + + jobsMu.Lock() + importJobs[jobID] = progressChan + jobsMu.Unlock() + + go func() { + migrate.ImportLastFM(lastfmUsername, lastfmAPIKey, userId, progressChan) + + jobsMu.Lock() + delete(importJobs, jobID) + jobsMu.Unlock() + close(progressChan) + }() + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]string{ + "job_id": jobID, + "status": "started", + }) +} + +func importLastFMProgressHandler(w http.ResponseWriter, r *http.Request) { + jobID := r.URL.Query().Get("job") + if jobID == "" { + http.Error(w, "Missing job ID", http.StatusBadRequest) + return + } + + jobsMu.RLock() + job, exists := importJobs[jobID] + jobsMu.RUnlock() + + if !exists { + http.Error(w, "Job not found", http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + fmt.Fprintf(w, "data: %s\n\n", `{"status":"connected"}`) + flusher.Flush() + + for update := range job { + data, err := json.Marshal(update) + if err != nil { + continue + } + fmt.Fprintf(w, "data: %s\n\n", string(data)) + flusher.Flush() + + if update.Status == "completed" || update.Status == "error" { + return + } + } +} + func Start() { addr := ":1234" r := chi.NewRouter() @@ -379,9 +470,12 @@ func Start() { r.Get("/login", loginPageHandler()) r.Get("/createaccount", createAccountPageHandler()) r.Get("/profile/{username}", profilePageHandler()) + r.Get("/import", importPageHandler()) r.Post("/loginsubmit", loginSubmit) r.Post("/createaccountsubmit", createAccount) r.Post("/settings/duplicate-edits", updateDuplicateEditsSetting) + r.Post("/import/lastfm", importLastFMHandler) + r.Get("/import/lastfm/progress", importLastFMProgressHandler) fmt.Printf("WebUI starting on %s\n", addr) http.ListenAndServe(addr, r) }