diff --git a/aws.go b/aws.go index ffc72da..1687d1a 100644 --- a/aws.go +++ b/aws.go @@ -223,14 +223,12 @@ func (a *S3) Put(key string, reader io.ReadSeeker, meta map[string]string, optio input.Expires = putOptions.expires } if a.compressor != nil { - wrapReader, i, err := GetReaderLength(input.Body) + length, err := GetReaderLength(input.Body) if err != nil { return err } - if i < a.cfg.CompressLimit { - input.Body = wrapReader - } else { - input.Body, err = a.compressor.Compress(wrapReader) + if length > a.cfg.CompressLimit { + input.Body, err = a.compressor.Compress(input.Body) if err != nil { return err } diff --git a/client.go b/client.go index 83d00c3..a199fbe 100644 --- a/client.go +++ b/client.go @@ -65,7 +65,7 @@ type Options struct { // CompressType gzip CompressType string // CompressLimit 大于该值之后才压缩 单位字节 - CompressLimit int + CompressLimit int64 } const ( diff --git a/compress.go b/compress.go index c001eb8..b9227d4 100644 --- a/compress.go +++ b/compress.go @@ -3,7 +3,6 @@ package awos import ( "bytes" "compress/gzip" - "fmt" "io" "sync" ) @@ -29,31 +28,69 @@ type GzipCompressor struct { } func (g *GzipCompressor) Compress(reader io.ReadSeeker) (gzipReader io.ReadSeeker, err error) { - // TODO buffer limit - var buffer bytes.Buffer - gzipWriter := gzip.NewWriter(&buffer) - _, err = io.Copy(gzipWriter, reader) - if err != nil { - return nil, err + return &gzipReadSeeker{ + reader: reader, + }, nil +} + +func (g *GzipCompressor) ContentEncoding() string { + return compressTypeGzip +} + +type gzipReadSeeker struct { + reader io.ReadSeeker +} + +func (crs *gzipReadSeeker) Read(p []byte) (n int, err error) { + // 读取原始数据 + n, err = crs.reader.Read(p) + if err != nil && err != io.EOF { + return n, err + } + if n == 0 { + return 0, err } - err = gzipWriter.Close() + var compressedBuffer bytes.Buffer + gw := gzip.NewWriter(&compressedBuffer) + // 压缩读取的数据 + _, err = gw.Write(p[:n]) if err != nil { - return nil, err + _ = gw.Close() + return n, err } - fmt.Println("gzipCompressSuccess length: ", buffer.Len()) - return bytes.NewReader(buffer.Bytes()), nil + if err = gw.Close(); err != nil { + return 0, err + } + // 将压缩后的数据返回给调用者 + n = copy(p, compressedBuffer.Bytes()) + compressedBuffer.Reset() + return n, err } -func (g *GzipCompressor) ContentEncoding() string { - return compressTypeGzip +func (crs *gzipReadSeeker) Seek(offset int64, whence int) (int64, error) { + // 调用原始ReadSeeker的Seek方法 + return crs.reader.Seek(offset, whence) } var DefaultGzipCompressor = &GzipCompressor{} -func GetReaderLength(reader io.ReadSeeker) (io.ReadSeeker, int, error) { - all, err := io.ReadAll(reader) +func GetReaderLength(reader io.ReadSeeker) (int64, error) { + // 保存当前的读写位置 + originalPos, err := reader.Seek(0, io.SeekCurrent) if err != nil { - return nil, 0, err + return 0, err } - return bytes.NewReader(all), len(all), nil + + // 移动到文件末尾以获取字节长度 + length, err := reader.Seek(0, io.SeekEnd) + if err != nil { + return 0, err + } + // 恢复原始读写位置 + _, err = reader.Seek(originalPos, io.SeekStart) + if err != nil { + return 0, err + } + + return length, nil } diff --git a/compress_test.go b/compress_test.go index 15d4098..01de548 100644 --- a/compress_test.go +++ b/compress_test.go @@ -16,6 +16,11 @@ func TestCompress_gzip(t *testing.T) { if err != nil { panic(err) } + length, err := GetReaderLength(source) + if err != nil { + panic(err) + } + t.Logf("length %d", length) reader, err := DefaultGzipCompressor.Compress(source) if err != nil { panic(err) diff --git a/config.go b/config.go index 49f56a6..1fe1938 100644 --- a/config.go +++ b/config.go @@ -42,7 +42,7 @@ type bucketConfig struct { // CompressType gzip CompressType string // CompressLimit 大于该值之后才压缩 单位字节 - CompressLimit int + CompressLimit int64 } // DefaultConfig 返回默认配置 diff --git a/oss.go b/oss.go index 0730429..08023b4 100644 --- a/oss.go +++ b/oss.go @@ -214,14 +214,12 @@ func (ossClient *OSS) Put(key string, reader io.ReadSeeker, meta map[string]stri ossOptions = append(ossOptions, oss.Expires(*putOptions.expires)) } if ossClient.compressor != nil { - readSeeker, l, err := GetReaderLength(reader) + length, err := GetReaderLength(reader) if err != nil { return err } - if l < ossClient.cfg.CompressLimit { - reader = readSeeker - } else { - reader, err = ossClient.compressor.Compress(readSeeker) + if length > ossClient.cfg.CompressLimit { + reader, err = ossClient.compressor.Compress(reader) if err != nil { return err }