diff --git a/common/buf/copy.go b/common/buf/copy.go index 654bf0a42..75e6408b8 100644 --- a/common/buf/copy.go +++ b/common/buf/copy.go @@ -82,23 +82,19 @@ func CountSize(sc *SizeCounter) CopyOption { func copyInternal(reader Reader, writer Writer, handler *copyHandler) error { for { buffer, err := handler.readFrom(reader) + if !buffer.IsEmpty() { + for _, handler := range handler.onData { + handler(buffer) + } + + if werr := handler.writeTo(writer, buffer); werr != nil { + buffer.Release() + return werr + } + } if err != nil { return err } - - if buffer.IsEmpty() { - buffer.Release() - continue - } - - for _, handler := range handler.onData { - handler(buffer) - } - - if err := handler.writeTo(writer, buffer); err != nil { - buffer.Release() - return err - } } } diff --git a/common/buf/reader.go b/common/buf/reader.go index 8a327b3ee..9b80e9bfe 100644 --- a/common/buf/reader.go +++ b/common/buf/reader.go @@ -6,6 +6,11 @@ import ( "v2ray.com/core/common/errors" ) +var ( + _ Reader = (*BytesToBufferReader)(nil) + _ io.Reader = (*BytesToBufferReader)(nil) +) + // BytesToBufferReader is a Reader that adjusts its reading speed automatically. type BytesToBufferReader struct { io.Reader @@ -37,15 +42,21 @@ func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) { } nBytes, err := r.Reader.Read(r.buffer) - if err != nil { - return nil, err + if nBytes > 0 { + mb := NewMultiBufferCap(nBytes/Size + 1) + mb.Write(r.buffer[:nBytes]) + return mb, err } - - mb := NewMultiBufferCap(nBytes/Size + 1) - mb.Write(r.buffer[:nBytes]) - return mb, nil + return nil, err } +var ( + _ Reader = (*BufferedReader)(nil) + _ io.Reader = (*BufferedReader)(nil) + _ io.ByteReader = (*BufferedReader)(nil) + _ io.WriterTo = (*BufferedReader)(nil) +) + type BufferedReader struct { stream Reader legacyReader io.Reader @@ -72,6 +83,12 @@ func (r *BufferedReader) IsBuffered() bool { return r.buffered } +func (r *BufferedReader) ReadByte() (byte, error) { + var b [1]byte + _, err := r.Read(b[:]) + return b[0], err +} + func (r *BufferedReader) Read(b []byte) (int, error) { if r.leftOver != nil { nBytes, _ := r.leftOver.Read(b) @@ -87,15 +104,14 @@ func (r *BufferedReader) Read(b []byte) (int, error) { } mb, err := r.stream.ReadMultiBuffer() - if err != nil { - return 0, err + if mb != nil { + nBytes, _ := mb.Read(b) + if !mb.IsEmpty() { + r.leftOver = mb + } + return nBytes, err } - - nBytes, _ := mb.Read(b) - if !mb.IsEmpty() { - r.leftOver = mb - } - return nBytes, nil + return 0, err } func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) { @@ -120,11 +136,13 @@ func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) { for { mb, err := r.stream.ReadMultiBuffer() - if err != nil { - return totalBytes, err + if mb != nil { + totalBytes += int64(mb.Len()) + if werr := mbWriter.WriteMultiBuffer(mb); werr != nil { + return totalBytes, err + } } - totalBytes += int64(mb.Len()) - if err := mbWriter.WriteMultiBuffer(mb); err != nil { + if err != nil { return totalBytes, err } } diff --git a/common/buf/writer.go b/common/buf/writer.go index f51fa3fbc..8b10bb57d 100644 --- a/common/buf/writer.go +++ b/common/buf/writer.go @@ -6,6 +6,12 @@ import ( "v2ray.com/core/common/errors" ) +var ( + _ io.ReaderFrom = (*BufferToBytesWriter)(nil) + _ io.Writer = (*BufferToBytesWriter)(nil) + _ Writer = (*BufferToBytesWriter)(nil) +) + // BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer. type BufferToBytesWriter struct { io.Writer @@ -33,6 +39,13 @@ func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) { return sc.Size, err } +var ( + _ io.ReaderFrom = (*BufferedWriter)(nil) + _ io.Writer = (*BufferedWriter)(nil) + _ Writer = (*BufferedWriter)(nil) + _ io.ByteWriter = (*BufferedWriter)(nil) +) + // BufferedWriter is a Writer with internal buffer. type BufferedWriter struct { writer Writer @@ -54,6 +67,11 @@ func NewBufferedWriter(writer Writer) *BufferedWriter { return w } +func (w *BufferedWriter) WriteByte(c byte) error { + _, err := w.Write([]byte{c}) + return err +} + // Write implements io.Writer. func (w *BufferedWriter) Write(b []byte) (int, error) { if !w.buffered && w.legacyWriter != nil { @@ -130,17 +148,12 @@ func (w *BufferedWriter) SetBuffered(f bool) error { // ReadFrom implements io.ReaderFrom. func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) { - var sc SizeCounter - if !w.buffer.IsEmpty() { - sc.Size += int64(w.buffer.Len()) - if err := w.Flush(); err != nil { - return sc.Size, err - } + if err := w.SetBuffered(false); err != nil { + return 0, err } - w.buffered = false + var sc SizeCounter err := Copy(NewReader(reader), w, CountSize(&sc)) - return sc.Size, err } diff --git a/common/buf/writer_test.go b/common/buf/writer_test.go index 3f64411dd..0e6b15db8 100644 --- a/common/buf/writer_test.go +++ b/common/buf/writer_test.go @@ -37,15 +37,17 @@ func TestBytesWriterReadFrom(t *testing.T) { assert := With(t) cache := ray.NewStream(context.Background()) - reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192)) + const size = 50000 + reader := bufio.NewReader(io.LimitReader(rand.Reader, size)) writer := NewBufferedWriter(cache) writer.SetBuffered(false) - _, err := reader.WriteTo(writer) + nBytes, err := reader.WriteTo(writer) + assert(nBytes, Equals, int64(size)) assert(err, IsNil) mb, err := cache.ReadMultiBuffer() assert(err, IsNil) - assert(mb.Len(), Equals, 8192) + assert(mb.Len(), Equals, size) } func TestDiscardBytes(t *testing.T) {