Compare commits

...

4 Commits

4 changed files with 142 additions and 126 deletions

View File

@@ -12,6 +12,8 @@ import (
"time" "time"
"muzi/db" "muzi/db"
"github.com/jackc/pgx/v5"
) )
type LastFMTrack struct { type LastFMTrack struct {
@@ -60,11 +62,46 @@ type Response struct {
} `json:"recenttracks"` } `json:"recenttracks"`
} }
func fetchPage(client *http.Client, page int, lfmUsername, apiKey string, userId int) pageResult {
resp, err := client.Get(
"https://ws.audioscrobbler.com/2.0/?method=user.getrecenttracks&user=" +
lfmUsername + "&api_key=" + apiKey + "&format=json&limit=100&page=" + strconv.Itoa(page),
)
if err != nil {
return pageResult{pageNum: page, err: err}
}
defer resp.Body.Close()
var data Response
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
return pageResult{pageNum: page, err: err}
}
var pageTracks []LastFMTrack
for j := range data.Recenttracks.Track {
if data.Recenttracks.Track[j].Attr.Nowplaying == "true" {
continue
}
unixTime, err := strconv.ParseInt(data.Recenttracks.Track[j].Date.Uts, 10, 64)
if err != nil {
continue
}
pageTracks = append(pageTracks, LastFMTrack{
UserId: userId,
Timestamp: time.Unix(unixTime, 0),
SongName: data.Recenttracks.Track[j].Name,
Artist: data.Recenttracks.Track[j].Artist.Text,
Album: data.Recenttracks.Track[j].Album.Text,
})
}
return pageResult{pageNum: page, tracks: pageTracks, err: nil}
}
func ImportLastFM( func ImportLastFM(
username string, lfmUsername string,
apiKey string, apiKey string,
userId int, userId int,
progressChan chan<- ProgressUpdate, progressChan chan<- ProgressUpdate,
username string,
) error { ) error {
totalImported := 0 totalImported := 0
@@ -74,7 +111,7 @@ func ImportLastFM(
resp, err := client.Get( resp, err := client.Get(
"https://ws.audioscrobbler.com/2.0/?method=user.getrecenttracks&user=" + "https://ws.audioscrobbler.com/2.0/?method=user.getrecenttracks&user=" +
username + "&api_key=" + apiKey + "&format=json&limit=100", lfmUsername + "&api_key=" + apiKey + "&format=json&limit=100",
) )
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error getting LastFM HTTP response: %v\n", err) fmt.Fprintf(os.Stderr, "Error getting LastFM HTTP response: %v\n", err)
@@ -83,10 +120,15 @@ func ImportLastFM(
} }
return err return err
} }
defer resp.Body.Close()
var initialData Response var initialData Response
json.NewDecoder(resp.Body).Decode(&initialData) err = json.NewDecoder(resp.Body).Decode(&initialData)
if err != nil {
fmt.Fprintf(os.Stderr,
"Error decoding initial LastFM response: %v\n", err)
return err
}
totalPages, err := strconv.Atoi(initialData.Recenttracks.Attr.TotalPages) totalPages, err := strconv.Atoi(initialData.Recenttracks.Attr.TotalPages)
resp.Body.Close()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing total pages: %v\n", err) fmt.Fprintf(os.Stderr, "Error parsing total pages: %v\n", err)
if progressChan != nil { if progressChan != nil {
@@ -94,7 +136,8 @@ func ImportLastFM(
} }
return err return err
} }
fmt.Printf("Total pages: %d\n", totalPages) fmt.Printf("%s started a LastFM import job of %d total pages\n", username,
totalPages)
// send initial progress update // send initial progress update
if progressChan != nil { if progressChan != nil {
@@ -115,40 +158,7 @@ func ImportLastFM(
go func(workerID int) { go func(workerID int) {
defer wg.Done() defer wg.Done()
for page := workerID + 1; page <= totalPages; page += 10 { for page := workerID + 1; page <= totalPages; page += 10 {
resp, err := client.Get( pageChan <- fetchPage(client, page, lfmUsername, apiKey, userId)
"https://ws.audioscrobbler.com/2.0/?method=user.getrecenttracks&user=" +
username + "&api_key=" + apiKey + "&format=json&limit=100&page=" + strconv.Itoa(page),
)
if err != nil {
pageChan <- pageResult{pageNum: page, err: err}
continue
}
var data Response
if err := json.NewDecoder(resp.Body).Decode(&data); err != nil {
resp.Body.Close()
pageChan <- pageResult{pageNum: page, err: err}
continue
}
resp.Body.Close()
var pageTracks []LastFMTrack
for j := range data.Recenttracks.Track {
if data.Recenttracks.Track[j].Attr.Nowplaying == "true" {
continue
}
unixTime, err := strconv.ParseInt(data.Recenttracks.Track[j].Date.Uts, 10, 64)
if err != nil {
continue
}
pageTracks = append(pageTracks, LastFMTrack{
UserId: userId,
Timestamp: time.Unix(unixTime, 0),
SongName: data.Recenttracks.Track[j].Name,
Artist: data.Recenttracks.Track[j].Artist.Text,
Album: data.Recenttracks.Track[j].Album.Text,
})
}
pageChan <- pageResult{pageNum: page, tracks: pageTracks, err: nil}
} }
}(worker) }(worker)
} }
@@ -171,13 +181,14 @@ func ImportLastFM(
for len(trackBatch) >= batchSize { for len(trackBatch) >= batchSize {
batch := trackBatch[:batchSize] batch := trackBatch[:batchSize]
trackBatch = trackBatch[batchSize:] trackBatch = trackBatch[batchSize:]
err := insertBatch(batch, &totalImported, batchSize) err := insertBatch(batch, &totalImported)
if err != nil { if err != nil {
// prevent logs being filled by duplicate warnings
if !strings.Contains(err.Error(), "duplicate") {
fmt.Fprintf(os.Stderr, "Batch insert failed: %v\n", err) fmt.Fprintf(os.Stderr, "Batch insert failed: %v\n", err)
} }
} }
fmt.Printf("Processed page %d/%d\n", result.pageNum, totalPages) }
// increment completed pages counter // increment completed pages counter
completedMu.Lock() completedMu.Lock()
completedPages++ completedPages++
@@ -197,13 +208,19 @@ func ImportLastFM(
} }
if len(trackBatch) > 0 { if len(trackBatch) > 0 {
err := insertBatch(trackBatch, &totalImported, batchSize) err := insertBatch(trackBatch, &totalImported)
if err != nil { if err != nil {
// prevent logs being filled by duplicate warnings
if !strings.Contains(err.Error(), "duplicate") {
fmt.Fprintf(os.Stderr, "Final batch insert failed: %v\n", err) fmt.Fprintf(os.Stderr, "Final batch insert failed: %v\n", err)
} }
} }
}
fmt.Printf("%d tracks imported from LastFM for user %s\n", totalImported, username) fmt.Printf("User %s imported %d tracks from LastFM account %s\n",
username,
totalImported,
lfmUsername)
// send completion update // send completion update
if progressChan != nil { if progressChan != nil {
@@ -218,65 +235,21 @@ func ImportLastFM(
return nil return nil
} }
func insertBatch(tracks []LastFMTrack, totalImported *int, batchSize int) error { func insertBatch(tracks []LastFMTrack, totalImported *int) error {
tx, err := db.Pool.Begin(context.Background()) copyCount, err := db.Pool.CopyFrom(context.Background(),
if err != nil { pgx.Identifier{"history"},
return err []string{
} "user_id", "timestamp", "song_name", "artist", "album_name",
"ms_played", "platform",
var batchValues []string },
var batchArgs []any pgx.CopyFromSlice(len(tracks), func(i int) ([]any, error) {
t := tracks[i]
for i, track := range tracks { return []any{
batchValues = append(batchValues, fmt.Sprintf( t.UserId, t.Timestamp, t.SongName, t.Artist,
"($%d, $%d, $%d, $%d, $%d, $%d, $%d)", t.Album, 0, "lastfm",
len(batchArgs)+1, }, nil
len(batchArgs)+2, }),
len(batchArgs)+3,
len(batchArgs)+4,
len(batchArgs)+5,
len(batchArgs)+6,
len(batchArgs)+7,
))
// lastfm doesn't store playtime for each track, so set to 0
batchArgs = append(
batchArgs,
track.UserId,
track.Timestamp,
track.SongName,
track.Artist,
track.Album,
0,
"lastfm",
) )
*totalImported += int(copyCount)
if len(batchValues) >= batchSize || i == len(tracks)-1 {
result, err := tx.Exec(
context.Background(),
`INSERT INTO history (user_id, timestamp, song_name, artist, album_name, ms_played, platform) VALUES `+
strings.Join(
batchValues,
", ",
)+` ON CONFLICT ON CONSTRAINT history_user_id_song_name_artist_timestamp_key DO NOTHING;`,
batchArgs...,
)
if err != nil {
tx.Rollback(context.Background())
return err return err
}
rowsAffected := result.RowsAffected()
if rowsAffected > 0 {
*totalImported += int(rowsAffected)
}
batchValues = batchValues[:0]
batchArgs = batchArgs[:0]
}
}
if err := tx.Commit(context.Background()); err != nil {
tx.Rollback(context.Background())
return err
}
return nil
} }

View File

@@ -207,6 +207,8 @@ func JsonToDB(jsonFile string, userId int) error {
continue continue
} }
// TODO: replace strings.Join with pgx copy
_, err = conn.Exec( _, err = conn.Exec(
context.Background(), context.Background(),
`INSERT INTO history (user_id, timestamp, song_name, artist, album_name, ms_played, platform) VALUES `+ `INSERT INTO history (user_id, timestamp, song_name, artist, album_name, ms_played, platform) VALUES `+

View File

@@ -24,6 +24,16 @@
Unable to create session. Please try again. Unable to create session. Please try again.
</div> </div>
{{end}} {{end}}
{{if eq .Error "userlength"}}
<div class="login-error">
Username length must be greater than 0.
</div>
{{end}}
{{if eq .Error "usertaken"}}
<div class="login-error">
Username must be unique. Please try again.
</div>
{{end}}
</form> </form>
</div> </div>
</body> </body>

View File

@@ -145,7 +145,7 @@ func getUserIdByUsername(ctx context.Context, username string) (int, error) {
} }
func hashPassword(pass []byte) (string, error) { func hashPassword(pass []byte) (string, error) {
if len(pass) < 8 || len(pass) > 64 { if len([]rune(string(pass))) < 8 || len(pass) > 64 {
return "", errors.New("Error: Password must be greater than 8 chars.") return "", errors.New("Error: Password must be greater than 8 chars.")
} }
hashedPassword, err := bcrypt.GenerateFromPassword(pass, bcrypt.DefaultCost) hashedPassword, err := bcrypt.GenerateFromPassword(pass, bcrypt.DefaultCost)
@@ -167,20 +167,29 @@ func verifyPassword(hashedPassword string, enteredPassword []byte) bool {
func createAccount(w http.ResponseWriter, r *http.Request) { func createAccount(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" { if r.Method == "POST" {
r.ParseForm() err := r.ParseForm()
username := r.FormValue("uname")
hashedPassword, err := hashPassword([]byte(r.FormValue("pass")))
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error hashing password: %v\n", err) http.Error(w, err.Error(), http.StatusBadRequest)
http.Redirect(w, r, "/createaccount?error=length", http.StatusSeeOther)
return return
} }
err = db.CreateUsersTable() username := r.FormValue("uname")
if len([]rune(string(username))) == 0 {
http.Redirect(w, r, "/createaccount?error=userlength", http.StatusSeeOther)
return
}
var usertaken bool
err = db.Pool.QueryRow(r.Context(),
"SELECT EXISTS(SELECT 1 FROM users WHERE username = $1)", username).
Scan(&usertaken)
if usertaken == true {
http.Redirect(w, r, "/createaccount?error=usertaken", http.StatusSeeOther)
return
}
hashedPassword, err := hashPassword([]byte(r.FormValue("pass")))
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Error ensuring users table exists: %v\n", err) fmt.Fprintf(os.Stderr, "Error hashing password: %v\n", err)
http.Redirect(w, r, "/createaccount", http.StatusSeeOther) http.Redirect(w, r, "/createaccount?error=passlength", http.StatusSeeOther)
return return
} }
@@ -204,6 +213,8 @@ func createAccount(w http.ResponseWriter, r *http.Request) {
Value: sessionID, Value: sessionID,
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: 86400 * 30, // 30 days MaxAge: 86400 * 30, // 30 days
}) })
http.Redirect(w, r, "/profile/"+username, http.StatusSeeOther) http.Redirect(w, r, "/profile/"+username, http.StatusSeeOther)
@@ -226,12 +237,20 @@ func createAccountPageHandler() http.HandlerFunc {
func loginSubmit(w http.ResponseWriter, r *http.Request) { func loginSubmit(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" { if r.Method == "POST" {
r.ParseForm() err := r.ParseForm()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
username := r.FormValue("uname") username := r.FormValue("uname")
if username == "" {
http.Redirect(w, r, "/login?error=invalid-creds", http.StatusSeeOther)
return
}
password := r.FormValue("pass") password := r.FormValue("pass")
var storedPassword string var storedPassword string
err := db.Pool.QueryRow(r.Context(), "SELECT password FROM users WHERE username = $1;", username). err = db.Pool.QueryRow(r.Context(), "SELECT password FROM users WHERE username = $1;", username).
Scan(&storedPassword) Scan(&storedPassword)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Cannot get password for entered username: %v\n", err) fmt.Fprintf(os.Stderr, "Cannot get password for entered username: %v\n", err)
@@ -248,6 +267,8 @@ func loginSubmit(w http.ResponseWriter, r *http.Request) {
Value: sessionID, Value: sessionID,
Path: "/", Path: "/",
HttpOnly: true, HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
MaxAge: 86400 * 30, // 30 days MaxAge: 86400 * 30, // 30 days
}) })
http.Redirect(w, r, "/profile/"+username, http.StatusSeeOther) http.Redirect(w, r, "/profile/"+username, http.StatusSeeOther)
@@ -351,11 +372,15 @@ func profilePageHandler() http.HandlerFunc {
func updateDuplicateEditsSetting(w http.ResponseWriter, r *http.Request) { func updateDuplicateEditsSetting(w http.ResponseWriter, r *http.Request) {
if r.Method == "POST" { if r.Method == "POST" {
r.ParseForm() err := r.ParseForm()
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
username := r.FormValue("username") username := r.FormValue("username")
allow := r.FormValue("allow") == "true" allow := r.FormValue("allow") == "true"
_, err := db.Pool.Exec( _, err = db.Pool.Exec(
r.Context(), r.Context(),
`UPDATE users SET allow_duplicate_edits = $1 WHERE username = $2;`, `UPDATE users SET allow_duplicate_edits = $1 WHERE username = $2;`,
allow, allow,
@@ -407,9 +432,13 @@ func importLastFMHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
r.ParseForm() err = r.ParseForm()
lastfmUsername := r.FormValue("lastfm_username") if err != nil {
lastfmAPIKey := r.FormValue("lastfm_api_key") http.Error(w, err.Error(), http.StatusBadRequest)
return
}
lastfmUsername := template.HTMLEscapeString(r.FormValue("lastfm_username"))
lastfmAPIKey := template.HTMLEscapeString(r.FormValue("lastfm_api_key"))
if lastfmUsername == "" || lastfmAPIKey == "" { if lastfmUsername == "" || lastfmAPIKey == "" {
http.Error(w, "Missing required fields", http.StatusBadRequest) http.Error(w, "Missing required fields", http.StatusBadRequest)
@@ -429,7 +458,8 @@ func importLastFMHandler(w http.ResponseWriter, r *http.Request) {
jobsMu.Unlock() jobsMu.Unlock()
go func() { go func() {
migrate.ImportLastFM(lastfmUsername, lastfmAPIKey, userId, progressChan) migrate.ImportLastFM(lastfmUsername, lastfmAPIKey, userId, progressChan,
username)
jobsMu.Lock() jobsMu.Lock()
delete(importJobs, jobID) delete(importJobs, jobID)
@@ -503,5 +533,6 @@ func Start() {
r.Post("/import/lastfm", importLastFMHandler) r.Post("/import/lastfm", importLastFMHandler)
r.Get("/import/lastfm/progress", importLastFMProgressHandler) r.Get("/import/lastfm/progress", importLastFMProgressHandler)
fmt.Printf("WebUI starting on %s\n", addr) fmt.Printf("WebUI starting on %s\n", addr)
http.ListenAndServe(addr, r) prot := http.NewCrossOriginProtection()
http.ListenAndServe(addr, prot.Handler(r))
} }