diff --git a/common/net/port.go b/common/net/port.go index fc791e025..1e665550a 100644 --- a/common/net/port.go +++ b/common/net/port.go @@ -7,7 +7,7 @@ import ( type Port serial.Uint16Literal func PortFromBytes(port []byte) Port { - return Port(serial.ParseUint16(port)) + return Port(serial.BytesLiteral(port).Uint16Value()) } func (this Port) Value() uint16 { diff --git a/common/serial/bytes.go b/common/serial/bytes.go index ebbfa7591..0b315aa0c 100644 --- a/common/serial/bytes.go +++ b/common/serial/bytes.go @@ -10,6 +10,24 @@ func (this BytesLiteral) Value() []byte { return []byte(this) } +func (this BytesLiteral) Uint8Value() uint8 { + return this.Value()[0] +} + +func (this BytesLiteral) Uint16() Uint16Literal { + return Uint16Literal(this.Uint16Value()) +} + +func (this BytesLiteral) Uint16Value() uint16 { + value := this.Value() + return uint16(value[0])<<8 + uint16(value[1]) +} + +func (this BytesLiteral) IntValue() int { + value := this.Value() + return int(value[0])<<24 + int(value[1])<<16 + int(value[2])<<8 + int(value[3]) +} + func (this BytesLiteral) Uint32Value() uint32 { value := this.Value() return uint32(value[0])<<24 + diff --git a/common/serial/numbers.go b/common/serial/numbers.go index 235dab182..9de8759cc 100644 --- a/common/serial/numbers.go +++ b/common/serial/numbers.go @@ -10,17 +10,6 @@ type Uint16 interface { type Uint16Literal uint16 -func ParseUint16(data []byte) Uint16Literal { - switch len(data) { - case 0: - return Uint16Literal(0) - case 1: - return Uint16Literal(uint16(data[0])) - default: - return Uint16Literal(uint16(data[0])<<8 + uint16(data[1])) - } -} - func (this Uint16Literal) String() string { return strconv.Itoa(int(this)) } diff --git a/proxy/vmess/command/accounts.go b/proxy/vmess/command/accounts.go index b01c78770..04516cc98 100644 --- a/proxy/vmess/command/accounts.go +++ b/proxy/vmess/command/accounts.go @@ -78,7 +78,7 @@ func (this *SwitchAccount) Unmarshal(data []byte) error { if len(data) < alterIdStart+2 { return transport.CorruptedPacket } - this.AlterIds = serial.ParseUint16(data[alterIdStart : alterIdStart+2]) + this.AlterIds = serial.BytesLiteral(data[alterIdStart : alterIdStart+2]).Uint16() levelStart := alterIdStart + 2 if len(data) < levelStart+1 { return transport.CorruptedPacket