You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

151 lines
4.2 KiB

// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package main
import (
"context"
"log"
"sync"
"sync/atomic"
"time"
)
// getCredsFunc is the signature for credential retrieval functions.
// It takes a context, a string and an int, and returns three strings plus an
// error.
// NOTE(review): by analogy with getCredsCached the arguments are presumably
// (ctx, link, streamID) and the results (username, password, server address,
// err) — confirm against the implementations, which are not visible here.
type getCredsFunc func(context.Context, string, int) (string, string, string, error)
// TurnCredentials stores cached TURN credentials together with the metadata
// getCredsCached uses to decide whether the entry may be reused.
type TurnCredentials struct {
// Username is the first string returned by the fetch function.
Username string
// Password is the second string returned by the fetch function.
Password string
// ServerAddr is the third string returned by the fetch function.
ServerAddr string
// ExpiresAt is the instant from which getCredsCached treats the entry as
// stale. It is set to the time the fetch returned plus credentialLifetime
// minus cacheSafetyMargin.
ExpiresAt time.Time
// Link is the link the credentials were fetched for; a cached entry is only
// reused for a request with the same link.
Link string
}
// StreamCredentialsCache holds one cached set of TURN credentials. Despite the
// name, an instance is not owned by a single stream: getStreamCache hands the
// same instance to every stream ID that maps to the same cache ID (see
// getCacheID).
type StreamCredentialsCache struct {
// creds is the cached entry. The zero value means "nothing cached": its
// ExpiresAt is in the past, so it never counts as a hit.
creds TurnCredentials
// mutex guards creds.
mutex sync.RWMutex
// errorCount and lastErrorTime are reset to zero by invalidate.
// NOTE(review): presumably they track recent auth failures; the code that
// updates them (and the unit of lastErrorTime) is not visible here — confirm.
errorCount atomic.Int32
lastErrorTime atomic.Int64
}
// Tunables for the credentials cache.
const (
// credentialLifetime and cacheSafetyMargin together determine how long a
// cached entry is reused: getCredsCached sets its expiry to the fetch time
// plus credentialLifetime minus cacheSafetyMargin.
credentialLifetime = 10 * time.Minute
cacheSafetyMargin = 60 * time.Second
// maxCacheErrors and errorWindow are not referenced in this part of the file.
// NOTE(review): presumably the error threshold and time window that trigger
// a cache invalidation — confirm at the use site.
maxCacheErrors = 3
errorWindow = 10 * time.Second
streamsPerCache = 4 // Number of streams sharing one credentials cache
)
// getCacheID maps a stream ID to the ID of the credentials cache it shares
// with its neighbours. The result is streamID divided by streamsPerCache using
// integer division, so each run of streamsPerCache consecutive non-negative
// stream IDs resolves to the same cache ID.
func getCacheID(streamID int) int {
	cacheID := streamID / streamsPerCache
	return cacheID
}
// credentialsStore is the registry of credentials caches, keyed by cache ID
// (see getCacheID) rather than by individual stream ID. mu guards the caches
// map itself; the contents of each cache are guarded by that cache's own mutex.
var credentialsStore = struct {
mu sync.RWMutex
caches map[int]*StreamCredentialsCache
}{
caches: make(map[int]*StreamCredentialsCache),
}
// getStreamCache returns the credentials cache shared by streamID's cache
// group, creating and registering it on first use. Every stream ID that maps
// to the same cache ID receives the same *StreamCredentialsCache.
func getStreamCache(streamID int) *StreamCredentialsCache {
	id := getCacheID(streamID)

	// Common case: the cache already exists, so a read lock is enough.
	credentialsStore.mu.RLock()
	existing, ok := credentialsStore.caches[id]
	credentialsStore.mu.RUnlock()
	if ok {
		return existing
	}

	// Not found: take the write lock to register a new cache.
	credentialsStore.mu.Lock()
	defer credentialsStore.mu.Unlock()

	// Another goroutine may have registered it between the two locks; look
	// again so that only one instance per cache ID is ever handed out.
	if existing, ok := credentialsStore.caches[id]; ok {
		return existing
	}

	created := new(StreamCredentialsCache)
	credentialsStore.caches[id] = created
	return created
}
// invalidate discards whatever credentials this cache currently holds and
// clears its error bookkeeping, so the next lookup has to fetch again.
// streamID is used only for the log message.
func (c *StreamCredentialsCache) invalidate(streamID int) {
	// Hold the write lock just long enough to reset the cached entry.
	func() {
		c.mutex.Lock()
		defer c.mutex.Unlock()
		c.creds = TurnCredentials{}
	}()

	// Start the auth error accounting from scratch as well.
	c.errorCount.Store(0)
	c.lastErrorTime.Store(0)

	log.Printf("[Auth] Credentials cache invalidated for stream %d", streamID)
}
// fetchMu serializes credential fetching to avoid API rate limiting.
// It is a single package-level lock: serializeFetch holds it for the full
// duration of each fetch, so at most one fetch issued through serializeFetch
// runs at a time, regardless of stream or cache.
var fetchMu sync.Mutex
// fetchFunc is the signature for credential retrieval functions (without cache
// logic). It receives the request context and the link to fetch credentials
// for. getCredsCached stores the three returned strings as username, password
// and server address, in that order; a non-nil error means nothing is cached.
type fetchFunc func(ctx context.Context, link string) (string, string, string, error)
// serializeFetch runs storeFn(ctx, link) while holding the global fetchMu, so
// that at most one credential fetch is in flight at a time (to avoid API rate
// limiting).
//
// Waiting for fetchMu can take as long as all the fetches queued ahead of this
// call, so ctx is re-checked once the lock is held. If the caller was cancelled
// or timed out while waiting, ctx.Err() is returned with empty strings and
// storeFn is not invoked, which avoids spending a rate-limited API call on a
// result nobody will use. Otherwise the results of storeFn are returned
// unchanged. ctx must be non-nil.
func serializeFetch(ctx context.Context, link string, storeFn fetchFunc) (string, string, string, error) {
	fetchMu.Lock()
	defer fetchMu.Unlock()

	if err := ctx.Err(); err != nil {
		return "", "", "", err
	}
	return storeFn(ctx, link)
}
// getCredsCached returns TURN credentials for link, serving them from the
// cache shared by streamID's cache group when possible and fetching them via
// storeFn otherwise. This is the general entry point for credential retrieval
// with caching.
//
// A cached entry is reused only if it was fetched for the same link and has
// not yet reached its expiry (credentialLifetime minus cacheSafetyMargin after
// the fetch returned). Cache hits are served under a read lock so that streams
// sharing a cache do not serialize on one another. A miss takes the write lock,
// re-checks the cache (another stream may have filled it while this call was
// waiting) and only then fetches. The write lock is held for the whole fetch so
// that concurrent misses on the same cache wait for one fetch instead of each
// issuing their own.
//
// On success it returns (username, password, server address, nil). If ctx is
// already done when a fetch would start, ctx.Err() is returned. If the fetch
// fails, its error is returned unchanged and the cache is left untouched.
func getCredsCached(ctx context.Context, link string, streamID int, storeFn fetchFunc) (string, string, string, error) {
	cache := getStreamCache(streamID)
	cacheID := getCacheID(streamID)

	// usable reports whether tc was fetched for this link and is still within
	// its lifetime. The zero TurnCredentials is never usable because its
	// ExpiresAt lies in the past.
	usable := func(tc TurnCredentials) bool {
		return tc.Link == link && time.Now().Before(tc.ExpiresAt)
	}
	// serve logs a cache hit and returns the cached triple.
	serve := func(tc TurnCredentials) (string, string, string, error) {
		log.Printf("[Auth] Using cached credentials (cache=%d, expires in %v)", cacheID, time.Until(tc.ExpiresAt))
		return tc.Username, tc.Password, tc.ServerAddr, nil
	}

	// Fast path: copy the entry under the read lock and evaluate the copy, so
	// concurrent hits never block each other.
	cache.mutex.RLock()
	snapshot := cache.creds
	cache.mutex.RUnlock()
	if usable(snapshot) {
		return serve(snapshot)
	}

	// Slow path: exclusive access for the re-check, the fetch and the store.
	cache.mutex.Lock()
	defer cache.mutex.Unlock()

	// Check cache again - another stream may have populated it while waiting.
	if usable(cache.creds) {
		return serve(cache.creds)
	}
	log.Printf("[Auth] Cache miss (cache=%d), starting credential fetch...", cacheID)

	// Don't start a long fetch for a caller that has already given up.
	if err := ctx.Err(); err != nil {
		return "", "", "", err
	}

	// Fetch credentials with global mutex to avoid API rate limiting.
	user, pass, addr, err := serializeFetch(ctx, link, storeFn)
	if err != nil {
		return "", "", "", err
	}

	// Store in cache, expiring the entry a safety margin before the
	// credentials themselves are expected to lapse.
	fresh := TurnCredentials{
		Username:   user,
		Password:   pass,
		ServerAddr: addr,
		ExpiresAt:  time.Now().Add(credentialLifetime - cacheSafetyMargin),
		Link:       link,
	}
	cache.creds = fresh
	log.Printf("[Auth] Success! Credentials cached until %v (cache=%d)", fresh.ExpiresAt, cacheID)
	return user, pass, addr, nil
}