diff --git a/proxy/vmess/encoding/encoding_test.go b/proxy/vmess/encoding/encoding_test.go index 8d0211e48..ee6b1a1d6 100644 --- a/proxy/vmess/encoding/encoding_test.go +++ b/proxy/vmess/encoding/encoding_test.go @@ -46,7 +46,7 @@ func TestRequestSerialization(t *testing.T) { buffer2.Append(buffer.Bytes()) ctx, cancel := context.WithCancel(context.Background()) - sessionHistory := NewSessionHistory(ctx) + sessionHistory := NewSessionHistory() userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash) userValidator.Add(user) diff --git a/proxy/vmess/encoding/server.go b/proxy/vmess/encoding/server.go index 0859ed822..8a81a33a9 100644 --- a/proxy/vmess/encoding/server.go +++ b/proxy/vmess/encoding/server.go @@ -1,7 +1,6 @@ package encoding import ( - "context" "crypto/aes" "crypto/cipher" "crypto/md5" @@ -32,26 +31,25 @@ type SessionHistory struct { sync.RWMutex cache map[sessionId]time.Time token *signal.Semaphore - ctx context.Context + timer *time.Timer } -func NewSessionHistory(ctx context.Context) *SessionHistory { +func NewSessionHistory() *SessionHistory { h := &SessionHistory{ cache: make(map[sessionId]time.Time, 128), token: signal.NewSemaphore(1), - ctx: ctx, } return h } func (h *SessionHistory) add(session sessionId) { h.Lock() - h.cache[session] = time.Now().Add(time.Minute * 3) - h.Unlock() + defer h.Unlock() + h.cache[session] = time.Now().Add(time.Minute * 3) select { case <-h.token.Wait(): - go h.run() + h.timer = time.AfterFunc(time.Minute*3, h.removeExpiredEntries) default: } } @@ -66,31 +64,21 @@ func (h *SessionHistory) has(session sessionId) bool { return false } -func (h *SessionHistory) run() { - defer h.token.Signal() +func (h *SessionHistory) removeExpiredEntries() { + now := time.Now() - for { - select { - case <-h.ctx.Done(): - return - case <-time.After(time.Second * 30): - } - session2Remove := make([]sessionId, 0, 16) - now := time.Now() - h.Lock() - if len(h.cache) == 0 { - h.Unlock() - return - } - for session, expire := range h.cache { - if expire.Before(now) { - session2Remove = append(session2Remove, session) - } - } - for _, session := range session2Remove { + h.Lock() + defer h.Unlock() + + for session, expire := range h.cache { + if expire.Before(now) { delete(h.cache, session) } - h.Unlock() + } + + if h.timer != nil { + h.timer.Stop() + h.timer = nil } } diff --git a/proxy/vmess/inbound/inbound.go b/proxy/vmess/inbound/inbound.go index a7b358d27..7490cbc74 100644 --- a/proxy/vmess/inbound/inbound.go +++ b/proxy/vmess/inbound/inbound.go @@ -93,7 +93,7 @@ func New(ctx context.Context, config *Config) (*Handler, error) { clients: vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash), detours: config.Detour, usersByEmail: newUserByEmail(config.User, config.GetDefaultValue()), - sessionHistory: encoding.NewSessionHistory(ctx), + sessionHistory: encoding.NewSessionHistory(), } for _, user := range config.User {