diff --git a/common/signal/once.go b/common/signal/once.go new file mode 100644 index 000000000..db30b009d --- /dev/null +++ b/common/signal/once.go @@ -0,0 +1,29 @@ +package signal + +import ( + "sync" + "sync/atomic" +) + +type Once struct { + m sync.Mutex + done uint32 +} + +func (o *Once) Do(f func()) { + if atomic.LoadUint32(&o.done) == 1 { + return + } + o.m.Lock() + defer o.m.Unlock() + if o.done == 0 { + atomic.StoreUint32(&o.done, 1) + f() + } +} + +func (o *Once) Reset() { + o.m.Lock() + defer o.m.Unlock() + atomic.StoreUint32(&o.done, 0) +} diff --git a/transport/internet/kcp/connection.go b/transport/internet/kcp/connection.go index 69ff703fd..7d465697a 100644 --- a/transport/internet/kcp/connection.go +++ b/transport/internet/kcp/connection.go @@ -9,6 +9,7 @@ import ( "github.com/v2ray/v2ray-core/common/alloc" "github.com/v2ray/v2ray-core/common/log" + "github.com/v2ray/v2ray-core/common/signal" ) var ( @@ -63,6 +64,7 @@ type Connection struct { chReadEvent chan struct{} writer io.WriteCloser since int64 + terminateOnce signal.Once } // NewConnection create a new KCP connection between local and remote. @@ -76,21 +78,7 @@ func NewConnection(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr, conn.since = nowMillisec() mtu := uint32(effectiveConfig.Mtu - block.HeaderSize() - headerSize) - conn.kcp = NewKCP(conv, mtu, func(buf []byte, size int) { - if size >= IKCP_OVERHEAD { - ext := alloc.NewBuffer().Clear().Append(buf[:size]) - cmd := CommandData - opt := Option(0) - if conn.state == ConnStateReadyToClose { - opt = OptionClose - } - ext.Prepend([]byte{byte(cmd), byte(opt)}) - go conn.output(ext) - } - if conn.state == ConnStateReadyToClose && conn.kcp.WaitSnd() == 0 { - go conn.NotifyTermination() - } - }) + conn.kcp = NewKCP(conv, mtu, conn.output) conn.kcp.WndSize(effectiveConfig.Sndwnd, effectiveConfig.Rcvwnd) conn.kcp.NoDelay(1, 20, 2, 1) conn.kcp.current = conn.Elapsed() @@ -199,7 +187,7 @@ func (this *Connection) NotifyTermination() { this.RUnlock() buffer := alloc.NewSmallBuffer().Clear() buffer.AppendBytes(byte(CommandTerminate), byte(OptionClose), byte(0), byte(0), byte(0), byte(0)) - this.output(buffer) + this.outputBuffer(buffer) time.Sleep(time.Second) @@ -207,6 +195,19 @@ func (this *Connection) NotifyTermination() { this.Terminate() } +func (this *Connection) ForceTimeout() { + if this == nil { + return + } + for i := 0; i < 5; i++ { + if this.state == ConnStateClosed { + return + } + time.Sleep(time.Minute) + } + go this.terminateOnce.Do(this.NotifyTermination) +} + // Close closes the connection. func (this *Connection) Close() error { if this == nil || this.state == ConnStateClosed || this.state == ConnStateReadyToClose { @@ -219,7 +220,9 @@ func (this *Connection) Close() error { if this.state == ConnStateActive { this.state = ConnStateReadyToClose if this.kcp.WaitSnd() == 0 { - go this.NotifyTermination() + go this.terminateOnce.Do(this.NotifyTermination) + } else { + go this.ForceTimeout() } } @@ -280,7 +283,7 @@ func (this *Connection) SetWriteDeadline(t time.Time) error { return nil } -func (this *Connection) output(payload *alloc.Buffer) { +func (this *Connection) outputBuffer(payload *alloc.Buffer) { defer payload.Release() if this == nil { return @@ -296,6 +299,29 @@ func (this *Connection) output(payload *alloc.Buffer) { this.writer.Write(payload.Value) } +func (this *Connection) output(payload []byte) { + if this == nil || this.state == ConnStateClosed { + return + } + + if this.state == ConnStateReadyToClose && this.kcp.WaitSnd() == 0 { + go this.terminateOnce.Do(this.NotifyTermination) + } + + if len(payload) < IKCP_OVERHEAD { + return + } + + buffer := alloc.NewBuffer().Clear().Append(payload) + cmd := CommandData + opt := Option(0) + if this.state == ConnStateReadyToClose { + opt = OptionClose + } + buffer.Prepend([]byte{byte(cmd), byte(opt)}) + this.outputBuffer(buffer) +} + // kcp update, input loop func (this *Connection) updateTask() { for this.state != ConnStateClosed { diff --git a/transport/internet/kcp/kcp.go b/transport/internet/kcp/kcp.go index 42f6e66b5..f80d2d178 100644 --- a/transport/internet/kcp/kcp.go +++ b/transport/internet/kcp/kcp.go @@ -34,7 +34,7 @@ const ( ) // Output is a closure which captures conn and calls conn.Write -type Output func(buf []byte, size int) +type Output func(buf []byte) /* encode 8 bits unsigned int */ func ikcp_encode8u(p []byte, c byte) []byte { @@ -573,7 +573,7 @@ func (kcp *KCP) flush() { for i := 0; i < count; i++ { size := len(buffer) - len(ptr) if size+IKCP_OVERHEAD > int(kcp.mtu) { - kcp.output(buffer, size) + kcp.output(buffer[:size]) ptr = buffer } seg.sn, seg.ts = kcp.ack_get(i) @@ -609,7 +609,7 @@ func (kcp *KCP) flush() { seg.cmd = IKCP_CMD_WASK size := len(buffer) - len(ptr) if size+IKCP_OVERHEAD > int(kcp.mtu) { - kcp.output(buffer, size) + kcp.output(buffer[:size]) ptr = buffer } ptr = seg.encode(ptr) @@ -620,7 +620,7 @@ func (kcp *KCP) flush() { seg.cmd = IKCP_CMD_WINS size := len(buffer) - len(ptr) if size+IKCP_OVERHEAD > int(kcp.mtu) { - kcp.output(buffer, size) + kcp.output(buffer[:size]) ptr = buffer } ptr = seg.encode(ptr) @@ -703,7 +703,7 @@ func (kcp *KCP) flush() { need := IKCP_OVERHEAD + len(segment.data) if size+need >= int(kcp.mtu) { - kcp.output(buffer, size) + kcp.output(buffer[:size]) ptr = buffer } @@ -720,7 +720,7 @@ func (kcp *KCP) flush() { // flash remain segments size := len(buffer) - len(ptr) if size > 0 { - kcp.output(buffer, size) + kcp.output(buffer[:size]) } // update ssthresh diff --git a/transport/internet/tcp/connection_cache.go b/transport/internet/tcp/connection_cache.go index 032a92df5..9baf002c7 100644 --- a/transport/internet/tcp/connection_cache.go +++ b/transport/internet/tcp/connection_cache.go @@ -3,33 +3,11 @@ package tcp import ( "net" "sync" - "sync/atomic" "time" + + "github.com/v2ray/v2ray-core/common/signal" ) -type Once struct { - m sync.Mutex - done uint32 -} - -func (o *Once) Do(f func()) { - if atomic.LoadUint32(&o.done) == 1 { - return - } - o.m.Lock() - defer o.m.Unlock() - if o.done == 0 { - atomic.StoreUint32(&o.done, 1) - f() - } -} - -func (o *Once) Reset() { - o.m.Lock() - defer o.m.Unlock() - atomic.StoreUint32(&o.done, 0) -} - type AwaitingConnection struct { conn net.Conn expire time.Time @@ -42,7 +20,7 @@ func (this *AwaitingConnection) Expired() bool { type ConnectionCache struct { sync.Mutex cache map[string][]*AwaitingConnection - cleanupOnce Once + cleanupOnce signal.Once } func NewConnectionCache() *ConnectionCache {