290 lines
7.8 KiB
Go
290 lines
7.8 KiB
Go
// Package auth implements the server side of the Yggdrasil authentication protocol.
|
|
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"time"
|
|
|
|
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/config"
|
|
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/database"
|
|
)
|
|
|
|
// Handler serves Yggdrasil endpoints.
|
|
type Handler struct {
|
|
db *database.DB
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewHandler creates a new auth handler.
|
|
func NewHandler(db *database.DB, cfg *config.Config) *Handler {
|
|
return &Handler{db: db, cfg: cfg}
|
|
}
|
|
|
|
// RegisterRoutes attaches the authserver endpoints (authenticate, refresh,
// validate) to mux. Each pattern is method-qualified, so only POST requests
// are routed to the handlers.
func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"POST /authserver/authenticate", h.authenticate},
		{"POST /authserver/refresh", h.refresh},
		{"POST /authserver/validate", h.validate},
	}
	for _, rt := range routes {
		mux.HandleFunc(rt.pattern, rt.handler)
	}
}
|
|
|
|
// ── Request / Response types ──────────────────────────────────
|
|
|
|
type (
	// authenticateRequest is the JSON body accepted by POST /authserver/authenticate.
	// Username may hold either a username or an e-mail address.
	authenticateRequest struct {
		Username string `json:"username"`
		Password string `json:"password"`
	}

	// authenticateResponse is the success body of POST /authserver/authenticate.
	// The Go field AvailableProfile is serialized under the plural key
	// "availableProfiles"; SelectedProfile and User are omitted when nil.
	authenticateResponse struct {
		AccessToken      string          `json:"accessToken"`
		ClientToken      string          `json:"clientToken"`
		AvailableProfile []profile       `json:"availableProfiles"`
		SelectedProfile  *profile        `json:"selectedProfile,omitempty"`
		User             *userProperties `json:"user,omitempty"`
	}

	// profile identifies a game profile by id and display name.
	profile struct {
		ID   string `json:"id"`
		Name string `json:"name"`
	}

	// userProperties is the optional "user" object of the authenticate response.
	userProperties struct {
		ID         string     `json:"id"`
		Properties []property `json:"properties"`
	}

	// property is a single name/value entry inside userProperties.
	property struct {
		Name  string `json:"name"`
		Value string `json:"value"`
	}

	// refreshRequest is the JSON body of POST /authserver/refresh; the
	// validate endpoint decodes the same shape.
	refreshRequest struct {
		AccessToken string `json:"accessToken"`
		ClientToken string `json:"clientToken"`
	}

	// refreshResponse is the success body of POST /authserver/refresh.
	// selectedProfile is always emitted (null when nil).
	refreshResponse struct {
		AccessToken     string   `json:"accessToken"`
		ClientToken     string   `json:"clientToken"`
		SelectedProfile *profile `json:"selectedProfile"`
	}

	// errorResponse is the JSON error envelope shared by all endpoints.
	errorResponse struct {
		Error        string `json:"error"`
		ErrorMessage string `json:"errorMessage"`
	}
)
|
|
|
|
// ── Handlers ──────────────────────────────────────────────────
|
|
|
|
// authenticate handles POST /authserver/authenticate.
//
// It looks the account up by username or e-mail, checks the password, stores
// a new session row that expires 24 hours from now, and responds with a fresh
// accessToken/clientToken pair plus the account's single game profile (used
// both as the selected profile and as the only available profile). When the
// client sets the Yggdrasil "requestUser" flag, the response also carries a
// "user" object.
//
// Unknown accounts and wrong passwords produce the same 401 response, so the
// endpoint does not reveal which one failed.
func (h *Handler) authenticate(w http.ResponseWriter, r *http.Request) {
	// This endpoint is reachable without credentials; cap the body so a
	// client cannot make io.ReadAll buffer an arbitrarily large payload.
	const maxBodyBytes = 1 << 20
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	if err != nil {
		writeError(w, http.StatusBadRequest, "Bad Request", "Cannot read body")
		return
	}

	// authenticateRequest is embedded so encoding/json promotes its
	// username/password fields; requestUser is decoded alongside them.
	var req struct {
		authenticateRequest
		RequestUser bool `json:"requestUser"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		writeError(w, http.StatusBadRequest, "Bad Request", "Invalid JSON")
		return
	}

	// Look up user by username or email.
	user, err := h.findUser(r.Context(), req.Username)
	if err != nil {
		writeError(w, http.StatusUnauthorized, "Forbidden", "Invalid credentials")
		return
	}

	// Verify password (SHA-256 hex comparison).
	if !verifyPassword(req.Password, user.PasswordHash) {
		writeError(w, http.StatusUnauthorized, "Forbidden", "Invalid credentials")
		return
	}

	// Always issue a server-generated token pair.
	// NOTE(review): a clientToken sent by the client is ignored; the protocol
	// allows reusing it, but that needs checking against the session table's
	// uniqueness constraints first.
	accessToken := GenerateToken()
	clientToken := GenerateToken()

	expiresAt := time.Now().Add(24 * time.Hour)
	_, err = h.db.Pool().Exec(r.Context(),
		`INSERT INTO yggdrasil_sessions (client_token, access_token, user_id, expires_at)
		 VALUES ($1, $2, $3, $4)`,
		clientToken, accessToken, user.ID, expiresAt,
	)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "Internal Error", "Failed to create session")
		return
	}

	selected := profile{ID: user.UUID, Name: user.Username}
	resp := authenticateResponse{
		AccessToken:      accessToken,
		ClientToken:      clientToken,
		SelectedProfile:  &selected,
		AvailableProfile: []profile{selected},
	}
	if req.RequestUser {
		// An empty (non-nil) slice so "properties" encodes as [] rather than null.
		// NOTE(review): the account UUID doubles as the user id — confirm this
		// matches what clients expect.
		resp.User = &userProperties{ID: user.UUID, Properties: []property{}}
	}

	writeJSON(w, http.StatusOK, resp)
}
|
|
|
|
// refresh handles POST /authserver/refresh.
//
// It finds the session matching the accessToken/clientToken pair and rejects
// it with 401 if the pair is unknown or the session has expired. Otherwise it
// rotates the access token, extends the session by 24 hours, and returns the
// new access token, the echoed clientToken and the account's profile.
func (h *Handler) refresh(w http.ResponseWriter, r *http.Request) {
	// Cap the body; the endpoint accepts requests from anyone holding (or
	// guessing at) a token.
	const maxBodyBytes = 1 << 20
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	if err != nil {
		writeError(w, http.StatusBadRequest, "Bad Request", "Cannot read body")
		return
	}

	var req refreshRequest
	if err := json.Unmarshal(body, &req); err != nil {
		writeError(w, http.StatusBadRequest, "Bad Request", "Invalid JSON")
		return
	}

	// Look up session.
	var userID int
	var expiresAt time.Time
	err = h.db.Pool().QueryRow(r.Context(),
		`SELECT user_id, expires_at FROM yggdrasil_sessions
		 WHERE access_token = $1 AND client_token = $2`,
		req.AccessToken, req.ClientToken,
	).Scan(&userID, &expiresAt)
	if err != nil {
		writeError(w, http.StatusUnauthorized, "Forbidden", "Invalid token")
		return
	}
	if time.Now().After(expiresAt) {
		writeError(w, http.StatusUnauthorized, "Forbidden", "Token expired")
		return
	}

	// Load the profile before rotating. If this fails, the client's current
	// token must stay valid rather than be invalidated for no response.
	var username, uuid string
	err = h.db.Pool().QueryRow(r.Context(),
		`SELECT username, uuid FROM users WHERE id = $1`, userID,
	).Scan(&username, &uuid)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "Internal Error", "User not found")
		return
	}

	// Rotate the access token. The WHERE clause re-checks the full token pair.
	// RETURNING proves a row was actually updated: if a concurrent refresh
	// already rotated this session, nothing matches and Scan fails. In that
	// case we must not hand out a token that was never stored.
	newAccessToken := GenerateToken()
	err = h.db.Pool().QueryRow(r.Context(),
		`UPDATE yggdrasil_sessions SET access_token = $1, expires_at = $2
		 WHERE access_token = $3 AND client_token = $4
		 RETURNING user_id`,
		newAccessToken, time.Now().Add(24*time.Hour), req.AccessToken, req.ClientToken,
	).Scan(&userID)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "Internal Error", "Failed to refresh")
		return
	}

	writeJSON(w, http.StatusOK, refreshResponse{
		AccessToken: newAccessToken,
		ClientToken: req.ClientToken,
		SelectedProfile: &profile{
			ID:   uuid,
			Name: username,
		},
	})
}
|
|
|
|
// validate handles POST /authserver/validate.
//
// It responds 204 No Content only when the accessToken/clientToken pair names
// an existing, unexpired session. Every other outcome must be a non-2xx
// response: a 204 tells the client the token is still good, so answering 204
// for unknown or expired tokens would stop clients from refreshing or
// re-authenticating.
//
// Errors use the same 400/401 statuses and error bodies as refresh.
func (h *Handler) validate(w http.ResponseWriter, r *http.Request) {
	const maxBodyBytes = 1 << 20
	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	if err != nil {
		writeError(w, http.StatusBadRequest, "Bad Request", "Cannot read body")
		return
	}

	var req refreshRequest
	if err := json.Unmarshal(body, &req); err != nil {
		writeError(w, http.StatusBadRequest, "Bad Request", "Invalid JSON")
		return
	}

	var expiresAt time.Time
	err = h.db.Pool().QueryRow(r.Context(),
		`SELECT expires_at FROM yggdrasil_sessions
		 WHERE access_token = $1 AND client_token = $2`,
		req.AccessToken, req.ClientToken,
	).Scan(&expiresAt)
	if err != nil {
		writeError(w, http.StatusUnauthorized, "Forbidden", "Invalid token")
		return
	}
	if time.Now().After(expiresAt) {
		writeError(w, http.StatusUnauthorized, "Forbidden", "Token expired")
		return
	}

	w.WriteHeader(http.StatusNoContent)
}
|
|
|
|
// ── Helpers ───────────────────────────────────────────────────
|
|
|
|
// findUser loads the account whose username or e-mail equals login.
// It returns the scan/query error unchanged when no row can be read; callers
// treat any error as "no such user".
// NOTE(review): if one account's username equals another account's e-mail,
// QueryRow reads whichever row the database returns first.
func (h *Handler) findUser(ctx context.Context, login string) (*database.User, error) {
	const query = `SELECT id, username, email, password_hash, uuid, role FROM users
		WHERE username = $1 OR email = $1`

	u := new(database.User)
	row := h.db.Pool().QueryRow(ctx, query, login)
	if err := row.Scan(&u.ID, &u.Username, &u.Email, &u.PasswordHash, &u.UUID, &u.Role); err != nil {
		return nil, err
	}
	return u, nil
}
|
|
|
|
// verifyPassword reports whether the lowercase hex SHA-256 digest of
// password equals hash. The comparison runs in constant time for
// equal-length inputs.
// NOTE(review): unsalted SHA-256 is a fast hash and offers little protection
// if the users table leaks; moving to a salted, slow KDF needs a hash-format
// migration.
func verifyPassword(password, hash string) bool {
	sum := sha256.Sum256([]byte(password))
	computed := hex.EncodeToString(sum[:])
	match := subtle.ConstantTimeCompare([]byte(computed), []byte(hash))
	return match == 1
}
|
|
|
|
// GenerateToken returns a new random token: 16 bytes from crypto/rand,
// hex-encoded to 32 lowercase characters.
//
// It panics if the system random source fails. Ignoring that error would
// return a token built from a zeroed or partially filled buffer, which would
// be guessable; crashing is safer than issuing it.
func GenerateToken() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("auth: reading crypto/rand: " + err.Error())
	}
	return hex.EncodeToString(b)
}
|
|
|
|
// writeJSON sends v as a JSON body with the given status code and an
// application/json Content-Type.
func writeJSON(w http.ResponseWriter, status int, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// Headers and status are already on the wire, so an encoding error
	// cannot be reported to the client; it is dropped deliberately.
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
|
|
|
|
// writeError sends the standard Yggdrasil-style error envelope
// {"error": err, "errorMessage": msg} with the given status code.
func writeError(w http.ResponseWriter, status int, err, msg string) {
	resp := errorResponse{Error: err, ErrorMessage: msg}
	writeJSON(w, status, resp)
}
|
|
|
|
// HashPassword returns the lowercase hex SHA-256 digest of password, in the
// format verifyPassword expects to find in storage.
// NOTE(review): this is an unsalted fast hash; consider a salted, slow KDF
// (with a migration path for existing hashes).
func HashPassword(password string) string {
	sum := sha256.Sum256([]byte(password))
	encoded := hex.EncodeToString(sum[:])
	return encoded
}
|
|
|
|
// GenerateUUID returns a random version-4 UUID (RFC 4122 variant) in the
// canonical 8-4-4-4-12 lowercase-hex form, e.g.
// "xxxxxxxx-xxxx-4xxx-[89ab]xxx-xxxxxxxxxxxx".
//
// It panics if crypto/rand fails: an identifier derived from a zeroed or
// partially filled buffer would no longer be unique.
func GenerateUUID() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		panic("auth: reading crypto/rand: " + err.Error())
	}
	b[6] = (b[6] & 0x0f) | 0x40 // version 4
	b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant (binary 10xx)
	return fmt.Sprintf("%x-%x-%x-%x-%x",
		b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}
|