diff --git a/scrobble/lastfm.go b/scrobble/lastfm.go index 70d3cd1..2b2f79a 100644 --- a/scrobble/lastfm.go +++ b/scrobble/lastfm.go @@ -23,6 +23,15 @@ func NewLastFMHandler() *LastFMHandler { } func (h *LastFMHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + if r.URL.Query().Get("hs") == "true" { + h.handleHandshake(w, r) + return + } + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) return @@ -34,31 +43,44 @@ func (h *LastFMHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - method := r.FormValue("method") - apiKey := r.FormValue("api_key") + method := r.PostForm.Get("method") + apiKey := r.PostForm.Get("api_key") + sk := r.PostForm.Get("s") + track := r.PostForm.Get("t") - switch method { - case "auth.gettoken": - h.handleGetToken(w, apiKey) - case "auth.getsession": - h.handleGetSession(w, r) - case "track.updateNowPlaying": - h.handleNowPlaying(w, r) - case "track.scrobble": - h.handleScrobble(w, r) - default: - h.respond(w, "failed", 400, fmt.Sprintf("Invalid method: %s", method)) + if method != "" { + switch method { + case "auth.gettoken": + h.handleGetToken(w, apiKey) + case "auth.getsession": + h.handleGetSession(w, r) + case "track.updateNowPlaying": + h.handleNowPlaying(w, r) + case "track.scrobble": + h.handleScrobble(w, r) + default: + h.respond(w, "failed", 400, fmt.Sprintf("Invalid method: %s", method)) + } + return } + + if sk != "" { + if r.PostForm.Get("a[0]") != "" && (r.PostForm.Get("t[0]") != "" || r.PostForm.Get("i[0]") != "") { + h.handleScrobble(w, r) + return + } + if track != "" { + h.handleNowPlaying(w, r) + return + } + } + + h.respond(w, "failed", 400, "Missing required parameters") } func (h *LastFMHandler) respond(w http.ResponseWriter, status string, code int, message string) { - w.Header().Set("Content-Type", "application/xml; charset=utf-8") - fmt.Fprintf(w, ` - - - %s - -`, status, code, message) + w.Header().Set("Content-Type", "text/plain") + w.Write([]byte(fmt.Sprintf("FAILED %s", message))) } func (h *LastFMHandler) respondOK(w http.ResponseWriter, content string) { @@ -66,6 +88,40 @@ func (h *LastFMHandler) respondOK(w http.ResponseWriter, content string) { w.Write([]byte(content)) } +func (h *LastFMHandler) handleHandshake(w http.ResponseWriter, r *http.Request) { + username := r.URL.Query().Get("u") + token := r.URL.Query().Get("t") + authToken := r.URL.Query().Get("a") + + if username == "" || token == "" || authToken == "" { + w.Write([]byte("BADAUTH")) + return + } + + userId, err := GetUserByUsername(username) + if err != nil { + w.Write([]byte("BADAUTH")) + return + } + + sessionKey, err := GenerateSessionKey() + if err != nil { + w.Write([]byte("FAILED Could not generate session")) + return + } + + _, err = db.Pool.Exec(context.Background(), + `UPDATE users SET api_secret = $1 WHERE pk = $2`, + sessionKey, userId) + if err != nil { + fmt.Fprintf(os.Stderr, "Error updating session key: %v\n", err) + w.Write([]byte("FAILED Database error")) + return + } + + w.Write([]byte(fmt.Sprintf("OK\n%s\nhttp://127.0.0.1:1234/2.0/\nhttp://127.0.0.1:1234/2.0/\n", sessionKey))) +} + func (h *LastFMHandler) handleGetToken(w http.ResponseWriter, apiKey string) { userId, _, err := GetUserByAPIKey(apiKey) if err != nil { @@ -121,7 +177,7 @@ func (h *LastFMHandler) handleGetSession(w http.ResponseWriter, r *http.Request) } func (h *LastFMHandler) handleNowPlaying(w http.ResponseWriter, r *http.Request) { - sessionKey := r.FormValue("sk") + sessionKey := r.PostForm.Get("s") if sessionKey == "" { h.respond(w, "failed", 9, "Invalid session") return @@ -133,11 +189,16 @@ func (h *LastFMHandler) handleNowPlaying(w http.ResponseWriter, r *http.Request) return } - artist := r.FormValue("artist") - track := r.FormValue("track") - album := r.FormValue("album") + artist := r.PostForm.Get("a") + track := r.PostForm.Get("t") + album := r.PostForm.Get("b") - duration := r.FormValue("duration") + if track == "" { + h.respondOK(w, "OK") + return + } + + duration := r.PostForm.Get("l") msPlayed := 0 if duration != "" { if d, err := strconv.Atoi(duration); err == nil { @@ -145,24 +206,21 @@ func (h *LastFMHandler) handleNowPlaying(w http.ResponseWriter, r *http.Request) } } - if track != "" { - UpdateNowPlaying(NowPlaying{ - UserId: userId, - SongName: track, - Artist: artist, - Album: album, - MsPlayed: msPlayed, - Platform: "lastfm_api", - UpdatedAt: time.Now(), - }) - } + UpdateNowPlaying(NowPlaying{ + UserId: userId, + SongName: track, + Artist: artist, + Album: album, + MsPlayed: msPlayed, + Platform: "lastfm_api", + UpdatedAt: time.Now(), + }) - h.respondOK(w, ` -`) + h.respondOK(w, "OK") } func (h *LastFMHandler) handleScrobble(w http.ResponseWriter, r *http.Request) { - sessionKey := r.FormValue("sk") + sessionKey := r.PostForm.Get("s") if sessionKey == "" { h.respond(w, "failed", 9, "Invalid session") return @@ -174,7 +232,7 @@ func (h *LastFMHandler) handleScrobble(w http.ResponseWriter, r *http.Request) { return } - scrobbles := h.parseScrobbles(r.Form, userId) + scrobbles := h.parseScrobbles(r.PostForm, userId) if len(scrobbles) == 0 { h.respond(w, "failed", 1, "No scrobbles to submit") return @@ -194,10 +252,7 @@ func (h *LastFMHandler) handleScrobble(w http.ResponseWriter, r *http.Request) { ClearNowPlaying(userId) - h.respondOK(w, fmt.Sprintf(` - - -`, accepted, ignored)) + h.respondOK(w, fmt.Sprintf("OK\n%d\n%d\n", accepted, ignored)) } func (h *LastFMHandler) parseScrobbles(form url.Values, userId int) []Scrobble { @@ -207,15 +262,15 @@ func (h *LastFMHandler) parseScrobbles(form url.Values, userId int) []Scrobble { var artist, track, album, timestampStr string if i == 0 { - artist = form.Get("artist") - track = form.Get("track") - album = form.Get("album") - timestampStr = form.Get("timestamp") + artist = form.Get("a[0]") + track = form.Get("t[0]") + album = form.Get("b[0]") + timestampStr = form.Get("i[0]") } else { - artist = form.Get(fmt.Sprintf("artist[%d]", i-1)) - track = form.Get(fmt.Sprintf("track[%d]", i-1)) - album = form.Get(fmt.Sprintf("album[%d]", i-1)) - timestampStr = form.Get(fmt.Sprintf("timestamp[%d]", i-1)) + artist = form.Get(fmt.Sprintf("a[%d]", i)) + track = form.Get(fmt.Sprintf("t[%d]", i)) + album = form.Get(fmt.Sprintf("b[%d]", i)) + timestampStr = form.Get(fmt.Sprintf("i[%d]", i)) } if artist == "" || track == "" || timestampStr == "" { @@ -227,7 +282,7 @@ func (h *LastFMHandler) parseScrobbles(form url.Values, userId int) []Scrobble { continue } - duration := form.Get(fmt.Sprintf("duration[%d]", i-1)) + duration := form.Get(fmt.Sprintf("l[%d]", i)) msPlayed := 0 if duration != "" { if d, err := strconv.Atoi(duration); err == nil { diff --git a/scrobble/scrobble.go b/scrobble/scrobble.go index aa5e34e..366235d 100644 --- a/scrobble/scrobble.go +++ b/scrobble/scrobble.go @@ -36,7 +36,7 @@ type NowPlaying struct { UpdatedAt time.Time } -var CurrentNowPlaying = make(map[int]NowPlaying) +var CurrentNowPlaying = make(map[int]map[string]NowPlaying) func GenerateAPIKey() (string, error) { bytes := make([]byte, 16) @@ -80,6 +80,20 @@ func GetUserByAPIKey(apiKey string) (int, string, error) { return userId, username, nil } +func GetUserByUsername(username string) (int, error) { + if username == "" { + return 0, fmt.Errorf("empty username") + } + + var userId int + err := db.Pool.QueryRow(context.Background(), + "SELECT pk FROM users WHERE username = $1", username).Scan(&userId) + if err != nil { + return 0, err + } + return userId, nil +} + func GetUserBySessionKey(sessionKey string) (int, string, error) { if sessionKey == "" { return 0, "", fmt.Errorf("empty session key") @@ -167,18 +181,38 @@ func checkDuplicate(userId int, artist, songName string, timestamp time.Time) (b } func UpdateNowPlaying(np NowPlaying) { - CurrentNowPlaying[np.UserId] = np + if CurrentNowPlaying[np.UserId] == nil { + CurrentNowPlaying[np.UserId] = make(map[string]NowPlaying) + } + CurrentNowPlaying[np.UserId][np.Platform] = np } func GetNowPlaying(userId int) (NowPlaying, bool) { - np, ok := CurrentNowPlaying[userId] - return np, ok + platforms := CurrentNowPlaying[userId] + if platforms == nil { + return NowPlaying{}, false + } + np, ok := platforms["lastfm_api"] + if ok && np.SongName != "" { + return np, true + } + np, ok = platforms["spotify"] + if ok && np.SongName != "" { + return np, true + } + return NowPlaying{}, false } func ClearNowPlaying(userId int) { delete(CurrentNowPlaying, userId) } +func ClearNowPlayingPlatform(userId int, platform string) { + if CurrentNowPlaying[userId] != nil { + delete(CurrentNowPlaying[userId], platform) + } +} + func GetUserSpotifyCredentials(userId int) (clientId, clientSecret, accessToken, refreshToken string, expiresAt time.Time, err error) { var clientIdPg, clientSecretPg, accessTokenPg, refreshTokenPg pgtype.Text var expiresAtPg pgtype.Timestamptz diff --git a/scrobble/spotify.go b/scrobble/spotify.go index 9cb5464..363f1b5 100644 --- a/scrobble/spotify.go +++ b/scrobble/spotify.go @@ -297,7 +297,7 @@ func checkCurrentlyPlaying(userId int, accessToken string) error { defer resp.Body.Close() if resp.StatusCode == 204 { - ClearNowPlaying(userId) + ClearNowPlayingPlatform(userId, "spotify") return nil } @@ -311,7 +311,7 @@ func checkCurrentlyPlaying(userId int, accessToken string) error { } if !playing.IsPlaying || playing.Item.Name == "" { - ClearNowPlaying(userId) + ClearNowPlayingPlatform(userId, "spotify") return nil } diff --git a/web/web.go b/web/web.go index 5d65eae..09a1414 100644 --- a/web/web.go +++ b/web/web.go @@ -16,9 +16,15 @@ import ( "github.com/go-chi/chi/v5/middleware" ) +const serverAddr = "127.0.0.1:1234" + // 50 MiB const maxHeaderSize int64 = 50 * 1024 * 1024 +func serverAddrStr() string { + return serverAddr +} + // Holds all the parsed HTML templates var templates *template.Template @@ -68,7 +74,7 @@ func rootHandler() http.HandlerFunc { // Serves all pages at the specified address. func Start() { - addr := ":1234" + addr := serverAddr r := chi.NewRouter() r.Use(middleware.Logger) r.Handle("/files/*", http.StripPrefix("/files", http.FileServer(http.Dir("./static")))) @@ -84,7 +90,8 @@ func Start() { r.Get("/import/lastfm/progress", importLastFMProgressHandler) r.Get("/import/spotify/progress", importSpotifyProgressHandler) - r.Post("/2.0/", http.HandlerFunc(scrobble.NewLastFMHandler().ServeHTTP)) + r.Handle("/2.0", scrobble.NewLastFMHandler()) + r.Handle("/2.0/", scrobble.NewLastFMHandler()) r.Post("/1/submit-listens", http.HandlerFunc(scrobble.NewListenbrainzHandler().ServeHTTP)) r.Route("/scrobble/spotify", func(r chi.Router) { r.Get("/authorize", http.HandlerFunc(scrobble.NewSpotifyHandler().ServeHTTP))