fix: add per-hash mutex to prevent concurrent CAS writes
StoreFile now uses a per-hash sync.Mutex to prevent race conditions when multiple workers (launcher fetcher or parallel uploads) write the same file simultaneously. Duplicate writes are idempotent — if another goroutine stored the file while we waited, return the existing hash without re-writing.
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"gitea.mrixs.me/Mrixs/MrixsCraft-server/pkg/utils"
|
"gitea.mrixs.me/Mrixs/MrixsCraft-server/pkg/utils"
|
||||||
|
|
||||||
@@ -18,6 +19,37 @@ import (
|
|||||||
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/database"
|
"gitea.mrixs.me/Mrixs/MrixsCraft-server/internal/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// hashLocks provides per-hash mutexes to prevent concurrent writes to the
// same CAS entry. hashLockRefs counts, per hash, the goroutines that currently
// hold or are waiting for that mutex, so an entry can be removed as soon as
// nobody needs it and the maps do not grow with every hash ever stored.
// Both maps are protected by hashLocksMu.
var (
	hashLocks    = make(map[string]*sync.Mutex)
	hashLockRefs = make(map[string]int)
	hashLocksMu  sync.Mutex
)

// acquireLock returns (and creates if needed) the mutex for a given hash
// and locks it. Caller MUST call releaseLock for the same hash.
func acquireLock(hash string) {
	hashLocksMu.Lock()
	mu, ok := hashLocks[hash]
	if !ok {
		mu = &sync.Mutex{}
		hashLocks[hash] = mu
	}
	// Register interest before blocking on mu so that a concurrent
	// releaseLock cannot delete the entry while we are still waiting for it.
	hashLockRefs[hash]++
	hashLocksMu.Unlock()

	// Lock outside hashLocksMu so waiting on one hash never blocks callers
	// working on other hashes.
	mu.Lock()
}

// releaseLock unlocks the per-hash mutex and drops its bookkeeping entry once
// no goroutine holds or waits for it. Must be called after acquireLock to
// avoid deadlocks. Calling it for a hash that has no entry is a no-op.
func releaseLock(hash string) {
	hashLocksMu.Lock()
	mu, ok := hashLocks[hash]
	if ok {
		hashLockRefs[hash]--
		if hashLockRefs[hash] <= 0 {
			// Last user: forget the mutex so the maps stay bounded by the
			// number of hashes currently in flight.
			delete(hashLocks, hash)
			delete(hashLockRefs, hash)
		}
	}
	hashLocksMu.Unlock()

	if ok {
		mu.Unlock()
	}
}
|
||||||
|
|
||||||
// mimeByExtension maps common file extensions to MIME types for CAS serving.
|
// mimeByExtension maps common file extensions to MIME types for CAS serving.
|
||||||
var mimeByExtension = map[string]string{
|
var mimeByExtension = map[string]string{
|
||||||
".jar": "application/java-archive",
|
".jar": "application/java-archive",
|
||||||
@@ -147,8 +179,16 @@ func isValidHash(hash string) bool {
|
|||||||
|
|
||||||
// StoreFile writes data to the CAS directory structure.
|
// StoreFile writes data to the CAS directory structure.
|
||||||
// Returns the SHA-1 hash of the stored data.
|
// Returns the SHA-1 hash of the stored data.
|
||||||
|
// Uses a per-hash mutex to prevent concurrent writes of the same entry.
|
||||||
func StoreFile(casDir string, data []byte) (string, error) {
|
func StoreFile(casDir string, data []byte) (string, error) {
|
||||||
hash := utils.SHA1Bytes(data)
|
hash := utils.SHA1Bytes(data)
|
||||||
|
acquireLock(hash)
|
||||||
|
defer releaseLock(hash)
|
||||||
|
|
||||||
|
if FileExists(casDir, hash) {
|
||||||
|
return hash, nil // Already stored by a concurrent caller.
|
||||||
|
}
|
||||||
|
|
||||||
destDir := filepath.Join(casDir, hash[:2])
|
destDir := filepath.Join(casDir, hash[:2])
|
||||||
if err := os.MkdirAll(destDir, 0o755); err != nil {
|
if err := os.MkdirAll(destDir, 0o755); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
|||||||
@@ -1,8 +1,12 @@
|
|||||||
package cas
|
package cas
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/hex"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -68,6 +72,45 @@ func TestStoreFile_Duplicate(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStoreFile_ConcurrentSameHash(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
data := []byte("concurrent write test")
|
||||||
|
|
||||||
|
const workers = 10
|
||||||
|
var success int64
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(workers)
|
||||||
|
|
||||||
|
for i := 0; i < workers; i++ {
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
hash, err := StoreFile(dir, data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("StoreFile failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(hash) != 40 {
|
||||||
|
t.Errorf("invalid hash length: %d", len(hash))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
atomic.AddInt64(&success, 1)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if success != workers {
|
||||||
|
t.Errorf("expected %d successes, got %d", workers, success)
|
||||||
|
}
|
||||||
|
|
||||||
|
// All goroutines must produce the same hash for identical data.
|
||||||
|
h := sha1.Sum(data)
|
||||||
|
hash := hex.EncodeToString(h[:])
|
||||||
|
if !FileExists(dir, hash) {
|
||||||
|
t.Error("file not found after concurrent writes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFileExists(t *testing.T) {
|
func TestFileExists(t *testing.T) {
|
||||||
dir := t.TempDir()
|
dir := t.TempDir()
|
||||||
data := []byte("test data")
|
data := []byte("test data")
|
||||||
|
|||||||
Reference in New Issue
Block a user