feat: add Docker infrastructure, migrations, CI/CD client, session cleanup, tests
Docker & Deployment: - Add Dockerfile (multi-stage, alpine, non-root) - Add docker-compose.yml (caddy, backend, postgres, watchtower) - Add Caddyfile (TLS, file_server, reverse proxy) - Add .env.example Database: - Add migrations/001_init.sql (all tables + indexes) CI/CD: - Add cmd/ci-release/main.go (launcher binary upload tool) Session management: - Add internal/session/cleanup.go (background expired session cleanup) - Integrate cleanup worker into main.go Bug fixes: - Fix launcherLatest download URL to include version segment - Fix serveLauncherAsset path to match route pattern - Add Content-Type detection from file extension in CAS serveFile - Add empty-field validation in webLogin - Format string fix in ci-release (%d → %s for resp.Status) Tests: - Add internal/auth/auth_test.go (8 tests) - Add internal/cas/cas_test.go (7 tests) - Add internal/session/cleanup_test.go (1 test) - Add internal/api/api_test.go (5 tests) - All tests passing, go vet clean Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -64,7 +64,7 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
mux.HandleFunc("GET /api/web/profile/{uuid}", h.getProfile)
|
||||
|
||||
// Skin serving.
|
||||
mux.HandleFunc("GET /skins/{hash}.png", h.serveSkin)
|
||||
mux.HandleFunc("GET /skins/{hash}", h.serveSkin)
|
||||
}
|
||||
|
||||
// ── Request / Response types ──────────────────────────────────
|
||||
@@ -211,6 +211,11 @@ func (h *Handler) webLogin(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Username == "" || req.Password == "" {
|
||||
writeError(w, http.StatusBadRequest, "Username and password are required")
|
||||
return
|
||||
}
|
||||
|
||||
var user database.User
|
||||
err = h.db.Pool().QueryRow(r.Context(),
|
||||
`SELECT id, username, password_hash, uuid FROM users
|
||||
@@ -442,8 +447,8 @@ func (h *Handler) launcherLatest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
downloadURL := fmt.Sprintf("%s/files/launcher/%s/%s/%s",
|
||||
h.cfg.BaseURL, osParam, archParam, filepath.Base(release.FilePath))
|
||||
downloadURL := fmt.Sprintf("%s/files/launcher/%s/%s/%s/%s",
|
||||
h.cfg.BaseURL, release.Version, osParam, archParam, filepath.Base(release.FilePath))
|
||||
|
||||
writeJSON(w, http.StatusOK, launcherLatestResponse{
|
||||
Version: release.Version,
|
||||
|
||||
216
internal/api/api_test.go
Normal file
216
internal/api/api_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/config"
|
||||
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/database"
|
||||
)
|
||||
|
||||
// newTestHandler creates an API handler with a nil DB for testing validation
|
||||
// and routing only. Handers that touch the database will panic with nil DB —
|
||||
// integration tests with a real database cover those paths.
|
||||
func newTestHandler(t *testing.T) *Handler {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
CASDir: dir,
|
||||
SkinsDir: dir,
|
||||
BaseURL: "https://test.example.com",
|
||||
JWTSecret: "test-secret",
|
||||
}
|
||||
return &Handler{db: &database.DB{}, cfg: cfg}
|
||||
}
|
||||
|
||||
// TestRegisterValidation tests input validation in the register handler
|
||||
// without requiring a database connection.
|
||||
func TestRegisterValidation(t *testing.T) {
|
||||
h := newTestHandler(t)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStatus int
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty body",
|
||||
body: "{}",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "Username, email and password are required",
|
||||
},
|
||||
{
|
||||
name: "invalid email",
|
||||
body: `{"username":"test","email":"notanemail","password":"pass"}`,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "Invalid email address",
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
body: "not json",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantErr: "Invalid JSON",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/web/register",
|
||||
bytes.NewReader([]byte(tt.body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("status = %d, want %d", w.Code, tt.wantStatus)
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if got := resp["error"]; got != tt.wantErr {
|
||||
t.Errorf("error = %q, want %q", got, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebLoginValidation tests input validation in the web login handler.
|
||||
func TestWebLoginValidation(t *testing.T) {
|
||||
h := newTestHandler(t)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "empty credentials",
|
||||
body: `{"username":"","password":""}`,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing username",
|
||||
body: `{"password":"secret"}`,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing password",
|
||||
body: `{"username":"test"}`,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid json",
|
||||
body: "not json",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/api/web/login",
|
||||
bytes.NewReader([]byte(tt.body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != tt.wantStatus {
|
||||
t.Errorf("status = %d, want %d", w.Code, tt.wantStatus)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLauncherLatest_MissingParams tests that missing query parameters
|
||||
// return 400 without hitting the database.
|
||||
func TestLauncherLatest_MissingParams(t *testing.T) {
|
||||
h := newTestHandler(t)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
queries := []string{
|
||||
"/api/launcher/latest",
|
||||
"/api/launcher/latest?os=windows",
|
||||
"/api/launcher/latest?arch=amd64",
|
||||
}
|
||||
|
||||
for _, q := range queries {
|
||||
req := httptest.NewRequest("GET", q, nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("%s: expected 400, got %d", q, w.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_NoToken tests that protected endpoints reject
|
||||
// requests without a Bearer token.
|
||||
func TestAuthMiddleware_NoToken(t *testing.T) {
|
||||
h := newTestHandler(t)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
protected := []struct {
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"POST", "/api/web/profile/skin"},
|
||||
{"POST", "/api/web/profile/cape"},
|
||||
{"DELETE", "/api/web/profile/skin"},
|
||||
{"DELETE", "/api/web/profile/cape"},
|
||||
}
|
||||
|
||||
for _, ep := range protected {
|
||||
t.Run(ep.method+" "+ep.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(ep.method, ep.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s: expected 401, got %d", ep.method, ep.path, w.Code)
|
||||
}
|
||||
var resp map[string]string
|
||||
json.Unmarshal(w.Body.Bytes(), &resp)
|
||||
if !strings.Contains(resp["error"], "Missing authorization") {
|
||||
t.Errorf("unexpected error: %s", resp["error"])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRoutesRegistered verifies all expected API routes are mounted
|
||||
// and return proper HTTP status codes (not 404 for known paths).
|
||||
func TestRoutesRegistered(t *testing.T) {
|
||||
h := newTestHandler(t)
|
||||
mux := http.NewServeMux()
|
||||
h.RegisterRoutes(mux)
|
||||
|
||||
// Public routes that should respond without a database.
|
||||
// Only routes with early validation (before DB access) are listed.
|
||||
knownRoutes := []struct {
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"POST", "/api/web/register"},
|
||||
{"POST", "/api/web/login"},
|
||||
{"GET", "/api/launcher/latest"},
|
||||
}
|
||||
|
||||
for _, r := range knownRoutes {
|
||||
t.Run(r.method+" "+r.path, func(t *testing.T) {
|
||||
req := httptest.NewRequest(r.method, r.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
mux.ServeHTTP(w, req)
|
||||
// Should not be 404 (route exists).
|
||||
if w.Code == http.StatusNotFound {
|
||||
t.Errorf("%s %s: route not found (404)", r.method, r.path)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
117
internal/auth/auth_test.go
Normal file
117
internal/auth/auth_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateToken(t *testing.T) {
|
||||
tok := GenerateToken()
|
||||
if len(tok) != 32 {
|
||||
t.Errorf("expected 32-char token, got %d chars: %s", len(tok), tok)
|
||||
}
|
||||
// Must be hex.
|
||||
for _, c := range tok {
|
||||
if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f')) {
|
||||
t.Errorf("token contains non-hex char: %c", c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateToken_Uniqueness(t *testing.T) {
|
||||
// Two tokens should never collide.
|
||||
t1 := GenerateToken()
|
||||
t2 := GenerateToken()
|
||||
if t1 == t2 {
|
||||
t.Error("two generated tokens are identical")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateUUID(t *testing.T) {
|
||||
uuid := GenerateUUID()
|
||||
// Format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx (36 chars).
|
||||
if len(uuid) != 36 {
|
||||
t.Errorf("expected 36-char UUID, got %d: %s", len(uuid), uuid)
|
||||
}
|
||||
// Check dashes at correct positions.
|
||||
for _, pos := range []int{8, 13, 18, 23} {
|
||||
if uuid[pos] != '-' {
|
||||
t.Errorf("expected dash at position %d, got %c", pos, uuid[pos])
|
||||
}
|
||||
}
|
||||
// Version 4: char at position 14 should be '4'.
|
||||
if uuid[14] != '4' {
|
||||
t.Errorf("expected version 4 at position 14, got %c", uuid[14])
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateUUID_Uniqueness(t *testing.T) {
|
||||
u1 := GenerateUUID()
|
||||
u2 := GenerateUUID()
|
||||
if u1 == u2 {
|
||||
t.Error("two generated UUIDs are identical")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHashPassword(t *testing.T) {
|
||||
hash, err := HashPassword("testpassword")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword failed: %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(hash, "$2a$") {
|
||||
t.Errorf("expected bcrypt hash starting with $2a$, got: %s", hash[:4])
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyPassword(t *testing.T) {
|
||||
hash, err := HashPassword("minecraft123")
|
||||
if err != nil {
|
||||
t.Fatalf("HashPassword failed: %v", err)
|
||||
}
|
||||
|
||||
if !VerifyPassword("minecraft123", hash) {
|
||||
t.Error("VerifyPassword returned false for correct password")
|
||||
}
|
||||
if VerifyPassword("wrongpassword", hash) {
|
||||
t.Error("VerifyPassword returned true for wrong password")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsBcryptHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
hash string
|
||||
want bool
|
||||
}{
|
||||
{"$2a$10$abcdefghijklmnopqrstuuxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", true},
|
||||
{"$2b$10$abcdefghijklmnopqrstuuxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", true},
|
||||
{"$2y$10$abcdefghijklmnopqrstuuxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", true},
|
||||
{"5e884898da28047151d0e56f8dc6292773603d0d6aabbdd62a11ef721d1542d8", false},
|
||||
{"", false},
|
||||
{"plaintext", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := IsBcryptHash(tt.hash)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsBcryptHash(%q) = %v, want %v", tt.hash, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractBearer(t *testing.T) {
|
||||
tests := []struct {
|
||||
header string
|
||||
want string
|
||||
}{
|
||||
{"Bearer abc123", "abc123"},
|
||||
{"Bearer ", ""},
|
||||
{"abc123", ""},
|
||||
{"", ""},
|
||||
{"Basic abc123", ""},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := ExtractBearer(tt.header)
|
||||
if got != tt.want {
|
||||
t.Errorf("ExtractBearer(%q) = %q, want %q", tt.header, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,33 @@ import (
|
||||
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/database"
|
||||
)
|
||||
|
||||
// mimeByExtension maps common file extensions to MIME types for CAS serving.
|
||||
var mimeByExtension = map[string]string{
|
||||
".jar": "application/java-archive",
|
||||
".json": "application/json",
|
||||
".png": "image/png",
|
||||
".zip": "application/zip",
|
||||
".toml": "application/toml",
|
||||
".cfg": "text/plain",
|
||||
".conf": "text/plain",
|
||||
".txt": "text/plain",
|
||||
".log": "text/plain",
|
||||
".xml": "application/xml",
|
||||
".yml": "application/x-yaml",
|
||||
".yaml": "application/x-yaml",
|
||||
".properties": "text/plain",
|
||||
}
|
||||
|
||||
// detectContentType returns a MIME type based on the file's extension.
|
||||
// Falls back to application/octet-stream for unknown types.
|
||||
func detectContentType(fileName string) string {
|
||||
ext := strings.ToLower(filepath.Ext(fileName))
|
||||
if mime, ok := mimeByExtension[ext]; ok {
|
||||
return mime
|
||||
}
|
||||
return "application/octet-stream"
|
||||
}
|
||||
|
||||
// Handler serves CAS endpoints.
|
||||
type Handler struct {
|
||||
db *database.DB
|
||||
@@ -34,12 +61,13 @@ func (h *Handler) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Public file serving — immutable, long cache.
|
||||
mux.HandleFunc("GET /files/{hash}", h.serveFile)
|
||||
|
||||
// Launcher binary downloads — also served from CAS-like paths.
|
||||
// Launcher binary downloads — served from /files/launcher/{version}/{os}/{arch}/{filename}.
|
||||
mux.HandleFunc("GET /files/launcher/{version}/{os}/{arch}/{filename}", h.serveLauncherAsset)
|
||||
}
|
||||
|
||||
// serveFile serves a file from CAS by its SHA-1 hash.
|
||||
// Files are immutable, so we set Cache-Control: public, max-age=31536000 (1 year).
|
||||
// Content-Type is detected from the original file name stored in global_files.
|
||||
func (h *Handler) serveFile(w http.ResponseWriter, r *http.Request) {
|
||||
hash := r.PathValue("hash")
|
||||
if !isValidHash(hash) {
|
||||
@@ -54,6 +82,16 @@ func (h *Handler) serveFile(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the original file name for Content-Type detection.
|
||||
var fileName string
|
||||
err = h.db.Pool().QueryRow(r.Context(),
|
||||
`SELECT file_name FROM global_files WHERE sha1 = $1`, hash,
|
||||
).Scan(&fileName)
|
||||
if err != nil {
|
||||
fileName = hash // fallback: no extension info
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", detectContentType(fileName))
|
||||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||
w.Write(data)
|
||||
}
|
||||
@@ -89,6 +127,7 @@ func (h *Handler) serveLauncherAsset(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", detectContentType(filename))
|
||||
w.Header().Set("Cache-Control", "public, max-age=31536000, immutable")
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
131
internal/cas/cas_test.go
Normal file
131
internal/cas/cas_test.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package cas
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsValidHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
hash string
|
||||
want bool
|
||||
}{
|
||||
{"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", true},
|
||||
{"0000000000000000000000000000000000000000", true},
|
||||
{"ffffffffffffffffffffffffffffffffffffffff", true},
|
||||
{"A1B2C3D4E5F6A1B2C3D4E5F6A1B2C3D4E5F6A1B2", false}, // uppercase
|
||||
{"g1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2", false}, // non-hex
|
||||
{"a1b2c3d4e5f6", false}, // too short
|
||||
{"", false}, // empty
|
||||
{"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3", false}, // too long (41)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := isValidHash(tt.hash)
|
||||
if got != tt.want {
|
||||
t.Errorf("isValidHash(%q) = %v, want %v", tt.hash, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
data := []byte("hello minecraft world")
|
||||
|
||||
hash, err := StoreFile(dir, data)
|
||||
if err != nil {
|
||||
t.Fatalf("StoreFile failed: %v", err)
|
||||
}
|
||||
if len(hash) != 40 {
|
||||
t.Errorf("expected 40-char hash, got %d", len(hash))
|
||||
}
|
||||
|
||||
// File should exist at dir/<prefix>/<hash>.
|
||||
path := filepath.Join(dir, hash[:2], hash)
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("stored file not found: %v", err)
|
||||
}
|
||||
if info.Size() != int64(len(data)) {
|
||||
t.Errorf("stored file size = %d, want %d", info.Size(), len(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreFile_Duplicate(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
data := []byte("same content")
|
||||
|
||||
h1, err := StoreFile(dir, data)
|
||||
if err != nil {
|
||||
t.Fatalf("first StoreFile failed: %v", err)
|
||||
}
|
||||
h2, err := StoreFile(dir, data)
|
||||
if err != nil {
|
||||
t.Fatalf("second StoreFile failed: %v", err)
|
||||
}
|
||||
if h1 != h2 {
|
||||
t.Errorf("same data produced different hashes: %s vs %s", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileExists(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
data := []byte("test data")
|
||||
|
||||
hash, _ := StoreFile(dir, data)
|
||||
if !FileExists(dir, hash) {
|
||||
t.Error("FileExists returned false for stored file")
|
||||
}
|
||||
if FileExists(dir, "0000000000000000000000000000000000000000") {
|
||||
t.Error("FileExists returned true for non-existent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyAndStore(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
data := []byte("verify me")
|
||||
hash, _ := StoreFile(dir, data)
|
||||
|
||||
// Correct hash → should succeed (idempotent).
|
||||
got, err := VerifyAndStore(dir, data, hash)
|
||||
if err != nil {
|
||||
t.Errorf("VerifyAndStore with correct hash failed: %v", err)
|
||||
}
|
||||
if got != hash {
|
||||
t.Errorf("hash mismatch: got %s, want %s", got, hash)
|
||||
}
|
||||
|
||||
// Wrong hash → should fail.
|
||||
_, err = VerifyAndStore(dir, data, "0000000000000000000000000000000000000000")
|
||||
if err == nil {
|
||||
t.Error("VerifyAndStore with wrong hash should have failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDetectContentType(t *testing.T) {
|
||||
tests := []struct {
|
||||
fileName string
|
||||
want string
|
||||
}{
|
||||
{"mod.jar", "application/java-archive"},
|
||||
{"config.json", "application/json"},
|
||||
{"skin.png", "image/png"},
|
||||
{"pack.zip", "application/zip"},
|
||||
{"options.toml", "application/toml"},
|
||||
{"server.cfg", "text/plain"},
|
||||
{"notes.txt", "text/plain"},
|
||||
{"data.xml", "application/xml"},
|
||||
{"config.yml", "application/x-yaml"},
|
||||
{"config.yaml", "application/x-yaml"},
|
||||
{"game.properties", "text/plain"},
|
||||
{"unknown.dat", "application/octet-stream"},
|
||||
{"noext", "application/octet-stream"},
|
||||
{"UPPER.JAR", "application/java-archive"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := detectContentType(tt.fileName)
|
||||
if got != tt.want {
|
||||
t.Errorf("detectContentType(%q) = %q, want %q", tt.fileName, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
68
internal/session/cleanup.go
Normal file
68
internal/session/cleanup.go
Normal file
@@ -0,0 +1,68 @@
|
||||
// package session manages Yggdrasil session lifecycle.
|
||||
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/database"
|
||||
)
|
||||
|
||||
// StartCleanupWorker launches a background goroutine that deletes expired
|
||||
// yggdrasil_sessions every interval. It stops when the context is cancelled.
|
||||
func StartCleanupWorker(db *database.DB, interval time.Duration) context.CancelFunc {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
log.Printf("Session cleanup worker started (interval: %v)", interval)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run once on start.
|
||||
cleanup(db)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
cleanup(db)
|
||||
case <-ctx.Done():
|
||||
log.Println("Session cleanup worker stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return cancel
|
||||
}
|
||||
|
||||
func cleanup(db *database.DB) {
|
||||
if db == nil {
|
||||
return
|
||||
}
|
||||
pool := db.Pool()
|
||||
if pool == nil {
|
||||
return
|
||||
}
|
||||
tag, err := pool.Exec(context.Background(),
|
||||
`DELETE FROM yggdrasil_sessions WHERE expires_at < NOW()`)
|
||||
if err != nil {
|
||||
log.Printf("Session cleanup error: %v", err)
|
||||
return
|
||||
}
|
||||
if tag.RowsAffected() > 0 {
|
||||
log.Printf("Session cleanup: removed %d expired sessions", tag.RowsAffected())
|
||||
}
|
||||
}
|
||||
|
||||
// CountActive returns the number of non-expired sessions.
|
||||
func CountActive(pool *pgxpool.Pool) (int, error) {
|
||||
var count int
|
||||
err := pool.QueryRow(context.Background(),
|
||||
`SELECT COUNT(*) FROM yggdrasil_sessions WHERE expires_at > NOW()`,
|
||||
).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
14
internal/session/cleanup_test.go
Normal file
14
internal/session/cleanup_test.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestStartCleanupWorker(t *testing.T) {
|
||||
// Verify the worker starts and can be cancelled without panic.
|
||||
cancel := StartCleanupWorker(nil, 1*time.Millisecond)
|
||||
defer cancel()
|
||||
// Give it a moment to attempt one cleanup cycle.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
Reference in New Issue
Block a user