diff --git a/migrate/spotify.go b/migrate/spotify.go index e103478..0c0531c 100644 --- a/migrate/spotify.go +++ b/migrate/spotify.go @@ -207,6 +207,8 @@ func JsonToDB(jsonFile string, userId int) error { continue } + // TODO: replace strings.Join with pgx copy + _, err = conn.Exec( context.Background(), `INSERT INTO history (user_id, timestamp, song_name, artist, album_name, ms_played, platform) VALUES `+ diff --git a/templates/create_account.gohtml b/templates/create_account.gohtml index d62fcee..6368c06 100644 --- a/templates/create_account.gohtml +++ b/templates/create_account.gohtml @@ -24,6 +24,16 @@ Unable to create session. Please try again. {{end}} + {{if eq .Error "userlength"}} +
+ Username length must be greater than 0. +
+ {{end}} + {{if eq .Error "usertaken"}} +
+ Username must be unique. Please try again. +
+ {{end}} diff --git a/web/web.go b/web/web.go index e35c066..beae2e0 100644 --- a/web/web.go +++ b/web/web.go @@ -145,7 +145,7 @@ func getUserIdByUsername(ctx context.Context, username string) (int, error) { } func hashPassword(pass []byte) (string, error) { - if len(pass) < 8 || len(pass) > 64 { + if len([]rune(string(pass))) < 8 || len(pass) > 64 { return "", errors.New("Error: Password must be greater than 8 chars.") } hashedPassword, err := bcrypt.GenerateFromPassword(pass, bcrypt.DefaultCost) @@ -167,20 +167,29 @@ func verifyPassword(hashedPassword string, enteredPassword []byte) bool { func createAccount(w http.ResponseWriter, r *http.Request) { if r.Method == "POST" { - r.ParseForm() - - username := r.FormValue("uname") - hashedPassword, err := hashPassword([]byte(r.FormValue("pass"))) + err := r.ParseForm() if err != nil { - fmt.Fprintf(os.Stderr, "Error hashing password: %v\n", err) - http.Redirect(w, r, "/createaccount?error=length", http.StatusSeeOther) + http.Error(w, err.Error(), http.StatusBadRequest) return } - err = db.CreateUsersTable() + username := r.FormValue("uname") + if len([]rune(string(username))) == 0 { + http.Redirect(w, r, "/createaccount?error=userlength", http.StatusSeeOther) + return + } + var usertaken bool + err = db.Pool.QueryRow(r.Context(), + "SELECT EXISTS(SELECT 1 FROM users WHERE username = $1)", username). + Scan(&usertaken) + if usertaken == true { + http.Redirect(w, r, "/createaccount?error=usertaken", http.StatusSeeOther) + return + } + hashedPassword, err := hashPassword([]byte(r.FormValue("pass"))) if err != nil { - fmt.Fprintf(os.Stderr, "Error ensuring users table exists: %v\n", err) - http.Redirect(w, r, "/createaccount", http.StatusSeeOther) + fmt.Fprintf(os.Stderr, "Error hashing password: %v\n", err) + http.Redirect(w, r, "/createaccount?error=passlength", http.StatusSeeOther) return } @@ -228,12 +237,20 @@ func createAccountPageHandler() http.HandlerFunc { func loginSubmit(w http.ResponseWriter, r *http.Request) { if r.Method == "POST" { - r.ParseForm() + err := r.ParseForm() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } username := r.FormValue("uname") + if username == "" { + http.Redirect(w, r, "/login?error=invalid-creds", http.StatusSeeOther) + return + } password := r.FormValue("pass") var storedPassword string - err := db.Pool.QueryRow(r.Context(), "SELECT password FROM users WHERE username = $1;", username). + err = db.Pool.QueryRow(r.Context(), "SELECT password FROM users WHERE username = $1;", username). Scan(&storedPassword) if err != nil { fmt.Fprintf(os.Stderr, "Cannot get password for entered username: %v\n", err) @@ -355,11 +372,15 @@ func profilePageHandler() http.HandlerFunc { func updateDuplicateEditsSetting(w http.ResponseWriter, r *http.Request) { if r.Method == "POST" { - r.ParseForm() + err := r.ParseForm() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } username := r.FormValue("username") allow := r.FormValue("allow") == "true" - _, err := db.Pool.Exec( + _, err = db.Pool.Exec( r.Context(), `UPDATE users SET allow_duplicate_edits = $1 WHERE username = $2;`, allow, @@ -411,9 +432,13 @@ func importLastFMHandler(w http.ResponseWriter, r *http.Request) { return } - r.ParseForm() - lastfmUsername := r.FormValue("lastfm_username") - lastfmAPIKey := r.FormValue("lastfm_api_key") + err = r.ParseForm() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + lastfmUsername := template.HTMLEscapeString(r.FormValue("lastfm_username")) + lastfmAPIKey := template.HTMLEscapeString(r.FormValue("lastfm_api_key")) if lastfmUsername == "" || lastfmAPIKey == "" { http.Error(w, "Missing required fields", http.StatusBadRequest)