New Upstream Release - golang-github-valyala-fasthttp
Ready changes
Summary
Merged new upstream version: 1.44.0 (was: 1.31.0).
Diff
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 5130b05..5a0541d 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -8,12 +8,41 @@ jobs:
lint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v2
- - uses: actions/setup-go@v2
+ - uses: actions/checkout@v3
+ - uses: actions/setup-go@v3
with:
- go-version: 1.17.x
+ go-version: 1.19.x
+
+ - name: Get Go cache paths
+ id: go-env
+ run: |
+ echo "::set-output name=cache::$(go env GOCACHE)"
+ echo "::set-output name=modcache::$(go env GOMODCACHE)"
+ - name: Set up Go cache
+ uses: actions/cache@v3
+ with:
+ key: golangci-lint-${{ runner.os }}-go-${{ hashFiles('go.mod') }}
+ restore-keys: golangci-lint-${{ runner.os }}-go-
+ path: |
+ ${{ steps.go-env.outputs.cache }}
+ ${{ steps.go-env.outputs.modcache }}
+
+ - name: Install golangci-lint
+ run: go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.48.0
+
+ - name: Get golangci-lint cache path
+ id: golangci-lint-cache-status
+ run: |
+ echo "::set-output name=dir::$(golangci-lint cache status | head -1 | sed 's/^Dir: //')"
+
+ - name: Set up golangci-lint cache
+ uses: actions/cache@v3
+ with:
+ key: golangci-lint-${{ runner.os }}-golangci-lint-${{ hashFiles('go.mod') }}
+ restore-keys: golangci-lint-${{ runner.os }}-golangci-lint-
+ path: ${{ steps.golangci-lint-cache-status.outputs.dir }}
+
- run: go version
- run: diff -u <(echo -n) <(gofmt -d .)
- - uses: golangci/golangci-lint-action@v2
- with:
- version: v1.28.3
+ - name: Run golangci-lint
+ run: golangci-lint run
diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml
index 13b8fa7..1b4d2b2 100644
--- a/.github/workflows/security.yml
+++ b/.github/workflows/security.yml
@@ -8,15 +8,14 @@ jobs:
test:
strategy:
matrix:
- go-version: [1.17.x]
+ go-version: [1.19.x]
platform: [ubuntu-latest]
runs-on: ${{ matrix.platform }}
+ env:
+ GO111MODULE: on
steps:
- - name: Install Go
- uses: actions/setup-go@v1
+ - uses: actions/checkout@v3
+ - name: Run Gosec Security Scanner
+ uses: securego/gosec@v2.12.0
with:
- go-version: ${{ matrix.go-version }}
- - name: Checkout code
- uses: actions/checkout@v2
- - name: Security
- run: go get github.com/securego/gosec/cmd/gosec; `go env GOPATH`/bin/gosec -exclude=G104,G304 ./...
+ args: '-exclude=G104,G304,G402 ./...'
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 068a6f5..8fc0843 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -8,14 +8,29 @@ jobs:
test:
strategy:
matrix:
- go-version: [1.15.x, 1.16.x, 1.17.x]
+ go-version: [1.16.x, 1.17.x, 1.18.x, 1.19.x]
os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- - uses: actions/checkout@v2
- - uses: actions/setup-go@v2
+ - uses: actions/checkout@v3
+ - uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go-version }}
+
+ - name: Get Go cache paths
+ id: go-env
+ run: |
+ echo "::set-output name=cache::$(go env GOCACHE)"
+ echo "::set-output name=modcache::$(go env GOMODCACHE)"
+ - name: Set up Go cache
+ uses: actions/cache@v3
+ with:
+ key: golangci-lint-${{ runner.os }}-go-${{ hashFiles('go.mod') }}
+ restore-keys: golangci-lint-${{ runner.os }}-go-
+ path: |
+ ${{ steps.go-env.outputs.cache }}
+ ${{ steps.go-env.outputs.modcache }}
+
- run: go version
- run: go test ./...
- run: go test -race ./...
diff --git a/.gitignore b/.gitignore
index 489dbd2..4d5d57d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,6 +3,7 @@ tags
*.fasthttp.gz
*.fasthttp.br
.idea
+.vscode
.DS_Store
vendor/
diff --git a/README.md b/README.md
index 59d0dd0..dcc727c 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
Fast HTTP implementation for Go.
# fasthttp might not be for you!
-fasthttp was design for some high performance edge cases. **Unless** your server/client needs to handle **thousands of small to medium requests per seconds** and needs a consistent low millisecond response time fasthttp might not be for you. **For most cases `net/http` is much better** as it's easier to use and can handle more cases. For most cases you won't even notice the performance difference.
+fasthttp was designed for some high performance edge cases. **Unless** your server/client needs to handle **thousands of small to medium requests per second** and needs a consistent low millisecond response time fasthttp might not be for you. **For most cases `net/http` is much better** as it's easier to use and can handle more cases. For most cases you won't even notice the performance difference.
## General info and links
@@ -320,6 +320,23 @@ with fasthttp support:
fasthttp.ListenAndServe(":80", m)
```
+* Because creating a new channel for every request is just too expensive, the channel returned by `RequestCtx.Done()` is only closed when the server is shutting down.
+
+ ```go
+ func main() {
+ fasthttp.ListenAndServe(":8080", fasthttp.TimeoutHandler(func(ctx *fasthttp.RequestCtx) {
+ select {
+ case <-ctx.Done():
+ // ctx.Done() is only closed when the server is shutting down.
+ log.Println("context cancelled")
+ return
+ case <-time.After(10 * time.Second):
+ log.Println("process finished ok")
+ }
+ }, time.Second*2, "timeout"))
+ }
+ ```
+
* net/http -> fasthttp conversion table:
* All the pseudocode below assumes w, r and ctx have these types:
@@ -485,6 +502,27 @@ statusCode, body, err := fasthttp.Get(nil, "http://google.com/")
uintBuf := fasthttp.AppendUint(nil, 1234)
```
+* String and `[]byte` buffers may be converted without memory allocations:
+```go
+func b2s(b []byte) string {
+ return *(*string)(unsafe.Pointer(&b))
+}
+
+func s2b(s string) (b []byte) {
+ bh := (*reflect.SliceHeader)(unsafe.Pointer(&b))
+ sh := (*reflect.StringHeader)(unsafe.Pointer(&s))
+ bh.Data = sh.Data
+ bh.Cap = sh.Len
+ bh.Len = sh.Len
+ return b
+}
+```
+
+### Warning:
+This is an **unsafe** way: the resulting string and `[]byte` buffer share the same underlying bytes.
+
+**Please make sure not to modify the bytes in the `[]byte` buffer while the string is still in use!**
+
## Related projects
* [fasthttp](https://github.com/fasthttp) - various useful
diff --git a/allocation_test.go b/allocation_test.go
index 56c922b..348fef1 100644
--- a/allocation_test.go
+++ b/allocation_test.go
@@ -5,7 +5,6 @@ package fasthttp
import (
"net"
-
"testing"
)
@@ -39,7 +38,7 @@ func TestAllocationServeConn(t *testing.T) {
func TestAllocationClient(t *testing.T) {
ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
- t.Fatalf("cannot listen: %s", err)
+ t.Fatalf("cannot listen: %v", err)
}
defer ln.Close()
diff --git a/args.go b/args.go
index a6db223..c8fdccd 100644
--- a/args.go
+++ b/args.go
@@ -230,7 +230,7 @@ func (a *Args) SetBytesKV(key, value []byte) {
// SetNoValue sets only 'key' as argument without the '='.
//
-// Only key in argumemt, like key1&key2
+// Only key in argument, like key1&key2
func (a *Args) SetNoValue(key string) {
a.args = setArg(a.args, key, "", argsNoValue)
}
@@ -343,7 +343,7 @@ func (a *Args) GetUfloatOrZero(key string) float64 {
// true is returned for "1", "t", "T", "true", "TRUE", "True", "y", "yes", "Y", "YES", "Yes",
// otherwise false is returned.
func (a *Args) GetBool(key string) bool {
- switch b2s(a.Peek(key)) {
+ switch string(a.Peek(key)) {
// Support the same true cases as strconv.ParseBool
// See: https://github.com/golang/go/blob/4e1b11e2c9bdb0ddea1141eed487be1a626ff5be/src/strconv/atob.go#L12
// and Y and Yes versions.
@@ -361,12 +361,20 @@ func visitArgs(args []argsKV, f func(k, v []byte)) {
}
}
+func visitArgsKey(args []argsKV, f func(k []byte)) {
+ for i, n := 0, len(args); i < n; i++ {
+ kv := &args[i]
+ f(kv.key)
+ }
+}
+
func copyArgs(dst, src []argsKV) []argsKV {
if cap(dst) < len(src) {
tmp := make([]argsKV, len(src))
+ dstLen := len(dst)
dst = dst[:cap(dst)] // copy all of dst.
copy(tmp, dst)
- for i := len(dst); i < len(tmp); i++ {
+ for i := dstLen; i < len(tmp); i++ {
// Make sure nothing is nil.
tmp[i].key = []byte{}
tmp[i].value = []byte{}
@@ -537,13 +545,28 @@ func (s *argsScanner) next(kv *argsKV) bool {
}
func decodeArgAppend(dst, src []byte) []byte {
- if bytes.IndexByte(src, '%') < 0 && bytes.IndexByte(src, '+') < 0 {
+ idxPercent := bytes.IndexByte(src, '%')
+ idxPlus := bytes.IndexByte(src, '+')
+ if idxPercent == -1 && idxPlus == -1 {
// fast path: src doesn't contain encoded chars
return append(dst, src...)
}
+ idx := 0
+ if idxPercent == -1 {
+ idx = idxPlus
+ } else if idxPlus == -1 {
+ idx = idxPercent
+ } else if idxPercent > idxPlus {
+ idx = idxPlus
+ } else {
+ idx = idxPercent
+ }
+
+ dst = append(dst, src[:idx]...)
+
// slow path
- for i := 0; i < len(src); i++ {
+ for i := idx; i < len(src); i++ {
c := src[i]
if c == '%' {
if i+2 >= len(src) {
@@ -572,13 +595,16 @@ func decodeArgAppend(dst, src []byte) []byte {
// The function is copy-pasted from decodeArgAppend due to the performance
// reasons only.
func decodeArgAppendNoPlus(dst, src []byte) []byte {
- if bytes.IndexByte(src, '%') < 0 {
+ idx := bytes.IndexByte(src, '%')
+ if idx < 0 {
// fast path: src doesn't contain encoded chars
return append(dst, src...)
+ } else {
+ dst = append(dst, src[:idx]...)
}
// slow path
- for i := 0; i < len(src); i++ {
+ for i := idx; i < len(src); i++ {
c := src[i]
if c == '%' {
if i+2 >= len(src) {
@@ -598,3 +624,21 @@ func decodeArgAppendNoPlus(dst, src []byte) []byte {
}
return dst
}
+
+func peekAllArgBytesToDst(dst [][]byte, h []argsKV, k []byte) [][]byte {
+ for i, n := 0, len(h); i < n; i++ {
+ kv := &h[i]
+ if bytes.Equal(kv.key, k) {
+ dst = append(dst, kv.value)
+ }
+ }
+ return dst
+}
+
+func peekArgsKeys(dst [][]byte, h []argsKV) [][]byte {
+ for i, n := 0, len(h); i < n; i++ {
+ kv := &h[i]
+ dst = append(dst, kv.key)
+ }
+ return dst
+}
diff --git a/args_test.go b/args_test.go
index c780c3b..ed32dd9 100644
--- a/args_test.go
+++ b/args_test.go
@@ -237,7 +237,7 @@ func TestArgsWriteTo(t *testing.T) {
var w bytebufferpool.ByteBuffer
n, err := a.WriteTo(&w)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != int64(len(s)) {
t.Fatalf("unexpected n: %d. Expecting %d", n, len(s))
@@ -329,7 +329,7 @@ func TestArgsCopyTo(t *testing.T) {
func testCopyTo(t *testing.T, a *Args) {
keys := make(map[string]struct{})
- a.VisitAll(func(k, v []byte) {
+ a.VisitAll(func(k, _ []byte) {
keys[string(k)] = struct{}{}
})
@@ -340,7 +340,7 @@ func testCopyTo(t *testing.T, a *Args) {
t.Fatalf("ArgsCopyTo fail, a: \n%+v\nb: \n%+v\n", *a, b) //nolint
}
- b.VisitAll(func(k, v []byte) {
+ b.VisitAll(func(k, _ []byte) {
if _, ok := keys[string(k)]; !ok {
t.Fatalf("unexpected key %q after copying from %q", k, a.String())
}
@@ -594,7 +594,7 @@ func TestArgsDeleteAll(t *testing.T) {
a.Add("q2", "1234")
a.Del("q1")
if a.Len() != 1 || a.Has("q1") {
- t.Fatalf("Expected q1 arg to be completely deleted. Current Args: %s", a.String())
+ t.Fatalf("Expected q1 arg to be completely deleted. Current Args: %q", a.String())
}
}
diff --git a/brotli.go b/brotli.go
index a88fdce..815e4b3 100644
--- a/brotli.go
+++ b/brotli.go
@@ -92,10 +92,10 @@ var (
//
// Supported compression levels are:
//
-// * CompressBrotliNoCompression
-// * CompressBrotliBestSpeed
-// * CompressBrotliBestCompression
-// * CompressBrotliDefaultCompression
+// - CompressBrotliNoCompression
+// - CompressBrotliBestSpeed
+// - CompressBrotliBestCompression
+// - CompressBrotliDefaultCompression
func AppendBrotliBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{dst}
WriteBrotliLevel(w, src, level) //nolint:errcheck
@@ -107,10 +107,10 @@ func AppendBrotliBytesLevel(dst, src []byte, level int) []byte {
//
// Supported compression levels are:
//
-// * CompressBrotliNoCompression
-// * CompressBrotliBestSpeed
-// * CompressBrotliBestCompression
-// * CompressBrotliDefaultCompression
+// - CompressBrotliNoCompression
+// - CompressBrotliBestSpeed
+// - CompressBrotliBestCompression
+// - CompressBrotliDefaultCompression
func WriteBrotliLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
@@ -140,7 +140,7 @@ func nonblockingWriteBrotli(ctxv interface{}) {
_, err := zw.Write(ctx.p)
if err != nil {
- panic(fmt.Sprintf("BUG: brotli.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err))
+ panic(fmt.Sprintf("BUG: brotli.Writer.Write for len(p)=%d returned unexpected error: %v", len(ctx.p), err))
}
releaseRealBrotliWriter(zw, ctx.level)
diff --git a/brotli_test.go b/brotli_test.go
index 1ac94dd..4872c6b 100644
--- a/brotli_test.go
+++ b/brotli_test.go
@@ -4,7 +4,7 @@ import (
"bufio"
"bytes"
"fmt"
- "io/ioutil"
+ "io"
"testing"
)
@@ -42,7 +42,7 @@ func testBrotliBytesSingleCase(s string) error {
unbrotliedS, err := AppendUnbrotliBytes(prefix, brotlipedS[len(prefix):])
if err != nil {
- return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err)
+ return fmt.Errorf("unexpected error when uncompressing %q: %w", s, err)
}
if !bytes.Equal(unbrotliedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, unbrotliedS[:len(prefix)], prefix)
@@ -83,17 +83,17 @@ func testBrotliCompressSingleCase(s string) error {
var buf bytes.Buffer
zw := acquireStacklessBrotliWriter(&buf, CompressDefaultCompression)
if _, err := zw.Write([]byte(s)); err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
releaseStacklessBrotliWriter(zw, CompressDefaultCompression)
zr, err := acquireBrotliReader(&buf)
if err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
- body, err := ioutil.ReadAll(zr)
+ body, err := io.ReadAll(zr)
if err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
if string(body) != s {
return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s)
@@ -118,9 +118,9 @@ func TestCompressHandlerBrotliLevel(t *testing.T) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce := resp.Header.Peek(HeaderContentEncoding)
+ ce := resp.Header.ContentEncoding()
if string(ce) != "" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "")
}
@@ -138,15 +138,15 @@ func TestCompressHandlerBrotliLevel(t *testing.T) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "gzip" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip")
}
body, err := resp.BodyGunzip()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
@@ -161,15 +161,15 @@ func TestCompressHandlerBrotliLevel(t *testing.T) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "br" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "br")
}
body, err = resp.BodyUnbrotli()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
diff --git a/bytesconv.go b/bytesconv.go
index 81f8e5b..274082f 100644
--- a/bytesconv.go
+++ b/bytesconv.go
@@ -11,7 +11,6 @@ import (
"math"
"net"
"reflect"
- "strings"
"sync"
"time"
"unsafe"
@@ -19,29 +18,24 @@ import (
// AppendHTMLEscape appends html-escaped s to dst and returns the extended dst.
func AppendHTMLEscape(dst []byte, s string) []byte {
- if strings.IndexByte(s, '<') < 0 &&
- strings.IndexByte(s, '>') < 0 &&
- strings.IndexByte(s, '"') < 0 &&
- strings.IndexByte(s, '\'') < 0 {
+ var (
+ prev int
+ sub string
+ )
- // fast path - nothing to escape
- return append(dst, s...)
- }
-
- // slow path
- var prev int
- var sub string
for i, n := 0, len(s); i < n; i++ {
sub = ""
switch s[i] {
+ case '&':
+ sub = "&amp;"
case '<':
sub = "&lt;"
case '>':
sub = "&gt;"
case '"':
- sub = "&quot;"
+ sub = "&#34;" // "&#34;" is shorter than "&quot;".
case '\'':
- sub = "&#39;"
+ sub = "&#39;" // "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
}
if len(sub) > 0 {
dst = append(dst, s[prev:i]...)
@@ -98,7 +92,7 @@ func ParseIPv4(dst net.IP, ipStr []byte) (net.IP, error) {
}
v, err := ParseUint(b[:n])
if err != nil {
- return dst, fmt.Errorf("cannot parse ipStr %q: %s", ipStr, err)
+ return dst, fmt.Errorf("cannot parse ipStr %q: %w", ipStr, err)
}
if v > 255 {
return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v)
@@ -108,7 +102,7 @@ func ParseIPv4(dst net.IP, ipStr []byte) (net.IP, error) {
}
v, err := ParseUint(b)
if err != nil {
- return dst, fmt.Errorf("cannot parse ipStr %q: %s", ipStr, err)
+ return dst, fmt.Errorf("cannot parse ipStr %q: %w", ipStr, err)
}
if v > 255 {
return dst, fmt.Errorf("cannot parse ipStr %q: ip part cannot exceed 255: parsed %d", ipStr, v)
@@ -240,7 +234,7 @@ func ParseUfloat(buf []byte) (float64, error) {
if err != nil {
return -1, errInvalidFloatExponent
}
- return float64(v) * offset * math.Pow10(minus*int(vv)), nil
+ return float64(v) * offset * math.Pow10(minus*vv), nil
}
return -1, errUnexpectedFloatChar
}
@@ -258,9 +252,7 @@ var (
)
func readHexInt(r *bufio.Reader) (int, error) {
- n := 0
- i := 0
- var k int
+ var k, i, n int
for {
c, err := r.ReadByte()
if err != nil {
@@ -380,7 +372,7 @@ func appendQuotedPath(dst, src []byte) []byte {
for _, c := range src {
if quotedPathShouldEscapeTable[int(c)] != 0 {
- dst = append(dst, '%', upperhex[c>>4], upperhex[c&15])
+ dst = append(dst, '%', upperhex[c>>4], upperhex[c&0xf])
} else {
dst = append(dst, c)
}
diff --git a/bytesconv_32.go b/bytesconv_32.go
index 6a6fec2..b574883 100644
--- a/bytesconv_32.go
+++ b/bytesconv_32.go
@@ -1,5 +1,5 @@
-//go:build !amd64 && !arm64 && !ppc64 && !ppc64le
-// +build !amd64,!arm64,!ppc64,!ppc64le
+//go:build !amd64 && !arm64 && !ppc64 && !ppc64le && !s390x
+// +build !amd64,!arm64,!ppc64,!ppc64le,!s390x
package fasthttp
diff --git a/bytesconv_32_test.go b/bytesconv_32_test.go
index cec5aa9..3f5d5de 100644
--- a/bytesconv_32_test.go
+++ b/bytesconv_32_test.go
@@ -1,5 +1,5 @@
-//go:build !amd64 && !arm64 && !ppc64 && !ppc64le
-// +build !amd64,!arm64,!ppc64,!ppc64le
+//go:build !amd64 && !arm64 && !ppc64 && !ppc64le && !s390x
+// +build !amd64,!arm64,!ppc64,!ppc64le,!s390x
package fasthttp
diff --git a/bytesconv_64.go b/bytesconv_64.go
index 1300d5a..94d0ec6 100644
--- a/bytesconv_64.go
+++ b/bytesconv_64.go
@@ -1,5 +1,5 @@
-//go:build amd64 || arm64 || ppc64 || ppc64le
-// +build amd64 arm64 ppc64 ppc64le
+//go:build amd64 || arm64 || ppc64 || ppc64le || s390x
+// +build amd64 arm64 ppc64 ppc64le s390x
package fasthttp
diff --git a/bytesconv_64_test.go b/bytesconv_64_test.go
index 5351591..0689809 100644
--- a/bytesconv_64_test.go
+++ b/bytesconv_64_test.go
@@ -1,5 +1,5 @@
-//go:build amd64 || arm64 || ppc64 || ppc64le
-// +build amd64 arm64 ppc64 ppc64le
+//go:build amd64 || arm64 || ppc64 || ppc64le || s390x
+// +build amd64 arm64 ppc64 ppc64le s390x
package fasthttp
diff --git a/bytesconv_table_gen.go b/bytesconv_table_gen.go
index 134f9dd..abf69d4 100644
--- a/bytesconv_table_gen.go
+++ b/bytesconv_table_gen.go
@@ -6,8 +6,8 @@ package main
import (
"bytes"
"fmt"
- "io/ioutil"
"log"
+ "os"
)
const (
@@ -107,7 +107,7 @@ func main() {
fmt.Fprintf(w, "const quotedArgShouldEscapeTable = %q\n", quotedArgShouldEscapeTable)
fmt.Fprintf(w, "const quotedPathShouldEscapeTable = %q\n", quotedPathShouldEscapeTable)
- if err := ioutil.WriteFile("bytesconv_table.go", w.Bytes(), 0660); err != nil {
+ if err := os.WriteFile("bytesconv_table.go", w.Bytes(), 0660); err != nil {
log.Fatal(err)
}
}
diff --git a/bytesconv_test.go b/bytesconv_test.go
index 4c35371..67c41f9 100644
--- a/bytesconv_test.go
+++ b/bytesconv_test.go
@@ -4,20 +4,48 @@ import (
"bufio"
"bytes"
"fmt"
+ "html"
"net"
+ "net/url"
"testing"
"time"
"github.com/valyala/bytebufferpool"
)
+func TestAppendQuotedArg(t *testing.T) {
+ t.Parallel()
+
+ // Sync with url.QueryEscape
+ allcases := make([]byte, 256)
+ for i := 0; i < 256; i++ {
+ allcases[i] = byte(i)
+ }
+ res := string(AppendQuotedArg(nil, allcases))
+ expect := url.QueryEscape(string(allcases))
+ if res != expect {
+ t.Fatalf("unexpected string %q. Expecting %q.", res, expect)
+ }
+}
+
func TestAppendHTMLEscape(t *testing.T) {
t.Parallel()
+ // Sync with html.EscapeString
+ allcases := make([]byte, 256)
+ for i := 0; i < 256; i++ {
+ allcases[i] = byte(i)
+ }
+ res := string(AppendHTMLEscape(nil, string(allcases)))
+ expect := string(html.EscapeString(string(allcases)))
+ if res != expect {
+ t.Fatalf("unexpected string %q. Expecting %q.", res, expect)
+ }
+
testAppendHTMLEscape(t, "", "")
testAppendHTMLEscape(t, "<", "&lt;")
testAppendHTMLEscape(t, "a", "a")
- testAppendHTMLEscape(t, `><"''`, "&gt;&lt;&quot;&#39;&#39;")
+ testAppendHTMLEscape(t, `><"''`, "&gt;&lt;&#34;&#39;&#39;")
testAppendHTMLEscape(t, "fo<b x='ss'>a</b>xxx", "fo&lt;b x=&#39;ss&#39;&gt;a&lt;/b&gt;xxx")
}
@@ -48,7 +76,7 @@ func testParseIPv4(t *testing.T, ipStr string, isValid bool) {
ip, err := ParseIPv4(nil, []byte(ipStr))
if isValid {
if err != nil {
- t.Fatalf("unexpected error when parsing ip %q: %s", ipStr, err)
+ t.Fatalf("unexpected error when parsing ip %q: %v", ipStr, err)
}
s := string(AppendIPv4(nil, ip))
if s != ipStr {
@@ -103,10 +131,10 @@ func testWriteHexInt(t *testing.T, n int, expectedS string) {
var w bytebufferpool.ByteBuffer
bw := bufio.NewWriter(&w)
if err := writeHexInt(bw, n); err != nil {
- t.Fatalf("unexpected error when writing hex %x: %s", n, err)
+ t.Fatalf("unexpected error when writing hex %x: %v", n, err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error when flushing hex %x: %s", n, err)
+ t.Fatalf("unexpected error when flushing hex %x: %v", n, err)
}
s := string(w.B)
if s != expectedS {
@@ -140,7 +168,7 @@ func testReadHexIntSuccess(t *testing.T, s string, expectedN int) {
br := bufio.NewReader(r)
n, err := readHexInt(br)
if err != nil {
- t.Fatalf("unexpected error: %s. s=%q", err, s)
+ t.Fatalf("unexpected error: %v. s=%q", err, s)
}
if n != expectedN {
t.Fatalf("unexpected hex int %d. Expected %d. s=%q", n, expectedN, s)
@@ -246,7 +274,7 @@ func testParseUfloatError(t *testing.T, s string) {
func testParseUfloatSuccess(t *testing.T, s string, expectedF float64) {
f, err := ParseUfloat([]byte(s))
if err != nil {
- t.Fatalf("Unexpected error when parsing %q: %s", s, err)
+ t.Fatalf("Unexpected error when parsing %q: %v", s, err)
}
delta := f - expectedF
if delta < 0 {
@@ -270,7 +298,7 @@ func testParseUintError(t *testing.T, s string) {
func testParseUintSuccess(t *testing.T, s string, expectedN int) {
n, err := ParseUint([]byte(s))
if err != nil {
- t.Fatalf("Unexpected error when parsing %q: %s", s, err)
+ t.Fatalf("Unexpected error when parsing %q: %v", s, err)
}
if n != expectedN {
t.Fatalf("Unexpected value %d. Expected %d. num=%q", n, expectedN, s)
diff --git a/bytesconv_timing_test.go b/bytesconv_timing_test.go
index 34795df..334376a 100644
--- a/bytesconv_timing_test.go
+++ b/bytesconv_timing_test.go
@@ -18,7 +18,7 @@ func BenchmarkAppendHTMLEscape(b *testing.B) {
for i := 0; i < 10; i++ {
buf = AppendHTMLEscape(buf[:0], sOrig)
if string(buf) != sExpected {
- b.Fatalf("unexpected escaped string: %s. Expecting %s", buf, sExpected)
+ b.Fatalf("unexpected escaped string: %q. Expecting %q", buf, sExpected)
}
}
}
@@ -34,7 +34,7 @@ func BenchmarkHTMLEscapeString(b *testing.B) {
for i := 0; i < 10; i++ {
s = html.EscapeString(sOrig)
if s != sExpected {
- b.Fatalf("unexpected escaped string: %s. Expecting %s", s, sExpected)
+ b.Fatalf("unexpected escaped string: %q. Expecting %q", s, sExpected)
}
}
}
@@ -49,7 +49,7 @@ func BenchmarkParseIPv4(b *testing.B) {
for pb.Next() {
ip, err = ParseIPv4(ip, ipStr)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
}
})
@@ -88,10 +88,10 @@ func BenchmarkParseUint(b *testing.B) {
for pb.Next() {
n, err := ParseUint(buf)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if n != 1234567 {
- b.Fatalf("unexpected result: %d. Expecting %s", n, buf)
+ b.Fatalf("unexpected result: %d. Expecting %q", n, buf)
}
}
})
diff --git a/client.go b/client.go
index 47bbd58..23bfa57 100644
--- a/client.go
+++ b/client.go
@@ -1,14 +1,14 @@
+// go:build !windows || !race
+
package fasthttp
import (
"bufio"
- "bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
- "strconv"
"strings"
"sync"
"sync/atomic"
@@ -296,6 +296,12 @@ type Client struct {
// By default will use isIdempotent function
RetryIf RetryIfFunc
+ // Connection pool strategy. Can be either LIFO or FIFO (default).
+ ConnPoolStrategy ConnPoolStrategyType
+
+ // ConfigureClient configures the fasthttp.HostClient.
+ ConfigureClient func(hc *HostClient) error
+
mLock sync.Mutex
m map[string]*HostClient
ms map[string]*HostClient
@@ -380,7 +386,8 @@ func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, b
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
- return clientDoTimeout(req, resp, timeout, c)
+ req.timeout = timeout
+ return c.Do(req, resp)
}
// DoDeadline performs the given request and waits for response until
@@ -407,7 +414,8 @@ func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration)
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
- return clientDoDeadline(req, resp, deadline, c)
+ req.timeout = time.Until(deadline)
+ return c.Do(req, resp)
}
// DoRedirects performs the given http request and fills the given http response,
@@ -462,11 +470,10 @@ func (c *Client) Do(req *Request, resp *Response) error {
host := uri.Host()
isTLS := false
- scheme := uri.Scheme()
- if bytes.Equal(scheme, strHTTPS) {
+ if uri.isHttps() {
isTLS = true
- } else if !bytes.Equal(scheme, strHTTP) {
- return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme)
+ } else if !uri.isHttp() {
+ return fmt.Errorf("unsupported protocol %q. http and https are supported", uri.Scheme())
}
startCleaner := false
@@ -487,7 +494,7 @@ func (c *Client) Do(req *Request, resp *Response) error {
hc := m[string(host)]
if hc == nil {
hc = &HostClient{
- Addr: addMissingPort(string(host), isTLS),
+ Addr: AddMissingPort(string(host), isTLS),
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
Dial: c.Dial,
@@ -507,14 +514,26 @@ func (c *Client) Do(req *Request, resp *Response) error {
DisablePathNormalizing: c.DisablePathNormalizing,
MaxConnWaitTimeout: c.MaxConnWaitTimeout,
RetryIf: c.RetryIf,
+ ConnPoolStrategy: c.ConnPoolStrategy,
clientReaderPool: &c.readerPool,
clientWriterPool: &c.writerPool,
}
+
+ if c.ConfigureClient != nil {
+ if err := c.ConfigureClient(hc); err != nil {
+ return err
+ }
+ }
+
m[string(host)] = hc
if len(m) == 1 {
startCleaner = true
}
}
+
+ atomic.AddInt32(&hc.pendingClientRequests, 1)
+ defer atomic.AddInt32(&hc.pendingClientRequests, -1)
+
c.mLock.Unlock()
if startCleaner {
@@ -550,15 +569,14 @@ func (c *Client) mCleaner(m map[string]*HostClient) {
}
for {
+ time.Sleep(sleep)
c.mLock.Lock()
for k, v := range m {
v.connsLock.Lock()
- shouldRemove := v.connsCount == 0
- v.connsLock.Unlock()
-
- if shouldRemove {
+ if v.connsCount == 0 && atomic.LoadInt32(&v.pendingClientRequests) == 0 {
delete(m, k)
}
+ v.connsLock.Unlock()
}
if len(m) == 0 {
mustStop = true
@@ -568,7 +586,6 @@ func (c *Client) mCleaner(m map[string]*HostClient) {
if mustStop {
break
}
- time.Sleep(sleep)
}
}
@@ -606,6 +623,14 @@ type RetryIfFunc func(request *Request) bool
// TransportFunc wraps every request/response.
type TransportFunc func(*Request, *Response) error
+// ConnPoolStrategyType defines the strategy of connection pool enqueue/dequeue.
+type ConnPoolStrategyType int
+
+const (
+ FIFO ConnPoolStrategyType = iota
+ LIFO
+)
+
// HostClient balances http requests among hosts listed in Addr.
//
// HostClient may be used for balancing load among multiple upstream hosts.
@@ -760,7 +785,9 @@ type HostClient struct {
// Transport defines a transport-like mechanism that wraps every request/response.
Transport TransportFunc
- clientName atomic.Value
+ // Connection pool strategy. Can be either LIFO or FIFO (default).
+ ConnPoolStrategy ConnPoolStrategyType
+
lastUseTime uint32
connsLock sync.Mutex
@@ -783,6 +810,10 @@ type HostClient struct {
pendingRequests int32
+ // pendingClientRequests counts the number of requests that a Client is currently running using this HostClient.
+ // It will be incremented earlier than pendingRequests and will be used by Client to see if the HostClient is still in use.
+ pendingClientRequests int32
+
connsCleanerRun bool
}
@@ -874,7 +905,7 @@ type clientURLResponse struct {
}
func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) {
- timeout := -time.Since(deadline)
+ timeout := time.Until(deadline)
if timeout <= 0 {
return 0, dst, ErrTimeout
}
@@ -950,6 +981,8 @@ var clientURLResponseChPool sync.Pool
func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (statusCode int, body []byte, err error) {
req := AcquireRequest()
+ defer ReleaseRequest(req)
+
req.Header.SetMethod(MethodPost)
req.Header.SetContentTypeBytes(strPostArgsContentType)
if postArgs != nil {
@@ -960,7 +993,6 @@ func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (status
statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c)
- ReleaseRequest(req)
return statusCode, body, err
}
@@ -1119,7 +1151,8 @@ func ReleaseResponse(resp *Response) {
// If requests take too long and the connection pool gets filled up please
// try setting a ReadTimeout.
func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error {
- return clientDoTimeout(req, resp, timeout, c)
+ req.timeout = timeout
+ return c.Do(req, resp)
}
// DoDeadline performs the given request and waits for response until
@@ -1141,7 +1174,8 @@ func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Durati
// It is recommended obtaining req and resp via AcquireRequest
// and AcquireResponse in performance-critical code.
func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
- return clientDoDeadline(req, resp, deadline, c)
+ req.timeout = time.Until(deadline)
+ return c.Do(req, resp)
}
// DoRedirects performs the given http request and fills the given http response,
@@ -1168,93 +1202,6 @@ func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount
return err
}
-func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error {
- deadline := time.Now().Add(timeout)
- return clientDoDeadline(req, resp, deadline, c)
-}
-
-func clientDoDeadline(req *Request, resp *Response, deadline time.Time, c clientDoer) error {
- timeout := -time.Since(deadline)
- if timeout <= 0 {
- return ErrTimeout
- }
-
- var ch chan error
- chv := errorChPool.Get()
- if chv == nil {
- chv = make(chan error, 1)
- }
- ch = chv.(chan error)
-
- // Make req and resp copies, since on timeout they no longer
- // may be accessed.
- reqCopy := AcquireRequest()
- req.copyToSkipBody(reqCopy)
- swapRequestBody(req, reqCopy)
- respCopy := AcquireResponse()
- if resp != nil {
- // Not calling resp.copyToSkipBody(respCopy) here to avoid
- // unexpected messing with headers
- respCopy.SkipBody = resp.SkipBody
- }
-
- // Note that the request continues execution on ErrTimeout until
- // client-specific ReadTimeout exceeds. This helps limiting load
- // on slow hosts by MaxConns* concurrent requests.
- //
- // Without this 'hack' the load on slow host could exceed MaxConns*
- // concurrent requests, since timed out requests on client side
- // usually continue execution on the host.
-
- var mu sync.Mutex
- var timedout, responded bool
-
- go func() {
- reqCopy.timeout = timeout
- errDo := c.Do(reqCopy, respCopy)
- mu.Lock()
- {
- if !timedout {
- if resp != nil {
- respCopy.copyToSkipBody(resp)
- swapResponseBody(resp, respCopy)
- }
- swapRequestBody(reqCopy, req)
- ch <- errDo
- responded = true
- }
- }
- mu.Unlock()
-
- ReleaseResponse(respCopy)
- ReleaseRequest(reqCopy)
- }()
-
- tc := AcquireTimer(timeout)
- var err error
- select {
- case err = <-ch:
- case <-tc.C:
- mu.Lock()
- {
- if responded {
- err = <-ch
- } else {
- timedout = true
- err = ErrTimeout
- }
- }
- mu.Unlock()
- }
- ReleaseTimer(tc)
-
- errorChPool.Put(chv)
-
- return err
-}
-
-var errorChPool sync.Pool
-
// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
@@ -1361,7 +1308,7 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
req.secureErrorLogMessage = c.SecureErrorLogMessage
req.Header.secureErrorLogMessage = c.SecureErrorLogMessage
- if c.IsTLS != bytes.Equal(req.uri.Scheme(), strHTTPS) {
+ if c.IsTLS != req.URI().isHttps() {
return false, ErrHostClientRedirectToDifferentScheme
}
@@ -1379,14 +1326,24 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
- req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
+ userAgent := c.Name
+ if userAgent == "" && !c.NoDefaultUserAgentHeader {
+ userAgent = defaultUserAgent
+ }
+ if userAgent != "" {
+ req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
+ }
}
-
if c.Transport != nil {
err := c.Transport(req, resp)
return err == nil, err
}
+ var deadline time.Time
+ if req.timeout > 0 {
+ deadline = time.Now().Add(req.timeout)
+ }
+
cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
if err != nil {
return false, err
@@ -1395,11 +1352,17 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
resp.parseNetConn(conn)
+ writeDeadline := deadline
if c.WriteTimeout > 0 {
+ tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
+ if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
+ writeDeadline = tmpWriteDeadline
+ }
+ }
+ if !writeDeadline.IsZero() {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
- currentTime := time.Now()
- if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil {
+ if err = conn.SetWriteDeadline(writeDeadline); err != nil {
c.closeConn(cc)
return true, err
}
@@ -1421,18 +1384,30 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
if err == nil {
err = bw.Flush()
}
- if err != nil {
- c.releaseWriter(bw)
+ c.releaseWriter(bw)
+
+ // Return ErrTimeout on any timeout.
+ if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
+ err = ErrTimeout
+ }
+
+ isConnRST := isConnectionReset(err)
+ if err != nil && !isConnRST {
c.closeConn(cc)
return true, err
}
- c.releaseWriter(bw)
+ readDeadline := deadline
if c.ReadTimeout > 0 {
+ tmpReadDeadline := time.Now().Add(c.ReadTimeout)
+ if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
+ readDeadline = tmpReadDeadline
+ }
+ }
+ if !readDeadline.IsZero() {
// Set Deadline every time, since golang has fixed the performance issue
// See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details
- currentTime := time.Now()
- if err = conn.SetReadDeadline(currentTime.Add(c.ReadTimeout)); err != nil {
+ if err = conn.SetReadDeadline(readDeadline); err != nil {
c.closeConn(cc)
return true, err
}
@@ -1446,22 +1421,22 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
}
br := c.acquireReader(conn)
- if err = resp.ReadLimitBody(br, c.MaxResponseBodySize); err != nil {
- c.releaseReader(br)
+ err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
+ c.releaseReader(br)
+ if err != nil {
c.closeConn(cc)
// Don't retry in case of ErrBodyTooLarge since we will just get the same again.
retry := err != ErrBodyTooLarge
return retry, err
}
- c.releaseReader(br)
- if resetConnection || req.ConnectionClose() || resp.ConnectionClose() {
+ if resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST {
c.closeConn(cc)
} else {
c.releaseConn(cc)
}
- return false, err
+ return false, nil
}
var (
@@ -1481,6 +1456,10 @@ var (
// to broken server.
ErrConnectionClosed = errors.New("the server closed connection before returning the first response byte. " +
"Make sure the server returns 'Connection: close' response header before closing the connection")
+
+ // ErrConnPoolStrategyNotImpl is returned when HostClient.ConnPoolStrategy is not implement yet.
+ // If you see this error, then you need to check your HostClient configuration.
+ ErrConnPoolStrategyNotImpl = errors.New("connection pool strategy is not implement")
)
type timeoutError struct{}
@@ -1492,7 +1471,7 @@ func (e *timeoutError) Error() string {
// Only implement the Timeout() function of the net.Error interface.
// This allows for checks like:
//
-// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
+// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
func (e *timeoutError) Timeout() bool {
return true
}
@@ -1528,10 +1507,20 @@ func (c *HostClient) acquireConn(reqTimeout time.Duration, connectionClose bool)
}
}
} else {
- n--
- cc = c.conns[n]
- c.conns[n] = nil
- c.conns = c.conns[:n]
+ switch c.ConnPoolStrategy {
+ case LIFO:
+ n--
+ cc = c.conns[n]
+ c.conns[n] = nil
+ c.conns = c.conns[:n]
+ case FIFO:
+ cc = c.conns[0]
+ copy(c.conns, c.conns[1:])
+ c.conns[n-1] = nil
+ c.conns = c.conns[:n-1]
+ default:
+ return nil, ErrConnPoolStrategyNotImpl
+ }
}
c.connsLock.Unlock()
@@ -1859,10 +1848,6 @@ func newClientTLSConfig(c *tls.Config, addr string) *tls.Config {
c = c.Clone()
}
- if c.ClientSessionCache == nil {
- c.ClientSessionCache = tls.NewLRUClientSessionCache(0)
- }
-
if len(c.ServerName) == 0 {
serverName := tlsServerName(addr)
if serverName == "*" {
@@ -1953,48 +1938,40 @@ func (c *HostClient) cachedTLSConfig(addr string) *tls.Config {
// ErrTLSHandshakeTimeout indicates there is a timeout from tls handshake.
var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out")
-var timeoutErrorChPool sync.Pool
-
-func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
- tc := AcquireTimer(timeout)
- defer ReleaseTimer(tc)
-
- var ch chan error
- chv := timeoutErrorChPool.Get()
- if chv == nil {
- chv = make(chan error)
- }
- ch = chv.(chan error)
- defer timeoutErrorChPool.Put(chv)
-
- conn := tls.Client(rawConn, tlsConfig)
-
- go func() {
- ch <- conn.Handshake()
- }()
-
- select {
- case <-tc.C:
- rawConn.Close()
- <-ch
- return nil, ErrTLSHandshakeTimeout
- case err := <-ch:
- if err != nil {
+func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.Time) (_ net.Conn, retErr error) {
+ defer func() {
+ if retErr != nil {
rawConn.Close()
- return nil, err
}
- return conn, nil
+ }()
+ conn := tls.Client(rawConn, tlsConfig)
+ err := conn.SetDeadline(deadline)
+ if err != nil {
+ return nil, err
+ }
+ err = conn.Handshake()
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+ return nil, ErrTLSHandshakeTimeout
+ }
+ if err != nil {
+ return nil, err
+ }
+ err = conn.SetDeadline(time.Time{})
+ if err != nil {
+ return nil, err
}
+ return conn, nil
}
func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
+ deadline := time.Now().Add(timeout)
if dial == nil {
if dialDualStack {
dial = DialDualStack
} else {
dial = Dial
}
- addr = addMissingPort(addr, isTLS)
+ addr = AddMissingPort(addr, isTLS)
}
conn, err := dial(addr)
if err != nil {
@@ -2003,41 +1980,47 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
if conn == nil {
panic("BUG: DialFunc returned (nil, nil)")
}
- _, isTLSAlready := conn.(*tls.Conn)
+
+ // We assume that any conn that has the Handshake() method is a TLS conn already.
+ // This doesn't cover just tls.Conn but also other TLS implementations.
+ _, isTLSAlready := conn.(interface{ Handshake() error })
+
if isTLS && !isTLSAlready {
if timeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
- return tlsClientHandshake(conn, tlsConfig, timeout)
+ return tlsClientHandshake(conn, tlsConfig, deadline)
}
return conn, nil
}
-func (c *HostClient) getClientName() []byte {
- v := c.clientName.Load()
- var clientName []byte
- if v == nil {
- clientName = []byte(c.Name)
- if len(clientName) == 0 && !c.NoDefaultUserAgentHeader {
- clientName = defaultUserAgent
- }
- c.clientName.Store(clientName)
- } else {
- clientName = v.([]byte)
+// AddMissingPort adds a port to a host if it is missing.
+// A literal IPv6 address in hostport must be enclosed in square
+// brackets, as in "[::1]:80", "[::1%lo0]:80".
+func AddMissingPort(addr string, isTLS bool) string {
+ addrLen := len(addr)
+ if addrLen == 0 {
+ return addr
}
- return clientName
-}
-func addMissingPort(addr string, isTLS bool) string {
- n := strings.Index(addr, ":")
- if n >= 0 {
- return addr
+ isIp6 := addr[0] == '['
+ if isIp6 {
+ // if the IPv6 has opening bracket but closing bracket is the last char then it doesn't have a port
+ isIp6WithoutPort := addr[addrLen-1] == ']'
+ if !isIp6WithoutPort {
+ return addr
+ }
+ } else { // IPv4
+ columnPos := strings.LastIndexByte(addr, ':')
+ if columnPos > 0 {
+ return addr
+ }
}
- port := 80
+ port := ":80"
if isTLS {
- port = 443
+ port = ":443"
}
- return net.JoinHostPort(addr, strconv.Itoa(port))
+ return addr + port
}
// A wantConn records state about a wanted connection
@@ -2324,7 +2307,6 @@ type pipelineConnClient struct {
tlsConfigLock sync.Mutex
tlsConfig *tls.Config
- clientName atomic.Value
}
type pipelineWork struct {
@@ -2384,7 +2366,7 @@ func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time.
func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error {
c.init()
- timeout := -time.Since(deadline)
+ timeout := time.Until(deadline)
if timeout < 0 {
return ErrTimeout
}
@@ -2395,10 +2377,16 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
- req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
+ userAgent := c.Name
+ if userAgent == "" && !c.NoDefaultUserAgentHeader {
+ userAgent = defaultUserAgent
+ }
+ if userAgent != "" {
+ req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
+ }
}
- w := acquirePipelineWork(&c.workPool, timeout)
+ w := c.acquirePipelineWork(timeout)
w.respCopy.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
w.req = &w.reqCopy
w.resp = &w.respCopy
@@ -2416,7 +2404,7 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
select {
case c.chW <- w:
case <-w.t.C:
- releasePipelineWork(&c.workPool, w)
+ c.releasePipelineWork(w)
return ErrTimeout
}
}
@@ -2430,7 +2418,7 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
swapResponseBody(resp, &w.respCopy)
}
err = w.err
- releasePipelineWork(&c.workPool, w)
+ c.releasePipelineWork(w)
case <-w.t.C:
err = ErrTimeout
}
@@ -2438,6 +2426,40 @@ func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline t
return err
}
+func (c *pipelineConnClient) acquirePipelineWork(timeout time.Duration) (w *pipelineWork) {
+ v := c.workPool.Get()
+ if v != nil {
+ w = v.(*pipelineWork)
+ } else {
+ w = &pipelineWork{
+ done: make(chan struct{}, 1),
+ }
+ }
+ if timeout > 0 {
+ if w.t == nil {
+ w.t = time.NewTimer(timeout)
+ } else {
+ w.t.Reset(timeout)
+ }
+ w.deadline = time.Now().Add(timeout)
+ } else {
+ w.deadline = zeroTime
+ }
+ return w
+}
+
+func (c *pipelineConnClient) releasePipelineWork(w *pipelineWork) {
+ if w.t != nil {
+ w.t.Stop()
+ }
+ w.reqCopy.Reset()
+ w.respCopy.Reset()
+ w.req = nil
+ w.resp = nil
+ w.err = nil
+ c.workPool.Put(w)
+}
+
// Do performs the given http request and sets the corresponding response.
//
// Request must contain at least non-zero RequestURI with full url (including
@@ -2462,10 +2484,16 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
userAgentOld := req.Header.UserAgent()
if len(userAgentOld) == 0 {
- req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...)
+ userAgent := c.Name
+ if userAgent == "" && !c.NoDefaultUserAgentHeader {
+ userAgent = defaultUserAgent
+ }
+ if userAgent != "" {
+ req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
+ }
}
- w := acquirePipelineWork(&c.workPool, 0)
+ w := c.acquirePipelineWork(0)
w.req = req
if resp != nil {
resp.Header.disableNormalizing = c.DisableHeaderNamesNormalizing
@@ -2488,7 +2516,7 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
select {
case c.chW <- w:
default:
- releasePipelineWork(&c.workPool, w)
+ c.releasePipelineWork(w)
return ErrPipelineOverflow
}
}
@@ -2497,7 +2525,7 @@ func (c *pipelineConnClient) Do(req *Request, resp *Response) error {
<-w.done
err := w.err
- releasePipelineWork(&c.workPool, w)
+ c.releasePipelineWork(w)
return err
}
@@ -2589,9 +2617,9 @@ func (c *pipelineConnClient) init() {
// Keep restarting the worker if it fails (connection errors for example).
for {
if err := c.worker(); err != nil {
- c.logger().Printf("error in PipelineClient(%q): %s", c.Addr, err)
- if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
- // Throttle client reconnections on temporary errors
+ c.logger().Printf("error in PipelineClient(%q): %v", c.Addr, err)
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+ // Throttle client reconnections on timeout errors
time.Sleep(time.Second)
}
} else {
@@ -2858,52 +2886,4 @@ func (c *pipelineConnClient) PendingRequests() int {
return n
}
-func (c *pipelineConnClient) getClientName() []byte {
- v := c.clientName.Load()
- var clientName []byte
- if v == nil {
- clientName = []byte(c.Name)
- if len(clientName) == 0 && !c.NoDefaultUserAgentHeader {
- clientName = defaultUserAgent
- }
- c.clientName.Store(clientName)
- } else {
- clientName = v.([]byte)
- }
- return clientName
-}
-
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
-
-func acquirePipelineWork(pool *sync.Pool, timeout time.Duration) *pipelineWork {
- v := pool.Get()
- if v == nil {
- v = &pipelineWork{
- done: make(chan struct{}, 1),
- }
- }
- w := v.(*pipelineWork)
- if timeout > 0 {
- if w.t == nil {
- w.t = time.NewTimer(timeout)
- } else {
- w.t.Reset(timeout)
- }
- w.deadline = time.Now().Add(timeout)
- } else {
- w.deadline = zeroTime
- }
- return w
-}
-
-func releasePipelineWork(pool *sync.Pool, w *pipelineWork) {
- if w.t != nil {
- w.t.Stop()
- }
- w.reqCopy.Reset()
- w.respCopy.Reset()
- w.req = nil
- w.resp = nil
- w.err = nil
- pool.Put(w)
-}
diff --git a/client_example_test.go b/client_example_test.go
index 9a1ad6a..c2366e4 100644
--- a/client_example_test.go
+++ b/client_example_test.go
@@ -16,7 +16,7 @@ func ExampleHostClient() {
// Fetch google page via local proxy.
statusCode, body, err := c.Get(nil, "http://google.com/foo/bar")
if err != nil {
- log.Fatalf("Error when loading google page through local proxy: %s", err)
+ log.Fatalf("Error when loading google page through local proxy: %v", err)
}
if statusCode != fasthttp.StatusOK {
log.Fatalf("Unexpected status code: %d. Expecting %d", statusCode, fasthttp.StatusOK)
@@ -26,7 +26,7 @@ func ExampleHostClient() {
// Fetch foobar page via local proxy. Reuse body buffer.
statusCode, body, err = c.Get(body, "http://foobar.com/google/com")
if err != nil {
- log.Fatalf("Error when loading foobar page through local proxy: %s", err)
+ log.Fatalf("Error when loading foobar page through local proxy: %v", err)
}
if statusCode != fasthttp.StatusOK {
log.Fatalf("Unexpected status code: %d. Expecting %d", statusCode, fasthttp.StatusOK)
diff --git a/client_test.go b/client_test.go
index e960745..aa53461 100644
--- a/client_test.go
+++ b/client_test.go
@@ -181,7 +181,7 @@ func TestClientInvalidURI(t *testing.T) {
ln := fasthttputil.NewInmemoryListener()
requests := int64(0)
s := &Server{
- Handler: func(ctx *RequestCtx) {
+ Handler: func(_ *RequestCtx) {
atomic.AddInt64(&requests, 1)
},
}
@@ -275,7 +275,7 @@ func TestClientURLAuth(t *testing.T) {
val := <-ch
if val != expected {
- t.Fatalf("wrong %s header: %s expected %s", HeaderAuthorization, val, expected)
+ t.Fatalf("wrong %q header: %q expected %q", HeaderAuthorization, val, expected)
}
}
}
@@ -357,14 +357,14 @@ func TestClientParseConn(t *testing.T) {
}
if res.RemoteAddr().Network() != network {
- t.Fatalf("req RemoteAddr parse network fail: %s, hope: %s", res.RemoteAddr().Network(), network)
+ t.Fatalf("req RemoteAddr parse network fail: %q, hope: %q", res.RemoteAddr().Network(), network)
}
if host != res.RemoteAddr().String() {
- t.Fatalf("req RemoteAddr parse addr fail: %s, hope: %s", res.RemoteAddr().String(), host)
+ t.Fatalf("req RemoteAddr parse addr fail: %q, hope: %q", res.RemoteAddr().String(), host)
}
if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) {
- t.Fatalf("res LocalAddr addr match fail: %s, hope match: %s", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$")
+ t.Fatalf("res LocalAddr addr match fail: %q, hope match: %q", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$")
}
}
@@ -439,7 +439,7 @@ func TestClientRedirectSameSchema(t *testing.T) {
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != nil {
- t.Fatalf("HostClient error: %s", err)
+ t.Fatalf("HostClient error: %v", err)
return
}
@@ -474,7 +474,7 @@ func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) {
statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond)
if err != nil {
- t.Fatalf("HostClient error: %s", err)
+ t.Fatalf("HostClient error: %v", err)
return
}
@@ -545,7 +545,7 @@ func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener {
}
if err != nil {
- t.Fatalf("cannot listen isTLS %v: %s", isTLS, err)
+ t.Fatalf("cannot listen isTLS %v: %v", isTLS, err)
}
return ln
@@ -573,7 +573,7 @@ func testClientRedirectChangingSchemaServer(t *testing.T, https, http net.Listen
go func() {
err := s.Serve(ln)
if err != nil {
- t.Errorf("unexpected error returned from Serve(): %s", err)
+ t.Errorf("unexpected error returned from Serve(): %v", err)
}
close(ch)
}()
@@ -626,13 +626,17 @@ func TestClientHeaderCase(t *testing.T) {
}
func TestClientReadTimeout(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.SkipNow()
+ }
+
t.Parallel()
ln := fasthttputil.NewInmemoryListener()
timeout := false
s := &Server{
- Handler: func(ctx *RequestCtx) {
+ Handler: func(_ *RequestCtx) {
if timeout {
time.Sleep(time.Second)
} else {
@@ -720,7 +724,7 @@ func TestClientDefaultUserAgent(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if userAgentSeen != string(defaultUserAgent) {
+ if userAgentSeen != defaultUserAgent {
t.Fatalf("User-Agent defers %q != %q", userAgentSeen, defaultUserAgent)
}
}
@@ -804,7 +808,7 @@ func TestClientDoWithCustomHeaders(t *testing.T) {
uri := "/foo/bar/baz?a=b&cd=12"
headers := map[string]string{
"Foo": "bar",
- "Host": "xxx.com",
+ "Host": "example.com",
"Content-Type": "asdfsdf",
"a-b-c-d-f": "",
}
@@ -814,14 +818,14 @@ func TestClientDoWithCustomHeaders(t *testing.T) {
go func() {
conn, err := ln.Accept()
if err != nil {
- ch <- fmt.Errorf("cannot accept client connection: %s", err)
+ ch <- fmt.Errorf("cannot accept client connection: %w", err)
return
}
br := bufio.NewReader(conn)
var req Request
if err = req.Read(br); err != nil {
- ch <- fmt.Errorf("cannot read client request: %s", err)
+ ch <- fmt.Errorf("cannot read client request: %w", err)
return
}
if string(req.Header.Method()) != MethodPost {
@@ -854,11 +858,11 @@ func TestClientDoWithCustomHeaders(t *testing.T) {
var resp Response
bw := bufio.NewWriter(conn)
if err = resp.Write(bw); err != nil {
- ch <- fmt.Errorf("cannot send response: %s", err)
+ ch <- fmt.Errorf("cannot send response: %w", err)
return
}
if err = bw.Flush(); err != nil {
- ch <- fmt.Errorf("cannot flush response: %s", err)
+ ch <- fmt.Errorf("cannot flush response: %w", err)
return
}
@@ -877,7 +881,7 @@ func TestClientDoWithCustomHeaders(t *testing.T) {
err := c.DoTimeout(&req, &resp, time.Second)
if err != nil {
- t.Fatalf("error when doing request: %s", err)
+ t.Fatalf("error when doing request: %v", err)
}
select {
@@ -923,7 +927,7 @@ func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -959,7 +963,7 @@ func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -984,14 +988,14 @@ func testPipelineClientDo(t *testing.T, c *PipelineClient) {
time.Sleep(10 * time.Millisecond)
continue
}
- t.Fatalf("unexpected error on iteration %d: %s", i, err)
+ t.Errorf("unexpected error on iteration %d: %v", i, err)
}
if resp.StatusCode() != StatusOK {
- t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
+ t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
body := string(resp.Body())
if body != "OK" {
- t.Fatalf("unexpected body: %q. Expecting %q", body, "OK")
+ t.Errorf("unexpected body: %q. Expecting %q", body, "OK")
}
// sleep for a while, so the connection to the host may expire.
@@ -1028,7 +1032,7 @@ func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1046,11 +1050,11 @@ func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.
for i := 0; i < 5; i++ {
if timeout > 0 {
if err := c.DoTimeout(&req, &resp, timeout); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
} else {
if err := c.Do(&req, &resp); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
hv := resp.Header.Peek("foo-BAR")
@@ -1064,7 +1068,7 @@ func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -1088,7 +1092,7 @@ func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1105,7 +1109,7 @@ func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
var resp Response
for i := 0; i < 5; i++ {
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
hv := resp.Header.Peek("foo-BAR")
if string(hv) != "baz" {
@@ -1118,7 +1122,7 @@ func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -1143,7 +1147,7 @@ func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1162,7 +1166,7 @@ func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) {
var resp Response
for i := 0; i < 5; i++ {
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
hv := resp.Header.Peek("received-uri")
if string(hv) != urlWithEncodedPath {
@@ -1171,7 +1175,7 @@ func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -1187,7 +1191,7 @@ func TestHostClientPendingRequests(t *testing.T) {
doneCh := make(chan struct{})
readyCh := make(chan struct{}, concurrency)
s := &Server{
- Handler: func(ctx *RequestCtx) {
+ Handler: func(_ *RequestCtx) {
readyCh <- struct{}{}
<-doneCh
},
@@ -1196,7 +1200,7 @@ func TestHostClientPendingRequests(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1221,7 +1225,7 @@ func TestHostClientPendingRequests(t *testing.T) {
resp := AcquireResponse()
if err := c.DoTimeout(req, resp, 10*time.Second); err != nil {
- resultCh <- fmt.Errorf("unexpected error: %s", err)
+ resultCh <- fmt.Errorf("unexpected error: %w", err)
return
}
@@ -1253,7 +1257,7 @@ func TestHostClientPendingRequests(t *testing.T) {
select {
case err := <-resultCh:
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
@@ -1267,7 +1271,7 @@ func TestHostClientPendingRequests(t *testing.T) {
// stop the server
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -1298,7 +1302,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1328,7 +1332,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) {
time.Sleep(time.Millisecond)
continue
}
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
break
}
@@ -1346,7 +1350,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) {
wg.Wait()
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -1376,7 +1380,7 @@ func TestHostClientMaxConnDuration(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1392,7 +1396,7 @@ func TestHostClientMaxConnDuration(t *testing.T) {
for i := 0; i < 5; i++ {
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
@@ -1404,7 +1408,7 @@ func TestHostClientMaxConnDuration(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -1431,7 +1435,7 @@ func TestHostClientMultipleAddrs(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1448,7 +1452,7 @@ func TestHostClientMultipleAddrs(t *testing.T) {
for i := 0; i < 9; i++ {
statusCode, body, err := c.Get(nil, "http://foobar/baz/aaa?bbb=ddd")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
@@ -1459,7 +1463,7 @@ func TestHostClientMultipleAddrs(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -1501,7 +1505,7 @@ func TestClientFollowRedirects(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1516,7 +1520,7 @@ func TestClientFollowRedirects(t *testing.T) {
for i := 0; i < 10; i++ {
statusCode, body, err := c.GetTimeout(nil, "http://xxx/foo", time.Second)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
@@ -1529,7 +1533,7 @@ func TestClientFollowRedirects(t *testing.T) {
for i := 0; i < 10; i++ {
statusCode, body, err := c.Get(nil, "http://xxx/aaab/sss")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code: %d", statusCode)
@@ -1547,7 +1551,31 @@ func TestClientFollowRedirects(t *testing.T) {
err := c.DoRedirects(req, resp, 16)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if statusCode := resp.StatusCode(); statusCode != StatusOK {
+ t.Fatalf("unexpected status code: %d", statusCode)
+ }
+
+ if body := string(resp.Body()); body != "/bar" {
+ t.Fatalf("unexpected response %q. Expecting %q", body, "/bar")
+ }
+
+ ReleaseRequest(req)
+ ReleaseResponse(resp)
+ }
+
+ for i := 0; i < 10; i++ {
+ req := AcquireRequest()
+ resp := AcquireResponse()
+
+ req.SetRequestURI("http://xxx/foo")
+
+ req.SetTimeout(time.Second)
+ err := c.DoRedirects(req, resp, 16)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode := resp.StatusCode(); statusCode != StatusOK {
@@ -1562,6 +1590,32 @@ func TestClientFollowRedirects(t *testing.T) {
ReleaseResponse(resp)
}
+ for i := 0; i < 10; i++ {
+ req := AcquireRequest()
+ resp := AcquireResponse()
+
+ req.SetRequestURI("http://xxx/foo")
+
+ testConn, _ := net.Dial("tcp", ln.Addr().String())
+ timeoutConn := &Client{
+ Dial: func(addr string) (net.Conn, error) {
+ return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
+ },
+ }
+
+ req.SetTimeout(time.Millisecond)
+ err := timeoutConn.DoRedirects(req, resp, 16)
+ if err == nil {
+ t.Errorf("expecting error")
+ }
+ if err != ErrTimeout {
+ t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
+ }
+
+ ReleaseRequest(req)
+ ReleaseResponse(resp)
+ }
+
req := AcquireRequest()
resp := AcquireResponse()
@@ -1609,6 +1663,7 @@ func TestClientDoTimeoutSuccess(t *testing.T) {
defer s.Stop()
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
+ testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
}
func TestClientDoTimeoutSuccessConcurrent(t *testing.T) {
@@ -1623,6 +1678,7 @@ func TestClientDoTimeoutSuccessConcurrent(t *testing.T) {
go func() {
defer wg.Done()
testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
+ testClientRequestSetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100)
}()
}
wg.Wait()
@@ -1631,9 +1687,13 @@ func TestClientDoTimeoutSuccessConcurrent(t *testing.T) {
func TestClientGetTimeoutError(t *testing.T) {
t.Parallel()
+ s := startEchoServer(t, "tcp", "127.0.0.1:")
+ defer s.Stop()
+
+ testConn, _ := net.Dial("tcp", s.ln.Addr().String())
c := &Client{
Dial: func(addr string) (net.Conn, error) {
- return &readTimeoutConn{t: time.Second}, nil
+ return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
},
}
@@ -1643,9 +1703,13 @@ func TestClientGetTimeoutError(t *testing.T) {
func TestClientGetTimeoutErrorConcurrent(t *testing.T) {
t.Parallel()
+ s := startEchoServer(t, "tcp", "127.0.0.1:")
+ defer s.Stop()
+
+ testConn, _ := net.Dial("tcp", s.ln.Addr().String())
c := &Client{
Dial: func(addr string) (net.Conn, error) {
- return &readTimeoutConn{t: time.Second}, nil
+ return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
},
MaxConnsPerHost: 1000,
}
@@ -1664,21 +1728,30 @@ func TestClientGetTimeoutErrorConcurrent(t *testing.T) {
func TestClientDoTimeoutError(t *testing.T) {
t.Parallel()
+ s := startEchoServer(t, "tcp", "127.0.0.1:")
+ defer s.Stop()
+
+ testConn, _ := net.Dial("tcp", s.ln.Addr().String())
c := &Client{
Dial: func(addr string) (net.Conn, error) {
- return &readTimeoutConn{t: time.Second}, nil
+ return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
},
}
testClientDoTimeoutError(t, c, 100)
+ testClientRequestSetTimeoutError(t, c, 100)
}
func TestClientDoTimeoutErrorConcurrent(t *testing.T) {
t.Parallel()
+ s := startEchoServer(t, "tcp", "127.0.0.1:")
+ defer s.Stop()
+
+ testConn, _ := net.Dial("tcp", s.ln.Addr().String())
c := &Client{
Dial: func(addr string) (net.Conn, error) {
- return &readTimeoutConn{t: time.Second}, nil
+ return &readTimeoutConn{Conn: testConn, t: time.Second}, nil
},
MaxConnsPerHost: 1000,
}
@@ -1701,10 +1774,10 @@ func testClientDoTimeoutError(t *testing.T, c *Client, n int) {
for i := 0; i < n; i++ {
err := c.DoTimeout(&req, &resp, time.Millisecond)
if err == nil {
- t.Fatalf("expecting error")
+ t.Errorf("expecting error")
}
if err != ErrTimeout {
- t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
+ t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
}
}
@@ -1714,32 +1787,51 @@ func testClientGetTimeoutError(t *testing.T, c *Client, n int) {
for i := 0; i < n; i++ {
statusCode, body, err := c.GetTimeout(buf, "http://foobar.com/baz", time.Millisecond)
if err == nil {
- t.Fatalf("expecting error")
+ t.Errorf("expecting error")
}
if err != ErrTimeout {
- t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
+ t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
if statusCode != 0 {
- t.Fatalf("unexpected statusCode=%d. Expecting %d", statusCode, 0)
+ t.Errorf("unexpected statusCode=%d. Expecting %d", statusCode, 0)
}
if body == nil {
- t.Fatalf("body must be non-nil")
+ t.Errorf("body must be non-nil")
+ }
+ }
+}
+
+func testClientRequestSetTimeoutError(t *testing.T, c *Client, n int) {
+ var req Request
+ var resp Response
+ req.SetRequestURI("http://foobar.com/baz")
+ for i := 0; i < n; i++ {
+ req.SetTimeout(time.Millisecond)
+ err := c.Do(&req, &resp)
+ if err == nil {
+ t.Errorf("expecting error")
+ }
+ if err != ErrTimeout {
+ t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
}
}
type readTimeoutConn struct {
net.Conn
- t time.Duration
+ t time.Duration
+ wc chan struct{}
+ rc chan struct{}
}
func (r *readTimeoutConn) Read(p []byte) (int, error) {
- time.Sleep(r.t)
- return 0, io.EOF
+ <-r.rc
+ return 0, os.ErrDeadlineExceeded
}
func (r *readTimeoutConn) Write(p []byte) (int, error) {
- return len(p), nil
+ <-r.wc
+ return 0, os.ErrDeadlineExceeded
}
func (r *readTimeoutConn) Close() error {
@@ -1754,12 +1846,30 @@ func (r *readTimeoutConn) RemoteAddr() net.Addr {
return nil
}
+func (r *readTimeoutConn) SetReadDeadline(d time.Time) error {
+ r.rc = make(chan struct{}, 1)
+ go func() {
+ time.Sleep(time.Until(d))
+ r.rc <- struct{}{}
+ }()
+ return nil
+}
+
+func (r *readTimeoutConn) SetWriteDeadline(d time.Time) error {
+ r.wc = make(chan struct{}, 1)
+ go func() {
+ time.Sleep(time.Until(d))
+ r.wc <- struct{}{}
+ }()
+ return nil
+}
+
func TestClientNonIdempotentRetry(t *testing.T) {
t.Parallel()
dialsCount := 0
c := &Client{
- Dial: func(addr string) (net.Conn, error) {
+ Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
@@ -1781,7 +1891,7 @@ func TestClientNonIdempotentRetry(t *testing.T) {
dialsCount = 0
statusCode, body, err := c.Post(nil, "http://foobar/a/b", nil)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
@@ -1794,7 +1904,7 @@ func TestClientNonIdempotentRetry(t *testing.T) {
dialsCount = 0
statusCode, body, err = c.Get(nil, "http://foobar/a/b")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
@@ -1809,7 +1919,7 @@ func TestClientNonIdempotentRetry_BodyStream(t *testing.T) {
dialsCount := 0
c := &Client{
- Dial: func(addr string) (net.Conn, error) {
+ Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1, 2:
@@ -1846,7 +1956,7 @@ func TestClientIdempotentRequest(t *testing.T) {
dialsCount := 0
c := &Client{
- Dial: func(addr string) (net.Conn, error) {
+ Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
@@ -1871,7 +1981,7 @@ func TestClientIdempotentRequest(t *testing.T) {
// idempotent GET must succeed.
statusCode, body, err := c.Get(nil, "http://foobar/a/b")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
@@ -1902,7 +2012,7 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {
dialsCount := 0
c := &Client{
- Dial: func(addr string) (net.Conn, error) {
+ Dial: func(_ string) (net.Conn, error) {
dialsCount++
switch dialsCount {
case 1:
@@ -1932,7 +2042,7 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {
// Post must succeed for http://foobar/a/b uri.
statusCode, body, err := c.Post(nil, "http://foobar/a/b", &args)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", statusCode)
@@ -1962,7 +2072,7 @@ func TestHostClientTransport(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -1992,7 +2102,7 @@ func TestHostClientTransport(t *testing.T) {
for i := 0; i < 5; i++ {
statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK)
@@ -2003,7 +2113,7 @@ func TestHostClientTransport(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -2142,7 +2252,7 @@ func TestSingleEchoConn(t *testing.T) {
err := c.Do(&req, &res)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if res.StatusCode() != 345 {
t.Fatalf("unexpected status code: %d. Expecting 345", res.StatusCode())
@@ -2322,14 +2432,14 @@ func testClientGet(t *testing.T, c clientGetter, addr string, n int) {
statusCode, body, err := c.Get(buf, uri)
buf = body
if err != nil {
- t.Fatalf("unexpected error when doing http request: %s", err)
+ t.Errorf("unexpected error when doing http request: %v", err)
}
if statusCode != StatusOK {
- t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
+ t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
resultURI := string(body)
if resultURI != uri {
- t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri)
+ t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
}
}
}
@@ -2342,17 +2452,41 @@ func testClientDoTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
req.SetRequestURI(uri)
if err := c.DoTimeout(&req, &resp, time.Second); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
+ }
+ if resp.StatusCode() != StatusOK {
+ t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
+ }
+ resultURI := string(resp.Body())
+ if strings.HasPrefix(uri, "https") {
+ resultURI = uri[:5] + resultURI[4:]
+ }
+ if resultURI != uri {
+ t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
+ }
+ }
+}
+
+func testClientRequestSetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
+ var req Request
+ var resp Response
+
+ for i := 0; i < n; i++ {
+ uri := fmt.Sprintf("%s/foo/%d?bar=baz", addr, i)
+ req.SetRequestURI(uri)
+ req.SetTimeout(time.Second)
+ if err := c.Do(&req, &resp); err != nil {
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
- t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
+ t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
resultURI := string(resp.Body())
if strings.HasPrefix(uri, "https") {
resultURI = uri[:5] + resultURI[4:]
}
if resultURI != uri {
- t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri)
+ t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
}
}
}
@@ -2364,17 +2498,17 @@ func testClientGetTimeoutSuccess(t *testing.T, c *Client, addr string, n int) {
statusCode, body, err := c.GetTimeout(buf, uri, time.Second)
buf = body
if err != nil {
- t.Fatalf("unexpected error when doing http request: %s", err)
+ t.Errorf("unexpected error when doing http request: %v", err)
}
if statusCode != StatusOK {
- t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
+ t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
resultURI := string(body)
if strings.HasPrefix(uri, "https") {
resultURI = uri[:5] + resultURI[4:]
}
if resultURI != uri {
- t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri)
+ t.Errorf("unexpected uri %q. Expecting %q", resultURI, uri)
}
}
}
@@ -2390,14 +2524,14 @@ func testClientPost(t *testing.T, c clientPoster, addr string, n int) {
statusCode, body, err := c.Post(buf, uri, &args)
buf = body
if err != nil {
- t.Fatalf("unexpected error when doing http request: %s", err)
+ t.Errorf("unexpected error when doing http request: %v", err)
}
if statusCode != StatusOK {
- t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
+ t.Errorf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
}
s := string(body)
if s != argsS {
- t.Fatalf("unexpected response %q. Expecting %q", s, argsS)
+ t.Errorf("unexpected response %q. Expecting %q", s, argsS)
}
}
}
@@ -2480,7 +2614,7 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch
ln, err = net.Listen(network, addr)
}
if err != nil {
- t.Fatalf("cannot listen %q: %s", addr, err)
+ t.Fatalf("cannot listen %q: %v", addr, err)
}
s := &Server{
@@ -2497,7 +2631,7 @@ func startEchoServerExt(t *testing.T, network, addr string, isTLS bool) *testEch
go func() {
err := s.Serve(ln)
if err != nil {
- t.Errorf("unexpected error returned from Serve(): %s", err)
+ t.Errorf("unexpected error returned from Serve(): %v", err)
}
close(ch)
}()
@@ -2569,7 +2703,7 @@ func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -2595,7 +2729,7 @@ func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
resp := AcquireResponse()
if err := c.Do(req, resp); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
@@ -2614,7 +2748,7 @@ func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) {
t.Errorf("connsWait has %v items remaining", c.connsWait.len())
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -2648,7 +2782,7 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -2676,7 +2810,7 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
if err := c.Do(req, resp); err != nil {
if err != ErrNoFreeConns {
- t.Errorf("unexpected error: %s. Expecting %s", err, ErrNoFreeConns)
+ t.Errorf("unexpected error: %v. Expecting %v", err, ErrNoFreeConns)
}
atomic.AddUint32(&errNoFreeConnsCount, 1)
} else {
@@ -2704,7 +2838,7 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) {
t.Errorf("unexpected errorCount: %d. Expecting > 0", errNoFreeConnsCount)
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -2738,11 +2872,12 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
time.Sleep(sleep)
ctx.WriteString("foo") //nolint:errcheck
},
+ Logger: &testLogger{}, // Don't print connection closed errors.
}
serverStopCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverStopCh)
}()
@@ -2770,7 +2905,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil {
if err != ErrTimeout {
- t.Errorf("unexpected error: %s. Expecting %s", err, ErrTimeout)
+ t.Errorf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
atomic.AddUint32(&errTimeoutCount, 1)
} else {
@@ -2795,7 +2930,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
}
w.mu.Lock()
if w.err != nil && w.err != ErrTimeout {
- t.Errorf("unexpected error: %s. Expecting %s", w.err, ErrTimeout)
+ t.Errorf("unexpected error: %v. Expecting %v", w.err, ErrTimeout)
}
w.mu.Unlock()
}
@@ -2804,7 +2939,7 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
t.Errorf("unexpected errTimeoutCount: %d. Expecting > 0", errTimeoutCount)
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-serverStopCh:
@@ -2816,3 +2951,83 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
t.Fatalf("at least one request body was empty")
}
}
+
+func TestHttpsRequestWithoutParsedURL(t *testing.T) {
+ t.Parallel()
+
+ client := HostClient{
+ IsTLS: true,
+ Transport: func(r1 *Request, r2 *Response) error {
+ return nil
+ },
+ }
+
+ req := &Request{}
+
+ req.SetRequestURI("https://foo.com/bar")
+
+ _, err := client.doNonNilReqResp(req, &Response{})
+ if err != nil {
+ t.Fatal("https requests with IsTLS client must succeed")
+ }
+}
+
+func Test_AddMissingPort(t *testing.T) {
+ type args struct {
+ addr string
+ isTLS bool
+ }
+ tests := []struct {
+ name string
+ args args
+ want string
+ }{
+ {
+ args: args{"127.1", false}, // 127.1 is a short form of 127.0.0.1
+ want: "127.1:80",
+ },
+ {
+ args: args{"127.0.0.1", false},
+ want: "127.0.0.1:80",
+ },
+ {
+ args: args{"127.0.0.1", true},
+ want: "127.0.0.1:443",
+ },
+ {
+ args: args{"[::1]", false},
+ want: "[::1]:80",
+ },
+ {
+ args: args{"::1", false},
+ want: "::1", // keep as is
+ },
+ {
+ args: args{"[::1]", true},
+ want: "[::1]:443",
+ },
+ {
+ args: args{"127.0.0.1:8080", false},
+ want: "127.0.0.1:8080",
+ },
+ {
+ args: args{"127.0.0.1:8443", true},
+ want: "127.0.0.1:8443",
+ },
+ {
+ args: args{"[::1]:8080", false},
+ want: "[::1]:8080",
+ },
+ {
+ args: args{"[::1]:8443", true},
+ want: "[::1]:8443",
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.want, func(t *testing.T) {
+ if got := AddMissingPort(tt.args.addr, tt.args.isTLS); got != tt.want {
+ t.Errorf("AddMissingPort() = %v, want %v", got, tt.want)
+ }
+ })
+ }
+}
diff --git a/client_timing_test.go b/client_timing_test.go
index 7beaaba..0755155 100644
--- a/client_timing_test.go
+++ b/client_timing_test.go
@@ -3,7 +3,7 @@ package fasthttp
import (
"bytes"
"fmt"
- "io/ioutil"
+ "io"
"net"
"net/http"
"runtime"
@@ -102,7 +102,7 @@ func BenchmarkClientGetTimeoutFastServer(b *testing.B) {
for pb.Next() {
statusCode, bodyBuf, err = c.GetTimeout(bodyBuf[:0], url, time.Second)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
b.Fatalf("unexpected status code: %d", statusCode)
@@ -131,7 +131,7 @@ func BenchmarkClientDoFastServer(b *testing.B) {
req.Header.SetRequestURI(fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)))
for pb.Next() {
if err := c.Do(&req, &resp); err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.Header.StatusCode() != StatusOK {
b.Fatalf("unexpected status code: %d", resp.Header.StatusCode())
@@ -159,20 +159,20 @@ func BenchmarkNetHTTPClientDoFastServer(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
req, err := http.NewRequest(MethodGet, fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)), nil)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
for pb.Next() {
resp, err := c.Do(req)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d", resp.StatusCode)
}
- respBody, err := ioutil.ReadAll(resp.Body)
+ respBody, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
- b.Fatalf("unexpected error when reading response body: %s", err)
+ b.Fatalf("unexpected error when reading response body: %v", err)
}
if !bytes.Equal(respBody, body) {
b.Fatalf("unexpected response body: %q. Expected %q", respBody, body)
@@ -207,13 +207,13 @@ func benchmarkClientGetEndToEndTCP(b *testing.B, parallelism int) {
ln, err := net.Listen("tcp4", addr)
if err != nil {
- b.Fatalf("cannot listen %q: %s", addr, err)
+ b.Fatalf("cannot listen %q: %v", addr, err)
}
ch := make(chan struct{})
go func() {
if err := Serve(ln, fasthttpEchoHandler); err != nil {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -230,7 +230,7 @@ func benchmarkClientGetEndToEndTCP(b *testing.B, parallelism int) {
for pb.Next() {
statusCode, body, err := c.Get(buf, url)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
@@ -267,14 +267,14 @@ func benchmarkNetHTTPClientGetEndToEndTCP(b *testing.B, parallelism int) {
ln, err := net.Listen("tcp4", addr)
if err != nil {
- b.Fatalf("cannot listen %q: %s", addr, err)
+ b.Fatalf("cannot listen %q: %v", addr, err)
}
ch := make(chan struct{})
go func() {
if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -292,15 +292,15 @@ func benchmarkNetHTTPClientGetEndToEndTCP(b *testing.B, parallelism int) {
for pb.Next() {
resp, err := c.Get(url)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
- body, err := ioutil.ReadAll(resp.Body)
+ body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
- b.Fatalf("unexpected error when reading response body: %s", err)
+ b.Fatalf("unexpected error when reading response body: %v", err)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
@@ -342,7 +342,7 @@ func benchmarkClientGetEndToEndInmemory(b *testing.B, parallelism int) {
ch := make(chan struct{})
go func() {
if err := Serve(ln, fasthttpEchoHandler); err != nil {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -360,7 +360,7 @@ func benchmarkClientGetEndToEndInmemory(b *testing.B, parallelism int) {
for pb.Next() {
statusCode, body, err := c.Get(buf, url)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if statusCode != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK)
@@ -403,7 +403,7 @@ func benchmarkNetHTTPClientGetEndToEndInmemory(b *testing.B, parallelism int) {
go func() {
if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -422,15 +422,15 @@ func benchmarkNetHTTPClientGetEndToEndInmemory(b *testing.B, parallelism int) {
for pb.Next() {
resp, err := c.Get(url)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
- body, err := ioutil.ReadAll(resp.Body)
+ body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
- b.Fatalf("unexpected error when reading response body: %s", err)
+ b.Fatalf("unexpected error when reading response body: %v", err)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
@@ -466,7 +466,7 @@ func benchmarkClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) {
ch := make(chan struct{})
go func() {
if err := Serve(ln, h); err != nil {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -485,7 +485,7 @@ func benchmarkClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) {
var resp Response
for pb.Next() {
if err := c.DoTimeout(&req, &resp, 5*time.Second); err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -515,7 +515,7 @@ func BenchmarkNetHTTPClientEndToEndBigResponse10Inmemory(b *testing.B) {
func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) {
bigResponse := createFixedBody(1024 * 1024)
- h := func(w http.ResponseWriter, r *http.Request) {
+ h := func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set(HeaderContentType, "text/plain")
w.Write(bigResponse) //nolint:errcheck
}
@@ -525,7 +525,7 @@ func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism
go func() {
if err := http.Serve(ln, http.HandlerFunc(h)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -544,20 +544,20 @@ func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism
b.RunParallel(func(pb *testing.PB) {
req, err := http.NewRequest(MethodGet, url, nil)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
for pb.Next() {
resp, err := c.Do(req)
if err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
- body, err := ioutil.ReadAll(resp.Body)
+ body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
- b.Fatalf("unexpected error when reading response body: %s", err)
+ b.Fatalf("unexpected error when reading response body: %v", err)
}
if !bytes.Equal(bigResponse, body) {
b.Fatalf("unexpected response %q. Expecting %q", body, bigResponse)
@@ -598,7 +598,7 @@ func benchmarkPipelineClient(b *testing.B, parallelism int) {
ch := make(chan struct{})
go func() {
if err := Serve(ln, h); err != nil {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -621,7 +621,7 @@ func benchmarkPipelineClient(b *testing.B, parallelism int) {
var resp Response
for pb.Next() {
if err := c.Do(&req, &resp); err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
diff --git a/client_timing_wait_test.go b/client_timing_wait_test.go
index caaacac..906ec7d 100644
--- a/client_timing_wait_test.go
+++ b/client_timing_wait_test.go
@@ -1,10 +1,7 @@
-//go:build go1.11
-// +build go1.11
-
package fasthttp
import (
- "io/ioutil"
+ "io"
"net"
"net/http"
"strings"
@@ -45,7 +42,7 @@ func benchmarkClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism int) {
go func() {
if err := Serve(ln, newFasthttpSleepEchoHandler(sleepDuration)); err != nil {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -65,7 +62,7 @@ func benchmarkClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism int) {
statusCode, body, err := c.Get(buf, url)
if err != nil {
if err != ErrNoFreeConns {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
} else {
if statusCode != StatusOK {
@@ -119,7 +116,7 @@ func benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism
go func() {
if err := http.Serve(ln, newNethttpSleepEchoHandler(sleep)); err != nil && !strings.Contains(
err.Error(), "use of closed network connection") {
- b.Errorf("error when serving requests: %s", err)
+ b.Errorf("error when serving requests: %v", err)
}
close(ch)
}()
@@ -140,16 +137,16 @@ func benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism
resp, err := c.Get(url)
if err != nil {
if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
} else {
if resp.StatusCode != http.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK)
}
- body, err := ioutil.ReadAll(resp.Body)
+ body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
- b.Fatalf("unexpected error when reading response body: %s", err)
+ b.Fatalf("unexpected error when reading response body: %v", err)
}
if string(body) != requestURI {
b.Fatalf("unexpected response %q. Expecting %q", body, requestURI)
diff --git a/client_unix_test.go b/client_unix_test.go
new file mode 100644
index 0000000..380f909
--- /dev/null
+++ b/client_unix_test.go
@@ -0,0 +1,135 @@
+//go:build !windows
+// +build !windows
+
+package fasthttp
+
+import (
+ "io"
+ "net"
+ "net/http"
+ "strings"
+ "testing"
+)
+
+// See issue #1232
+func TestRstConnResponseWhileSending(t *testing.T) {
+ const expectedStatus = http.StatusTeapot
+ const payload = "payload"
+
+ srv, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+
+ go func() {
+ for {
+ conn, err := srv.Accept()
+ if err != nil {
+ return
+ }
+
+ // Read at least one byte of the header
+ // Otherwise we would have an unsolicited response
+ _, err = io.ReadAll(io.LimitReader(conn, 1))
+ if err != nil {
+ t.Error(err)
+ }
+
+ // Respond
+ _, err = conn.Write([]byte("HTTP/1.1 418 Teapot\r\n\r\n"))
+ if err != nil {
+ t.Error(err)
+ }
+
+ // Forcefully close connection
+ err = conn.(*net.TCPConn).SetLinger(0)
+ if err != nil {
+ t.Error(err)
+ }
+ conn.Close()
+ }
+ }()
+
+ svrUrl := "http://" + srv.Addr().String()
+ client := HostClient{Addr: srv.Addr().String()}
+
+ for i := 0; i < 100; i++ {
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ resp := AcquireResponse()
+ defer ReleaseResponse(resp)
+
+ req.Header.SetMethod("POST")
+ req.SetBodyStream(strings.NewReader(payload), len(payload))
+ req.SetRequestURI(svrUrl)
+
+ err = client.Do(req, resp)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if expectedStatus != resp.StatusCode() {
+ t.Fatalf("Expected %d status code, but got %d", expectedStatus, resp.StatusCode())
+ }
+ }
+}
+
+// See issue #1232
+func TestRstConnClosedWithoutResponse(t *testing.T) {
+ const payload = "payload"
+
+ srv, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+
+ go func() {
+ for {
+ conn, err := srv.Accept()
+ if err != nil {
+ return
+ }
+
+ // Read at least one byte of the header
+ // Otherwise we would have an unsolicited response
+ _, err = io.ReadAll(io.LimitReader(conn, 1))
+ if err != nil {
+ t.Error(err)
+ }
+
+ // Respond with incomplete header
+ _, err = conn.Write([]byte("Http"))
+ if err != nil {
+ t.Error(err)
+ }
+
+ // Forcefully close connection
+ err = conn.(*net.TCPConn).SetLinger(0)
+ if err != nil {
+ t.Error(err)
+ }
+ conn.Close()
+ }
+ }()
+
+ svrUrl := "http://" + srv.Addr().String()
+ client := HostClient{Addr: srv.Addr().String()}
+
+ for i := 0; i < 100; i++ {
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ resp := AcquireResponse()
+ defer ReleaseResponse(resp)
+
+ req.Header.SetMethod("POST")
+ req.SetBodyStream(strings.NewReader(payload), len(payload))
+ req.SetRequestURI(svrUrl)
+
+ err = client.Do(req, resp)
+
+ if !isConnectionReset(err) {
+ t.Fatal("Expected connection reset error")
+ }
+ }
+}
diff --git a/compress.go b/compress.go
index f590d28..49fbf3c 100644
--- a/compress.go
+++ b/compress.go
@@ -101,7 +101,7 @@ func acquireRealGzipWriter(w io.Writer, level int) *gzip.Writer {
if v == nil {
zw, err := gzip.NewWriterLevel(w, level)
if err != nil {
- panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %s", level, err))
+ panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %v", level, err))
}
return zw
}
@@ -127,11 +127,11 @@ var (
//
// Supported compression levels are:
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
func AppendGzipBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{dst}
WriteGzipLevel(w, src, level) //nolint:errcheck
@@ -143,11 +143,11 @@ func AppendGzipBytesLevel(dst, src []byte, level int) []byte {
//
// Supported compression levels are:
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
func WriteGzipLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
@@ -177,7 +177,7 @@ func nonblockingWriteGzip(ctxv interface{}) {
_, err := zw.Write(ctx.p)
if err != nil {
- panic(fmt.Sprintf("BUG: gzip.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err))
+ panic(fmt.Sprintf("BUG: gzip.Writer.Write for len(p)=%d returned unexpected error: %v", len(ctx.p), err))
}
releaseRealGzipWriter(zw, ctx.level)
@@ -223,11 +223,11 @@ func AppendGunzipBytes(dst, src []byte) ([]byte, error) {
//
// Supported compression levels are:
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
func AppendDeflateBytesLevel(dst, src []byte, level int) []byte {
w := &byteSliceWriter{dst}
WriteDeflateLevel(w, src, level) //nolint:errcheck
@@ -239,11 +239,11 @@ func AppendDeflateBytesLevel(dst, src []byte, level int) []byte {
//
// Supported compression levels are:
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
func WriteDeflateLevel(w io.Writer, p []byte, level int) (int, error) {
switch w.(type) {
case *byteSliceWriter,
@@ -273,7 +273,7 @@ func nonblockingWriteDeflate(ctxv interface{}) {
_, err := zw.Write(ctx.p)
if err != nil {
- panic(fmt.Sprintf("BUG: zlib.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err))
+ panic(fmt.Sprintf("BUG: zlib.Writer.Write for len(p)=%d returned unexpected error: %v", len(ctx.p), err))
}
releaseRealDeflateWriter(zw, ctx.level)
@@ -379,7 +379,7 @@ func acquireRealDeflateWriter(w io.Writer, level int) *zlib.Writer {
if v == nil {
zw, err := zlib.NewWriterLevel(w, level)
if err != nil {
- panic(fmt.Sprintf("BUG: unexpected error from zlib.NewWriterLevel(%d): %s", level, err))
+ panic(fmt.Sprintf("BUG: unexpected error from zlib.NewWriterLevel(%d): %v", level, err))
}
return zw
}
diff --git a/compress_test.go b/compress_test.go
index 572641a..329f0f5 100644
--- a/compress_test.go
+++ b/compress_test.go
@@ -3,7 +3,7 @@ package fasthttp
import (
"bytes"
"fmt"
- "io/ioutil"
+ "io"
"testing"
"time"
)
@@ -78,7 +78,7 @@ func testGzipBytesSingleCase(s string) error {
gunzippedS, err := AppendGunzipBytes(prefix, gzippedS[len(prefix):])
if err != nil {
- return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err)
+ return fmt.Errorf("unexpected error when uncompressing %q: %w", s, err)
}
if !bytes.Equal(gunzippedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, gunzippedS[:len(prefix)], prefix)
@@ -99,7 +99,7 @@ func testDeflateBytesSingleCase(s string) error {
inflatedS, err := AppendInflateBytes(prefix, deflatedS[len(prefix):])
if err != nil {
- return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err)
+ return fmt.Errorf("unexpected error when uncompressing %q: %w", s, err)
}
if !bytes.Equal(inflatedS[:len(prefix)], prefix) {
return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, inflatedS[:len(prefix)], prefix)
@@ -165,17 +165,17 @@ func testGzipCompressSingleCase(s string) error {
var buf bytes.Buffer
zw := acquireStacklessGzipWriter(&buf, CompressDefaultCompression)
if _, err := zw.Write([]byte(s)); err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
zr, err := acquireGzipReader(&buf)
if err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
- body, err := ioutil.ReadAll(zr)
+ body, err := io.ReadAll(zr)
if err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
if string(body) != s {
return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s)
@@ -188,17 +188,17 @@ func testFlateCompressSingleCase(s string) error {
var buf bytes.Buffer
zw := acquireStacklessDeflateWriter(&buf, CompressDefaultCompression)
if _, err := zw.Write([]byte(s)); err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
releaseStacklessDeflateWriter(zw, CompressDefaultCompression)
zr, err := acquireFlateReader(&buf)
if err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
- body, err := ioutil.ReadAll(zr)
+ body, err := io.ReadAll(zr)
if err != nil {
- return fmt.Errorf("unexpected error: %s. s=%q", err, s)
+ return fmt.Errorf("unexpected error: %w. s=%q", err, s)
}
if string(body) != s {
return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s)
@@ -213,7 +213,7 @@ func testConcurrent(concurrency int, f func() error) error {
go func(idx int) {
err := f()
if err != nil {
- ch <- fmt.Errorf("error in goroutine %d: %s", idx, err)
+ ch <- fmt.Errorf("error in goroutine %d: %w", idx, err)
}
ch <- nil
}(i)
diff --git a/cookie_test.go b/cookie_test.go
index 1a12b44..b4b81ac 100644
--- a/cookie_test.go
+++ b/cookie_test.go
@@ -33,7 +33,7 @@ func testCookieValueWithEqualAndSpaceChars(t *testing.T, expectedName, expectedP
var c1 Cookie
if err := c1.Parse(s); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
name := c1.Key()
if string(name) != expectedName {
@@ -55,7 +55,7 @@ func TestCookieSecureHttpOnly(t *testing.T) {
var c Cookie
if err := c.Parse("foo=bar; HttpOnly; secure"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !c.Secure() {
t.Fatalf("secure must be set")
@@ -78,7 +78,7 @@ func TestCookieSecure(t *testing.T) {
var c Cookie
if err := c.Parse("foo=bar; secure"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !c.Secure() {
t.Fatalf("secure must be set")
@@ -89,7 +89,7 @@ func TestCookieSecure(t *testing.T) {
}
if err := c.Parse("foo=bar"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if c.Secure() {
t.Fatalf("Unexpected secure flag set")
@@ -106,7 +106,7 @@ func TestCookieSameSite(t *testing.T) {
var c Cookie
if err := c.Parse("foo=bar; samesite"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if c.SameSite() != CookieSameSiteDefaultMode {
t.Fatalf("SameSite must be set")
@@ -117,7 +117,7 @@ func TestCookieSameSite(t *testing.T) {
}
if err := c.Parse("foo=bar; samesite=lax"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if c.SameSite() != CookieSameSiteLaxMode {
t.Fatalf("SameSite Lax Mode must be set")
@@ -128,7 +128,7 @@ func TestCookieSameSite(t *testing.T) {
}
if err := c.Parse("foo=bar; samesite=strict"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if c.SameSite() != CookieSameSiteStrictMode {
t.Fatalf("SameSite Strict Mode must be set")
@@ -139,7 +139,7 @@ func TestCookieSameSite(t *testing.T) {
}
if err := c.Parse("foo=bar; samesite=none"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if c.SameSite() != CookieSameSiteNoneMode {
t.Fatalf("SameSite None Mode must be set")
@@ -150,7 +150,7 @@ func TestCookieSameSite(t *testing.T) {
}
if err := c.Parse("foo=bar"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
c.SetSameSite(CookieSameSiteNoneMode)
s = c.String()
@@ -162,7 +162,7 @@ func TestCookieSameSite(t *testing.T) {
}
if err := c.Parse("foo=bar"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if c.SameSite() != CookieSameSiteDisabled {
t.Fatalf("Unexpected SameSite flag set")
@@ -180,7 +180,7 @@ func TestCookieMaxAge(t *testing.T) {
maxAge := 100
if err := c.Parse("foo=bar; max-age=100"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if maxAge != c.MaxAge() {
t.Fatalf("max-age must be set")
@@ -191,7 +191,7 @@ func TestCookieMaxAge(t *testing.T) {
}
if err := c.Parse("foo=bar; expires=Tue, 10 Nov 2009 23:00:00 GMT; max-age=100;"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if maxAge != c.MaxAge() {
t.Fatalf("max-age ignored")
@@ -221,7 +221,7 @@ func TestCookieHttpOnly(t *testing.T) {
var c Cookie
if err := c.Parse("foo=bar; HttpOnly"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !c.HTTPOnly() {
t.Fatalf("HTTPOnly must be set")
@@ -232,7 +232,7 @@ func TestCookieHttpOnly(t *testing.T) {
}
if err := c.Parse("foo=bar"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if c.HTTPOnly() {
t.Fatalf("Unexpected HTTPOnly flag set")
@@ -286,7 +286,7 @@ func testCookieAcquireRelease(t *testing.T) {
s := c.String()
c.Reset()
if err := c.Parse(s); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(c.Key()) != key {
@@ -322,7 +322,7 @@ func TestCookieParse(t *testing.T) {
func testCookieParse(t *testing.T, s, expectedS string) {
var c Cookie
if err := c.Parse(s); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
result := string(c.Cookie())
if result != expectedS {
diff --git a/cookie_timing_test.go b/cookie_timing_test.go
index bcf958c..1af2687 100644
--- a/cookie_timing_test.go
+++ b/cookie_timing_test.go
@@ -9,7 +9,7 @@ func BenchmarkCookieParseMin(b *testing.B) {
s := []byte("xxx=yyy")
for i := 0; i < b.N; i++ {
if err := c.ParseBytes(s); err != nil {
- b.Fatalf("unexpected error when parsing cookies: %s", err)
+ b.Fatalf("unexpected error when parsing cookies: %v", err)
}
}
}
@@ -19,7 +19,7 @@ func BenchmarkCookieParseNoExpires(b *testing.B) {
s := []byte("xxx=yyy; domain=foobar.com; path=/a/b")
for i := 0; i < b.N; i++ {
if err := c.ParseBytes(s); err != nil {
- b.Fatalf("unexpected error when parsing cookies: %s", err)
+ b.Fatalf("unexpected error when parsing cookies: %v", err)
}
}
}
@@ -29,7 +29,7 @@ func BenchmarkCookieParseFull(b *testing.B) {
s := []byte("xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b")
for i := 0; i < b.N; i++ {
if err := c.ParseBytes(s); err != nil {
- b.Fatalf("unexpected error when parsing cookies: %s", err)
+ b.Fatalf("unexpected error when parsing cookies: %v", err)
}
}
}
diff --git a/debian/changelog b/debian/changelog
index 0ade4c1..d952856 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,11 @@
+golang-github-valyala-fasthttp (1:1.44.0-1) UNRELEASED; urgency=low
+
+ * New upstream release.
+ * Drop patch 0001-bytesconv-add-appropriate-build-tags-for-s390x.patch,
+ present upstream.
+
+ -- Debian Janitor <janitor@jelmer.uk> Sun, 26 Feb 2023 06:27:43 -0000
+
golang-github-valyala-fasthttp (1:1.31.0-4) unstable; urgency=medium
[ Guillem Jover ]
diff --git a/debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch b/debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch
deleted file mode 100644
index 177a3b0..0000000
--- a/debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch
+++ /dev/null
@@ -1,73 +0,0 @@
-Description: Add appropriate build tags for s390x
- The bytesconv 32-bit tests fail on s390x, because it is a 64-bit
- architecture. Add the appropriate build flags so that 32-bit tests do
- not run on this architecture.
-Author: Nick Rosbrook <nick.rosbrook@canonical.com>
-Forwarded: https://github.com/valyala/fasthttp/pull/1250
-Last-Update: 2022-03-16
----
-From d6c6e4a7cc9c17158dc2c93090e5b7d26ca42e15 Mon Sep 17 00:00:00 2001
-From: Nick Rosbrook <nr@enr0n.net>
-Date: Wed, 16 Mar 2022 09:41:03 -0400
-Subject: [PATCH] bytesconv: add appropriate build tags for s390x
-
-The bytesconv 32-bit tests fail on s390x, because it is a 64-bit
-architecture. Add the appropriate build flags so that 32-bit tests do
-not run on this architecture.
----
- bytesconv_32.go | 4 ++--
- bytesconv_32_test.go | 4 ++--
- bytesconv_64.go | 4 ++--
- bytesconv_64_test.go | 4 ++--
- 4 files changed, 8 insertions(+), 8 deletions(-)
-diff --git a/bytesconv_32.go b/bytesconv_32.go
-index 6a6fec2..b574883 100644
---- a/bytesconv_32.go
-+++ b/bytesconv_32.go
-@@ -1,5 +1,5 @@
--//go:build !amd64 && !arm64 && !ppc64 && !ppc64le
--// +build !amd64,!arm64,!ppc64,!ppc64le
-+//go:build !amd64 && !arm64 && !ppc64 && !ppc64le && !s390x
-+// +build !amd64,!arm64,!ppc64,!ppc64le,!s390x
-
- package fasthttp
-
-diff --git a/bytesconv_32_test.go b/bytesconv_32_test.go
-index cec5aa9..3f5d5de 100644
---- a/bytesconv_32_test.go
-+++ b/bytesconv_32_test.go
-@@ -1,5 +1,5 @@
--//go:build !amd64 && !arm64 && !ppc64 && !ppc64le
--// +build !amd64,!arm64,!ppc64,!ppc64le
-+//go:build !amd64 && !arm64 && !ppc64 && !ppc64le && !s390x
-+// +build !amd64,!arm64,!ppc64,!ppc64le,!s390x
-
- package fasthttp
-
-diff --git a/bytesconv_64.go b/bytesconv_64.go
-index 1300d5a..94d0ec6 100644
---- a/bytesconv_64.go
-+++ b/bytesconv_64.go
-@@ -1,5 +1,5 @@
--//go:build amd64 || arm64 || ppc64 || ppc64le
--// +build amd64 arm64 ppc64 ppc64le
-+//go:build amd64 || arm64 || ppc64 || ppc64le || s390x
-+// +build amd64 arm64 ppc64 ppc64le s390x
-
- package fasthttp
-
-diff --git a/bytesconv_64_test.go b/bytesconv_64_test.go
-index 5351591..0689809 100644
---- a/bytesconv_64_test.go
-+++ b/bytesconv_64_test.go
-@@ -1,5 +1,5 @@
--//go:build amd64 || arm64 || ppc64 || ppc64le
--// +build amd64 arm64 ppc64 ppc64le
-+//go:build amd64 || arm64 || ppc64 || ppc64le || s390x
-+// +build amd64 arm64 ppc64 ppc64le s390x
-
- package fasthttp
-
---
-2.32.0
-
diff --git a/debian/patches/series b/debian/patches/series
index bcf16dc..e69de29 100644
--- a/debian/patches/series
+++ b/debian/patches/series
@@ -1 +0,0 @@
-0001-bytesconv-add-appropriate-build-tags-for-s390x.patch
diff --git a/doc.go b/doc.go
index efcd4a0..f2bf58d 100644
--- a/doc.go
+++ b/doc.go
@@ -3,35 +3,53 @@ Package fasthttp provides fast HTTP server and client API.
Fasthttp provides the following features:
- * Optimized for speed. Easily handles more than 100K qps and more than 1M
- concurrent keep-alive connections on modern hardware.
- * Optimized for low memory usage.
- * Easy 'Connection: Upgrade' support via RequestCtx.Hijack.
- * Server provides the following anti-DoS limits:
-
- * The number of concurrent connections.
- * The number of concurrent connections per client IP.
- * The number of requests per connection.
- * Request read timeout.
- * Response write timeout.
- * Maximum request header size.
- * Maximum request body size.
- * Maximum request execution time.
- * Maximum keep-alive connection lifetime.
- * Early filtering out non-GET requests.
-
- * A lot of additional useful info is exposed to request handler:
-
- * Server and client address.
- * Per-request logger.
- * Unique request id.
- * Request start time.
- * Connection start time.
- * Request sequence number for the current connection.
-
- * Client supports automatic retry on idempotent requests' failure.
- * Fasthttp API is designed with the ability to extend existing client
- and server implementations or to write custom client and server
- implementations from scratch.
+ 1. Optimized for speed. Easily handles more than 100K qps and more than 1M
+ concurrent keep-alive connections on modern hardware.
+
+ 2. Optimized for low memory usage.
+
+ 3. Easy 'Connection: Upgrade' support via RequestCtx.Hijack.
+
+ 4. Server provides the following anti-DoS limits:
+
+ - The number of concurrent connections.
+
+ - The number of concurrent connections per client IP.
+
+ - The number of requests per connection.
+
+ - Request read timeout.
+
+ - Response write timeout.
+
+ - Maximum request header size.
+
+ - Maximum request body size.
+
+ - Maximum request execution time.
+
+ - Maximum keep-alive connection lifetime.
+
+ - Early filtering out non-GET requests.
+
+ 5. A lot of additional useful info is exposed to request handler:
+
+ - Server and client address.
+
+ - Per-request logger.
+
+ - Unique request id.
+
+ - Request start time.
+
+ - Connection start time.
+
+ - Request sequence number for the current connection.
+
+ 6. Client supports automatic retry on idempotent requests' failure.
+
+ 7. Fasthttp API is designed with the ability to extend existing client
+ and server implementations or to write custom client and server
+ implementations from scratch.
*/
package fasthttp
diff --git a/examples/client/.gitignore b/examples/client/.gitignore
new file mode 100644
index 0000000..b051c6c
--- /dev/null
+++ b/examples/client/.gitignore
@@ -0,0 +1 @@
+client
diff --git a/examples/client/Makefile b/examples/client/Makefile
new file mode 100644
index 0000000..d2844fb
--- /dev/null
+++ b/examples/client/Makefile
@@ -0,0 +1,6 @@
+client: clean
+ go get -u github.com/valyala/fasthttp
+ go build
+
+clean:
+ rm -f client
diff --git a/examples/client/README.md b/examples/client/README.md
new file mode 100644
index 0000000..cef2d35
--- /dev/null
+++ b/examples/client/README.md
@@ -0,0 +1,21 @@
+# Client Example
+
+The Client is useful when working with multiple hostnames.
+
+See the simple `sendGetRequest()` for a GET request and the more advanced `sendPostRequest()` for a POST request.
+
+The `sendPostRequest()` also shows:
+* Per-request timeout with `DoTimeout()`
+* Send a body as a byte slice with `SetBodyRaw()`. This is useful if you generated the request body yourself. Otherwise, prefer `SetBody()`, which copies it.
+* Parse JSON from response
+* Gracefully report errors, e.g. timeouts as warnings and other errors as failures with detailed error messages.
+
+## How to build and run
+Start a web server on localhost:8080 then execute:
+
+ make
+ ./client
+
+## Client vs HostClient
+Internally, the Client creates a dedicated HostClient for each domain/IP address and cleans up unused ones after a period of time.
+So if you have a single heavily loaded API endpoint, it's better to use HostClient. See the example in [examples/host_client](../host_client/).
diff --git a/examples/client/client.go b/examples/client/client.go
new file mode 100644
index 0000000..66881c9
--- /dev/null
+++ b/examples/client/client.go
@@ -0,0 +1,125 @@
+package main
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+ "net/http"
+ "os"
+ "reflect"
+ "time"
+
+ "github.com/valyala/fasthttp"
+)
+
+var headerContentTypeJson = []byte("application/json")
+
+var client *fasthttp.Client
+
+type Entity struct {
+ Id int
+ Name string
+}
+
+func main() {
+ // You may read the timeouts from some config
+ readTimeout, _ := time.ParseDuration("500ms")
+ writeTimeout, _ := time.ParseDuration("500ms")
+ maxIdleConnDuration, _ := time.ParseDuration("1h")
+ client = &fasthttp.Client{
+ ReadTimeout: readTimeout,
+ WriteTimeout: writeTimeout,
+ MaxIdleConnDuration: maxIdleConnDuration,
+ NoDefaultUserAgentHeader: true, // Don't send: User-Agent: fasthttp
+ DisableHeaderNamesNormalizing: true, // If you set the case on your headers correctly you can enable this
+ DisablePathNormalizing: true,
+ // increase DNS cache time to an hour instead of default minute
+ Dial: (&fasthttp.TCPDialer{
+ Concurrency: 4096,
+ DNSCacheDuration: time.Hour,
+ }).Dial,
+ }
+ sendGetRequest()
+ sendPostRequest()
+}
+
+func sendGetRequest() {
+ req := fasthttp.AcquireRequest()
+ req.SetRequestURI("http://localhost:8080/")
+ req.Header.SetMethod(fasthttp.MethodGet)
+ resp := fasthttp.AcquireResponse()
+ err := client.Do(req, resp)
+ fasthttp.ReleaseRequest(req)
+ if err == nil {
+ fmt.Printf("DEBUG Response: %s\n", resp.Body())
+ } else {
+ fmt.Fprintf(os.Stderr, "ERR Connection error: %v\n", err)
+ }
+ fasthttp.ReleaseResponse(resp)
+}
+
+func sendPostRequest() {
+ // per-request timeout
+ reqTimeout := time.Duration(100) * time.Millisecond
+
+ reqEntity := &Entity{
+ Name: "New entity",
+ }
+ reqEntityBytes, _ := json.Marshal(reqEntity)
+
+ req := fasthttp.AcquireRequest()
+ req.SetRequestURI("http://localhost:8080/")
+ req.Header.SetMethod(fasthttp.MethodPost)
+ req.Header.SetContentTypeBytes(headerContentTypeJson)
+ req.SetBodyRaw(reqEntityBytes)
+ resp := fasthttp.AcquireResponse()
+ err := client.DoTimeout(req, resp, reqTimeout)
+ fasthttp.ReleaseRequest(req)
+ if err == nil {
+ statusCode := resp.StatusCode()
+ respBody := resp.Body()
+ fmt.Printf("DEBUG Response: %s\n", respBody)
+ if statusCode == http.StatusOK {
+ respEntity := &Entity{}
+ err = json.Unmarshal(respBody, respEntity)
+ if err == io.EOF || err == nil {
+ fmt.Printf("DEBUG Parsed Response: %v\n", respEntity)
+ } else {
+ fmt.Fprintf(os.Stderr, "ERR failed to parse response: %v\n", err)
+ }
+ } else {
+ fmt.Fprintf(os.Stderr, "ERR invalid HTTP response code: %d\n", statusCode)
+ }
+ } else {
+ errName, known := httpConnError(err)
+ if known {
+ fmt.Fprintf(os.Stderr, "WARN conn error: %v\n", errName)
+ } else {
+ fmt.Fprintf(os.Stderr, "ERR conn failure: %v %v\n", errName, err)
+ }
+ }
+ fasthttp.ReleaseResponse(resp)
+}
+
+func httpConnError(err error) (string, bool) {
+ errName := ""
+ known := false
+ if err == fasthttp.ErrTimeout {
+ errName = "timeout"
+ known = true
+ } else if err == fasthttp.ErrNoFreeConns {
+ errName = "conn_limit"
+ known = true
+ } else if err == fasthttp.ErrConnectionClosed {
+ errName = "conn_close"
+ known = true
+ } else {
+ errName = reflect.TypeOf(err).String()
+ if errName == "*net.OpError" {
+ // Write and Read errors are not so often and in fact they just mean timeout problems
+ errName = "timeout"
+ known = true
+ }
+ }
+ return errName, known
+}
diff --git a/examples/fileserver/fileserver.go b/examples/fileserver/fileserver.go
index f6fbd4c..61cc457 100644
--- a/examples/fileserver/fileserver.go
+++ b/examples/fileserver/fileserver.go
@@ -63,7 +63,7 @@ func main() {
log.Printf("Starting HTTP server on %q", *addr)
go func() {
if err := fasthttp.ListenAndServe(*addr, requestHandler); err != nil {
- log.Fatalf("error in ListenAndServe: %s", err)
+ log.Fatalf("error in ListenAndServe: %v", err)
}
}()
}
@@ -73,7 +73,7 @@ func main() {
log.Printf("Starting HTTPS server on %q", *addrTLS)
go func() {
if err := fasthttp.ListenAndServeTLS(*addrTLS, *certFile, *keyFile, requestHandler); err != nil {
- log.Fatalf("error in ListenAndServeTLS: %s", err)
+ log.Fatalf("error in ListenAndServeTLS: %v", err)
}
}()
}
diff --git a/examples/helloworldserver/helloworldserver.go b/examples/helloworldserver/helloworldserver.go
index 22b518a..a22e0b7 100644
--- a/examples/helloworldserver/helloworldserver.go
+++ b/examples/helloworldserver/helloworldserver.go
@@ -22,7 +22,7 @@ func main() {
}
if err := fasthttp.ListenAndServe(*addr, h); err != nil {
- log.Fatalf("Error in ListenAndServe: %s", err)
+ log.Fatalf("Error in ListenAndServe: %v", err)
}
}
diff --git a/examples/host_client/.gitignore b/examples/host_client/.gitignore
new file mode 100644
index 0000000..097652f
--- /dev/null
+++ b/examples/host_client/.gitignore
@@ -0,0 +1 @@
+host_client
diff --git a/examples/host_client/Makefile b/examples/host_client/Makefile
new file mode 100644
index 0000000..161ab44
--- /dev/null
+++ b/examples/host_client/Makefile
@@ -0,0 +1,6 @@
+host_client: clean
+ go get -u github.com/valyala/fasthttp
+ go build
+
+clean:
+ rm -f host_client
diff --git a/examples/host_client/README.md b/examples/host_client/README.md
new file mode 100644
index 0000000..e40b397
--- /dev/null
+++ b/examples/host_client/README.md
@@ -0,0 +1,13 @@
+# Host Client Example
+
+The HostClient is useful when calling an API from a single host.
+The example also shows how to use URI.
+You may create the parsed URI once and reuse it in many requests.
+The URI here has a username and password for Basic Auth, but you may also set other parts, e.g. `SetPath()` or `SetQueryString()`.
+
+## How to build and run
+Start a web server on localhost:8080 then execute:
+
+ make
+ ./host_client
+
diff --git a/examples/host_client/hostclient.go b/examples/host_client/hostclient.go
new file mode 100644
index 0000000..997abd4
--- /dev/null
+++ b/examples/host_client/hostclient.go
@@ -0,0 +1,35 @@
+package main
+
+import (
+ "fmt"
+ "os"
+
+ "github.com/valyala/fasthttp"
+)
+
+func main() {
+ // Get URI from a pool
+ url := fasthttp.AcquireURI()
+ url.Parse(nil, []byte("http://localhost:8080/"))
+ url.SetUsername("Aladdin")
+ url.SetPassword("Open Sesame")
+
+ hc := &fasthttp.HostClient{
+ Addr: "localhost:8080", // The host address and port must be set explicitly
+ }
+
+ req := fasthttp.AcquireRequest()
+ req.SetURI(url) // copy url into request
+ fasthttp.ReleaseURI(url) // now you may release the URI
+
+ req.Header.SetMethod(fasthttp.MethodGet)
+ resp := fasthttp.AcquireResponse()
+ err := hc.Do(req, resp)
+ fasthttp.ReleaseRequest(req)
+ if err == nil {
+ fmt.Printf("Response: %s\n", resp.Body())
+ } else {
+ fmt.Fprintf(os.Stderr, "Connection error: %v\n", err)
+ }
+ fasthttp.ReleaseResponse(resp)
+}
diff --git a/expvarhandler/expvar.go b/expvarhandler/expvar.go
index 6254baa..d9e17bf 100644
--- a/expvarhandler/expvar.go
+++ b/expvarhandler/expvar.go
@@ -30,7 +30,7 @@ func ExpvarHandler(ctx *fasthttp.RequestCtx) {
r, err := getExpvarRegexp(ctx)
if err != nil {
expvarRegexpErrors.Add(1)
- fmt.Fprintf(ctx, "Error when obtaining expvar regexp: %s", err)
+ fmt.Fprintf(ctx, "Error when obtaining expvar regexp: %v", err)
ctx.SetStatusCode(fasthttp.StatusBadRequest)
return
}
@@ -58,7 +58,7 @@ func getExpvarRegexp(ctx *fasthttp.RequestCtx) (*regexp.Regexp, error) {
}
rr, err := regexp.Compile(r)
if err != nil {
- return nil, fmt.Errorf("cannot parse r=%q: %s", r, err)
+ return nil, fmt.Errorf("cannot parse r=%q: %w", r, err)
}
return rr, nil
}
diff --git a/expvarhandler/expvar_test.go b/expvarhandler/expvar_test.go
index 550a65e..6f9f286 100644
--- a/expvarhandler/expvar_test.go
+++ b/expvarhandler/expvar_test.go
@@ -26,7 +26,7 @@ func TestExpvarHandlerBasic(t *testing.T) {
var m map[string]interface{}
if err := json.Unmarshal(body, &m); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if _, ok := m["cmdline"]; !ok {
diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go
index 51b9c7a..dcd43e4 100644
--- a/fasthttpadaptor/adaptor.go
+++ b/fasthttpadaptor/adaptor.go
@@ -3,6 +3,7 @@
package fasthttpadaptor
import (
+ "io"
"net/http"
"github.com/valyala/fasthttp"
@@ -15,11 +16,11 @@ import (
// it has the following drawbacks comparing to using manually written fasthttp
// request handler:
//
-// * A lot of useful functionality provided by fasthttp is missing
-// from net/http handler.
-// * net/http -> fasthttp handler conversion has some overhead,
-// so the returned handler will be always slower than manually written
-// fasthttp handler.
+// - A lot of useful functionality provided by fasthttp is missing
+// from net/http handler.
+// - net/http -> fasthttp handler conversion has some overhead,
+// so the returned handler will be always slower than manually written
+// fasthttp handler.
//
// So it is advisable using this function only for quick net/http -> fasthttp
// switching. Then manually convert net/http handlers to fasthttp handlers
@@ -35,11 +36,11 @@ func NewFastHTTPHandlerFunc(h http.HandlerFunc) fasthttp.RequestHandler {
// it has the following drawbacks comparing to using manually written fasthttp
// request handler:
//
-// * A lot of useful functionality provided by fasthttp is missing
-// from net/http handler.
-// * net/http -> fasthttp handler conversion has some overhead,
-// so the returned handler will be always slower than manually written
-// fasthttp handler.
+// - A lot of useful functionality provided by fasthttp is missing
+// from net/http handler.
+// - net/http -> fasthttp handler conversion has some overhead,
+// so the returned handler will be always slower than manually written
+// fasthttp handler.
//
// So it is advisable using this function only for quick net/http -> fasthttp
// switching. Then manually convert net/http handlers to fasthttp handlers
@@ -48,12 +49,12 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
return func(ctx *fasthttp.RequestCtx) {
var r http.Request
if err := ConvertRequest(ctx, &r, true); err != nil {
- ctx.Logger().Printf("cannot parse requestURI %q: %s", r.RequestURI, err)
+ ctx.Logger().Printf("cannot parse requestURI %q: %v", r.RequestURI, err)
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
- var w netHTTPResponseWriter
+ w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()}
h.ServeHTTP(&w, r.WithContext(ctx))
ctx.SetStatusCode(w.StatusCode())
@@ -72,19 +73,19 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
// If the Header does not contain a Content-Type line, Write adds a Content-Type set
// to the result of passing the initial 512 bytes of written data to DetectContentType.
l := 512
- if len(w.body) < 512 {
- l = len(w.body)
+ b := ctx.Response.Body()
+ if len(b) < 512 {
+ l = len(b)
}
- ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(w.body[:l]))
+ ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(b[:l]))
}
- ctx.Write(w.body) //nolint:errcheck
}
}
type netHTTPResponseWriter struct {
statusCode int
h http.Header
- body []byte
+ w io.Writer
}
func (w *netHTTPResponseWriter) StatusCode() int {
@@ -106,6 +107,7 @@ func (w *netHTTPResponseWriter) WriteHeader(statusCode int) {
}
func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
- w.body = append(w.body, p...)
- return len(p), nil
+ return w.w.Write(p)
}
+
+func (w *netHTTPResponseWriter) Flush() {}
diff --git a/fasthttpadaptor/adaptor_test.go b/fasthttpadaptor/adaptor_test.go
index aa50fb6..23e2801 100644
--- a/fasthttpadaptor/adaptor_test.go
+++ b/fasthttpadaptor/adaptor_test.go
@@ -1,8 +1,7 @@
package fasthttpadaptor
import (
- "fmt"
- "io/ioutil"
+ "io"
"net"
"net/http"
"net/url"
@@ -20,7 +19,7 @@ func TestNewFastHTTPHandler(t *testing.T) {
expectedProtoMajor := 1
expectedProtoMinor := 1
expectedRequestURI := "/foo/bar?baz=123"
- expectedBody := "body 123 foo bar baz"
+ expectedBody := "<!doctype html><html>"
expectedContentLength := len(expectedBody)
expectedHost := "foobar.com"
expectedRemoteAddr := "1.2.3.4:6789"
@@ -31,10 +30,11 @@ func TestNewFastHTTPHandler(t *testing.T) {
}
expectedURL, err := url.ParseRequestURI(expectedRequestURI)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
expectedContextKey := "contextKey"
expectedContextValue := "contextValue"
+ expectedContentType := "text/html; charset=utf-8"
callsCount := 0
nethttpH := func(w http.ResponseWriter, r *http.Request) {
@@ -66,10 +66,10 @@ func TestNewFastHTTPHandler(t *testing.T) {
if r.RemoteAddr != expectedRemoteAddr {
t.Fatalf("unexpected remoteAddr %q. Expecting %q", r.RemoteAddr, expectedRemoteAddr)
}
- body, err := ioutil.ReadAll(r.Body)
+ body, err := io.ReadAll(r.Body)
r.Body.Close()
if err != nil {
- t.Fatalf("unexpected error when reading request body: %s", err)
+ t.Fatalf("unexpected error when reading request body: %v", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
@@ -91,7 +91,7 @@ func TestNewFastHTTPHandler(t *testing.T) {
w.Header().Set("Header1", "value1")
w.Header().Set("Header2", "value2")
w.WriteHeader(http.StatusBadRequest)
- fmt.Fprintf(w, "request body is %q", body)
+ w.Write(body) //nolint:errcheck
}
fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH))
fasthttpH = setContextValueMiddleware(fasthttpH, expectedContextKey, expectedContextValue)
@@ -109,7 +109,7 @@ func TestNewFastHTTPHandler(t *testing.T) {
remoteAddr, err := net.ResolveTCPAddr("tcp", expectedRemoteAddr)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
ctx.Init(&req, remoteAddr, nil)
@@ -129,9 +129,11 @@ func TestNewFastHTTPHandler(t *testing.T) {
if string(resp.Header.Peek("Header2")) != "value2" {
t.Fatalf("unexpected header value: %q. Expecting %q", resp.Header.Peek("Header2"), "value2")
}
- expectedResponseBody := fmt.Sprintf("request body is %q", expectedBody)
- if string(resp.Body()) != expectedResponseBody {
- t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedResponseBody)
+ if string(resp.Body()) != expectedBody {
+ t.Fatalf("unexpected response body %q. Expecting %q", resp.Body(), expectedBody)
+ }
+ if string(resp.Header.Peek("Content-Type")) != expectedContentType {
+ t.Fatalf("unexpected response content-type %q. Expecting %q", string(resp.Header.Peek("Content-Type")), expectedContentType)
}
}
@@ -141,32 +143,3 @@ func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value i
next(ctx)
}
}
-
-func TestContentType(t *testing.T) {
- t.Parallel()
-
- nethttpH := func(w http.ResponseWriter, r *http.Request) {
- w.Write([]byte("<!doctype html><html>")) //nolint:errcheck
- }
- fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH))
-
- var ctx fasthttp.RequestCtx
- var req fasthttp.Request
-
- req.SetRequestURI("http://example.com")
-
- remoteAddr, err := net.ResolveTCPAddr("tcp", "1.2.3.4:80")
- if err != nil {
- t.Fatalf("unexpected error: %s", err)
- }
- ctx.Init(&req, remoteAddr, nil)
-
- fasthttpH(&ctx)
-
- resp := &ctx.Response
- got := string(resp.Header.Peek("Content-Type"))
- expected := "text/html; charset=utf-8"
- if got != expected {
- t.Errorf("expected %q got %q", expected, got)
- }
-}
diff --git a/fasthttpadaptor/request.go b/fasthttpadaptor/request.go
index 7a49bfd..d763a98 100644
--- a/fasthttpadaptor/request.go
+++ b/fasthttpadaptor/request.go
@@ -2,33 +2,41 @@ package fasthttpadaptor
import (
"bytes"
- "io/ioutil"
+ "io"
"net/http"
"net/url"
+ "unsafe"
"github.com/valyala/fasthttp"
)
// ConvertRequest convert a fasthttp.Request to an http.Request
// forServer should be set to true when the http.Request is going to passed to a http.Handler.
+//
+// The http.Request must not be used after the fasthttp handler has returned!
+// Memory in use by the http.Request will be reused after your handler has returned!
func ConvertRequest(ctx *fasthttp.RequestCtx, r *http.Request, forServer bool) error {
body := ctx.PostBody()
- strRequestURI := string(ctx.RequestURI())
+ strRequestURI := b2s(ctx.RequestURI())
rURL, err := url.ParseRequestURI(strRequestURI)
if err != nil {
return err
}
- r.Method = string(ctx.Method())
- r.Proto = "HTTP/1.1"
- r.ProtoMajor = 1
+ r.Method = b2s(ctx.Method())
+ r.Proto = b2s(ctx.Request.Header.Protocol())
+ if r.Proto == "HTTP/2" {
+ r.ProtoMajor = 2
+ } else {
+ r.ProtoMajor = 1
+ }
r.ProtoMinor = 1
r.ContentLength = int64(len(body))
r.RemoteAddr = ctx.RemoteAddr().String()
- r.Host = string(ctx.Host())
+ r.Host = b2s(ctx.Host())
r.TLS = ctx.TLSConnectionState()
- r.Body = ioutil.NopCloser(bytes.NewReader(body))
+ r.Body = io.NopCloser(bytes.NewReader(body))
r.URL = rURL
if forServer {
@@ -44,8 +52,8 @@ func ConvertRequest(ctx *fasthttp.RequestCtx, r *http.Request, forServer bool) e
}
ctx.Request.Header.VisitAll(func(k, v []byte) {
- sk := string(k)
- sv := string(v)
+ sk := b2s(k)
+ sv := b2s(v)
switch sk {
case "Transfer-Encoding":
@@ -57,3 +65,8 @@ func ConvertRequest(ctx *fasthttp.RequestCtx, r *http.Request, forServer bool) e
return nil
}
+
+func b2s(b []byte) string {
+ /* #nosec G103 */
+ return *(*string)(unsafe.Pointer(&b))
+}
diff --git a/fasthttpadaptor/request_test.go b/fasthttpadaptor/request_test.go
new file mode 100644
index 0000000..1f214c2
--- /dev/null
+++ b/fasthttpadaptor/request_test.go
@@ -0,0 +1,29 @@
+package fasthttpadaptor
+
+import (
+ "net/http"
+ "testing"
+
+ "github.com/valyala/fasthttp"
+)
+
+func BenchmarkConvertRequest(b *testing.B) {
+ var httpReq http.Request
+
+ ctx := &fasthttp.RequestCtx{
+ Request: fasthttp.Request{
+ Header: fasthttp.RequestHeader{},
+ UseHostHeader: false,
+ },
+ }
+ ctx.Request.Header.SetMethod("GET")
+ ctx.Request.Header.Set("x", "test")
+ ctx.Request.Header.Set("y", "test")
+ ctx.Request.SetRequestURI("/test")
+ ctx.Request.SetHost("test")
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ _ = ConvertRequest(ctx, &httpReq, true)
+ }
+}
diff --git a/fasthttpproxy/http.go b/fasthttpproxy/http.go
index 6bc9c1e..9dd5a25 100644
--- a/fasthttpproxy/http.go
+++ b/fasthttpproxy/http.go
@@ -15,6 +15,7 @@ import (
// the provided HTTP proxy.
//
// Example usage:
+//
// c := &fasthttp.Client{
// Dial: fasthttpproxy.FasthttpHTTPDialer("username:password@localhost:9050"),
// }
@@ -26,15 +27,16 @@ func FasthttpHTTPDialer(proxy string) fasthttp.DialFunc {
// the provided HTTP proxy using the given timeout.
//
// Example usage:
+//
// c := &fasthttp.Client{
// Dial: fasthttpproxy.FasthttpHTTPDialerTimeout("username:password@localhost:9050", time.Second * 2),
// }
func FasthttpHTTPDialerTimeout(proxy string, timeout time.Duration) fasthttp.DialFunc {
var auth string
if strings.Contains(proxy, "@") {
- split := strings.Split(proxy, "@")
- auth = base64.StdEncoding.EncodeToString([]byte(split[0]))
- proxy = split[1]
+ index := strings.LastIndex(proxy, "@")
+ auth = base64.StdEncoding.EncodeToString([]byte(proxy[:index]))
+ proxy = proxy[index+1:]
}
return func(addr string) (net.Conn, error) {
@@ -49,7 +51,7 @@ func FasthttpHTTPDialerTimeout(proxy string, timeout time.Duration) fasthttp.Dia
return nil, err
}
- req := "CONNECT " + addr + " HTTP/1.1\r\n"
+ req := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", addr, addr)
if auth != "" {
req += "Proxy-Authorization: Basic " + auth + "\r\n"
}
diff --git a/fasthttpproxy/proxy_env.go b/fasthttpproxy/proxy_env.go
index 2457ad5..3f91440 100644
--- a/fasthttpproxy/proxy_env.go
+++ b/fasthttpproxy/proxy_env.go
@@ -9,9 +9,8 @@ import (
"sync/atomic"
"time"
- "golang.org/x/net/http/httpproxy"
-
"github.com/valyala/fasthttp"
+ "golang.org/x/net/http/httpproxy"
)
const (
@@ -24,6 +23,7 @@ const (
// the the env(HTTP_PROXY, HTTPS_PROXY and NO_PROXY) configured HTTP proxy.
//
// Example usage:
+//
// c := &fasthttp.Client{
// Dial: FasthttpProxyHTTPDialer(),
// }
@@ -31,10 +31,11 @@ func FasthttpProxyHTTPDialer() fasthttp.DialFunc {
return FasthttpProxyHTTPDialerTimeout(0)
}
-// FasthttpProxyHTTPDialer returns a fasthttp.DialFunc that dials using
+// FasthttpProxyHTTPDialerTimeout returns a fasthttp.DialFunc that dials using
// the env(HTTP_PROXY, HTTPS_PROXY and NO_PROXY) configured HTTP proxy using the given timeout.
//
// Example usage:
+//
// c := &fasthttp.Client{
// Dial: FasthttpProxyHTTPDialerTimeout(time.Second * 2),
// }
@@ -49,7 +50,7 @@ func FasthttpProxyHTTPDialerTimeout(timeout time.Duration) fasthttp.DialFunc {
port, _, err := net.SplitHostPort(addr)
if err != nil {
- return nil, fmt.Errorf("unexpected addr format: %v", err)
+ return nil, fmt.Errorf("unexpected addr format: %w", err)
}
reqURL := &url.URL{Host: addr, Scheme: httpScheme}
@@ -108,17 +109,17 @@ func FasthttpProxyHTTPDialerTimeout(timeout time.Duration) fasthttp.DialFunc {
if err := res.Read(bufio.NewReader(conn)); err != nil {
if connErr := conn.Close(); connErr != nil {
- return nil, fmt.Errorf("conn close err %v followed by read conn err %v", connErr, err)
+ return nil, fmt.Errorf("conn close err %v precede by read conn err %w", connErr, err)
}
return nil, err
}
if res.Header.StatusCode() != 200 {
if connErr := conn.Close(); connErr != nil {
return nil, fmt.Errorf(
- "conn close err %v followed by connect to proxy: code: %d body %s",
+ "conn close err %w precede by connect to proxy: code: %d body %q",
connErr, res.StatusCode(), string(res.Body()))
}
- return nil, fmt.Errorf("could not connect to proxy: code: %d body %s", res.StatusCode(), string(res.Body()))
+ return nil, fmt.Errorf("could not connect to proxy: code: %d body %q", res.StatusCode(), string(res.Body()))
}
return conn, nil
}
diff --git a/fasthttpproxy/socks5.go b/fasthttpproxy/socks5.go
index a334837..a01c204 100644
--- a/fasthttpproxy/socks5.go
+++ b/fasthttpproxy/socks5.go
@@ -12,6 +12,7 @@ import (
// the provided SOCKS5 proxy.
//
// Example usage:
+//
// c := &fasthttp.Client{
// Dial: fasthttpproxy.FasthttpSocksDialer("socks5://localhost:9050"),
// }
diff --git a/fasthttputil/ecdsa.key b/fasthttputil/ecdsa.key
deleted file mode 100644
index 7e201fc..0000000
--- a/fasthttputil/ecdsa.key
+++ /dev/null
@@ -1,5 +0,0 @@
------BEGIN EC PRIVATE KEY-----
-MHcCAQEEIBpQbZ6a5jL1Yh4wdP6yZk4MKjYWArD/QOLENFw8vbELoAoGCCqGSM49
-AwEHoUQDQgAEKQCZWgE2IBhb47ot8MIs1D4KSisHYlZ41IWyeutpjb0fjwwIhimh
-pl1Qld1/d2j3Z3vVyfa5yD+ncV7qCFZuSg==
------END EC PRIVATE KEY-----
diff --git a/fasthttputil/ecdsa.pem b/fasthttputil/ecdsa.pem
deleted file mode 100644
index ca1a7f2..0000000
--- a/fasthttputil/ecdsa.pem
+++ /dev/null
@@ -1,10 +0,0 @@
------BEGIN CERTIFICATE-----
-MIIBbTCCAROgAwIBAgIQPo718S+K+G7hc1SgTEU4QDAKBggqhkjOPQQDAjASMRAw
-DgYDVQQKEwdBY21lIENvMB4XDTE3MDQyMDIxMDExNFoXDTE4MDQyMDIxMDExNFow
-EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABCkA
-mVoBNiAYW+O6LfDCLNQ+CkorB2JWeNSFsnrraY29H48MCIYpoaZdUJXdf3do92d7
-1cn2ucg/p3Fe6ghWbkqjSzBJMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr
-BgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuCCWxvY2FsaG9zdDAKBggq
-hkjOPQQDAgNIADBFAiEAoLAIQkvSuIcHUqyWroA6yWYw2fznlRH/uO9/hMCxUCEC
-IClRYb/5O9eD/Eq/ozPnwNpsQHOeYefEhadJ/P82y0lG
------END CERTIFICATE-----
diff --git a/fasthttputil/inmemory_listener.go b/fasthttputil/inmemory_listener.go
index 87f8b62..b0ad189 100644
--- a/fasthttputil/inmemory_listener.go
+++ b/fasthttputil/inmemory_listener.go
@@ -14,9 +14,11 @@ var ErrInmemoryListenerClosed = errors.New("InmemoryListener is already closed:
// It may be used either for fast in-process client<->server communications
// without network stack overhead or for client<->server tests.
type InmemoryListener struct {
- lock sync.Mutex
- closed bool
- conns chan acceptConn
+ lock sync.Mutex
+ closed bool
+ conns chan acceptConn
+ listenerAddr net.Addr
+ addrLock sync.RWMutex
}
type acceptConn struct {
@@ -31,6 +33,14 @@ func NewInmemoryListener() *InmemoryListener {
}
}
+// SetLocalAddr sets the (simulated) local address for the listener.
+func (ln *InmemoryListener) SetLocalAddr(localAddr net.Addr) {
+ ln.addrLock.Lock()
+ defer ln.addrLock.Unlock()
+
+ ln.listenerAddr = localAddr
+}
+
// Accept implements net.Listener's Accept.
//
// It is safe calling Accept from concurrently running goroutines.
@@ -60,12 +70,26 @@ func (ln *InmemoryListener) Close() error {
return err
}
+type inmemoryAddr int
+
+func (inmemoryAddr) Network() string {
+ return "inmemory"
+}
+
+func (inmemoryAddr) String() string {
+ return "InmemoryListener"
+}
+
// Addr implements net.Listener's Addr.
func (ln *InmemoryListener) Addr() net.Addr {
- return &net.UnixAddr{
- Name: "InmemoryListener",
- Net: "memory",
+ ln.addrLock.RLock()
+ defer ln.addrLock.RUnlock()
+
+ if ln.listenerAddr != nil {
+ return ln.listenerAddr
}
+
+ return inmemoryAddr(0)
}
// Dial creates new client<->server connection.
@@ -74,7 +98,20 @@ func (ln *InmemoryListener) Addr() net.Addr {
//
// It is safe calling Dial from concurrently running goroutines.
func (ln *InmemoryListener) Dial() (net.Conn, error) {
+ return ln.DialWithLocalAddr(nil)
+}
+
+// DialWithLocalAddr creates new client<->server connection.
+// Just like a real Dial it only returns once the server
+// has accepted the connection. The local address of the
+// client connection can be set with local.
+//
+// It is safe calling Dial from concurrently running goroutines.
+func (ln *InmemoryListener) DialWithLocalAddr(local net.Addr) (net.Conn, error) {
pc := NewPipeConns()
+
+ pc.SetAddresses(local, ln.Addr(), ln.Addr(), local)
+
cConn := pc.Conn1()
sConn := pc.Conn2()
ln.lock.Lock()
diff --git a/fasthttputil/inmemory_listener_test.go b/fasthttputil/inmemory_listener_test.go
index cdd7763..e9d5125 100644
--- a/fasthttputil/inmemory_listener_test.go
+++ b/fasthttputil/inmemory_listener_test.go
@@ -5,7 +5,6 @@ import (
"context"
"fmt"
"io"
- "io/ioutil"
"net"
"net/http"
"sync"
@@ -23,13 +22,13 @@ func TestInmemoryListener(t *testing.T) {
go func(n int) {
conn, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
defer conn.Close()
req := fmt.Sprintf("request_%d", n)
nn, err := conn.Write([]byte(req))
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if nn != len(req) {
t.Errorf("unexpected number of bytes written: %d. Expecting %d", nn, len(req))
@@ -37,7 +36,7 @@ func TestInmemoryListener(t *testing.T) {
buf := make([]byte, 30)
nn, err = conn.Read(buf)
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
buf = buf[:nn]
resp := fmt.Sprintf("response_%d", n)
@@ -63,7 +62,7 @@ func TestInmemoryListener(t *testing.T) {
buf := make([]byte, 30)
n, err := conn.Read(buf)
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
buf = buf[:n]
if !bytes.HasPrefix(buf, []byte("request_")) {
@@ -72,7 +71,7 @@ func TestInmemoryListener(t *testing.T) {
resp := fmt.Sprintf("response_%s", buf[len("request_"):])
n, err = conn.Write([]byte(resp))
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if n != len(resp) {
t.Errorf("unexpected number of bytes written: %d. Expecting %d", n, len(resp))
@@ -89,7 +88,7 @@ func TestInmemoryListener(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -108,7 +107,7 @@ func (s *echoServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
time.Sleep(time.Millisecond * 100)
if _, err := io.Copy(w, r.Body); err != nil {
- s.t.Fatalf("unexpected error: %s", err)
+ s.t.Fatalf("unexpected error: %v", err)
}
}
@@ -131,7 +130,7 @@ func testInmemoryListenerHTTP(t *testing.T, f func(t *testing.T, client *http.Cl
go func() {
if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
}()
@@ -145,15 +144,15 @@ func testInmemoryListenerHTTP(t *testing.T, f func(t *testing.T, client *http.Cl
func testInmemoryListenerHTTPSingle(t *testing.T, client *http.Client, content string) {
res, err := client.Post("http://...", "text/plain", bytes.NewBufferString(content))
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- b, err := ioutil.ReadAll(res.Body)
+ b, err := io.ReadAll(res.Body)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
s := string(b)
if string(b) != content {
- t.Fatalf("unexpected response %s, expecting %s", s, content)
+ t.Fatalf("unexpected response %q, expecting %q", s, content)
}
}
@@ -190,3 +189,96 @@ func TestInmemoryListenerHTTPConcurrent(t *testing.T) {
wg.Wait()
})
}
+
+func acceptLoop(ln net.Listener) {
+ for {
+ conn, err := ln.Accept()
+ if err != nil {
+ panic(err)
+ }
+
+ conn.Close()
+ }
+}
+
+func TestInmemoryListenerAddrDefault(t *testing.T) {
+ t.Parallel()
+
+ ln := NewInmemoryListener()
+
+ verifyAddr(t, ln.Addr(), inmemoryAddr(0))
+
+ go func() {
+ c, err := ln.Dial()
+ if err != nil {
+ panic(err)
+ }
+
+ c.Close()
+ }()
+
+ lc, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ verifyAddr(t, lc.LocalAddr(), inmemoryAddr(0))
+ verifyAddr(t, lc.RemoteAddr(), pipeAddr(0))
+
+ go acceptLoop(ln)
+
+ c, err := ln.Dial()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ verifyAddr(t, c.LocalAddr(), pipeAddr(0))
+ verifyAddr(t, c.RemoteAddr(), inmemoryAddr(0))
+}
+
+func verifyAddr(t *testing.T, got, expected net.Addr) {
+ if got != expected {
+ t.Fatalf("unexpected addr: %v. Expecting %v", got, expected)
+ }
+}
+
+func TestInmemoryListenerAddrCustom(t *testing.T) {
+ t.Parallel()
+
+ ln := NewInmemoryListener()
+
+ listenerAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 12345}
+
+ ln.SetLocalAddr(listenerAddr)
+
+ verifyAddr(t, ln.Addr(), listenerAddr)
+
+ go func() {
+ c, err := ln.Dial()
+ if err != nil {
+ panic(err)
+ }
+
+ c.Close()
+ }()
+
+ lc, err := ln.Accept()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ verifyAddr(t, lc.LocalAddr(), listenerAddr)
+ verifyAddr(t, lc.RemoteAddr(), pipeAddr(0))
+
+ go acceptLoop(ln)
+
+ clientAddr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 2), Port: 65432}
+
+ c, err := ln.DialWithLocalAddr(clientAddr)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ verifyAddr(t, c.LocalAddr(), clientAddr)
+ verifyAddr(t, c.RemoteAddr(), listenerAddr)
+}
diff --git a/fasthttputil/inmemory_listener_timing_test.go b/fasthttputil/inmemory_listener_timing_test.go
index af0ef62..7aa5ac7 100644
--- a/fasthttputil/inmemory_listener_timing_test.go
+++ b/fasthttputil/inmemory_listener_timing_test.go
@@ -33,107 +33,23 @@ func BenchmarkTLSStreaming(b *testing.B) {
benchmark(b, streamingHandler, true)
}
-// BenchmarkTLSHandshake measures end-to-end TLS handshake performance
-// for fasthttp client and server.
-//
-// It re-establishes new TLS connection per each http request.
-func BenchmarkTLSHandshakeRSAWithClientSessionCache(b *testing.B) {
- bc := &benchConfig{
- IsTLS: true,
- DisableClientSessionCache: false,
- }
- benchmarkExt(b, handshakeHandler, bc)
-}
-
-func BenchmarkTLSHandshakeRSAWithoutClientSessionCache(b *testing.B) {
- bc := &benchConfig{
- IsTLS: true,
- DisableClientSessionCache: true,
- }
- benchmarkExt(b, handshakeHandler, bc)
-}
-
-func BenchmarkTLSHandshakeECDSAWithClientSessionCache(b *testing.B) {
- bc := &benchConfig{
- IsTLS: true,
- DisableClientSessionCache: false,
- UseECDSA: true,
- }
- benchmarkExt(b, handshakeHandler, bc)
-}
-
-func BenchmarkTLSHandshakeECDSAWithoutClientSessionCache(b *testing.B) {
- bc := &benchConfig{
- IsTLS: true,
- DisableClientSessionCache: true,
- UseECDSA: true,
- }
- benchmarkExt(b, handshakeHandler, bc)
-}
-
-func BenchmarkTLSHandshakeECDSAWithCurvesWithClientSessionCache(b *testing.B) {
- bc := &benchConfig{
- IsTLS: true,
- DisableClientSessionCache: false,
- UseCurves: true,
- UseECDSA: true,
- }
- benchmarkExt(b, handshakeHandler, bc)
-}
-
-func BenchmarkTLSHandshakeECDSAWithCurvesWithoutClientSessionCache(b *testing.B) {
- bc := &benchConfig{
- IsTLS: true,
- DisableClientSessionCache: true,
- UseCurves: true,
- UseECDSA: true,
- }
- benchmarkExt(b, handshakeHandler, bc)
-}
-
func benchmark(b *testing.B, h fasthttp.RequestHandler, isTLS bool) {
- bc := &benchConfig{
- IsTLS: isTLS,
- }
- benchmarkExt(b, h, bc)
-}
-
-type benchConfig struct {
- IsTLS bool
- DisableClientSessionCache bool
- UseCurves bool
- UseECDSA bool
-}
-
-func benchmarkExt(b *testing.B, h fasthttp.RequestHandler, bc *benchConfig) {
var serverTLSConfig, clientTLSConfig *tls.Config
- if bc.IsTLS {
+ if isTLS {
certFile := "rsa.pem"
keyFile := "rsa.key"
- if bc.UseECDSA {
- certFile = "ecdsa.pem"
- keyFile = "ecdsa.key"
- }
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
- b.Fatalf("cannot load TLS certificate from certFile=%q, keyFile=%q: %s", certFile, keyFile, err)
+ b.Fatalf("cannot load TLS certificate from certFile=%q, keyFile=%q: %v", certFile, keyFile, err)
}
serverTLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
PreferServerCipherSuites: true,
}
serverTLSConfig.CurvePreferences = []tls.CurveID{}
- if bc.UseCurves {
- serverTLSConfig.CurvePreferences = []tls.CurveID{
- tls.CurveP256,
- }
- }
clientTLSConfig = &tls.Config{
InsecureSkipVerify: true,
}
- if bc.DisableClientSessionCache {
- clientTLSConfig.ClientSessionCache = fakeSessionCache{}
- }
}
ln := fasthttputil.NewInmemoryListener()
serverStopCh := make(chan struct{})
@@ -143,7 +59,7 @@ func benchmarkExt(b *testing.B, h fasthttp.RequestHandler, bc *benchConfig) {
serverLn = tls.NewListener(serverLn, serverTLSConfig)
}
if err := fasthttp.Serve(serverLn, h); err != nil {
- b.Errorf("unexpected error in server: %s", err)
+ b.Errorf("unexpected error in server: %v", err)
}
close(serverStopCh)
}()
@@ -151,12 +67,12 @@ func benchmarkExt(b *testing.B, h fasthttp.RequestHandler, bc *benchConfig) {
Dial: func(addr string) (net.Conn, error) {
return ln.Dial()
},
- IsTLS: clientTLSConfig != nil,
+ IsTLS: isTLS,
TLSConfig: clientTLSConfig,
}
b.RunParallel(func(pb *testing.PB) {
- runRequests(b, pb, c)
+ runRequests(b, pb, c, isTLS)
})
ln.Close()
<-serverStopCh
@@ -173,26 +89,20 @@ func handshakeHandler(ctx *fasthttp.RequestCtx) {
ctx.SetConnectionClose()
}
-func runRequests(b *testing.B, pb *testing.PB, c *fasthttp.HostClient) {
+func runRequests(b *testing.B, pb *testing.PB, c *fasthttp.HostClient, isTLS bool) {
var req fasthttp.Request
- req.SetRequestURI("http://foo.bar/baz")
+ if isTLS {
+ req.SetRequestURI("https://foo.bar/baz")
+ } else {
+ req.SetRequestURI("http://foo.bar/baz")
+ }
var resp fasthttp.Response
for pb.Next() {
if err := c.Do(&req, &resp); err != nil {
- b.Fatalf("unexpected error: %s", err)
+ b.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != fasthttp.StatusOK {
b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusOK)
}
}
}
-
-type fakeSessionCache struct{}
-
-func (fakeSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
- return nil, false
-}
-
-func (fakeSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
- // no-op
-}
diff --git a/fasthttputil/pipeconns.go b/fasthttputil/pipeconns.go
index c992da3..8a338e1 100644
--- a/fasthttputil/pipeconns.go
+++ b/fasthttputil/pipeconns.go
@@ -35,10 +35,10 @@ func NewPipeConns() *PipeConns {
// PipeConns has the following additional features comparing to connections
// returned from net.Pipe():
//
-// * It is faster.
-// * It buffers Write calls, so there is no need to have concurrent goroutine
+// - It is faster.
+// - It buffers Write calls, so there is no need to have concurrent goroutine
// calling Read in order to unblock each Write call.
-// * It supports read and write deadlines.
+// - It supports read and write deadlines.
//
// PipeConns is NOT safe for concurrent use by multiple goroutines!
type PipeConns struct {
@@ -48,6 +48,21 @@ type PipeConns struct {
stopChLock sync.Mutex
}
+// SetAddresses sets the local and remote addresses for the connection.
+func (pc *PipeConns) SetAddresses(localAddr1, remoteAddr1, localAddr2, remoteAddr2 net.Addr) {
+ pc.c1.addrLock.Lock()
+ defer pc.c1.addrLock.Unlock()
+
+ pc.c2.addrLock.Lock()
+ defer pc.c2.addrLock.Unlock()
+
+ pc.c1.localAddr = localAddr1
+ pc.c1.remoteAddr = remoteAddr1
+
+ pc.c2.localAddr = localAddr2
+ pc.c2.remoteAddr = remoteAddr2
+}
+
// Conn1 returns the first end of bi-directional pipe.
//
// Data written to Conn1 may be read from Conn2.
@@ -92,6 +107,10 @@ type pipeConn struct {
writeDeadlineCh <-chan time.Time
readDeadlineChLock sync.Mutex
+
+ localAddr net.Addr
+ remoteAddr net.Addr
+ addrLock sync.RWMutex
}
func (c *pipeConn) Write(p []byte) (int, error) {
@@ -209,7 +228,7 @@ func (e *timeoutError) Error() string {
// Only implement the Timeout() function of the net.Error interface.
// This allows for checks like:
//
-// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
+// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
func (e *timeoutError) Timeout() bool {
return true
}
@@ -224,10 +243,24 @@ func (c *pipeConn) Close() error {
}
func (c *pipeConn) LocalAddr() net.Addr {
+ c.addrLock.RLock()
+ defer c.addrLock.RUnlock()
+
+ if c.localAddr != nil {
+ return c.localAddr
+ }
+
return pipeAddr(0)
}
func (c *pipeConn) RemoteAddr() net.Addr {
+ c.addrLock.RLock()
+ defer c.addrLock.RUnlock()
+
+ if c.remoteAddr != nil {
+ return c.remoteAddr
+ }
+
return pipeAddr(0)
}
@@ -266,7 +299,7 @@ func updateTimer(t *time.Timer, deadline time.Time) <-chan time.Time {
if deadline.IsZero() {
return nil
}
- d := -time.Since(deadline)
+ d := time.Until(deadline)
if d <= 0 {
return closedDeadlineCh
}
diff --git a/fasthttputil/pipeconns_test.go b/fasthttputil/pipeconns_test.go
index d61e5e5..9ac7ee1 100644
--- a/fasthttputil/pipeconns_test.go
+++ b/fasthttputil/pipeconns_test.go
@@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"io"
- "io/ioutil"
"net"
"testing"
"time"
@@ -18,7 +17,7 @@ func TestPipeConnsWriteTimeout(t *testing.T) {
deadline := time.Now().Add(time.Millisecond)
if err := c1.SetWriteDeadline(deadline); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
data := []byte("foobar")
@@ -28,7 +27,7 @@ func TestPipeConnsWriteTimeout(t *testing.T) {
if err == ErrTimeout {
break
}
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -38,14 +37,14 @@ func TestPipeConnsWriteTimeout(t *testing.T) {
t.Fatalf("expecting error")
}
if err != ErrTimeout {
- t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
}
// read the written data
c2 := pc.Conn2()
if err := c2.SetReadDeadline(time.Now().Add(10 * time.Millisecond)); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
for {
_, err := c2.Read(data)
@@ -53,7 +52,7 @@ func TestPipeConnsWriteTimeout(t *testing.T) {
if err == ErrTimeout {
break
}
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -63,7 +62,7 @@ func TestPipeConnsWriteTimeout(t *testing.T) {
t.Fatalf("expecting error")
}
if err != ErrTimeout {
- t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, ErrTimeout)
}
}
}
@@ -88,7 +87,7 @@ func testPipeConnsReadTimeout(t *testing.T, timeout time.Duration) {
deadline := time.Now().Add(timeout)
if err := c1.SetReadDeadline(deadline); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var buf [1]byte
@@ -98,23 +97,23 @@ func testPipeConnsReadTimeout(t *testing.T, timeout time.Duration) {
t.Fatalf("expecting error on iteration %d", i)
}
if err != ErrTimeout {
- t.Fatalf("unexpected error on iteration %d: %s. Expecting %s", i, err, ErrTimeout)
+ t.Fatalf("unexpected error on iteration %d: %v. Expecting %v", i, err, ErrTimeout)
}
}
// disable deadline and send data from c2 to c1
if err := c1.SetReadDeadline(zeroTime); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
data := []byte("foobar")
c2 := pc.Conn2()
if _, err := c2.Write(data); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
dataBuf := make([]byte, len(data))
if _, err := io.ReadFull(c1, dataBuf); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !bytes.Equal(data, dataBuf) {
t.Fatalf("unexpected data received: %q. Expecting %q", dataBuf, data)
@@ -162,9 +161,9 @@ func testPipeConnsCloseWhileReadWrite(t *testing.T) {
readCh := make(chan error)
go func() {
var err error
- if _, err = io.Copy(ioutil.Discard, c1); err != nil {
+ if _, err = io.Copy(io.Discard, c1); err != nil {
if err != errConnectionClosed {
- err = fmt.Errorf("unexpected error: %s", err)
+ err = fmt.Errorf("unexpected error: %w", err)
} else {
err = nil
}
@@ -178,7 +177,7 @@ func testPipeConnsCloseWhileReadWrite(t *testing.T) {
for {
if _, err = c2.Write([]byte("foobar")); err != nil {
if err != errConnectionClosed {
- err = fmt.Errorf("unexpected error: %s", err)
+ err = fmt.Errorf("unexpected error: %w", err)
} else {
err = nil
}
@@ -190,16 +189,16 @@ func testPipeConnsCloseWhileReadWrite(t *testing.T) {
time.Sleep(10 * time.Millisecond)
if err := c1.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := c2.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case err := <-readCh:
if err != nil {
- t.Fatalf("unexpected error in reader: %s", err)
+ t.Fatalf("unexpected error in reader: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
@@ -207,7 +206,7 @@ func testPipeConnsCloseWhileReadWrite(t *testing.T) {
select {
case err := <-writeCh:
if err != nil {
- t.Fatalf("unexpected error in writer: %s", err)
+ t.Fatalf("unexpected error in writer: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
@@ -244,7 +243,7 @@ func testPipeConnsReadWrite(t *testing.T, c1, c2 net.Conn) {
s1 := fmt.Sprintf("foo_%d", i)
n, err := c1.Write([]byte(s1))
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != len(s1) {
t.Fatalf("unexpected number of bytes written: %d. Expecting %d", n, len(s1))
@@ -254,7 +253,7 @@ func testPipeConnsReadWrite(t *testing.T, c1, c2 net.Conn) {
s2 := fmt.Sprintf("bar_%d", i)
n, err = c1.Write([]byte(s2))
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != len(s2) {
t.Fatalf("unexpected number of bytes written: %d. Expecting %d", n, len(s2))
@@ -264,7 +263,7 @@ func testPipeConnsReadWrite(t *testing.T, c1, c2 net.Conn) {
s := s1 + s2
n, err = c2.Read(buf[:])
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != len(s) {
t.Fatalf("unexpected number of bytes read: %d. Expecting %d", n, len(s))
@@ -297,7 +296,7 @@ func testPipeConnsCloseSerial(t *testing.T) {
func testPipeConnsClose(t *testing.T, c1, c2 net.Conn) {
if err := c1.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var buf [10]byte
@@ -319,7 +318,7 @@ func testPipeConnsClose(t *testing.T, c1, c2 net.Conn) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, io.EOF)
}
if n != 0 {
t.Fatalf("unexpected number of bytes read: %d. Expecting 0", n)
@@ -327,16 +326,16 @@ func testPipeConnsClose(t *testing.T, c1, c2 net.Conn) {
}
if err := c2.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
// attempt closing already closed conns
for i := 0; i < 10; i++ {
if err := c1.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := c2.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
}
@@ -358,3 +357,51 @@ func testConcurrency(t *testing.T, concurrency int, f func(*testing.T)) {
}
}
}
+
+func TestPipeConnsAddrDefault(t *testing.T) {
+ t.Parallel()
+
+ pc := NewPipeConns()
+ c1 := pc.Conn1()
+
+ if c1.LocalAddr() != pipeAddr(0) {
+ t.Fatalf("unexpected local address: %v", c1.LocalAddr())
+ }
+
+ if c1.RemoteAddr() != pipeAddr(0) {
+ t.Fatalf("unexpected remote address: %v", c1.RemoteAddr())
+ }
+}
+
+func TestPipeConnsAddrCustom(t *testing.T) {
+ t.Parallel()
+
+ pc := NewPipeConns()
+
+ addr1 := &net.TCPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}
+ addr2 := &net.TCPAddr{IP: net.IPv4(5, 6, 7, 8), Port: 5678}
+ addr3 := &net.TCPAddr{IP: net.IPv4(9, 10, 11, 12), Port: 9012}
+ addr4 := &net.TCPAddr{IP: net.IPv4(13, 14, 15, 16), Port: 3456}
+
+ pc.SetAddresses(addr1, addr2, addr3, addr4)
+
+ c1 := pc.Conn1() // first end: expects (localAddr1, remoteAddr1) = (addr1, addr2)
+
+ if c1.LocalAddr() != addr1 {
+ t.Fatalf("unexpected local address: %v", c1.LocalAddr())
+ }
+
+ if c1.RemoteAddr() != addr2 {
+ t.Fatalf("unexpected remote address: %v", c1.RemoteAddr())
+ }
+
+ c2 := pc.Conn2() // second end: expects (localAddr2, remoteAddr2) = (addr3, addr4)
+
+ if c2.LocalAddr() != addr3 {
+ t.Fatalf("unexpected local address: %v", c2.LocalAddr())
+ }
+
+ if c2.RemoteAddr() != addr4 {
+ t.Fatalf("unexpected remote address: %v", c2.RemoteAddr())
+ }
+}
diff --git a/fs.go b/fs.go
index f8d4add..a8bcd13 100644
--- a/fs.go
+++ b/fs.go
@@ -6,7 +6,6 @@ import (
"fmt"
"html"
"io"
- "io/ioutil"
"mime"
"net/http"
"os"
@@ -30,6 +29,10 @@ import (
// with good compression ratio.
//
// See also RequestCtx.SendFileBytes.
+//
+// WARNING: do not pass any user-supplied paths to this function!
+// WARNING: if path is based on user input, users will be able to request
+// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func ServeFileBytesUncompressed(ctx *RequestCtx, path []byte) {
ServeFileUncompressed(ctx, b2s(path))
}
@@ -43,6 +46,10 @@ func ServeFileBytesUncompressed(ctx *RequestCtx, path []byte) {
// with good compression ratio.
//
// See also RequestCtx.SendFile.
+//
+// WARNING: do not pass any user-supplied paths to this function!
+// WARNING: if path is based on user input, users will be able to request
+// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func ServeFileUncompressed(ctx *RequestCtx, path string) {
ctx.Request.Header.DelBytes(strAcceptEncoding)
ServeFile(ctx, path)
@@ -53,8 +60,8 @@ func ServeFileUncompressed(ctx *RequestCtx, path string) {
//
// HTTP response may contain uncompressed file contents in the following cases:
//
-// * Missing 'Accept-Encoding: gzip' request header.
-// * No write access to directory containing the file.
+// - Missing 'Accept-Encoding: gzip' request header.
+// - No write access to directory containing the file.
//
// Directory contents is returned if path points to directory.
//
@@ -62,6 +69,10 @@ func ServeFileUncompressed(ctx *RequestCtx, path string) {
// file contents.
//
// See also RequestCtx.SendFileBytes.
+//
+// WARNING: do not pass any user-supplied paths to this function!
+// WARNING: if path is based on user input, users will be able to request
+// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func ServeFileBytes(ctx *RequestCtx, path []byte) {
ServeFile(ctx, b2s(path))
}
@@ -71,24 +82,31 @@ func ServeFileBytes(ctx *RequestCtx, path []byte) {
//
// HTTP response may contain uncompressed file contents in the following cases:
//
-// * Missing 'Accept-Encoding: gzip' request header.
-// * No write access to directory containing the file.
+// - Missing 'Accept-Encoding: gzip' request header.
+// - No write access to directory containing the file.
//
// Directory contents is returned if path points to directory.
//
// Use ServeFileUncompressed is you don't need serving compressed file contents.
//
// See also RequestCtx.SendFile.
+//
+// WARNING: do not pass any user-supplied paths to this function!
+// WARNING: if path is based on user input, users will be able to request
+// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func ServeFile(ctx *RequestCtx, path string) {
rootFSOnce.Do(func() {
rootFSHandler = rootFS.NewRequestHandler()
})
- if len(path) == 0 || path[0] != '/' {
+
+ if len(path) == 0 || !filepath.IsAbs(path) {
// extend relative path to absolute path
- hasTrailingSlash := len(path) > 0 && path[len(path)-1] == '/'
+ hasTrailingSlash := len(path) > 0 && (path[len(path)-1] == '/' || path[len(path)-1] == '\\')
+
var err error
+ path = filepath.FromSlash(path)
if path, err = filepath.Abs(path); err != nil {
- ctx.Logger().Printf("cannot resolve path %q to absolute file path: %s", path, err)
+ ctx.Logger().Printf("cannot resolve path %q to absolute file path: %v", path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
@@ -96,6 +114,11 @@ func ServeFile(ctx *RequestCtx, path string) {
path += "/"
}
}
+
+ // convert the path to forward slashes regardless of the OS in order to set the URI properly;
+ // the handler converts it back to the OS path separator before opening the file
+ path = filepath.ToSlash(path)
+
ctx.Request.SetRequestURI(path)
rootFSHandler(ctx)
}
@@ -103,7 +126,8 @@ func ServeFile(ctx *RequestCtx, path string) {
var (
rootFSOnce sync.Once
rootFS = &FS{
- Root: "/",
+ Root: "",
+ AllowEmptyRoot: true,
GenerateIndexPages: true,
Compress: true,
CompressBrotli: true,
@@ -130,12 +154,11 @@ type PathRewriteFunc func(ctx *RequestCtx) []byte
//
// Examples:
//
-// * host=foobar.com, slashesCount=0, original path="/foo/bar".
+// - host=foobar.com, slashesCount=0, original path="/foo/bar".
// Resulting path: "/foobar.com/foo/bar"
//
-// * host=img.aaa.com, slashesCount=1, original path="/images/123/456.jpg"
+// - host=img.aaa.com, slashesCount=1, original path="/images/123/456.jpg"
// Resulting path: "/img.aaa.com/123/456.jpg"
-//
func NewVHostPathRewriter(slashesCount int) PathRewriteFunc {
return func(ctx *RequestCtx) []byte {
path := stripLeadingSlashes(ctx.Path(), slashesCount)
@@ -164,9 +187,9 @@ var strInvalidHost = []byte("invalid-host")
//
// Examples:
//
-// * slashesCount = 0, original path: "/foo/bar", result: "/foo/bar"
-// * slashesCount = 1, original path: "/foo/bar", result: "/bar"
-// * slashesCount = 2, original path: "/foo/bar", result: ""
+// - slashesCount = 0, original path: "/foo/bar", result: "/foo/bar"
+// - slashesCount = 1, original path: "/foo/bar", result: "/bar"
+// - slashesCount = 2, original path: "/foo/bar", result: ""
//
// The returned path rewriter may be used as FS.PathRewrite .
func NewPathSlashesStripper(slashesCount int) PathRewriteFunc {
@@ -180,9 +203,9 @@ func NewPathSlashesStripper(slashesCount int) PathRewriteFunc {
//
// Examples:
//
-// * prefixSize = 0, original path: "/foo/bar", result: "/foo/bar"
-// * prefixSize = 3, original path: "/foo/bar", result: "o/bar"
-// * prefixSize = 7, original path: "/foo/bar", result: "r"
+// - prefixSize = 0, original path: "/foo/bar", result: "/foo/bar"
+// - prefixSize = 3, original path: "/foo/bar", result: "o/bar"
+// - prefixSize = 7, original path: "/foo/bar", result: "r"
//
// The returned path rewriter may be used as FS.PathRewrite .
func NewPathPrefixStripper(prefixSize int) PathRewriteFunc {
@@ -205,6 +228,12 @@ type FS struct {
// Path to the root directory to serve files from.
Root string
+ // AllowEmptyRoot controls what happens when Root is empty. When false (default), Root defaults to the
+ // current working directory. An empty root is mostly useful when you want to use absolute paths
+ // on Windows that are on different filesystems. On Linux, setting Root to "/" already allows you to use
+ // absolute paths on any filesystem.
+ AllowEmptyRoot bool
+
// List of index file names to try opening during directory access.
//
// For example:
@@ -245,6 +274,10 @@ type FS struct {
// Brotli encoding is disabled by default.
CompressBrotli bool
+ // Path to the compressed root directory to serve files from. If this value
+ // is empty, Root is used.
+ CompressRoot string
+
// Enables byte range requests if set to true.
//
// Byte range requests are disabled by default.
@@ -315,9 +348,9 @@ const FSHandlerCacheDuration = 10 * time.Second
// from requested path before searching requested file in the root folder.
// Examples:
//
-// * stripSlashes = 0, original path: "/foo/bar", result: "/foo/bar"
-// * stripSlashes = 1, original path: "/foo/bar", result: "/bar"
-// * stripSlashes = 2, original path: "/foo/bar", result: ""
+// - stripSlashes = 0, original path: "/foo/bar", result: "/foo/bar"
+// - stripSlashes = 1, original path: "/foo/bar", result: "/bar"
+// - stripSlashes = 2, original path: "/foo/bar", result: ""
//
// The returned request handler automatically generates index pages
// for directories without index.html.
@@ -357,18 +390,34 @@ func (fs *FS) NewRequestHandler() RequestHandler {
return fs.h
}
-func (fs *FS) initRequestHandler() {
- root := fs.Root
-
- // serve files from the current working directory if root is empty
- if len(root) == 0 {
- root = "."
+func (fs *FS) normalizeRoot(root string) string {
+ // Resolve root against the current working directory if it is empty (and AllowEmptyRoot is false) or relative.
+ if (!fs.AllowEmptyRoot && len(root) == 0) || (len(root) > 0 && !filepath.IsAbs(root)) {
+ path, err := os.Getwd()
+ if err != nil {
+ path = "."
+ }
+ root = path + "/" + root
}
+ // convert the root directory slashes to the native format
+ root = filepath.FromSlash(root)
// strip trailing slashes from the root path
- for len(root) > 0 && root[len(root)-1] == '/' {
+ for len(root) > 0 && root[len(root)-1] == os.PathSeparator {
root = root[:len(root)-1]
}
+ return root
+}
+
+func (fs *FS) initRequestHandler() {
+ root := fs.normalizeRoot(fs.Root)
+
+ compressRoot := fs.CompressRoot
+ if len(compressRoot) == 0 {
+ compressRoot = root
+ } else {
+ compressRoot = fs.normalizeRoot(compressRoot)
+ }
cacheDuration := fs.CacheDuration
if cacheDuration <= 0 {
@@ -393,6 +442,7 @@ func (fs *FS) initRequestHandler() {
generateIndexPages: fs.GenerateIndexPages,
compress: fs.Compress,
compressBrotli: fs.CompressBrotli,
+ compressRoot: compressRoot,
pathNotFound: fs.PathNotFound,
acceptByteRange: fs.AcceptByteRange,
cacheDuration: cacheDuration,
@@ -441,6 +491,7 @@ type fsHandler struct {
generateIndexPages bool
compress bool
compressBrotli bool
+ compressRoot string
acceptByteRange bool
cacheDuration time.Duration
compressedFileSuffixes map[string]string
@@ -479,10 +530,10 @@ func (ff *fsFile) NewReader() (io.Reader, error) {
}
return r, err
}
- return ff.smallFileReader(), nil
+ return ff.smallFileReader()
}
-func (ff *fsFile) smallFileReader() io.Reader {
+func (ff *fsFile) smallFileReader() (io.Reader, error) {
v := ff.h.smallFileReaderPool.Get()
if v == nil {
v = &fsSmallFileReader{}
@@ -491,9 +542,9 @@ func (ff *fsFile) smallFileReader() io.Reader {
r.ff = ff
r.endPos = ff.contentLength
if r.startPos > 0 {
- panic("BUG: fsSmallFileReader with non-nil startPos found in the pool")
+ return nil, errors.New("bug: fsSmallFileReader with non-nil startPos found in the pool")
}
- return r
+ return r, nil
}
// files bigger than this size are sent with sendfile
@@ -505,7 +556,7 @@ func (ff *fsFile) isBig() bool {
func (ff *fsFile) bigFileReader() (io.Reader, error) {
if ff.f == nil {
- panic("BUG: ff.f must be non-nil in bigFileReader")
+ return nil, errors.New("bug: ff.f must be non-nil in bigFileReader")
}
var r io.Reader
@@ -524,7 +575,7 @@ func (ff *fsFile) bigFileReader() (io.Reader, error) {
f, err := os.Open(ff.f.Name())
if err != nil {
- return nil, fmt.Errorf("cannot open already opened file: %s", err)
+ return nil, fmt.Errorf("cannot open already opened file: %w", err)
}
return &bigFileReader{
f: f,
@@ -535,12 +586,12 @@ func (ff *fsFile) bigFileReader() (io.Reader, error) {
func (ff *fsFile) Release() {
if ff.f != nil {
- ff.f.Close()
+ _ = ff.f.Close()
if ff.isBig() {
ff.bigFilesLock.Lock()
for _, r := range ff.bigFiles {
- r.f.Close()
+ _ = r.f.Close()
}
ff.bigFilesLock.Unlock()
}
@@ -581,7 +632,7 @@ func (r *bigFileReader) Read(p []byte) (int, error) {
func (r *bigFileReader) WriteTo(w io.Writer) (int64, error) {
if rf, ok := w.(io.ReaderFrom); ok {
- // fast path. Senfile must be triggered
+ // fast path. sendfile must be triggered
return rf.ReadFrom(r.r)
}
@@ -593,16 +644,17 @@ func (r *bigFileReader) Close() error {
r.r = r.f
n, err := r.f.Seek(0, 0)
if err == nil {
- if n != 0 {
- panic("BUG: File.Seek(0,0) returned (non-zero, nil)")
+ if n == 0 {
+ ff := r.ff
+ ff.bigFilesLock.Lock()
+ ff.bigFiles = append(ff.bigFiles, r)
+ ff.bigFilesLock.Unlock()
+ } else {
+ _ = r.f.Close()
+ err = errors.New("bug: File.Seek(0,0) returned (non-zero, nil)")
}
-
- ff := r.ff
- ff.bigFilesLock.Lock()
- ff.bigFiles = append(ff.bigFiles, r)
- ff.bigFilesLock.Unlock()
} else {
- r.f.Close()
+ _ = r.f.Close()
}
r.ff.decReadersCount()
return err
@@ -680,7 +732,7 @@ func (r *fsSmallFileReader) WriteTo(w io.Writer) (int64, error) {
nw, errw := w.Write(buf[:n])
curPos += nw
if errw == nil && nw != n {
- panic("BUG: Write(p) returned (n, nil), where n != len(p)")
+ errw = errors.New("bug: Write(p) returned (n, nil), where n != len(p)")
}
if err == nil {
err = errw
@@ -742,6 +794,20 @@ func cleanCacheNolock(cache map[string]*fsFile, pendingFiles, filesToRelease []*
return pendingFiles, filesToRelease
}
+func (h *fsHandler) pathToFilePath(path string) string {
+ return filepath.FromSlash(h.root + path)
+}
+
+func (h *fsHandler) filePathToCompressed(filePath string) string {
+ if h.root == h.compressRoot {
+ return filePath
+ }
+ if !strings.HasPrefix(filePath, h.root) {
+ return filePath
+ }
+ return filepath.FromSlash(h.compressRoot + filePath[len(h.root):])
+}
+
func (h *fsHandler) handleRequest(ctx *RequestCtx) {
var path []byte
if h.pathRewrite != nil {
@@ -793,7 +859,8 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
if !ok {
pathStr := string(path)
- filePath := h.root + pathStr
+ filePath := h.pathToFilePath(pathStr)
+
var err error
ff, err = h.openFSFile(filePath, mustCompress, fileEncoding)
if mustCompress && err == errNoCreatePermission {
@@ -809,12 +876,12 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
}
ff, err = h.openIndexFile(ctx, filePath, mustCompress, fileEncoding)
if err != nil {
- ctx.Logger().Printf("cannot open dir index %q: %s", filePath, err)
+ ctx.Logger().Printf("cannot open dir index %q: %v", filePath, err)
ctx.Error("Directory index is forbidden", StatusForbidden)
return
}
} else if err != nil {
- ctx.Logger().Printf("cannot open file %q: %s", filePath, err)
+ ctx.Logger().Printf("cannot open file %q: %v", filePath, err)
if h.pathNotFound == nil {
ctx.Error("Cannot open requested path", StatusNotFound)
} else {
@@ -851,7 +918,7 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
r, err := ff.NewReader()
if err != nil {
- ctx.Logger().Printf("cannot obtain file reader for path=%q: %s", path, err)
+ ctx.Logger().Printf("cannot obtain file reader for path=%q: %v", path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
@@ -859,28 +926,28 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
hdr := &ctx.Response.Header
if ff.compressed {
if fileEncoding == "br" {
- hdr.SetCanonical(strContentEncoding, strBr)
+ hdr.SetContentEncodingBytes(strBr)
} else if fileEncoding == "gzip" {
- hdr.SetCanonical(strContentEncoding, strGzip)
+ hdr.SetContentEncodingBytes(strGzip)
}
}
statusCode := StatusOK
contentLength := ff.contentLength
if h.acceptByteRange {
- hdr.SetCanonical(strAcceptRanges, strBytes)
+ hdr.setNonSpecial(strAcceptRanges, strBytes)
if len(byteRange) > 0 {
startPos, endPos, err := ParseByteRange(byteRange, contentLength)
if err != nil {
- r.(io.Closer).Close()
- ctx.Logger().Printf("cannot parse byte range %q for path=%q: %s", byteRange, path, err)
+ _ = r.(io.Closer).Close()
+ ctx.Logger().Printf("cannot parse byte range %q for path=%q: %v", byteRange, path, err)
ctx.Error("Range Not Satisfiable", StatusRequestedRangeNotSatisfiable)
return
}
if err = r.(byteRangeUpdater).UpdateByteRange(startPos, endPos); err != nil {
- r.(io.Closer).Close()
- ctx.Logger().Printf("cannot seek byte range %q for path=%q: %s", byteRange, path, err)
+ _ = r.(io.Closer).Close()
+ ctx.Logger().Printf("cannot seek byte range %q for path=%q: %v", byteRange, path, err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
@@ -891,7 +958,7 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
}
}
- hdr.SetCanonical(strLastModified, ff.lastModifiedStr)
+ hdr.setNonSpecial(strLastModified, ff.lastModifiedStr)
if !ctx.IsHead() {
ctx.SetBodyStream(r, contentLength)
} else {
@@ -900,7 +967,7 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
ctx.Response.Header.SetContentLength(contentLength)
if rc, ok := r.(io.Closer); ok {
if err := rc.Close(); err != nil {
- ctx.Logger().Printf("cannot close file reader: %s", err)
+ ctx.Logger().Printf("cannot close file reader: %v", err)
ctx.Error("Internal Server Error", StatusInternalServerError)
return
}
@@ -981,7 +1048,7 @@ func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress
return ff, nil
}
if !os.IsNotExist(err) {
- return nil, fmt.Errorf("cannot open file %q: %s", indexFilePath, err)
+ return nil, fmt.Errorf("cannot open file %q: %w", indexFilePath, err)
}
}
@@ -1001,16 +1068,16 @@ func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool,
w := &bytebufferpool.ByteBuffer{}
basePathEscaped := html.EscapeString(string(base.Path()))
- fmt.Fprintf(w, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
- fmt.Fprintf(w, "<h1>%s</h1>", basePathEscaped)
- fmt.Fprintf(w, "<ul>")
+ _, _ = fmt.Fprintf(w, "<html><head><title>%s</title><style>.dir { font-weight: bold }</style></head><body>", basePathEscaped)
+ _, _ = fmt.Fprintf(w, "<h1>%s</h1>", basePathEscaped)
+ _, _ = fmt.Fprintf(w, "<ul>")
if len(basePathEscaped) > 1 {
var parentURI URI
base.CopyTo(&parentURI)
parentURI.Update(string(base.Path()) + "/..")
parentPathEscaped := html.EscapeString(string(parentURI.Path()))
- fmt.Fprintf(w, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
+ _, _ = fmt.Fprintf(w, `<li><a href="%s" class="dir">..</a></li>`, parentPathEscaped)
}
f, err := os.Open(dirPath)
@@ -1019,7 +1086,7 @@ func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool,
}
fileinfos, err := f.Readdir(0)
- f.Close()
+ _ = f.Close()
if err != nil {
return nil, err
}
@@ -1054,11 +1121,11 @@ nestedContinue:
auxStr = fmt.Sprintf("file, %d bytes", fi.Size())
className = "file"
}
- fmt.Fprintf(w, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
+ _, _ = fmt.Fprintf(w, `<li><a href="%s" class="%s">%s</a>, %s, last modified %s</li>`,
pathEscaped, className, html.EscapeString(name), auxStr, fsModTime(fi.ModTime()))
}
- fmt.Fprintf(w, "</ul></body></html>")
+ _, _ = fmt.Fprintf(w, "</ul></body></html>")
if mustCompress {
var zbuf bytebufferpool.ByteBuffer
@@ -1099,12 +1166,12 @@ func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string)
fileInfo, err := f.Stat()
if err != nil {
- f.Close()
- return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err)
+ _ = f.Close()
+ return nil, fmt.Errorf("cannot obtain info for file %q: %w", filePath, err)
}
if fileInfo.IsDir() {
- f.Close()
+ _ = f.Close()
return nil, errDirIndexRequired
}
@@ -1114,11 +1181,18 @@ func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string)
return h.newFSFile(f, fileInfo, false, "")
}
- compressedFilePath := filePath + h.compressedFileSuffixes[fileEncoding]
+ compressedFilePath := h.filePathToCompressed(filePath)
+ if compressedFilePath != filePath {
+ if err := os.MkdirAll(filepath.Dir(compressedFilePath), os.ModePerm); err != nil {
+ return nil, err
+ }
+ }
+ compressedFilePath += h.compressedFileSuffixes[fileEncoding]
+
absPath, err := filepath.Abs(compressedFilePath)
if err != nil {
- f.Close()
- return nil, fmt.Errorf("cannot determine absolute path for %q: %s", compressedFilePath, err)
+ _ = f.Close()
+ return nil, fmt.Errorf("cannot determine absolute path for %q: %v", compressedFilePath, err)
}
flock := getFileLock(absPath)
@@ -1135,7 +1209,7 @@ func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePat
// It is safe opening such a file, since the file creation
// is guarded by file mutex - see getFileLock call.
if _, err := os.Stat(compressedFilePath); err == nil {
- f.Close()
+ _ = f.Close()
return h.newCompressedFSFile(compressedFilePath, fileEncoding)
}
@@ -1144,9 +1218,9 @@ func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePat
tmpFilePath := compressedFilePath + ".tmp"
zf, err := os.Create(tmpFilePath)
if err != nil {
- f.Close()
+ _ = f.Close()
if !os.IsPermission(err) {
- return nil, fmt.Errorf("cannot create temporary file %q: %s", tmpFilePath, err)
+ return nil, fmt.Errorf("cannot create temporary file %q: %w", tmpFilePath, err)
}
return nil, errNoCreatePermission
}
@@ -1165,17 +1239,17 @@ func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePat
}
releaseStacklessGzipWriter(zw, CompressDefaultCompression)
}
- zf.Close()
- f.Close()
+ _ = zf.Close()
+ _ = f.Close()
if err != nil {
- return nil, fmt.Errorf("error when compressing file %q to %q: %s", filePath, tmpFilePath, err)
+ return nil, fmt.Errorf("error when compressing file %q to %q: %w", filePath, tmpFilePath, err)
}
if err = os.Chtimes(tmpFilePath, time.Now(), fileInfo.ModTime()); err != nil {
- return nil, fmt.Errorf("cannot change modification time to %s for tmp file %q: %s",
+ return nil, fmt.Errorf("cannot change modification time to %v for tmp file %q: %v",
fileInfo.ModTime(), tmpFilePath, err)
}
if err = os.Rename(tmpFilePath, compressedFilePath); err != nil {
- return nil, fmt.Errorf("cannot move compressed file from %q to %q: %s", tmpFilePath, compressedFilePath, err)
+ return nil, fmt.Errorf("cannot move compressed file from %q to %q: %w", tmpFilePath, compressedFilePath, err)
}
return h.newCompressedFSFile(compressedFilePath, fileEncoding)
}
@@ -1183,12 +1257,12 @@ func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePat
func (h *fsHandler) newCompressedFSFile(filePath string, fileEncoding string) (*fsFile, error) {
f, err := os.Open(filePath)
if err != nil {
- return nil, fmt.Errorf("cannot open compressed file %q: %s", filePath, err)
+ return nil, fmt.Errorf("cannot open compressed file %q: %w", filePath, err)
}
fileInfo, err := f.Stat()
if err != nil {
- f.Close()
- return nil, fmt.Errorf("cannot obtain info for compressed file %q: %s", filePath, err)
+ _ = f.Close()
+ return nil, fmt.Errorf("cannot obtain info for compressed file %q: %w", filePath, err)
}
return h.newFSFile(f, fileInfo, true, fileEncoding)
}
@@ -1209,12 +1283,12 @@ func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding
fileInfo, err := f.Stat()
if err != nil {
- f.Close()
- return nil, fmt.Errorf("cannot obtain info for file %q: %s", filePath, err)
+ _ = f.Close()
+ return nil, fmt.Errorf("cannot obtain info for file %q: %w", filePath, err)
}
if fileInfo.IsDir() {
- f.Close()
+ _ = f.Close()
if mustCompress {
return nil, fmt.Errorf("directory with unexpected suffix found: %q. Suffix: %q",
filePath, h.compressedFileSuffixes[fileEncoding])
@@ -1225,8 +1299,8 @@ func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding
if mustCompress {
fileInfoOriginal, err := os.Stat(filePathOriginal)
if err != nil {
- f.Close()
- return nil, fmt.Errorf("cannot obtain info for original file %q: %s", filePathOriginal, err)
+ _ = f.Close()
+ return nil, fmt.Errorf("cannot obtain info for original file %q: %w", filePathOriginal, err)
}
// Only re-create the compressed file if there was more than a second between the mod times.
@@ -1234,8 +1308,8 @@ func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding
// to look newer than the gzipped file.
if fileInfoOriginal.ModTime().Sub(fileInfo.ModTime()) >= time.Second {
// The compressed file became stale. Re-create it.
- f.Close()
- os.Remove(filePath)
+ _ = f.Close()
+ _ = os.Remove(filePath)
return h.compressAndOpenFSFile(filePathOriginal, fileEncoding)
}
}
@@ -1247,7 +1321,7 @@ func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool,
n := fileInfo.Size()
contentLength := int(n)
if n != int64(contentLength) {
- f.Close()
+ _ = f.Close()
return nil, fmt.Errorf("too big file: %d bytes", n)
}
@@ -1257,7 +1331,7 @@ func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool,
if len(contentType) == 0 {
data, err := readFileHeader(f, compressed, fileEncoding)
if err != nil {
- return nil, fmt.Errorf("cannot read header of the file %q: %s", f.Name(), err)
+ return nil, fmt.Errorf("cannot read header of the file %q: %w", f.Name(), err)
}
contentType = http.DetectContentType(data)
}
@@ -1302,7 +1376,7 @@ func readFileHeader(f *os.File, compressed bool, fileEncoding string) ([]byte, e
R: r,
N: 512,
}
- data, err := ioutil.ReadAll(lr)
+ data, err := io.ReadAll(lr)
if _, err := f.Seek(0, 0); err != nil {
return nil, err
}
@@ -1359,7 +1433,7 @@ func FileLastModified(path string) (time.Time, error) {
return zeroTime, err
}
fileInfo, err := f.Stat()
- f.Close()
+ _ = f.Close()
if err != nil {
return zeroTime, err
}
@@ -1370,18 +1444,10 @@ func fsModTime(t time.Time) time.Time {
return t.In(time.UTC).Truncate(time.Second)
}
-var (
- filesLockMap = make(map[string]*sync.Mutex)
- filesLockMapLock sync.Mutex
-)
+var filesLockMap sync.Map
func getFileLock(absPath string) *sync.Mutex {
- filesLockMapLock.Lock()
- flock := filesLockMap[absPath]
- if flock == nil {
- flock = &sync.Mutex{}
- filesLockMap[absPath] = flock
- }
- filesLockMapLock.Unlock()
- return flock
+ v, _ := filesLockMap.LoadOrStore(absPath, &sync.Mutex{})
+ filelock := v.(*sync.Mutex)
+ return filelock
}
diff --git a/fs_example_test.go b/fs_example_test.go
index 9073cb1..724e986 100644
--- a/fs_example_test.go
+++ b/fs_example_test.go
@@ -23,6 +23,6 @@ func ExampleFS() {
// Start the server.
if err := fasthttp.ListenAndServe(":8080", h); err != nil {
- log.Fatalf("error in ListenAndServe: %s", err)
+ log.Fatalf("error in ListenAndServe: %v", err)
}
}
diff --git a/fs_handler_example_test.go b/fs_handler_example_test.go
index dba4670..3831327 100644
--- a/fs_handler_example_test.go
+++ b/fs_handler_example_test.go
@@ -42,6 +42,6 @@ func requestHandler(ctx *fasthttp.RequestCtx) {
func ExampleFSHandler() {
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
- log.Fatalf("Error in server: %s", err)
+ log.Fatalf("Error in server: %v", err)
}
}
diff --git a/fs_test.go b/fs_test.go
index c6e125c..22a9b33 100644
--- a/fs_test.go
+++ b/fs_test.go
@@ -5,7 +5,6 @@ import (
"bytes"
"fmt"
"io"
- "io/ioutil"
"math/rand"
"os"
"path"
@@ -127,10 +126,10 @@ func TestServeFileHead(t *testing.T) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce := resp.Header.Peek(HeaderContentEncoding)
+ ce := resp.Header.ContentEncoding()
if len(ce) > 0 {
t.Fatalf("Unexpected 'Content-Encoding' %q", ce)
}
@@ -142,7 +141,7 @@ func TestServeFileHead(t *testing.T) {
expectedBody, err := getFileContents("/fs.go")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
contentLength := resp.Header.ContentLength()
if contentLength != len(expectedBody) {
@@ -155,13 +154,13 @@ func TestServeFileSmallNoReadFrom(t *testing.T) {
teststr := "hello, world!"
- tempdir, err := ioutil.TempDir("", "httpexpect")
+ tempdir, err := os.MkdirTemp("", "httpexpect")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempdir)
- if err := ioutil.WriteFile(
+ if err := os.WriteFile(
path.Join(tempdir, "hello"), []byte(teststr), 0666); err != nil {
t.Fatal(err)
}
@@ -191,7 +190,7 @@ func TestServeFileSmallNoReadFrom(t *testing.T) {
body := buf.String()
if body != teststr {
- t.Fatalf("expected '%s'", teststr)
+ t.Fatalf("expected '%q'", teststr)
}
}
@@ -219,21 +218,21 @@ func TestServeFileCompressed(t *testing.T) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce := resp.Header.Peek(HeaderContentEncoding)
+ ce := resp.Header.ContentEncoding()
if string(ce) != "gzip" {
t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip")
}
body, err := resp.BodyGunzip()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
expectedBody, err := getFileContents("/fs.go")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !bytes.Equal(body, expectedBody) {
t.Fatalf("unexpected body %q. expecting %q", body, expectedBody)
@@ -248,21 +247,21 @@ func TestServeFileCompressed(t *testing.T) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err = resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "br" {
t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "br")
}
body, err = resp.BodyUnbrotli()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
expectedBody, err = getFileContents("/fs.go")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !bytes.Equal(body, expectedBody) {
t.Fatalf("unexpected body %q. expecting %q", body, expectedBody)
@@ -284,10 +283,10 @@ func TestServeFileUncompressed(t *testing.T) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce := resp.Header.Peek(HeaderContentEncoding)
+ ce := resp.Header.ContentEncoding()
if len(ce) > 0 {
t.Fatalf("Unexpected 'Content-Encoding' %q", ce)
}
@@ -295,7 +294,7 @@ func TestServeFileUncompressed(t *testing.T) {
body := resp.Body()
expectedBody, err := getFileContents("/fs.go")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !bytes.Equal(body, expectedBody) {
t.Fatalf("unexpected body %q. expecting %q", body, expectedBody)
@@ -359,7 +358,7 @@ func testFSByteRange(t *testing.T, h RequestHandler, filePath string) {
expectedBody, err := getFileContents(filePath)
if err != nil {
- t.Fatalf("cannot read file %q: %s", filePath, err)
+ t.Fatalf("cannot read file %q: %v", filePath, err)
}
fileSize := len(expectedBody)
@@ -377,7 +376,7 @@ func testFSByteRange(t *testing.T, h RequestHandler, filePath string) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s. filePath=%q", err, filePath)
+ t.Fatalf("unexpected error: %v. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusPartialContent {
t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusPartialContent, filePath)
@@ -409,7 +408,7 @@ func getFileContents(path string) ([]byte, error) {
return nil, err
}
defer f.Close()
- return ioutil.ReadAll(f)
+ return io.ReadAll(f)
}
func TestParseByteRangeSuccess(t *testing.T) {
@@ -435,7 +434,7 @@ func TestParseByteRangeSuccess(t *testing.T) {
func testParseByteRangeSuccess(t *testing.T, v string, contentLength, startPos, endPos int) {
startPos1, endPos1, err := ParseByteRange([]byte(v), contentLength)
if err != nil {
- t.Fatalf("unexpected error: %s. v=%q, contentLength=%d", err, v, contentLength)
+ t.Fatalf("unexpected error: %v. v=%q, contentLength=%d", err, v, contentLength)
}
if startPos1 != startPos {
t.Fatalf("unexpected startPos=%d. Expecting %d. v=%q, contentLength=%d", startPos1, startPos, v, contentLength)
@@ -480,6 +479,11 @@ func testParseByteRangeError(t *testing.T, v string, contentLength int) {
}
func TestFSCompressConcurrent(t *testing.T) {
+ // Don't run this test on Windows: the Windows GitHub Actions runners are too slow and time out too often.
+ if runtime.GOOS == "windows" {
+ t.SkipNow()
+ }
+
// This test can't run parallel as files in / might be changed by other tests.
stop := make(chan struct{})
@@ -510,7 +514,7 @@ func TestFSCompressConcurrent(t *testing.T) {
for i := 0; i < concurrency; i++ {
select {
case <-ch:
- case <-time.After(time.Second * 3):
+ case <-time.After(time.Second * 2):
t.Fatalf("timeout")
}
}
@@ -537,6 +541,11 @@ func TestFSCompressSingleThread(t *testing.T) {
}
func testFSCompress(t *testing.T, h RequestHandler, filePath string) {
+ // File locking is flaky on Windows.
+ if runtime.GOOS == "windows" {
+ t.SkipNow()
+ }
+
var ctx RequestCtx
ctx.Init(&Request{}, nil, nil)
@@ -549,12 +558,12 @@ func testFSCompress(t *testing.T, h RequestHandler, filePath string) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s. filePath=%q", err, filePath)
+ t.Errorf("unexpected error: %v. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath)
}
- ce := resp.Header.Peek(HeaderContentEncoding)
+ ce := resp.Header.ContentEncoding()
if string(ce) != "" {
t.Errorf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath)
}
@@ -568,18 +577,18 @@ func testFSCompress(t *testing.T, h RequestHandler, filePath string) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s. filePath=%q", err, filePath)
+ t.Errorf("unexpected error: %v. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "gzip" {
t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath)
}
zbody, err := resp.BodyGunzip()
if err != nil {
- t.Errorf("unexpected error when gunzipping response body: %s. filePath=%q", err, filePath)
+ t.Errorf("unexpected error when gunzipping response body: %v. filePath=%q", err, filePath)
}
if string(zbody) != body {
t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath)
@@ -593,18 +602,18 @@ func testFSCompress(t *testing.T, h RequestHandler, filePath string) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s. filePath=%q", err, filePath)
+ t.Errorf("unexpected error: %v. filePath=%q", err, filePath)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "br" {
t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "br", filePath)
}
zbody, err = resp.BodyUnbrotli()
if err != nil {
- t.Errorf("unexpected error when unbrotling response body: %s. filePath=%q", err, filePath)
+ t.Errorf("unexpected error when unbrotling response body: %v. filePath=%q", err, filePath)
}
if string(zbody) != body {
t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath)
@@ -618,13 +627,13 @@ func TestFSHandlerSingleThread(t *testing.T) {
f, err := os.Open(".")
if err != nil {
- t.Fatalf("cannot open cwd: %s", err)
+ t.Fatalf("cannot open cwd: %v", err)
}
filenames, err := f.Readdirnames(0)
f.Close()
if err != nil {
- t.Fatalf("cannot read dirnames in cwd: %s", err)
+ t.Fatalf("cannot read dirnames in cwd: %v", err)
}
sort.Strings(filenames)
@@ -640,13 +649,13 @@ func TestFSHandlerConcurrent(t *testing.T) {
f, err := os.Open(".")
if err != nil {
- t.Fatalf("cannot open cwd: %s", err)
+ t.Fatalf("cannot open cwd: %v", err)
}
filenames, err := f.Readdirnames(0)
f.Close()
if err != nil {
- t.Fatalf("cannot read dirnames in cwd: %s", err)
+ t.Fatalf("cannot read dirnames in cwd: %v", err)
}
sort.Strings(filenames)
@@ -680,20 +689,20 @@ func fsHandlerTest(t *testing.T, requestHandler RequestHandler, filenames []stri
for _, name := range filenames {
f, err := os.Open(name)
if err != nil {
- t.Fatalf("cannot open file %q: %s", name, err)
+ t.Fatalf("cannot open file %q: %v", name, err)
}
stat, err := f.Stat()
if err != nil {
- t.Fatalf("cannot get file stat %q: %s", name, err)
+ t.Fatalf("cannot get file stat %q: %v", name, err)
}
if stat.IsDir() {
f.Close()
continue
}
- data, err := ioutil.ReadAll(f)
+ data, err := io.ReadAll(f)
f.Close()
if err != nil {
- t.Fatalf("cannot read file contents %q: %s", name, err)
+ t.Fatalf("cannot read file contents %q: %v", name, err)
}
ctx.URI().Update(name)
@@ -701,9 +710,9 @@ func fsHandlerTest(t *testing.T, requestHandler RequestHandler, filenames []stri
if ctx.Response.bodyStream == nil {
t.Fatalf("response body stream must be non-empty")
}
- body, err := ioutil.ReadAll(ctx.Response.bodyStream)
+ body, err := io.ReadAll(ctx.Response.bodyStream)
if err != nil {
- t.Fatalf("error when reading response body stream: %s", err)
+ t.Fatalf("error when reading response body stream: %v", err)
}
if !bytes.Equal(body, data) {
t.Fatalf("unexpected body returned: %q. Expecting %q", body, data)
@@ -720,9 +729,9 @@ func fsHandlerTest(t *testing.T, requestHandler RequestHandler, filenames []stri
if ctx.Response.bodyStream == nil {
t.Fatalf("response body stream must be non-empty")
}
- body, err := ioutil.ReadAll(ctx.Response.bodyStream)
+ body, err := io.ReadAll(ctx.Response.bodyStream)
if err != nil {
- t.Fatalf("error when reading response body stream: %s", err)
+ t.Fatalf("error when reading response body stream: %v", err)
}
if len(body) == 0 {
t.Fatalf("index page must be non-empty")
@@ -796,7 +805,7 @@ func TestServeFileContentType(t *testing.T) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
expected := []byte("image/png")
diff --git a/go.mod b/go.mod
index 1711091..0f02f59 100644
--- a/go.mod
+++ b/go.mod
@@ -1,13 +1,13 @@
module github.com/valyala/fasthttp
-go 1.12
+go 1.16
require (
- github.com/andybalholm/brotli v1.0.2
- github.com/klauspost/compress v1.13.4
+ github.com/andybalholm/brotli v1.0.4
+ github.com/klauspost/compress v1.15.9
github.com/valyala/bytebufferpool v1.0.0
github.com/valyala/tcplisten v1.0.0
- golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a
- golang.org/x/net v0.0.0-20210510120150-4163338589ed
- golang.org/x/sys v0.0.0-20210514084401-e8d321eab015
+ golang.org/x/crypto v0.0.0-20220214200702-86341886e292
+ golang.org/x/net v0.0.0-20220906165146-f3363e06e74c
+ golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10
)
diff --git a/go.sum b/go.sum
index 5c59aef..1f0f0cd 100644
--- a/go.sum
+++ b/go.sum
@@ -1,23 +1,24 @@
-github.com/andybalholm/brotli v1.0.2 h1:JKnhI/XQ75uFBTiuzXpzFrUriDPiZjlOSzh6wXogP0E=
-github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
-github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
-github.com/klauspost/compress v1.13.4 h1:0zhec2I8zGnjWcKyLl6i3gPqKANCCn5e9xmviEEeX6s=
-github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
+github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
+github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
+github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY=
+github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8=
github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc=
-golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc=
-golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
-golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
-golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I=
-golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/crypto v0.0.0-20220214200702-86341886e292 h1:f+lwQ+GtmgoY+A2YaQxlSOnDjXcQ7ZRLWOHbC6HtRqE=
+golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
+golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20220906165146-f3363e06e74c h1:yKufUcDwucU5urd+50/Opbt4AYpqthk7wHpHok8f1lo=
+golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E=
-golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10 h1:WIoqL4EROvwiPdUtaip4VcDdpZ4kha7wBWZrbVKCIZg=
+golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
diff --git a/header.go b/header.go
index 9fa7b4b..7c12308 100644
--- a/header.go
+++ b/header.go
@@ -33,15 +33,20 @@ type ResponseHeader struct {
noDefaultDate bool
statusCode int
+ statusMessage []byte
+ protocol []byte
contentLength int
contentLengthBytes []byte
secureErrorLogMessage bool
- contentType []byte
- server []byte
+ contentType []byte
+ contentEncoding []byte
+ server []byte
+ mulHeader [][]byte
- h []argsKV
- bufKV argsKV
+ h []argsKV
+ trailer []argsKV
+ bufKV argsKV
cookies []argsKV
}
@@ -56,9 +61,10 @@ type ResponseHeader struct {
type RequestHeader struct {
noCopy noCopy //nolint:unused,structcheck
- disableNormalizing bool
- noHTTP11 bool
- connectionClose bool
+ disableNormalizing bool
+ noHTTP11 bool
+ connectionClose bool
+ noDefaultContentType bool
// These two fields have been moved close to other bool fields
// for reducing RequestHeader object size.
@@ -74,9 +80,11 @@ type RequestHeader struct {
host []byte
contentType []byte
userAgent []byte
+ mulHeader [][]byte
- h []argsKV
- bufKV argsKV
+ h []argsKV
+ trailer []argsKV
+ bufKV argsKV
cookies []argsKV
@@ -98,13 +106,13 @@ func (h *ResponseHeader) SetContentRange(startPos, endPos, contentLength int) {
b = AppendUint(b, contentLength)
h.bufKV.value = b
- h.SetCanonical(strContentRange, h.bufKV.value)
+ h.setNonSpecial(strContentRange, h.bufKV.value)
}
-// SetByteRange sets 'Range: bytes=startPos-endPos' header.
+// SetByteRanges sets 'Range: bytes=startPos-endPos' header.
//
-// * If startPos is negative, then 'bytes=-startPos' value is set.
-// * If endPos is negative, then 'bytes=startPos-' value is set.
+// - If startPos is negative, then 'bytes=-startPos' value is set.
+// - If endPos is negative, then 'bytes=startPos-' value is set.
func (h *RequestHeader) SetByteRange(startPos, endPos int) {
b := h.bufKV.value[:0]
b = append(b, strBytes...)
@@ -120,7 +128,7 @@ func (h *RequestHeader) SetByteRange(startPos, endPos int) {
}
h.bufKV.value = b
- h.SetCanonical(strRange, h.bufKV.value)
+ h.setNonSpecial(strRange, h.bufKV.value)
}
// StatusCode returns response status code.
@@ -136,10 +144,33 @@ func (h *ResponseHeader) SetStatusCode(statusCode int) {
h.statusCode = statusCode
}
+// StatusMessage returns response status message.
+func (h *ResponseHeader) StatusMessage() []byte {
+ return h.statusMessage
+}
+
+// SetStatusMessage sets response status message bytes.
+func (h *ResponseHeader) SetStatusMessage(statusMessage []byte) {
+ h.statusMessage = append(h.statusMessage[:0], statusMessage...)
+}
+
+// Protocol returns response protocol bytes.
+func (h *ResponseHeader) Protocol() []byte {
+ if len(h.protocol) > 0 {
+ return h.protocol
+ }
+ return strHTTP11
+}
+
+// SetProtocol sets response protocol bytes.
+func (h *ResponseHeader) SetProtocol(protocol []byte) {
+ h.protocol = append(h.protocol[:0], protocol...)
+}
+
// SetLastModified sets 'Last-Modified' header to the given value.
func (h *ResponseHeader) SetLastModified(t time.Time) {
h.bufKV.value = AppendHTTPDate(h.bufKV.value[:0], t)
- h.SetCanonical(strLastModified, h.bufKV.value)
+ h.setNonSpecial(strLastModified, h.bufKV.value)
}
// ConnectionClose returns true if 'Connection: close' header is set.
@@ -297,6 +328,21 @@ func (h *ResponseHeader) SetContentTypeBytes(contentType []byte) {
h.contentType = append(h.contentType[:0], contentType...)
}
+// ContentEncoding returns Content-Encoding header value.
+func (h *ResponseHeader) ContentEncoding() []byte {
+ return h.contentEncoding
+}
+
+// SetContentEncoding sets Content-Encoding header value.
+func (h *ResponseHeader) SetContentEncoding(contentEncoding string) {
+ h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...)
+}
+
+// SetContentEncodingBytes sets Content-Encoding header value.
+func (h *ResponseHeader) SetContentEncodingBytes(contentEncoding []byte) {
+ h.contentEncoding = append(h.contentEncoding[:0], contentEncoding...)
+}
+
// Server returns Server header value.
func (h *ResponseHeader) Server() []byte {
return h.server
@@ -327,6 +373,21 @@ func (h *RequestHeader) SetContentTypeBytes(contentType []byte) {
h.contentType = append(h.contentType[:0], contentType...)
}
+// ContentEncoding returns Content-Encoding header value.
+func (h *RequestHeader) ContentEncoding() []byte {
+ return peekArgBytes(h.h, strContentEncoding)
+}
+
+// SetContentEncoding sets Content-Encoding header value.
+func (h *RequestHeader) SetContentEncoding(contentEncoding string) {
+ h.SetBytesK(strContentEncoding, contentEncoding)
+}
+
+// SetContentEncodingBytes sets Content-Encoding header value.
+func (h *RequestHeader) SetContentEncodingBytes(contentEncoding []byte) {
+ h.setNonSpecial(strContentEncoding, contentEncoding)
+}
+
// SetMultipartFormBoundary sets the following Content-Type:
// 'multipart/form-data; boundary=...'
// where ... is substituted by the given boundary.
@@ -357,6 +418,117 @@ func (h *RequestHeader) SetMultipartFormBoundaryBytes(boundary []byte) {
h.SetContentTypeBytes(h.bufKV.value)
}
+// SetTrailer sets header Trailer value for chunked response
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *ResponseHeader) SetTrailer(trailer string) error {
+ return h.SetTrailerBytes(s2b(trailer))
+}
+
+// SetTrailerBytes sets Trailer header value for chunked response
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *ResponseHeader) SetTrailerBytes(trailer []byte) error {
+ h.trailer = h.trailer[:0]
+ return h.AddTrailerBytes(trailer)
+}
+
+// AddTrailer add Trailer header value for chunked response
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *ResponseHeader) AddTrailer(trailer string) error {
+ return h.AddTrailerBytes(s2b(trailer))
+}
+
+var ErrBadTrailer = errors.New("contain forbidden trailer")
+
+// AddTrailerBytes add Trailer header value for chunked response
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *ResponseHeader) AddTrailerBytes(trailer []byte) error {
+ var err error
+ for i := -1; i+1 < len(trailer); {
+ trailer = trailer[i+1:]
+ i = bytes.IndexByte(trailer, ',')
+ if i < 0 {
+ i = len(trailer)
+ }
+ key := trailer[:i]
+ for len(key) > 0 && key[0] == ' ' {
+ key = key[1:]
+ }
+ for len(key) > 0 && key[len(key)-1] == ' ' {
+ key = key[:len(key)-1]
+ }
+ // Forbidden by RFC 7230, section 4.1.2
+ if isBadTrailer(key) {
+ err = ErrBadTrailer
+ continue
+ }
+ h.bufKV.key = append(h.bufKV.key[:0], key...)
+ normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
+ h.trailer = appendArgBytes(h.trailer, h.bufKV.key, nil, argsNoValue)
+ }
+
+ return err
+}
+
// MultipartFormBoundary returns boundary part
// from 'multipart/form-data; boundary=...' Content-Type.
func (h *RequestHeader) MultipartFormBoundary() []byte {
@@ -431,7 +603,7 @@ func (h *RequestHeader) SetUserAgentBytes(userAgent []byte) {
// Referer returns Referer header value.
func (h *RequestHeader) Referer() []byte {
- return h.PeekBytes(strReferer)
+ return peekArgBytes(h.h, strReferer)
}
// SetReferer sets Referer header value.
@@ -441,7 +613,7 @@ func (h *RequestHeader) SetReferer(referer string) {
// SetRefererBytes sets Referer header value.
func (h *RequestHeader) SetRefererBytes(referer []byte) {
- h.SetCanonical(strReferer, referer)
+ h.setNonSpecial(strReferer, referer)
}
// Method returns HTTP request method.
@@ -505,6 +677,115 @@ func (h *RequestHeader) SetRequestURIBytes(requestURI []byte) {
h.requestURI = append(h.requestURI[:0], requestURI...)
}
+// SetTrailer sets Trailer header value for chunked request
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *RequestHeader) SetTrailer(trailer string) error {
+ return h.SetTrailerBytes(s2b(trailer))
+}
+
+// SetTrailerBytes sets Trailer header value for chunked request
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *RequestHeader) SetTrailerBytes(trailer []byte) error {
+ h.trailer = h.trailer[:0]
+ return h.AddTrailerBytes(trailer)
+}
+
+// AddTrailer add Trailer header value for chunked request
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *RequestHeader) AddTrailer(trailer string) error {
+ return h.AddTrailerBytes(s2b(trailer))
+}
+
+// AddTrailerBytes add Trailer header value for chunked request
+// to indicate which headers will be sent after the body.
+//
+// Use Set to set the trailer header later.
+//
+// Trailers are only supported with chunked transfer.
+// Trailers allow the sender to include additional headers at the end of chunked messages.
+//
+// The following trailers are forbidden:
+// 1. necessary for message framing (e.g., Transfer-Encoding and Content-Length),
+// 2. routing (e.g., Host),
+// 3. request modifiers (e.g., controls and conditionals in Section 5 of [RFC7231]),
+// 4. authentication (e.g., see [RFC7235] and [RFC6265]),
+// 5. response control data (e.g., see Section 7.1 of [RFC7231]),
+// 6. determining how to process the payload (e.g., Content-Encoding, Content-Type, Content-Range, and Trailer)
+//
+// Return ErrBadTrailer if contain any forbidden trailers.
+func (h *RequestHeader) AddTrailerBytes(trailer []byte) error {
+ var err error
+ for i := -1; i+1 < len(trailer); {
+ trailer = trailer[i+1:]
+ i = bytes.IndexByte(trailer, ',')
+ if i < 0 {
+ i = len(trailer)
+ }
+ key := trailer[:i]
+ for len(key) > 0 && key[0] == ' ' {
+ key = key[1:]
+ }
+ for len(key) > 0 && key[len(key)-1] == ' ' {
+ key = key[:len(key)-1]
+ }
+ // Forbidden by RFC 7230, section 4.1.2
+ if isBadTrailer(key) {
+ err = ErrBadTrailer
+ continue
+ }
+ h.bufKV.key = append(h.bufKV.key[:0], key...)
+ normalizeHeaderKey(h.bufKV.key, h.disableNormalizing)
+ h.trailer = appendArgBytes(h.trailer, h.bufKV.key, nil, argsNoValue)
+ }
+
+ return err
+}
+
// IsGet returns true if request method is GET.
func (h *RequestHeader) IsGet() bool {
return string(h.Method()) == MethodGet
@@ -589,7 +870,7 @@ func (h *RequestHeader) HasAcceptEncodingBytes(acceptEncoding []byte) bool {
// i.e. the number of times f is called in VisitAll.
func (h *ResponseHeader) Len() int {
n := 0
- h.VisitAll(func(k, v []byte) { n++ })
+ h.VisitAll(func(_, _ []byte) { n++ })
return n
}
@@ -597,7 +878,7 @@ func (h *ResponseHeader) Len() int {
// i.e. the number of times f is called in VisitAll.
func (h *RequestHeader) Len() int {
n := 0
- h.VisitAll(func(k, v []byte) { n++ })
+ h.VisitAll(func(_, _ []byte) { n++ })
return n
}
@@ -608,9 +889,9 @@ func (h *RequestHeader) Len() int {
// while lowercasing all the other letters.
// Examples:
//
-// * CONNECTION -> Connection
-// * conteNT-tYPE -> Content-Type
-// * foo-bar-baz -> Foo-Bar-Baz
+// - CONNECTION -> Connection
+// - conteNT-tYPE -> Content-Type
+// - foo-bar-baz -> Foo-Bar-Baz
//
// Disable header names' normalization only if know what are you doing.
func (h *RequestHeader) DisableNormalizing() {
@@ -624,9 +905,9 @@ func (h *RequestHeader) DisableNormalizing() {
// the other letters.
// Examples:
//
-// * CONNECTION -> Connection
-// * conteNT-tYPE -> Content-Type
-// * foo-bar-baz -> Foo-Bar-Baz
+// - CONNECTION -> Connection
+// - conteNT-tYPE -> Content-Type
+// - foo-bar-baz -> Foo-Bar-Baz
//
// This is enabled by default unless disabled using DisableNormalizing()
func (h *RequestHeader) EnableNormalizing() {
@@ -640,9 +921,9 @@ func (h *RequestHeader) EnableNormalizing() {
// while lowercasing all the other letters.
// Examples:
//
-// * CONNECTION -> Connection
-// * conteNT-tYPE -> Content-Type
-// * foo-bar-baz -> Foo-Bar-Baz
+// - CONNECTION -> Connection
+// - conteNT-tYPE -> Content-Type
+// - foo-bar-baz -> Foo-Bar-Baz
//
// Disable header names' normalization only if know what are you doing.
func (h *ResponseHeader) DisableNormalizing() {
@@ -656,9 +937,9 @@ func (h *ResponseHeader) DisableNormalizing() {
// the other letters.
// Examples:
//
-// * CONNECTION -> Connection
-// * conteNT-tYPE -> Content-Type
-// * foo-bar-baz -> Foo-Bar-Baz
+// - CONNECTION -> Connection
+// - conteNT-tYPE -> Content-Type
+// - foo-bar-baz -> Foo-Bar-Baz
//
// This is enabled by default unless disabled using DisableNormalizing()
func (h *ResponseHeader) EnableNormalizing() {
@@ -683,19 +964,30 @@ func (h *ResponseHeader) resetSkipNormalize() {
h.connectionClose = false
h.statusCode = 0
+ h.statusMessage = h.statusMessage[:0]
+ h.protocol = h.protocol[:0]
h.contentLength = 0
h.contentLengthBytes = h.contentLengthBytes[:0]
h.contentType = h.contentType[:0]
+ h.contentEncoding = h.contentEncoding[:0]
h.server = h.server[:0]
h.h = h.h[:0]
h.cookies = h.cookies[:0]
+ h.trailer = h.trailer[:0]
+ h.mulHeader = h.mulHeader[:0]
+}
+
+// SetNoDefaultContentType allows you to control if a default Content-Type header will be set (false) or not (true).
+func (h *RequestHeader) SetNoDefaultContentType(noDefaultContentType bool) {
+ h.noDefaultContentType = noDefaultContentType
}
// Reset clears request header.
func (h *RequestHeader) Reset() {
h.disableNormalizing = false
+ h.SetNoDefaultContentType(false)
h.resetSkipNormalize()
}
@@ -712,6 +1004,8 @@ func (h *RequestHeader) resetSkipNormalize() {
h.host = h.host[:0]
h.contentType = h.contentType[:0]
h.userAgent = h.userAgent[:0]
+ h.trailer = h.trailer[:0]
+ h.mulHeader = h.mulHeader[:0]
h.h = h.h[:0]
h.cookies = h.cookies[:0]
@@ -731,12 +1025,16 @@ func (h *ResponseHeader) CopyTo(dst *ResponseHeader) {
dst.noDefaultDate = h.noDefaultDate
dst.statusCode = h.statusCode
+ dst.statusMessage = append(dst.statusMessage, h.statusMessage...)
+ dst.protocol = append(dst.protocol, h.protocol...)
dst.contentLength = h.contentLength
- dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...)
- dst.contentType = append(dst.contentType[:0], h.contentType...)
- dst.server = append(dst.server[:0], h.server...)
+ dst.contentLengthBytes = append(dst.contentLengthBytes, h.contentLengthBytes...)
+ dst.contentType = append(dst.contentType, h.contentType...)
+ dst.contentEncoding = append(dst.contentEncoding, h.contentEncoding...)
+ dst.server = append(dst.server, h.server...)
dst.h = copyArgs(dst.h, h.h)
dst.cookies = copyArgs(dst.cookies, h.cookies)
+ dst.trailer = copyArgs(dst.trailer, h.trailer)
}
// CopyTo copies all the headers to dst.
@@ -748,17 +1046,18 @@ func (h *RequestHeader) CopyTo(dst *RequestHeader) {
dst.connectionClose = h.connectionClose
dst.contentLength = h.contentLength
- dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...)
- dst.method = append(dst.method[:0], h.method...)
- dst.proto = append(dst.proto[:0], h.proto...)
- dst.requestURI = append(dst.requestURI[:0], h.requestURI...)
- dst.host = append(dst.host[:0], h.host...)
- dst.contentType = append(dst.contentType[:0], h.contentType...)
- dst.userAgent = append(dst.userAgent[:0], h.userAgent...)
+ dst.contentLengthBytes = append(dst.contentLengthBytes, h.contentLengthBytes...)
+ dst.method = append(dst.method, h.method...)
+ dst.proto = append(dst.proto, h.proto...)
+ dst.requestURI = append(dst.requestURI, h.requestURI...)
+ dst.host = append(dst.host, h.host...)
+ dst.contentType = append(dst.contentType, h.contentType...)
+ dst.userAgent = append(dst.userAgent, h.userAgent...)
+ dst.trailer = append(dst.trailer, h.trailer...)
dst.h = copyArgs(dst.h, h.h)
dst.cookies = copyArgs(dst.cookies, h.cookies)
dst.cookiesCollected = h.cookiesCollected
- dst.rawHeaders = append(dst.rawHeaders[:0], h.rawHeaders...)
+ dst.rawHeaders = append(dst.rawHeaders, h.rawHeaders...)
}
// VisitAll calls f for each header.
@@ -773,21 +1072,42 @@ func (h *ResponseHeader) VisitAll(f func(key, value []byte)) {
if len(contentType) > 0 {
f(strContentType, contentType)
}
+ contentEncoding := h.ContentEncoding()
+ if len(contentEncoding) > 0 {
+ f(strContentEncoding, contentEncoding)
+ }
server := h.Server()
if len(server) > 0 {
f(strServer, server)
}
if len(h.cookies) > 0 {
- visitArgs(h.cookies, func(k, v []byte) {
+ visitArgs(h.cookies, func(_, v []byte) {
f(strSetCookie, v)
})
}
+ if len(h.trailer) > 0 {
+ f(strTrailer, appendArgsKeyBytes(nil, h.trailer, strCommaSpace))
+ }
visitArgs(h.h, f)
if h.ConnectionClose() {
f(strConnection, strClose)
}
}
+// VisitAllTrailer calls f for each response Trailer.
+//
+// f must not retain references to value after returning.
+func (h *ResponseHeader) VisitAllTrailer(f func(value []byte)) {
+ visitArgsKey(h.trailer, f)
+}
+
+// VisitAllTrailer calls f for each request Trailer.
+//
+// f must not retain references to value after returning.
+func (h *RequestHeader) VisitAllTrailer(f func(value []byte)) {
+ visitArgsKey(h.trailer, f)
+}
+
// VisitAllCookie calls f for each response cookie.
//
// Cookie name is passed in key and the whole Set-Cookie header value
@@ -829,6 +1149,9 @@ func (h *RequestHeader) VisitAll(f func(key, value []byte)) {
if len(userAgent) > 0 {
f(strUserAgent, userAgent)
}
+ if len(h.trailer) > 0 {
+ f(strTrailer, appendArgsKeyBytes(nil, h.trailer, strCommaSpace))
+ }
h.collectCookies()
if len(h.cookies) > 0 {
@@ -876,6 +1199,8 @@ func (h *ResponseHeader) del(key []byte) {
switch string(key) {
case HeaderContentType:
h.contentType = h.contentType[:0]
+ case HeaderContentEncoding:
+ h.contentEncoding = h.contentEncoding[:0]
case HeaderServer:
h.server = h.server[:0]
case HeaderSetCookie:
@@ -885,6 +1210,8 @@ func (h *ResponseHeader) del(key []byte) {
h.contentLengthBytes = h.contentLengthBytes[:0]
case HeaderConnection:
h.connectionClose = false
+ case HeaderTrailer:
+ h.trailer = h.trailer[:0]
}
h.h = delAllArgsBytes(h.h, key)
}
@@ -917,6 +1244,8 @@ func (h *RequestHeader) del(key []byte) {
h.contentLengthBytes = h.contentLengthBytes[:0]
case HeaderConnection:
h.connectionClose = false
+ case HeaderTrailer:
+ h.trailer = h.trailer[:0]
}
h.h = delAllArgsBytes(h.h, key)
}
@@ -938,12 +1267,15 @@ func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool {
h.contentLengthBytes = append(h.contentLengthBytes[:0], value...)
}
return true
+ } else if caseInsensitiveCompare(strContentEncoding, key) {
+ h.SetContentEncodingBytes(value)
+ return true
} else if caseInsensitiveCompare(strConnection, key) {
if bytes.Equal(strClose, value) {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
- h.h = setArgBytes(h.h, key, value, argsHasValue)
+ h.setNonSpecial(key, value)
}
return true
}
@@ -962,6 +1294,9 @@ func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool {
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
+ } else if caseInsensitiveCompare(strTrailer, key) {
+ _ = h.SetTrailerBytes(value)
+ return true
}
case 'd':
if caseInsensitiveCompare(strDate, key) {
@@ -973,6 +1308,11 @@ func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool {
return false
}
+// setNonSpecial directly put into map i.e. not a basic header
+func (h *ResponseHeader) setNonSpecial(key []byte, value []byte) {
+ h.h = setArgBytes(h.h, key, value, argsHasValue)
+}
+
// setSpecialHeader handles special headers and return true when a header is processed.
func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
if len(key) == 0 {
@@ -995,7 +1335,7 @@ func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
h.SetConnectionClose()
} else {
h.ResetConnectionClose()
- h.h = setArgBytes(h.h, key, value, argsHasValue)
+ h.setNonSpecial(key, value)
}
return true
} else if caseInsensitiveCompare(strCookie, key) {
@@ -1007,6 +1347,9 @@ func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
if caseInsensitiveCompare(strTransferEncoding, key) {
// Transfer-Encoding is managed automatically.
return true
+ } else if caseInsensitiveCompare(strTrailer, key) {
+ _ = h.SetTrailerBytes(value)
+ return true
}
case 'h':
if caseInsensitiveCompare(strHost, key) {
@@ -1023,6 +1366,11 @@ func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
return false
}
+// setNonSpecial directly put into map i.e. not a basic header
+func (h *RequestHeader) setNonSpecial(key []byte, value []byte) {
+ h.h = setArgBytes(h.h, key, value, argsHasValue)
+}
+
// Add adds the given 'key: value' header.
//
// Multiple headers with the same key may be added with this function.
@@ -1031,6 +1379,9 @@ func (h *RequestHeader) setSpecialHeader(key, value []byte) bool {
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked response body.
func (h *ResponseHeader) Add(key, value string) {
h.AddBytesKV(s2b(key), s2b(value))
}
@@ -1043,6 +1394,9 @@ func (h *ResponseHeader) Add(key, value string) {
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked response body.
func (h *ResponseHeader) AddBytesK(key []byte, value string) {
h.AddBytesKV(key, s2b(value))
}
@@ -1055,6 +1409,9 @@ func (h *ResponseHeader) AddBytesK(key []byte, value string) {
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked response body.
func (h *ResponseHeader) AddBytesV(key string, value []byte) {
h.AddBytesKV(s2b(key), value)
}
@@ -1067,6 +1424,9 @@ func (h *ResponseHeader) AddBytesV(key string, value []byte) {
// the Content-Type, Content-Length, Connection, Server, Set-Cookie,
// Transfer-Encoding and Date headers can only be set once and will
// overwrite the previous value.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked response body.
func (h *ResponseHeader) AddBytesKV(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
@@ -1078,6 +1438,9 @@ func (h *ResponseHeader) AddBytesKV(key, value []byte) {
// Set sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked response body.
+//
// Use Add for setting multiple header values under the same key.
func (h *ResponseHeader) Set(key, value string) {
initHeaderKV(&h.bufKV, key, value, h.disableNormalizing)
@@ -1086,6 +1449,9 @@ func (h *ResponseHeader) Set(key, value string) {
// SetBytesK sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked response body.
+//
// Use AddBytesK for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesK(key []byte, value string) {
h.bufKV.value = append(h.bufKV.value[:0], value...)
@@ -1094,6 +1460,9 @@ func (h *ResponseHeader) SetBytesK(key []byte, value string) {
// SetBytesV sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked response body.
+//
// Use AddBytesV for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesV(key string, value []byte) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
@@ -1102,6 +1471,9 @@ func (h *ResponseHeader) SetBytesV(key string, value []byte) {
// SetBytesKV sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked response body.
+//
// Use AddBytesKV for setting multiple header values under the same key.
func (h *ResponseHeader) SetBytesKV(key, value []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
@@ -1111,12 +1483,14 @@ func (h *ResponseHeader) SetBytesKV(key, value []byte) {
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked response body.
func (h *ResponseHeader) SetCanonical(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
-
- h.h = setArgBytes(h.h, key, value, argsHasValue)
+ h.setNonSpecial(key, value)
}
// SetCookie sets the given response cookie.
@@ -1146,13 +1520,13 @@ func (h *RequestHeader) SetCookieBytesKV(key, value []byte) {
// This doesn't work for a cookie with specific domain or path,
// you should delete it manually like:
//
-// c := AcquireCookie()
-// c.SetKey(key)
-// c.SetDomain("example.com")
-// c.SetPath("/path")
-// c.SetExpire(CookieExpireDelete)
-// h.SetCookie(c)
-// ReleaseCookie(c)
+// c := AcquireCookie()
+// c.SetKey(key)
+// c.SetDomain("example.com")
+// c.SetPath("/path")
+// c.SetExpire(CookieExpireDelete)
+// h.SetCookie(c)
+// ReleaseCookie(c)
//
// Use DelCookie if you want just removing the cookie from response header.
func (h *ResponseHeader) DelClientCookie(key string) {
@@ -1169,13 +1543,13 @@ func (h *ResponseHeader) DelClientCookie(key string) {
// This doesn't work for a cookie with specific domain or path,
// you should delete it manually like:
//
-// c := AcquireCookie()
-// c.SetKey(key)
-// c.SetDomain("example.com")
-// c.SetPath("/path")
-// c.SetExpire(CookieExpireDelete)
-// h.SetCookie(c)
-// ReleaseCookie(c)
+// c := AcquireCookie()
+// c.SetKey(key)
+// c.SetDomain("example.com")
+// c.SetPath("/path")
+// c.SetExpire(CookieExpireDelete)
+// h.SetCookie(c)
+// ReleaseCookie(c)
//
// Use DelCookieBytes if you want just removing the cookie from response header.
func (h *ResponseHeader) DelClientCookieBytes(key []byte) {
@@ -1224,6 +1598,9 @@ func (h *RequestHeader) DelAllCookies() {
//
// Multiple headers with the same key may be added with this function.
// Use Set for setting a single header for the given key.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked request body.
func (h *RequestHeader) Add(key, value string) {
h.AddBytesKV(s2b(key), s2b(value))
}
@@ -1232,6 +1609,9 @@ func (h *RequestHeader) Add(key, value string) {
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesK for setting a single header for the given key.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked request body.
func (h *RequestHeader) AddBytesK(key []byte, value string) {
h.AddBytesKV(key, s2b(value))
}
@@ -1240,6 +1620,9 @@ func (h *RequestHeader) AddBytesK(key []byte, value string) {
//
// Multiple headers with the same key may be added with this function.
// Use SetBytesV for setting a single header for the given key.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked request body.
func (h *RequestHeader) AddBytesV(key string, value []byte) {
h.AddBytesKV(s2b(key), value)
}
@@ -1252,6 +1635,9 @@ func (h *RequestHeader) AddBytesV(key string, value []byte) {
// the Content-Type, Content-Length, Connection, Cookie,
// Transfer-Encoding, Host and User-Agent headers can only be set once
// and will overwrite the previous value.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see AddTrailer for more details),
+// it will be sent after the chunked request body.
func (h *RequestHeader) AddBytesKV(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
@@ -1263,6 +1649,9 @@ func (h *RequestHeader) AddBytesKV(key, value []byte) {
// Set sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked request body.
+//
// Use Add for setting multiple header values under the same key.
func (h *RequestHeader) Set(key, value string) {
initHeaderKV(&h.bufKV, key, value, h.disableNormalizing)
@@ -1271,6 +1660,9 @@ func (h *RequestHeader) Set(key, value string) {
// SetBytesK sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked request body.
+//
// Use AddBytesK for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesK(key []byte, value string) {
h.bufKV.value = append(h.bufKV.value[:0], value...)
@@ -1279,6 +1671,9 @@ func (h *RequestHeader) SetBytesK(key []byte, value string) {
// SetBytesV sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked request body.
+//
// Use AddBytesV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesV(key string, value []byte) {
k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
@@ -1287,6 +1682,9 @@ func (h *RequestHeader) SetBytesV(key string, value []byte) {
// SetBytesKV sets the given 'key: value' header.
//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked request body.
+//
// Use AddBytesKV for setting multiple header values under the same key.
func (h *RequestHeader) SetBytesKV(key, value []byte) {
h.bufKV.key = append(h.bufKV.key[:0], key...)
@@ -1296,12 +1694,14 @@ func (h *RequestHeader) SetBytesKV(key, value []byte) {
// SetCanonical sets the given 'key: value' header assuming that
// key is in canonical form.
+//
+// If the header is set as a Trailer (forbidden trailers will not be set, see SetTrailer for more details),
+// it will be sent after the chunked request body.
func (h *RequestHeader) SetCanonical(key, value []byte) {
if h.setSpecialHeader(key, value) {
return
}
-
- h.h = setArgBytes(h.h, key, value, argsHasValue)
+ h.setNonSpecial(key, value)
}
// Peek returns header value for the given key.
@@ -1350,6 +1750,8 @@ func (h *ResponseHeader) peek(key []byte) []byte {
switch string(key) {
case HeaderContentType:
return h.ContentType()
+ case HeaderContentEncoding:
+ return h.ContentEncoding()
case HeaderServer:
return h.Server()
case HeaderConnection:
@@ -1361,6 +1763,8 @@ func (h *ResponseHeader) peek(key []byte) []byte {
return h.contentLengthBytes
case HeaderSetCookie:
return appendResponseCookieBytes(nil, h.cookies)
+ case HeaderTrailer:
+ return appendArgsKeyBytes(nil, h.trailer, strCommaSpace)
default:
return peekArgBytes(h.h, key)
}
@@ -1386,11 +1790,153 @@ func (h *RequestHeader) peek(key []byte) []byte {
return appendRequestCookieBytes(nil, h.cookies)
}
return peekArgBytes(h.h, key)
+ case HeaderTrailer:
+ return appendArgsKeyBytes(nil, h.trailer, strCommaSpace)
default:
return peekArgBytes(h.h, key)
}
}
+// PeekAll returns all header value for the given key.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseRequest or your request handler returning.
+// Any future calls to the Peek* will modify the returned value.
+// Do not store references to returned value. Make copies instead.
+func (h *RequestHeader) PeekAll(key string) [][]byte {
+ k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
+ return h.peekAll(k)
+}
+
+func (h *RequestHeader) peekAll(key []byte) [][]byte {
+ h.mulHeader = h.mulHeader[:0]
+ switch string(key) {
+ case HeaderHost:
+ if host := h.Host(); len(host) > 0 {
+ h.mulHeader = append(h.mulHeader, host)
+ }
+ case HeaderContentType:
+ if contentType := h.ContentType(); len(contentType) > 0 {
+ h.mulHeader = append(h.mulHeader, contentType)
+ }
+ case HeaderUserAgent:
+ if ua := h.UserAgent(); len(ua) > 0 {
+ h.mulHeader = append(h.mulHeader, ua)
+ }
+ case HeaderConnection:
+ if h.ConnectionClose() {
+ h.mulHeader = append(h.mulHeader, strClose)
+ } else {
+ h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
+ }
+ case HeaderContentLength:
+ h.mulHeader = append(h.mulHeader, h.contentLengthBytes)
+ case HeaderCookie:
+ if h.cookiesCollected {
+ h.mulHeader = append(h.mulHeader, appendRequestCookieBytes(nil, h.cookies))
+ } else {
+ h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
+ }
+ case HeaderTrailer:
+ h.mulHeader = append(h.mulHeader, appendArgsKeyBytes(nil, h.trailer, strCommaSpace))
+ default:
+ h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
+ }
+ return h.mulHeader
+}
+
+// PeekAll returns all header value for the given key.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseResponse or your request handler returning.
+// Any future calls to the Peek* will modify the returned value.
+// Do not store references to returned value. Make copies instead.
+func (h *ResponseHeader) PeekAll(key string) [][]byte {
+ k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing)
+ return h.peekAll(k)
+}
+
+func (h *ResponseHeader) peekAll(key []byte) [][]byte {
+ h.mulHeader = h.mulHeader[:0]
+ switch string(key) {
+ case HeaderContentType:
+ if contentType := h.ContentType(); len(contentType) > 0 {
+ h.mulHeader = append(h.mulHeader, contentType)
+ }
+ case HeaderContentEncoding:
+ if contentEncoding := h.ContentEncoding(); len(contentEncoding) > 0 {
+ h.mulHeader = append(h.mulHeader, contentEncoding)
+ }
+ case HeaderServer:
+ if server := h.Server(); len(server) > 0 {
+ h.mulHeader = append(h.mulHeader, server)
+ }
+ case HeaderConnection:
+ if h.ConnectionClose() {
+ h.mulHeader = append(h.mulHeader, strClose)
+ } else {
+ h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
+ }
+ case HeaderContentLength:
+ h.mulHeader = append(h.mulHeader, h.contentLengthBytes)
+ case HeaderSetCookie:
+ h.mulHeader = append(h.mulHeader, appendResponseCookieBytes(nil, h.cookies))
+ case HeaderTrailer:
+ h.mulHeader = append(h.mulHeader, appendArgsKeyBytes(nil, h.trailer, strCommaSpace))
+ default:
+ h.mulHeader = peekAllArgBytesToDst(h.mulHeader, h.h, key)
+ }
+ return h.mulHeader
+}
+
+// PeekKeys return all header keys.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseRequest or your request handler returning.
+// Any future calls to the Peek* will modify the returned value.
+// Do not store references to returned value. Make copies instead.
+func (h *RequestHeader) PeekKeys() [][]byte {
+ h.mulHeader = h.mulHeader[:0]
+ h.mulHeader = peekArgsKeys(h.mulHeader, h.h)
+ return h.mulHeader
+}
+
+// PeekTrailerKeys return all trailer keys.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseRequest or your request handler returning.
+// Any future calls to the Peek* will modify the returned value.
+// Do not store references to returned value. Make copies instead.
+func (h *RequestHeader) PeekTrailerKeys() [][]byte {
+ h.mulHeader = h.mulHeader[:0]
+ h.mulHeader = peekArgsKeys(h.mulHeader, h.trailer)
+ return h.mulHeader
+}
+
+// PeekKeys return all header keys.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseResponse or your request handler returning.
+// Any future calls to the Peek* will modify the returned value.
+// Do not store references to returned value. Make copies instead.
+func (h *ResponseHeader) PeekKeys() [][]byte {
+ h.mulHeader = h.mulHeader[:0]
+ h.mulHeader = peekArgsKeys(h.mulHeader, h.h)
+ return h.mulHeader
+}
+
+// PeekTrailerKeys return all trailer keys.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseResponse or your request handler returning.
+// Any future calls to the Peek* will modify the returned value.
+// Do not store references to returned value. Make copies instead.
+func (h *ResponseHeader) PeekTrailerKeys() [][]byte {
+ h.mulHeader = h.mulHeader[:0]
+ h.mulHeader = peekArgsKeys(h.mulHeader, h.trailer)
+ return h.mulHeader
+}
+
// Cookie returns cookie for the given key.
func (h *RequestHeader) Cookie(key string) []byte {
h.collectCookies()
@@ -1454,11 +2000,11 @@ func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error {
}
}
return &ErrSmallBuffer{
- error: fmt.Errorf("error when reading response headers: %s", errSmallBuffer),
+ error: fmt.Errorf("error when reading response headers: %w", errSmallBuffer),
}
}
- return fmt.Errorf("error when reading response headers: %s", err)
+ return fmt.Errorf("error when reading response headers: %w", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parse(b)
@@ -1469,6 +2015,61 @@ func (h *ResponseHeader) tryRead(r *bufio.Reader, n int) error {
return nil
}
+// ReadTrailer reads response trailer header from r.
+//
+// io.EOF is returned if r is closed before reading the first byte.
+func (h *ResponseHeader) ReadTrailer(r *bufio.Reader) error {
+ n := 1
+ for {
+ err := h.tryReadTrailer(r, n)
+ if err == nil {
+ return nil
+ }
+ if err != errNeedMore {
+ return err
+ }
+ n = r.Buffered() + 1
+ }
+}
+
+func (h *ResponseHeader) tryReadTrailer(r *bufio.Reader, n int) error {
+ b, err := r.Peek(n)
+ if len(b) == 0 {
+ // Return ErrTimeout on any timeout.
+ if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
+ return ErrTimeout
+ }
+
+ if n == 1 || err == io.EOF {
+ return io.EOF
+ }
+
+ // This is for go 1.6 bug. See https://github.com/golang/go/issues/14121 .
+ if err == bufio.ErrBufferFull {
+ if h.secureErrorLogMessage {
+ return &ErrSmallBuffer{
+ error: fmt.Errorf("error when reading response trailer"),
+ }
+ }
+ return &ErrSmallBuffer{
+ error: fmt.Errorf("error when reading response trailer: %w", errSmallBuffer),
+ }
+ }
+
+ return fmt.Errorf("error when reading response trailer: %w", err)
+ }
+ b = mustPeekBuffered(r)
+ headersLen, errParse := h.parseTrailer(b)
+ if errParse != nil {
+ if err == io.EOF {
+ return err
+ }
+ return headerError("response", err, errParse, b, h.secureErrorLogMessage)
+ }
+ mustDiscard(r, headersLen)
+ return nil
+}
+
func headerError(typ string, err, errParse error, b []byte, secureErrorLogMessage bool) error {
if errParse != errNeedMore {
return headerErrorMsg(typ, errParse, b, secureErrorLogMessage)
@@ -1493,9 +2094,9 @@ func headerError(typ string, err, errParse error, b []byte, secureErrorLogMessag
func headerErrorMsg(typ string, err error, b []byte, secureErrorLogMessage bool) error {
if secureErrorLogMessage {
- return fmt.Errorf("error when reading %s headers: %s. Buffer size=%d", typ, err, len(b))
+ return fmt.Errorf("error when reading %s headers: %w. Buffer size=%d", typ, err, len(b))
}
- return fmt.Errorf("error when reading %s headers: %s. Buffer size=%d, contents: %s", typ, err, len(b), bufferSnippet(b))
+ return fmt.Errorf("error when reading %s headers: %w. Buffer size=%d, contents: %s", typ, err, len(b), bufferSnippet(b))
}
// Read reads request header from r.
@@ -1523,6 +2124,61 @@ func (h *RequestHeader) readLoop(r *bufio.Reader, waitForMore bool) error {
}
}
+// ReadTrailer reads request trailer header from r.
+//
+// io.EOF is returned if r is closed before reading the first byte.
+func (h *RequestHeader) ReadTrailer(r *bufio.Reader) error {
+ n := 1
+ for {
+ err := h.tryReadTrailer(r, n)
+ if err == nil {
+ return nil
+ }
+ if err != errNeedMore {
+ return err
+ }
+ n = r.Buffered() + 1
+ }
+}
+
+func (h *RequestHeader) tryReadTrailer(r *bufio.Reader, n int) error {
+ b, err := r.Peek(n)
+ if len(b) == 0 {
+ // Return ErrTimeout on any timeout.
+ if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
+ return ErrTimeout
+ }
+
+ if n == 1 || err == io.EOF {
+ return io.EOF
+ }
+
+ // This is for go 1.6 bug. See https://github.com/golang/go/issues/14121 .
+ if err == bufio.ErrBufferFull {
+ if h.secureErrorLogMessage {
+ return &ErrSmallBuffer{
+ error: fmt.Errorf("error when reading request trailer"),
+ }
+ }
+ return &ErrSmallBuffer{
+ error: fmt.Errorf("error when reading request trailer: %w", errSmallBuffer),
+ }
+ }
+
+ return fmt.Errorf("error when reading request trailer: %w", err)
+ }
+ b = mustPeekBuffered(r)
+ headersLen, errParse := h.parseTrailer(b)
+ if errParse != nil {
+ if err == io.EOF {
+ return err
+ }
+ return headerError("request", err, errParse, b, h.secureErrorLogMessage)
+ }
+ mustDiscard(r, headersLen)
+ return nil
+}
+
func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error {
h.resetSkipNormalize()
b, err := r.Peek(n)
@@ -1538,7 +2194,7 @@ func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error {
// This is for go 1.6 bug. See https://github.com/golang/go/issues/14121 .
if err == bufio.ErrBufferFull {
return &ErrSmallBuffer{
- error: fmt.Errorf("error when reading request headers: %s", errSmallBuffer),
+ error: fmt.Errorf("error when reading request headers: %w (n=%d, r.Buffered()=%d)", errSmallBuffer, n, r.Buffered()),
}
}
@@ -1548,7 +2204,7 @@ func (h *RequestHeader) tryRead(r *bufio.Reader, n int) error {
return ErrNothingRead{err}
}
- return fmt.Errorf("error when reading request headers: %s", err)
+ return fmt.Errorf("error when reading request headers: %w", err)
}
b = mustPeekBuffered(r)
headersLen, errParse := h.parse(b)
@@ -1619,6 +2275,8 @@ func (h *ResponseHeader) WriteTo(w io.Writer) (int64, error) {
// Header returns response header representation.
//
+// Headers that set as Trailer will not represent. Use TrailerHeader for trailers.
+//
// The returned value is valid until the request is released,
// either though ReleaseRequest or your request handler returning.
// Do not store references to returned value. Make copies instead.
@@ -1627,19 +2285,48 @@ func (h *ResponseHeader) Header() []byte {
return h.bufKV.value
}
+// writeTrailer writes response trailer to w.
+func (h *ResponseHeader) writeTrailer(w *bufio.Writer) error {
+ _, err := w.Write(h.TrailerHeader())
+ return err
+}
+
+// TrailerHeader returns response trailer header representation.
+//
+// Trailers will only be received with chunked transfer.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseRequest or your request handler returning.
+// Do not store references to returned value. Make copies instead.
+func (h *ResponseHeader) TrailerHeader() []byte {
+ h.bufKV.value = h.bufKV.value[:0]
+ for _, t := range h.trailer {
+ value := h.peek(t.key)
+ h.bufKV.value = appendHeaderLine(h.bufKV.value, t.key, value)
+ }
+ h.bufKV.value = append(h.bufKV.value, strCRLF...)
+ return h.bufKV.value
+}
+
// String returns response header representation.
func (h *ResponseHeader) String() string {
return string(h.Header())
}
-// AppendBytes appends response header representation to dst and returns
+// appendStatusLine appends the response status line to dst and returns
// the extended dst.
-func (h *ResponseHeader) AppendBytes(dst []byte) []byte {
+func (h *ResponseHeader) appendStatusLine(dst []byte) []byte {
statusCode := h.StatusCode()
if statusCode < 0 {
statusCode = StatusOK
}
- dst = append(dst, statusLine(statusCode)...)
+ return formatStatusLine(dst, h.Protocol(), statusCode, h.StatusMessage())
+}
+
+// AppendBytes appends response header representation to dst and returns
+// the extended dst.
+func (h *ResponseHeader) AppendBytes(dst []byte) []byte {
+ dst = h.appendStatusLine(dst[:0])
server := h.Server()
if len(server) != 0 {
@@ -1660,6 +2347,10 @@ func (h *ResponseHeader) AppendBytes(dst []byte) []byte {
dst = appendHeaderLine(dst, strContentType, contentType)
}
}
+ contentEncoding := h.ContentEncoding()
+ if len(contentEncoding) > 0 {
+ dst = appendHeaderLine(dst, strContentEncoding, contentEncoding)
+ }
if len(h.contentLengthBytes) > 0 {
dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes)
@@ -1667,11 +2358,24 @@ func (h *ResponseHeader) AppendBytes(dst []byte) []byte {
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
- if h.noDefaultDate || !bytes.Equal(kv.key, strDate) {
+
+ // Exclude trailer from header
+ exclude := false
+ for _, t := range h.trailer {
+ if bytes.Equal(kv.key, t.key) {
+ exclude = true
+ break
+ }
+ }
+ if !exclude && (h.noDefaultDate || !bytes.Equal(kv.key, strDate)) {
dst = appendHeaderLine(dst, kv.key, kv.value)
}
}
+ if len(h.trailer) > 0 {
+ dst = appendHeaderLine(dst, strTrailer, appendArgsKeyBytes(nil, h.trailer, strCommaSpace))
+ }
+
n := len(h.cookies)
if n > 0 {
for i := 0; i < n; i++ {
@@ -1703,6 +2407,8 @@ func (h *RequestHeader) WriteTo(w io.Writer) (int64, error) {
// Header returns request header representation.
//
+// Headers that set as Trailer will not represent. Use TrailerHeader for trailers.
+//
// The returned value is valid until the request is released,
// either though ReleaseRequest or your request handler returning.
// Do not store references to returned value. Make copies instead.
@@ -1711,6 +2417,29 @@ func (h *RequestHeader) Header() []byte {
return h.bufKV.value
}
+// writeTrailer writes request trailer to w.
+func (h *RequestHeader) writeTrailer(w *bufio.Writer) error {
+ _, err := w.Write(h.TrailerHeader())
+ return err
+}
+
+// TrailerHeader returns request trailer header representation.
+//
+// Trailers will only be received with chunked transfer.
+//
+// The returned value is valid until the request is released,
+// either though ReleaseRequest or your request handler returning.
+// Do not store references to returned value. Make copies instead.
+func (h *RequestHeader) TrailerHeader() []byte {
+ h.bufKV.value = h.bufKV.value[:0]
+ for _, t := range h.trailer {
+ value := h.peek(t.key)
+ h.bufKV.value = appendHeaderLine(h.bufKV.value, t.key, value)
+ }
+ h.bufKV.value = append(h.bufKV.value, strCRLF...)
+ return h.bufKV.value
+}
+
// RawHeaders returns raw header key/value bytes.
//
// Depending on server configuration, header keys may be normalized to
@@ -1751,7 +2480,7 @@ func (h *RequestHeader) AppendBytes(dst []byte) []byte {
}
contentType := h.ContentType()
- if len(contentType) == 0 && !h.ignoreBody() {
+ if !h.noDefaultContentType && len(contentType) == 0 && !h.ignoreBody() {
contentType = strDefaultContentType
}
if len(contentType) > 0 {
@@ -1763,7 +2492,21 @@ func (h *RequestHeader) AppendBytes(dst []byte) []byte {
for i, n := 0, len(h.h); i < n; i++ {
kv := &h.h[i]
- dst = appendHeaderLine(dst, kv.key, kv.value)
+ // Exclude trailer from header
+ exclude := false
+ for _, t := range h.trailer {
+ if bytes.Equal(kv.key, t.key) {
+ exclude = true
+ break
+ }
+ }
+ if !exclude {
+ dst = appendHeaderLine(dst, kv.key, kv.value)
+ }
+ }
+
+ if len(h.trailer) > 0 {
+ dst = appendHeaderLine(dst, strTrailer, appendArgsKeyBytes(nil, h.trailer, strCommaSpace))
}
// there is no need in h.collectCookies() here, since if cookies aren't collected yet,
@@ -1802,6 +2545,43 @@ func (h *ResponseHeader) parse(buf []byte) (int, error) {
return m + n, nil
}
+func (h *ResponseHeader) parseTrailer(buf []byte) (int, error) {
+ // Skip any 0 length chunk.
+ if buf[0] == '0' {
+ skip := len(strCRLF) + 1
+ if len(buf) < skip {
+ return 0, io.EOF
+ }
+ buf = buf[skip:]
+ }
+
+ var s headerScanner
+ s.b = buf
+ s.disableNormalizing = h.disableNormalizing
+ var err error
+ for s.next() {
+ if len(s.key) > 0 {
+ if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 {
+ err = fmt.Errorf("invalid trailer key %q", s.key)
+ continue
+ }
+ // Forbidden by RFC 7230, section 4.1.2
+ if isBadTrailer(s.key) {
+ err = fmt.Errorf("forbidden trailer key %q", s.key)
+ continue
+ }
+ h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
+ }
+ }
+ if s.err != nil {
+ return 0, s.err
+ }
+ if err != nil {
+ return 0, err
+ }
+ return s.hLen, nil
+}
+
func (h *RequestHeader) ignoreBody() bool {
return h.IsGet() || h.IsHead()
}
@@ -1824,6 +2604,87 @@ func (h *RequestHeader) parse(buf []byte) (int, error) {
return m + n, nil
}
+func (h *RequestHeader) parseTrailer(buf []byte) (int, error) {
+ // Skip any 0 length chunk.
+ if buf[0] == '0' {
+ skip := len(strCRLF) + 1
+ if len(buf) < skip {
+ return 0, io.EOF
+ }
+ buf = buf[skip:]
+ }
+
+ var s headerScanner
+ s.b = buf
+ s.disableNormalizing = h.disableNormalizing
+ var err error
+ for s.next() {
+ if len(s.key) > 0 {
+ if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 {
+ err = fmt.Errorf("invalid trailer key %q", s.key)
+ continue
+ }
+ // Forbidden by RFC 7230, section 4.1.2
+ if isBadTrailer(s.key) {
+ err = fmt.Errorf("forbidden trailer key %q", s.key)
+ continue
+ }
+ h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
+ }
+ }
+ if s.err != nil {
+ return 0, s.err
+ }
+ if err != nil {
+ return 0, err
+ }
+ return s.hLen, nil
+}
+
+func isBadTrailer(key []byte) bool {
+ if len(key) == 0 {
+ return true
+ }
+
+ switch key[0] | 0x20 {
+ case 'a':
+ return caseInsensitiveCompare(key, strAuthorization)
+ case 'c':
+ if len(key) > len(HeaderContentType) && caseInsensitiveCompare(key[:8], strContentType[:8]) {
+ // skip compare prefix 'Content-'
+ return caseInsensitiveCompare(key[8:], strContentEncoding[8:]) ||
+ caseInsensitiveCompare(key[8:], strContentLength[8:]) ||
+ caseInsensitiveCompare(key[8:], strContentType[8:]) ||
+ caseInsensitiveCompare(key[8:], strContentRange[8:])
+ }
+ return caseInsensitiveCompare(key, strConnection)
+ case 'e':
+ return caseInsensitiveCompare(key, strExpect)
+ case 'h':
+ return caseInsensitiveCompare(key, strHost)
+ case 'k':
+ return caseInsensitiveCompare(key, strKeepAlive)
+ case 'm':
+ return caseInsensitiveCompare(key, strMaxForwards)
+ case 'p':
+ if len(key) > len(HeaderProxyConnection) && caseInsensitiveCompare(key[:6], strProxyConnection[:6]) {
+ // skip compare prefix 'Proxy-'
+ return caseInsensitiveCompare(key[6:], strProxyConnection[6:]) ||
+ caseInsensitiveCompare(key[6:], strProxyAuthenticate[6:]) ||
+ caseInsensitiveCompare(key[6:], strProxyAuthorization[6:])
+ }
+ case 'r':
+ return caseInsensitiveCompare(key, strRange)
+ case 't':
+ return caseInsensitiveCompare(key, strTE) ||
+ caseInsensitiveCompare(key, strTrailer) ||
+ caseInsensitiveCompare(key, strTransferEncoding)
+ case 'w':
+ return caseInsensitiveCompare(key, strWWWAuthenticate)
+ }
+ return false
+}
+
func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) {
bNext := buf
var b []byte
@@ -1849,9 +2710,9 @@ func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) {
h.statusCode, n, err = parseUintBuf(b)
if err != nil {
if h.secureErrorLogMessage {
- return 0, fmt.Errorf("cannot parse response status code: %s", err)
+ return 0, fmt.Errorf("cannot parse response status code: %w", err)
}
- return 0, fmt.Errorf("cannot parse response status code: %s. Response %q", err, buf)
+ return 0, fmt.Errorf("cannot parse response status code: %w. Response %q", err, buf)
}
if len(b) > n && b[n] != ' ' {
if h.secureErrorLogMessage {
@@ -1859,6 +2720,9 @@ func (h *ResponseHeader) parseFirstLine(buf []byte) (int, error) {
}
return 0, fmt.Errorf("unexpected char at the end of status code. Response %q", buf)
}
+ if len(b) > n+1 {
+ h.SetStatusMessage(b[n+1:])
+ }
return len(buf) - len(bNext), nil
}
@@ -1952,6 +2816,10 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
h.contentType = append(h.contentType[:0], s.value...)
continue
}
+ if caseInsensitiveCompare(s.key, strContentEncoding) {
+ h.contentEncoding = append(h.contentEncoding[:0], s.value...)
+ continue
+ }
if caseInsensitiveCompare(s.key, strContentLength) {
if h.contentLength != -1 {
if h.contentLength, err = parseContentLength(s.value); err != nil {
@@ -1990,6 +2858,10 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
}
continue
}
+ if caseInsensitiveCompare(s.key, strTrailer) {
+ err = h.SetTrailerBytes(s.value)
+ continue
+ }
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
}
@@ -2012,7 +2884,7 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) {
h.connectionClose = !hasHeaderValue(v, strKeepAlive)
}
- return len(buf) - len(s.b), nil
+ return len(buf) - len(s.b), err
}
func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
@@ -2078,6 +2950,14 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) {
}
continue
}
+ if caseInsensitiveCompare(s.key, strTrailer) {
+ if nerr := h.SetTrailerBytes(s.value); nerr != nil {
+ if err == nil {
+ err = nerr
+ }
+ }
+ continue
+ }
}
}
h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue)
@@ -2121,13 +3001,15 @@ func (h *RequestHeader) collectCookies() {
h.cookiesCollected = true
}
+// errNonNumericChars is wrapped into the parseContentLength error when
+// parseUintBuf does not consume the whole Content-Length value.
+var errNonNumericChars = errors.New("non-numeric chars found")
+
+// parseContentLength parses a Content-Length header value.
+// On failure it returns -1 and an error prefixed with
+// "cannot parse Content-Length", wrapping either the parseUintBuf error
+// or errNonNumericChars when bytes remain after the parsed number.
func parseContentLength(b []byte) (int, error) {
v, n, err := parseUintBuf(b)
if err != nil {
- return -1, err
+ return -1, fmt.Errorf("cannot parse Content-Length: %w", err)
}
if n != len(b) {
- return -1, fmt.Errorf("non-numeric chars at the end of Content-Length")
+ return -1, fmt.Errorf("cannot parse Content-Length: %w", errNonNumericChars)
}
return v, nil
}
@@ -2438,9 +3320,9 @@ func removeNewLines(raw []byte) []byte {
// after dashes are also uppercased. All the other letters are lowercased.
// Examples:
//
-// * coNTENT-TYPe -> Content-Type
-// * HOST -> Host
-// * foo-bar-baz -> Foo-Bar-Baz
+// - coNTENT-TYPe -> Content-Type
+// - HOST -> Host
+// - foo-bar-baz -> Foo-Bar-Baz
func AppendNormalizedHeaderKey(dst []byte, key string) []byte {
dst = append(dst, key...)
normalizeHeaderKey(dst[len(dst)-len(key):], false)
@@ -2454,13 +3336,24 @@ func AppendNormalizedHeaderKey(dst []byte, key string) []byte {
// after dashes are also uppercased. All the other letters are lowercased.
// Examples:
//
-// * coNTENT-TYPe -> Content-Type
-// * HOST -> Host
-// * foo-bar-baz -> Foo-Bar-Baz
+// - coNTENT-TYPe -> Content-Type
+// - HOST -> Host
+// - foo-bar-baz -> Foo-Bar-Baz
func AppendNormalizedHeaderKeyBytes(dst, key []byte) []byte {
return AppendNormalizedHeaderKey(dst, b2s(key))
}
+// appendArgsKeyBytes appends the key of every element of args to dst,
+// inserting sep between consecutive keys (values are ignored), and
+// returns the extended slice.
+func appendArgsKeyBytes(dst []byte, args []argsKV, sep []byte) []byte {
+	for idx := range args {
+		if idx > 0 {
+			dst = append(dst, sep...)
+		}
+		dst = append(dst, args[idx].key...)
+	}
+	return dst
+}
+
var (
errNeedMore = errors.New("need more data: cannot find trailing lf")
errInvalidName = errors.New("invalid header name")
@@ -2492,6 +3385,6 @@ func mustPeekBuffered(r *bufio.Reader) []byte {
+// mustDiscard discards the next n bytes from r.
+// It panics if bufio.Reader.Discard returns an error.
func mustDiscard(r *bufio.Reader, n int) {
if _, err := r.Discard(n); err != nil {
- panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %s", n, err))
+ panic(fmt.Sprintf("bufio.Reader.Discard(%d) failed: %v", n, err))
}
}
diff --git a/header_regression_test.go b/header_regression_test.go
index 2a9187d..42f0344 100644
--- a/header_regression_test.go
+++ b/header_regression_test.go
@@ -69,7 +69,7 @@ func testIssue6RequestHeaderSetContentType(t *testing.T, method string) {
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
issue6VerifyRequestHeader(t, &h1, contentType, contentLength, method)
}
diff --git a/header_test.go b/header_test.go
index 575fae4..7d29425 100644
--- a/header_test.go
+++ b/header_test.go
@@ -4,9 +4,9 @@ import (
"bufio"
"bytes"
"encoding/base64"
+ "errors"
"fmt"
"io"
- "io/ioutil"
"net/http"
"reflect"
"strings"
@@ -33,10 +33,30 @@ func TestResponseHeaderAddContentType(t *testing.T) {
}
}
+// TestResponseHeaderAddContentEncoding checks that a Content-Encoding added
+// via Add is readable with Peek and is serialized exactly once.
+func TestResponseHeaderAddContentEncoding(t *testing.T) {
+	t.Parallel()
+
+	var h ResponseHeader
+	h.Add("Content-Encoding", "test")
+
+	got := string(h.Peek("Content-Encoding"))
+	expected := "test"
+	if got != expected {
+		t.Errorf("expected %q got %q", expected, got)
+	}
+
+	// Check the write error instead of silencing it: a failed write would
+	// otherwise surface as a confusing occurrence-count failure below.
+	var buf bytes.Buffer
+	if _, err := h.WriteTo(&buf); err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	if n := strings.Count(buf.String(), "Content-Encoding: "); n != 1 {
+		t.Errorf("Content-Encoding occurred %d times", n)
+	}
+}
+
func TestResponseHeaderMultiLineValue(t *testing.T) {
t.Parallel()
- s := "HTTP/1.1 200 OK\r\n" +
+ s := "HTTP/1.1 200 SuperOK\r\n" +
"EmptyValue1:\r\n" +
"Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" +
"Foo: Bar\r\n" +
@@ -45,11 +65,36 @@ func TestResponseHeaderMultiLineValue(t *testing.T) {
"\r\n"
header := new(ResponseHeader)
if _, err := header.parse([]byte(s)); err != nil {
- t.Fatalf("parse headers with multi-line values failed, %s", err)
+ t.Fatalf("parse headers with multi-line values failed, %v", err)
}
response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(s)), nil)
if err != nil {
- t.Fatalf("parse response using net/http failed, %s", err)
+ t.Fatalf("parse response using net/http failed, %v", err)
+ }
+
+ if !bytes.Equal(header.StatusMessage(), []byte("SuperOK")) {
+ t.Errorf("parse status line with non-default value failed, got: '%q' want: 'SuperOK'", header.StatusMessage())
+ }
+
+ header.SetProtocol([]byte("HTTP/3.3"))
+ if !bytes.Equal(header.Protocol(), []byte("HTTP/3.3")) {
+ t.Errorf("parse protocol with non-default value failed, got: '%q' want: 'HTTP/3.3'", header.Protocol())
+ }
+
+ if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/3.3 200 SuperOK\r\n")) {
+ t.Errorf("parse status line with non-default value failed, got: '%q' want: 'HTTP/3.3 200 SuperOK'", header.Protocol())
+ }
+
+ header.SetStatusMessage(nil)
+
+ if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/3.3 200 OK\r\n")) {
+ t.Errorf("parse status line with default protocol value failed, got: '%q' want: 'HTTP/3.3 200 OK'", header.appendStatusLine(nil))
+ }
+
+ header.SetStatusMessage(s2b(StatusMessage(200)))
+
+ if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/3.3 200 OK\r\n")) {
+ t.Errorf("parse status line with default protocol value failed, got: '%q' want: 'HTTP/3.3 200 OK'", header.appendStatusLine(nil))
}
for name, vals := range response.Header {
@@ -57,7 +102,7 @@ func TestResponseHeaderMultiLineValue(t *testing.T) {
want := vals[0]
if got != want {
- t.Errorf("unexpected %s got: %q want: %q", name, got, want)
+ t.Errorf("unexpected %q got: %q want: %q", name, got, want)
}
}
}
@@ -78,6 +123,18 @@ func TestResponseHeaderMultiLineName(t *testing.T) {
})
t.Errorf("expected error, got %q (%v)", m, err)
}
+
+ if !bytes.Equal(header.StatusMessage(), []byte("OK")) {
+ t.Errorf("expected default status line, got: %q", header.StatusMessage())
+ }
+
+ if !bytes.Equal(header.Protocol(), []byte("HTTP/1.1")) {
+ t.Errorf("expected default protocol, got: %q", header.Protocol())
+ }
+
+ if !bytes.Equal(header.appendStatusLine(nil), []byte("HTTP/1.1 200 OK\r\n")) {
+ t.Errorf("parse status line with non-default value failed, got: %q want: HTTP/1.1 200 OK", header.Protocol())
+ }
}
func TestResponseHeaderMultiLinePaniced(t *testing.T) {
@@ -101,7 +158,7 @@ func TestResponseHeaderEmptyValueFromHeader(t *testing.T) {
var h ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.ContentType()) != string(h1.ContentType()) {
t.Fatalf("unexpected content-type: %q. Expecting %q", h.ContentType(), h1.ContentType())
@@ -128,7 +185,7 @@ func TestResponseHeaderEmptyValueFromString(t *testing.T) {
var h ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.ContentType()) != "foo/bar" {
t.Fatalf("unexpected content-type: %q. Expecting %q", h.ContentType(), "foo/bar")
@@ -156,7 +213,7 @@ func TestRequestHeaderEmptyValueFromHeader(t *testing.T) {
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.Host()) != string(h1.Host()) {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), h1.Host())
@@ -182,7 +239,7 @@ func TestRequestHeaderEmptyValueFromString(t *testing.T) {
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
@@ -209,7 +266,7 @@ func TestRequestRawHeaders(t *testing.T) {
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
@@ -235,7 +292,7 @@ func TestRequestRawHeaders(t *testing.T) {
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
@@ -255,7 +312,7 @@ func TestRequestRawHeaders(t *testing.T) {
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.Host()) != "foobar" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar")
@@ -275,7 +332,7 @@ func TestRequestRawHeaders(t *testing.T) {
h.DisableNormalizing()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h.Host()) != "" {
t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "")
@@ -298,13 +355,13 @@ func TestRequestHeaderSetCookieWithSpecialChars(t *testing.T) {
s := h.String()
if !strings.Contains(s, "Cookie: ID&14") {
- t.Fatalf("Missing cookie in request header: [%s]", s)
+ t.Fatalf("Missing cookie in request header: %q", s)
}
var h1 RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
cookie := h1.Peek(HeaderCookie)
if string(cookie) != "ID&14" {
@@ -343,7 +400,7 @@ func TestResponseHeaderDelClientCookie(t *testing.T) {
t.Fatalf("expecting cookie %q", c.Key())
}
if !c.Expire().Equal(CookieExpireDelete) {
- t.Fatalf("unexpected cookie expiration time: %s. Expecting %s", c.Expire(), CookieExpireDelete)
+ t.Fatalf("unexpected cookie expiration time: %q. Expecting %q", c.Expire(), CookieExpireDelete)
}
if len(c.Value()) > 0 {
t.Fatalf("unexpected cookie value: %q. Expecting empty value", c.Value())
@@ -388,7 +445,7 @@ func TestResponseHeaderAdd(t *testing.T) {
br := bufio.NewReader(bytes.NewBufferString(s))
var h1 ResponseHeader
if err := h1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
h.VisitAll(func(k, v []byte) {
@@ -441,7 +498,7 @@ func TestRequestHeaderAdd(t *testing.T) {
br := bufio.NewReader(bytes.NewBufferString(s))
var h1 RequestHeader
if err := h1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
h.VisitAll(func(k, v []byte) {
@@ -498,6 +555,7 @@ func TestRequestHeaderDel(t *testing.T) {
h.Set("User-Agent", "asdfas")
h.Set("Content-Length", "1123")
h.Set("Cookie", "foobar=baz")
+ h.Set(HeaderTrailer, "foo, bar")
h.Del("foo-bar")
h.Del("connection")
@@ -506,6 +564,7 @@ func TestRequestHeaderDel(t *testing.T) {
h.Del("user-agent")
h.Del("content-length")
h.Del("cookie")
+ h.Del("trailer")
hv := h.Peek("aaa")
if string(hv) != "bbb" {
@@ -539,6 +598,10 @@ func TestRequestHeaderDel(t *testing.T) {
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
+ hv = h.Peek(HeaderTrailer)
+ if len(hv) > 0 {
+ t.Fatalf("non-zero value: %q", hv)
+ }
cv := h.Cookie("foobar")
if len(cv) > 0 {
@@ -557,8 +620,10 @@ func TestResponseHeaderDel(t *testing.T) {
h.Set("aaa", "bbb")
h.Set(HeaderConnection, "keep-alive")
h.Set(HeaderContentType, "aaa")
+ h.Set(HeaderContentEncoding, "gzip")
h.Set(HeaderServer, "aaabbb")
h.Set(HeaderContentLength, "1123")
+ h.Set(HeaderTrailer, "foo, bar")
var c Cookie
c.SetKey("foo")
@@ -571,6 +636,7 @@ func TestResponseHeaderDel(t *testing.T) {
h.Del(HeaderServer)
h.Del("content-length")
h.Del("set-cookie")
+ h.Del("trailer")
hv := h.Peek("aaa")
if string(hv) != "bbb" {
@@ -588,6 +654,10 @@ func TestResponseHeaderDel(t *testing.T) {
if string(hv) != string(defaultContentType) {
t.Fatalf("unexpected content-type: %q. Expecting %q", hv, defaultContentType)
}
+ hv = h.Peek(HeaderContentEncoding)
+ if string(hv) != ("gzip") {
+ t.Fatalf("unexpected content-encoding: %q. Expecting %q", hv, "gzip")
+ }
hv = h.Peek(HeaderServer)
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
@@ -596,6 +666,10 @@ func TestResponseHeaderDel(t *testing.T) {
if len(hv) > 0 {
t.Fatalf("non-zero value: %q", hv)
}
+ hv = h.Peek(HeaderTrailer)
+ if len(hv) > 0 {
+ t.Fatalf("non-zero value: %q", hv)
+ }
if h.Cookie(&c) {
t.Fatalf("unexpected cookie obtianed: %q", &c)
@@ -605,6 +679,51 @@ func TestResponseHeaderDel(t *testing.T) {
}
}
+// TestResponseHeaderSetTrailerGetBytes checks that a field declared in the
+// Trailer header is left out of Header() and emitted by TrailerHeader().
+func TestResponseHeaderSetTrailerGetBytes(t *testing.T) {
+	t.Parallel()
+
+	h := &ResponseHeader{}
+	h.noDefaultDate = true
+	h.Set("Foo", "bar")
+	h.Set(HeaderTrailer, "Baz")
+	h.Set("Baz", "test")
+
+	headerBytes := h.Header()
+	n, err := h.parseFirstLine(headerBytes)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Use one constant for both the comparison and the failure message so
+	// the reported expectation cannot drift (it previously showed "\n"
+	// while the check used "\r\n").
+	const expectedHeader = "Foo: bar\r\nTrailer: Baz\r\n\r\n"
+	if string(headerBytes[n:]) != expectedHeader {
+		t.Fatalf("Unexpected header: %q. Expected %q", headerBytes[n:], expectedHeader)
+	}
+	const expectedTrailer = "Baz: test\r\n\r\n"
+	if trailer := string(h.TrailerHeader()); trailer != expectedTrailer {
+		t.Fatalf("Unexpected trailer header: %q. Expected %q", trailer, expectedTrailer)
+	}
+}
+
+// TestRequestHeaderSetTrailerGetBytes checks that a field declared in the
+// Trailer header is left out of Header() and emitted by TrailerHeader().
+func TestRequestHeaderSetTrailerGetBytes(t *testing.T) {
+	t.Parallel()
+
+	h := &RequestHeader{}
+	h.Set("Foo", "bar")
+	h.Set(HeaderTrailer, "Baz")
+	h.Set("Baz", "test")
+
+	headerBytes := h.Header()
+	n, err := h.parseFirstLine(headerBytes)
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+
+	// Use one constant for both the comparison and the failure message so
+	// the reported expectation cannot drift (it previously showed "\n"
+	// while the check used "\r\n").
+	const expectedHeader = "Foo: bar\r\nTrailer: Baz\r\n\r\n"
+	if string(headerBytes[n:]) != expectedHeader {
+		t.Fatalf("Unexpected header: %q. Expected %q", headerBytes[n:], expectedHeader)
+	}
+	const expectedTrailer = "Baz: test\r\n\r\n"
+	if trailer := string(h.TrailerHeader()); trailer != expectedTrailer {
+		t.Fatalf("Unexpected trailer header: %q. Expected %q", trailer, expectedTrailer)
+	}
+}
+
func TestAppendNormalizedHeaderKeyBytes(t *testing.T) {
t.Parallel()
@@ -629,7 +748,7 @@ func TestRequestHeaderHTTP10ConnectionClose(t *testing.T) {
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !h.ConnectionClose() {
@@ -644,7 +763,7 @@ func TestRequestHeaderHTTP10ConnectionKeepAlive(t *testing.T) {
var h RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if h.ConnectionClose() {
@@ -680,7 +799,7 @@ func TestBufferSnippet(t *testing.T) {
func testBufferSnippet(t *testing.T, buf, expectedSnippet string) {
snippet := bufferSnippet([]byte(buf))
if snippet != expectedSnippet {
- t.Fatalf("unexpected snippet %s. Expecting %s", snippet, expectedSnippet)
+ t.Fatalf("unexpected snippet %q. Expecting %q", snippet, expectedSnippet)
}
}
@@ -693,7 +812,7 @@ func TestResponseHeaderTrailingCRLFSuccess(t *testing.T) {
var r ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
// try reading the trailing CRLF. It must return EOF
@@ -702,7 +821,7 @@ func TestResponseHeaderTrailingCRLFSuccess(t *testing.T) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, io.EOF)
}
}
@@ -715,7 +834,7 @@ func TestResponseHeaderTrailingCRLFError(t *testing.T) {
var r ResponseHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
// try reading the trailing CRLF. It must return EOF
@@ -724,7 +843,7 @@ func TestResponseHeaderTrailingCRLFError(t *testing.T) {
t.Fatalf("expecting error")
}
if err == io.EOF {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -737,7 +856,7 @@ func TestRequestHeaderTrailingCRLFSuccess(t *testing.T) {
var r RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
// try reading the trailing CRLF. It must return EOF
@@ -746,7 +865,7 @@ func TestRequestHeaderTrailingCRLFSuccess(t *testing.T) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, io.EOF)
}
}
@@ -759,7 +878,7 @@ func TestRequestHeaderTrailingCRLFError(t *testing.T) {
var r RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
// try reading the trailing CRLF. It must return EOF
@@ -768,7 +887,7 @@ func TestRequestHeaderTrailingCRLFError(t *testing.T) {
t.Fatalf("expecting error")
}
if err == io.EOF {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -783,7 +902,7 @@ func TestRequestHeaderReadEOF(t *testing.T) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, io.EOF)
}
// incomplete request header mustn't return io.EOF
@@ -808,7 +927,7 @@ func TestResponseHeaderReadEOF(t *testing.T) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, io.EOF)
}
// incomplete response header mustn't return io.EOF
@@ -831,14 +950,14 @@ func TestResponseHeaderOldVersion(t *testing.T) {
s += "HTTP/1.0 200 OK\r\nContent-Length: 2\r\nContent-Type: ass\r\nConnection: keep-alive\r\n\r\n42"
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !h.ConnectionClose() {
t.Fatalf("expecting 'Connection: close' for the response with old http protocol")
}
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if h.ConnectionClose() {
t.Fatalf("unexpected 'Connection: close' for keep-alive response with old http protocol")
@@ -939,7 +1058,7 @@ func testRequestMultipartFormBoundary(t *testing.T, s, boundary string) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s. s=%q, boundary=%q", err, s, boundary)
+ t.Fatalf("unexpected error: %v. s=%q, boundary=%q", err, s, boundary)
}
b := h.MultipartFormBoundary()
@@ -977,7 +1096,7 @@ func testResponseHeaderConnectionUpgrade(t *testing.T, s string, isUpgrade, isKe
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s. Response header %q", err, s)
+ t.Fatalf("unexpected error: %v. Response header %q", err, s)
}
upgrade := h.ConnectionUpgrade()
if upgrade != isUpgrade {
@@ -1024,7 +1143,7 @@ func testRequestHeaderConnectionUpgrade(t *testing.T, s string, isUpgrade, isKee
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s. Request header %q", err, s)
+ t.Fatalf("unexpected error: %v. Request header %q", err, s)
}
upgrade := h.ConnectionUpgrade()
if upgrade != isUpgrade {
@@ -1046,21 +1165,21 @@ func TestRequestHeaderProxyWithCookie(t *testing.T) {
r := bytes.NewBufferString("GET /foo HTTP/1.1\r\nFoo: bar\r\nHost: aaa.com\r\nCookie: foo=bar; bazzz=aaaaaaa; x=y\r\nCookie: aqqqqq=123\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var h1 RequestHeader
br.Reset(w)
if err := h1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(h1.RequestURI()) != "/foo" {
t.Fatalf("unexpected requestURI: %q. Expecting %q", h1.RequestURI(), "/foo")
@@ -1097,7 +1216,7 @@ func TestResponseHeaderFirstByteReadEOF(t *testing.T) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error %v. Expecting %v", err, io.EOF)
}
}
@@ -1153,7 +1272,7 @@ func testResponseHeaderHTTPVer(t *testing.T, s string, connectionClose bool) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s. response=%q", err, s)
+ t.Fatalf("unexpected error: %v. response=%q", err, s)
}
if h.ConnectionClose() != connectionClose {
t.Fatalf("unexpected connectionClose %v. Expecting %v. response=%q", h.ConnectionClose(), connectionClose, s)
@@ -1166,7 +1285,7 @@ func testRequestHeaderHTTPVer(t *testing.T, s string, connectionClose bool) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("unexpected error: %s. request=%q", err, s)
+ t.Fatalf("unexpected error: %v. request=%q", err, s)
}
if h.ConnectionClose() != connectionClose {
t.Fatalf("unexpected connectionClose %v. Expecting %v. request=%q", h.ConnectionClose(), connectionClose, s)
@@ -1180,7 +1299,9 @@ func TestResponseHeaderCopyTo(t *testing.T) {
h.Set(HeaderSetCookie, "foo=bar")
h.Set(HeaderContentType, "foobar")
+ h.Set(HeaderContentEncoding, "gzip")
h.Set("AAA-BBB", "aaaa")
+ h.Set(HeaderTrailer, "foo, bar")
var h1 ResponseHeader
h.CopyTo(&h1)
@@ -1190,9 +1311,15 @@ func TestResponseHeaderCopyTo(t *testing.T) {
if !bytes.Equal(h1.Peek(HeaderContentType), h.Peek(HeaderContentType)) {
t.Fatalf("unexpected content-type %q. Expected %q", h1.Peek("content-type"), h.Peek("content-type"))
}
+ if !bytes.Equal(h1.Peek(HeaderContentEncoding), h.Peek(HeaderContentEncoding)) {
+ t.Fatalf("unexpected content-encoding %q. Expected %q", h1.Peek("content-encoding"), h.Peek("content-encoding"))
+ }
if !bytes.Equal(h1.Peek("aaa-bbb"), h.Peek("AAA-BBB")) {
t.Fatalf("unexpected aaa-bbb %q. Expected %q", h1.Peek("aaa-bbb"), h.Peek("aaa-bbb"))
}
+ if !bytes.Equal(h1.Peek(HeaderTrailer), h.Peek(HeaderTrailer)) {
+ t.Fatalf("unexpected trailer %q. Expected %q", h1.Peek(HeaderTrailer), h.Peek(HeaderTrailer))
+ }
// flush buf
h.bufKV = argsKV{}
@@ -1210,8 +1337,10 @@ func TestRequestHeaderCopyTo(t *testing.T) {
h.Set(HeaderCookie, "aa=bb; cc=dd")
h.Set(HeaderContentType, "foobar")
+ h.Set(HeaderContentEncoding, "gzip")
h.Set(HeaderHost, "aaaa")
h.Set("aaaxxx", "123")
+ h.Set(HeaderTrailer, "foo, bar")
var h1 RequestHeader
h.CopyTo(&h1)
@@ -1221,12 +1350,18 @@ func TestRequestHeaderCopyTo(t *testing.T) {
if !bytes.Equal(h1.Peek("content-type"), h.Peek(HeaderContentType)) {
t.Fatalf("unexpected content-type %q. Expected %q", h1.Peek("content-type"), h.Peek("content-type"))
}
+ if !bytes.Equal(h1.Peek("content-encoding"), h.Peek(HeaderContentEncoding)) {
+ t.Fatalf("unexpected content-encoding %q. Expected %q", h1.Peek("content-encoding"), h.Peek("content-encoding"))
+ }
if !bytes.Equal(h1.Peek("host"), h.Peek("host")) {
t.Fatalf("unexpected host %q. Expected %q", h1.Peek("host"), h.Peek("host"))
}
if !bytes.Equal(h1.Peek("aaaxxx"), h.Peek("aaaxxx")) {
t.Fatalf("unexpected aaaxxx %q. Expected %q", h1.Peek("aaaxxx"), h.Peek("aaaxxx"))
}
+ if !bytes.Equal(h1.Peek(HeaderTrailer), h.Peek(HeaderTrailer)) {
+ t.Fatalf("unexpected trailer %q. Expected %q", h1.Peek(HeaderTrailer), h.Peek(HeaderTrailer))
+ }
// flush buf
h.bufKV = argsKV{}
@@ -1262,16 +1397,16 @@ func TestRequestContentTypeDefaultNotEmpty(t *testing.T) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
var h1 RequestHeader
br := bufio.NewReader(w)
if err := h1.Read(br); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
if string(h1.contentType) != "application/octet-stream" {
@@ -1279,6 +1414,33 @@ func TestRequestContentTypeDefaultNotEmpty(t *testing.T) {
}
}
+// TestRequestContentTypeNoDefault writes a DELETE request header with
+// SetNoDefaultContentType(true), reads it back, and expects the parsed
+// header to carry an empty Content-Type.
+func TestRequestContentTypeNoDefault(t *testing.T) {
+	t.Parallel()
+
+	var src RequestHeader
+	src.SetMethod(MethodDelete)
+	src.SetNoDefaultContentType(true)
+
+	var out bytes.Buffer
+	bw := bufio.NewWriter(&out)
+	if err := src.Write(bw); err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+	if err := bw.Flush(); err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+
+	var parsed RequestHeader
+	if err := parsed.Read(bufio.NewReader(&out)); err != nil {
+		t.Fatalf("Unexpected error: %v", err)
+	}
+
+	if got := string(parsed.contentType); got != "" {
+		t.Fatalf("unexpected Content-Type %q. Expecting %q", got, "")
+	}
+}
+
func TestResponseDateNoDefaultNotEmpty(t *testing.T) {
t.Parallel()
@@ -1307,16 +1469,16 @@ func TestRequestHeaderConnectionClose(t *testing.T) {
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := h.Write(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var h1 RequestHeader
br := bufio.NewReader(&w)
if err := h1.Read(br); err != nil {
- t.Fatalf("error when reading request header: %s", err)
+ t.Fatalf("error when reading request header: %v", err)
}
if !h1.ConnectionClose() {
@@ -1384,17 +1546,18 @@ func TestResponseHeaderVisitAll(t *testing.T) {
var h ResponseHeader
- r := bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\nSet-Cookie: aa=bb; path=/foo/bar\r\nSet-Cookie: ccc\r\n\r\n")
+ r := bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Encoding: gzip\r\nContent-Length: 123\r\nSet-Cookie: aa=bb; path=/foo/bar\r\nSet-Cookie: ccc\r\nTrailer: Foo, Bar\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
- if h.Len() != 4 {
- t.Fatalf("Unexpected number of headers: %d. Expected 4", h.Len())
+ if h.Len() != 6 {
+ t.Fatalf("Unexpected number of headers: %d. Expected 6", h.Len())
}
contentLengthCount := 0
contentTypeCount := 0
+ contentEncodingCount := 0
cookieCount := 0
h.VisitAll(func(key, value []byte) {
k := string(key)
@@ -1410,6 +1573,11 @@ func TestResponseHeaderVisitAll(t *testing.T) {
t.Fatalf("Unexpected content-type: %q. Expected %q", v, h.Peek(k))
}
contentTypeCount++
+ case HeaderContentEncoding:
+ if v != string(h.Peek(k)) {
+ t.Fatalf("Unexpected content-encoding: %q. Expected %q", v, h.Peek(k))
+ }
+ contentEncodingCount++
case HeaderSetCookie:
if cookieCount == 0 && v != "aa=bb; path=/foo/bar" {
t.Fatalf("unexpected cookie header: %q. Expected %q", v, "aa=bb; path=/foo/bar")
@@ -1418,6 +1586,10 @@ func TestResponseHeaderVisitAll(t *testing.T) {
t.Fatalf("unexpected cookie header: %q. Expected %q", v, "ccc")
}
cookieCount++
+ case HeaderTrailer:
+ if v != "Foo, Bar" {
+ t.Fatalf("Unexpected trailer header %q. Expected %q", v, "Foo, Bar")
+ }
default:
t.Fatalf("unexpected header %q=%q", k, v)
}
@@ -1428,6 +1600,9 @@ func TestResponseHeaderVisitAll(t *testing.T) {
if contentTypeCount != 1 {
t.Fatalf("unexpected number of content-type headers: %d. Expected 1", contentTypeCount)
}
+ if contentEncodingCount != 1 {
+ t.Fatalf("unexpected number of content-encoding headers: %d. Expected 1", contentEncodingCount)
+ }
if cookieCount != 2 {
t.Fatalf("unexpected number of cookie header: %d. Expected 2", cookieCount)
}
@@ -1438,14 +1613,14 @@ func TestRequestHeaderVisitAll(t *testing.T) {
var h RequestHeader
- r := bytes.NewBufferString("GET / HTTP/1.1\r\nHost: aa.com\r\nXX: YYY\r\nXX: ZZ\r\nCookie: a=b; c=d\r\n\r\n")
+ r := bytes.NewBufferString("GET / HTTP/1.1\r\nHost: aa.com\r\nXX: YYY\r\nXX: ZZ\r\nCookie: a=b; c=d\r\nTrailer: Foo, Bar\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
- if h.Len() != 4 {
- t.Fatalf("Unexpected number of header: %d. Expected 4", h.Len())
+ if h.Len() != 5 {
+ t.Fatalf("Unexpected number of header: %d. Expected 5", h.Len())
}
hostCount := 0
xxCount := 0
@@ -1472,6 +1647,10 @@ func TestRequestHeaderVisitAll(t *testing.T) {
t.Fatalf("Unexpected cookie %q. Expected %q", v, "a=b; c=d")
}
cookieCount++
+ case HeaderTrailer:
+ if v != "Foo, Bar" {
+ t.Fatalf("Unexpected trailer header %q. Expected %q", v, "Foo, Bar")
+ }
default:
t.Fatalf("Unexpected header %q=%q", k, v)
}
@@ -1487,7 +1666,7 @@ func TestRequestHeaderVisitAll(t *testing.T) {
}
}
-func TestResponseHeaderVisitAllInOrder(t *testing.T) {
+func TestRequestHeaderVisitAllInOrder(t *testing.T) {
t.Parallel()
var h RequestHeader
@@ -1495,7 +1674,7 @@ func TestResponseHeaderVisitAllInOrder(t *testing.T) {
r := bytes.NewBufferString("GET / HTTP/1.1\r\nContent-Type: aa\r\nCookie: a=b\r\nHost: example.com\r\nUser-Agent: xxx\r\n\r\n")
br := bufio.NewReader(r)
if err := h.Read(br); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
if h.Len() != 4 {
@@ -1530,6 +1709,38 @@ func TestResponseHeaderVisitAllInOrder(t *testing.T) {
})
}
+func TestResponseHeaderAddTrailerError(t *testing.T) {
+ t.Parallel()
+
+ var h ResponseHeader
+ err := h.AddTrailer("Foo, Content-Length , Bar,Transfer-Encoding,")
+ expectedTrailer := "Foo, Bar"
+
+ if !errors.Is(err, ErrBadTrailer) {
+ t.Fatalf("unexpected err %q. Expected %q", err, ErrBadTrailer)
+ }
+ if trailer := string(h.Peek(HeaderTrailer)); trailer != expectedTrailer {
+ t.Fatalf("unexpected trailer %q. Expected %q", trailer, expectedTrailer)
+ }
+
+}
+
+func TestRequestHeaderAddTrailerError(t *testing.T) {
+ t.Parallel()
+
+ var h RequestHeader
+ err := h.AddTrailer("Foo, Content-Length , Bar,Transfer-Encoding,")
+ expectedTrailer := "Foo, Bar"
+
+ if !errors.Is(err, ErrBadTrailer) {
+ t.Fatalf("unexpected err %q. Expected %q", err, ErrBadTrailer)
+ }
+ if trailer := string(h.Peek(HeaderTrailer)); trailer != expectedTrailer {
+ t.Fatalf("unexpected trailer %q. Expected %q", trailer, expectedTrailer)
+ }
+
+}
+
func TestResponseHeaderCookie(t *testing.T) {
t.Parallel()
@@ -1595,10 +1806,10 @@ func TestResponseHeaderCookie(t *testing.T) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
h.DelAllCookies()
@@ -1606,7 +1817,7 @@ func TestResponseHeaderCookie(t *testing.T) {
var h1 ResponseHeader
br := bufio.NewReader(w)
if err := h1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
c.SetKey("foobar")
@@ -1681,16 +1892,16 @@ func TestRequestHeaderCookie(t *testing.T) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := h.Write(bw); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
var h1 RequestHeader
br := bufio.NewReader(w)
if err := h1.Read(br); err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
if !bytes.Equal(h1.Cookie("foo"), h.Cookie("foo")) {
@@ -1731,7 +1942,7 @@ func TestResponseHeaderCookieIssue4(t *testing.T) {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen := false
- h.VisitAll(func(key, value []byte) {
+ h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
@@ -1752,7 +1963,7 @@ func TestResponseHeaderCookieIssue4(t *testing.T) {
t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar")
}
cookieSeen = false
- h.VisitAll(func(key, value []byte) {
+ h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderSetCookie:
cookieSeen = true
@@ -1776,7 +1987,7 @@ func TestRequestHeaderCookieIssue313(t *testing.T) {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen := false
- h.VisitAll(func(key, value []byte) {
+ h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
@@ -1794,7 +2005,7 @@ func TestRequestHeaderCookieIssue313(t *testing.T) {
t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar")
}
cookieSeen = false
- h.VisitAll(func(key, value []byte) {
+ h.VisitAll(func(key, _ []byte) {
switch string(key) {
case HeaderCookie:
cookieSeen = true
@@ -1831,7 +2042,7 @@ func testRequestHeaderMethod(t *testing.T, expectedMethod string) {
var h1 RequestHeader
br := bufio.NewReader(bytes.NewBufferString(s))
if err := h1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
m1 := h1.Method()
if string(m) != string(m1) {
@@ -1876,16 +2087,16 @@ func TestRequestHeaderSetGet(t *testing.T) {
bw := bufio.NewWriter(w)
err := h.Write(bw)
if err != nil {
- t.Fatalf("Unexpected error when writing request header: %s", err)
+ t.Fatalf("Unexpected error when writing request header: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("Unexpected error when flushing request header: %s", err)
+ t.Fatalf("Unexpected error when flushing request header: %v", err)
}
var h1 RequestHeader
br := bufio.NewReader(w)
if err = h1.Read(br); err != nil {
- t.Fatalf("Unexpected error when reading request header: %s", err)
+ t.Fatalf("Unexpected error when reading request header: %v", err)
}
if h1.ContentLength() != h.ContentLength() {
@@ -1912,6 +2123,7 @@ func TestResponseHeaderSetGet(t *testing.T) {
h := &ResponseHeader{}
h.Set("foo", "bar")
h.Set("content-type", "aaa/bbb")
+ h.Set("content-encoding", "gzip")
h.Set("connection", "close")
h.Set("content-length", "1234")
h.Set(HeaderServer, "aaaa")
@@ -1920,6 +2132,7 @@ func TestResponseHeaderSetGet(t *testing.T) {
expectResponseHeaderGet(t, h, "Foo", "bar")
expectResponseHeaderGet(t, h, HeaderContentType, "aaa/bbb")
+ expectResponseHeaderGet(t, h, HeaderContentEncoding, "gzip")
expectResponseHeaderGet(t, h, HeaderConnection, "close")
expectResponseHeaderGet(t, h, HeaderContentLength, "1234")
expectResponseHeaderGet(t, h, "seRVer", "aaaa")
@@ -1937,16 +2150,16 @@ func TestResponseHeaderSetGet(t *testing.T) {
bw := bufio.NewWriter(w)
err := h.Write(bw)
if err != nil {
- t.Fatalf("Unexpected error when writing response header: %s", err)
+ t.Fatalf("Unexpected error when writing response header: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("Unexpected error when flushing response header: %s", err)
+ t.Fatalf("Unexpected error when flushing response header: %v", err)
}
var h1 ResponseHeader
br := bufio.NewReader(w)
if err = h1.Read(br); err != nil {
- t.Fatalf("Unexpected error when reading response header: %s", err)
+ t.Fatalf("Unexpected error when reading response header: %v", err)
}
if h1.ContentLength() != h.ContentLength() {
@@ -1958,6 +2171,7 @@ func TestResponseHeaderSetGet(t *testing.T) {
expectResponseHeaderGet(t, &h1, "Foo", "bar")
expectResponseHeaderGet(t, &h1, HeaderContentType, "aaa/bbb")
+ expectResponseHeaderGet(t, &h1, HeaderContentEncoding, "gzip")
expectResponseHeaderGet(t, &h1, HeaderConnection, "close")
expectResponseHeaderGet(t, &h1, "seRVer", "aaaa")
expectResponseHeaderGet(t, &h1, "baz", "xxxxx")
@@ -1993,17 +2207,17 @@ func testResponseHeaderConnectionClose(t *testing.T, connectionClose bool) {
bw := bufio.NewWriter(w)
err := h.Write(bw)
if err != nil {
- t.Fatalf("Unexpected error when writing response header: %s", err)
+ t.Fatalf("Unexpected error when writing response header: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("Unexpected error when flushing response header: %s", err)
+ t.Fatalf("Unexpected error when flushing response header: %v", err)
}
var h1 ResponseHeader
br := bufio.NewReader(w)
err = h1.Read(br)
if err != nil {
- t.Fatalf("Unexpected error when reading response header: %s", err)
+ t.Fatalf("Unexpected error when reading response header: %v", err)
}
if h1.ConnectionClose() != h.ConnectionClose() {
t.Fatalf("Unexpected value for ConnectionClose: %v. Expected %v", h1.ConnectionClose(), h.ConnectionClose())
@@ -2066,25 +2280,23 @@ func TestRequestHeaderBufioPeek(t *testing.T) {
br := bufio.NewReaderSize(r, 4096)
h := &RequestHeader{}
if err := h.Read(br); err != nil {
- t.Fatalf("Unexpected error when reading request: %s", err)
+ t.Fatalf("Unexpected error when reading request: %v", err)
}
verifyRequestHeader(t, h, -2, "/", "foobar.com", "", "")
- verifyTrailer(t, br, "aaaa")
}
func TestResponseHeaderBufioPeek(t *testing.T) {
t.Parallel()
r := &bufioPeekReader{
- s: "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: aaa\r\n" + getHeaders(10) + "\r\n0123456789",
+ s: "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: text/plain\r\nContent-Encoding: gzip\r\n" + getHeaders(10) + "\r\n0123456789",
}
br := bufio.NewReaderSize(r, 4096)
h := &ResponseHeader{}
if err := h.Read(br); err != nil {
- t.Fatalf("Unexpected error when reading response: %s", err)
+ t.Fatalf("Unexpected error when reading response: %v", err)
}
- verifyResponseHeader(t, h, 200, 10, "aaa")
- verifyTrailer(t, br, "0123456789")
+ verifyResponseHeader(t, h, 200, 10, "text/plain", "gzip")
}
func getHeaders(n int) string {
@@ -2102,143 +2314,127 @@ func TestResponseHeaderReadSuccess(t *testing.T) {
// straight order of content-length and content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n",
- 200, 123, "text/html", "")
+ 200, 123, "text/html")
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close")
}
// reverse order of content-length and content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 202 OK\r\nContent-Type: text/plain; encoding=utf-8\r\nContent-Length: 543\r\nConnection: close\r\n\r\n",
- 202, 543, "text/plain; encoding=utf-8", "")
+ 202, 543, "text/plain; encoding=utf-8")
if !h.ConnectionClose() {
t.Fatalf("expecting connection: close")
}
// tranfer-encoding: chunked
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 505 Internal error\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n",
- 505, -1, "text/html", "")
+ 505, -1, "text/html")
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close")
}
// reverse order of content-type and tranfer-encoding
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 343 foobar\r\nTransfer-Encoding: chunked\r\nContent-Type: text/json\r\n\r\n",
- 343, -1, "text/json", "")
+ 343, -1, "text/json")
// additional headers
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 100 Continue\r\nFoobar: baz\r\nContent-Type: aaa/bbb\r\nUser-Agent: x\r\nContent-Length: 123\r\nZZZ: werer\r\n\r\n",
- 100, 123, "aaa/bbb", "")
-
- // trailer (aka body)
- testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 32245\r\n\r\nqwert aaa",
- 200, 32245, "text/plain", "qwert aaa")
+ 100, 123, "aaa/bbb")
// ancient http protocol
testResponseHeaderReadSuccess(t, h, "HTTP/0.9 300 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\nqqqq",
- 300, 123, "text/html", "qqqq")
+ 300, 123, "text/html")
// lf instead of crlf
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length: 123\nContent-Type: text/html\n\n",
- 200, 123, "text/html", "")
+ 200, 123, "text/html")
// Zero-length headers with mixed crlf and lf
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nContent-Length: 345\nZero-Value: \r\nContent-Type: aaa\n: zero-key\r\n\r\nooa",
- 400, 345, "aaa", "ooa")
+ 400, 345, "aaa")
// No space after colon
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\nContent-Length:34\nContent-Type: sss\n\naaaa",
- 200, 34, "sss", "aaaa")
+ 200, 34, "sss")
// invalid case
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\nconTEnt-leNGTH: 123\nConTENT-TYPE: ass\n\n",
- 400, 123, "ass", "")
+ 400, 123, "ass")
// duplicate content-length
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 456\r\nContent-Type: foo/bar\r\nContent-Length: 321\r\n\r\n",
- 200, 321, "foo/bar", "")
+ 200, 321, "foo/bar")
// duplicate content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 234\r\nContent-Type: foo/bar\r\nContent-Type: baz/bar\r\n\r\n",
- 200, 234, "baz/bar", "")
-
- // both transfer-encoding: chunked and content-length
- testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 123\r\nTransfer-Encoding: chunked\r\n\r\n",
- 200, -1, "foo/bar", "")
+ 200, 234, "baz/bar")
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 300 OK\r\nContent-Type: foo/barr\r\nTransfer-Encoding: chunked\r\nContent-Length: 354\r\n\r\n",
- 300, -1, "foo/barr", "")
+ 300, -1, "foo/barr")
// duplicate transfer-encoding: chunked
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\nTransfer-Encoding: chunked\r\n\r\n",
- 200, -1, "text/html", "")
+ 200, -1, "text/html")
// no reason string in the first line
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 456\r\nContent-Type: xxx/yyy\r\nContent-Length: 134\r\n\r\naaaxxx",
- 456, 134, "xxx/yyy", "aaaxxx")
+ 456, 134, "xxx/yyy")
// blank lines before the first line
testResponseHeaderReadSuccess(t, h, "\r\nHTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 0\r\n\r\nsss",
- 200, 0, "aa", "sss")
+ 200, 0, "aa")
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close")
}
// no content-length (informational responses)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 101 OK\r\n\r\n",
- 101, -2, "text/plain; charset=utf-8", "")
+ 101, -2, "text/plain; charset=utf-8")
if h.ConnectionClose() {
t.Fatalf("expecting connection: keep-alive for informational response")
}
// no content-length (no-content responses)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 204 OK\r\n\r\n",
- 204, -2, "text/plain; charset=utf-8", "")
+ 204, -2, "text/plain; charset=utf-8")
if h.ConnectionClose() {
t.Fatalf("expecting connection: keep-alive for no-content response")
}
// no content-length (not-modified responses)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 304 OK\r\n\r\n",
- 304, -2, "text/plain; charset=utf-8", "")
+ 304, -2, "text/plain; charset=utf-8")
if h.ConnectionClose() {
t.Fatalf("expecting connection: keep-alive for not-modified response")
}
// no content-length (identity transfer-encoding)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\n\r\nabcdefg",
- 200, -2, "foo/bar", "abcdefg")
+ 200, -2, "foo/bar")
if !h.ConnectionClose() {
t.Fatalf("expecting connection: close for identity response")
}
- // non-numeric content-length
- testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: faaa\r\nContent-Type: text/html\r\n\r\nfoobar",
- 200, -2, "text/html", "foobar")
- testResponseHeaderReadSuccess(t, h, "HTTP/1.1 201 OK\r\nContent-Length: 123aa\r\nContent-Type: text/ht\r\n\r\naaa",
- 201, -2, "text/ht", "aaa")
- testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Length: aa124\r\nContent-Type: html\r\n\r\nxx",
- 200, -2, "html", "xx")
-
// no content-type
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\r\nContent-Length: 123\r\n\r\nfoiaaa",
- 400, 123, string(defaultContentType), "foiaaa")
+ 400, 123, string(defaultContentType))
// no content-type and no default
h.SetNoDefaultContentType(true)
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\r\nContent-Length: 123\r\n\r\nfoiaaa",
- 400, 123, "", "foiaaa")
+ 400, 123, "")
h.SetNoDefaultContentType(false)
// no headers
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\n\r\naaaabbb",
- 200, -2, string(defaultContentType), "aaaabbb")
+ 200, -2, string(defaultContentType))
if !h.IsHTTP11() {
t.Fatalf("expecting http/1.1 protocol")
}
// ancient http protocol
testResponseHeaderReadSuccess(t, h, "HTTP/1.0 203 OK\r\nContent-Length: 123\r\nContent-Type: foobar\r\n\r\naaa",
- 203, 123, "foobar", "aaa")
+ 203, 123, "foobar")
if h.IsHTTP11() {
t.Fatalf("ancient protocol must be non-http/1.1")
}
@@ -2248,7 +2444,7 @@ func TestResponseHeaderReadSuccess(t *testing.T) {
// ancient http protocol with 'Connection: keep-alive' header.
testResponseHeaderReadSuccess(t, h, "HTTP/1.0 403 aa\r\nContent-Length: 0\r\nContent-Type: 2\r\nConnection: Keep-Alive\r\n\r\nww",
- 403, 0, "2", "ww")
+ 403, 0, "2")
if h.IsHTTP11() {
t.Fatalf("ancient protocol must be non-http/1.1")
}
@@ -2264,21 +2460,21 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
// simple headers
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\n",
- -2, "/foo/bar", "google.com", "", "", "")
+ -2, "/foo/bar", "google.com", "", "", nil)
if h.ConnectionClose() {
t.Fatalf("unexpected connection: close header")
}
// simple headers with body
testRequestHeaderReadSuccess(t, h, "GET /a/bar HTTP/1.1\r\nHost: gole.com\r\nconneCTION: close\r\n\r\nfoobar",
- -2, "/a/bar", "gole.com", "", "", "foobar")
+ -2, "/a/bar", "gole.com", "", "", nil)
if !h.ConnectionClose() {
t.Fatalf("connection: close unset")
}
// ancient http protocol
testRequestHeaderReadSuccess(t, h, "GET /bar HTTP/1.0\r\nHost: gole\r\n\r\npppp",
- -2, "/bar", "gole", "", "", "pppp")
+ -2, "/bar", "gole", "", "", nil)
if h.IsHTTP11() {
t.Fatalf("ancient http protocol cannot be http/1.1")
}
@@ -2288,7 +2484,7 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
// ancient http protocol with 'Connection: keep-alive' header
testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.0\r\nHost: bb\r\nConnection: keep-alive\r\n\r\nxxx",
- -2, "/aa", "bb", "", "", "xxx")
+ -2, "/aa", "bb", "", "", nil)
if h.IsHTTP11() {
t.Fatalf("ancient http protocol cannot be http/1.1")
}
@@ -2298,7 +2494,7 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
// complex headers with body
testRequestHeaderReadSuccess(t, h, "GET /aabar HTTP/1.1\r\nAAA: bbb\r\nHost: ole.com\r\nAA: bb\r\n\r\nzzz",
- -2, "/aabar", "ole.com", "", "", "zzz")
+ -2, "/aabar", "ole.com", "", "", nil)
if !h.IsHTTP11() {
t.Fatalf("expecting http/1.1 protocol")
}
@@ -2308,103 +2504,103 @@ func TestRequestHeaderReadSuccess(t *testing.T) {
// lf instead of crlf
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\nHost: google.com\n\n",
- -2, "/foo/bar", "google.com", "", "", "")
+ -2, "/foo/bar", "google.com", "", "", nil)
// post method
testRequestHeaderReadSuccess(t, h, "POST /aaa?bbb HTTP/1.1\r\nHost: foobar.com\r\nContent-Length: 1235\r\nContent-Type: aaa\r\n\r\nabcdef",
- 1235, "/aaa?bbb", "foobar.com", "", "aaa", "abcdef")
+ 1235, "/aaa?bbb", "foobar.com", "", "aaa", nil)
// zero-length headers with mixed crlf and lf
testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost: aaa\r\nZero: \n: Zero-Value\n\r\nxccv",
- -2, "/a", "aaa", "", "", "xccv")
+ -2, "/a", "aaa", "", "", nil)
// no space after colon
testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost:aaaxd\n\nsdfds",
- -2, "/a", "aaaxd", "", "", "sdfds")
+ -2, "/a", "aaaxd", "", "", nil)
// get with zero content-length
testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 0\n\n",
- 0, "/xxx", "aaa.com", "", "", "")
+ 0, "/xxx", "aaa.com", "", "", nil)
// get with non-zero content-length
testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 123\n\n",
- 123, "/xxx", "aaa.com", "", "", "")
+ 123, "/xxx", "aaa.com", "", "", nil)
// invalid case
testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\nhoST: bbb.com\n\naas",
- -2, "/aaa", "bbb.com", "", "", "aas")
+ -2, "/aaa", "bbb.com", "", "", nil)
// referer
testRequestHeaderReadSuccess(t, h, "GET /asdf HTTP/1.1\nHost: aaa.com\nReferer: bb.com\n\naaa",
- -2, "/asdf", "aaa.com", "bb.com", "", "aaa")
+ -2, "/asdf", "aaa.com", "bb.com", "", nil)
// duplicate host
testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aaaaaa.com\r\nHost: bb.com\r\n\r\n",
- -2, "/aa", "bb.com", "", "", "")
+ -2, "/aa", "bb.com", "", "", nil)
// post with duplicate content-type
testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n",
- 123, "/a", "aa", "", "xx", "")
+ 123, "/a", "aa", "", "xx", nil)
// post with duplicate content-length
testRequestHeaderReadSuccess(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n",
- 1, "/xx", "aa", "", "s", "")
+ 1, "/xx", "aa", "", "s", nil)
// non-post with content-type
testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n",
- -2, "/aaa", "bbb.com", "", "aaab", "")
+ -2, "/aaa", "bbb.com", "", "aaab", nil)
// non-post with content-length
testRequestHeaderReadSuccess(t, h, "HEAD / HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 123\r\n\r\n",
- 123, "/", "aaa.com", "", "", "")
+ 123, "/", "aaa.com", "", "", nil)
// non-post with content-type and content-length
testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aa.com\r\nContent-Type: abd/test\r\nContent-Length: 123\r\n\r\n",
- 123, "/aa", "aa.com", "", "abd/test", "")
+ 123, "/aa", "aa.com", "", "abd/test", nil)
// request uri with hostname
testRequestHeaderReadSuccess(t, h, "GET http://gooGle.com/foO/%20bar?xxx#aaa HTTP/1.1\r\nHost: aa.cOM\r\n\r\ntrail",
- -2, "http://gooGle.com/foO/%20bar?xxx#aaa", "aa.cOM", "", "", "trail")
+ -2, "http://gooGle.com/foO/%20bar?xxx#aaa", "aa.cOM", "", "", nil)
// no protocol in the first line
testRequestHeaderReadSuccess(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD",
- -2, "/foo/bar", "google.com", "", "", "isdD")
+ -2, "/foo/bar", "google.com", "", "", nil)
// blank lines before the first line
testRequestHeaderReadSuccess(t, h, "\r\n\n\r\nGET /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nsss",
- -2, "/aaa", "aaa.com", "", "", "sss")
+ -2, "/aaa", "aaa.com", "", "", nil)
// request uri with spaces
testRequestHeaderReadSuccess(t, h, "GET /foo/ bar baz HTTP/1.1\r\nHost: aa.com\r\n\r\nxxx",
- -2, "/foo/ bar baz", "aa.com", "", "", "xxx")
+ -2, "/foo/ bar baz", "aa.com", "", "", nil)
// no host
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nFOObar: assdfd\r\n\r\naaa",
- -2, "/foo/bar", "", "", "", "aaa")
+ -2, "/foo/bar", "", "", "", nil)
// no host, no headers
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar",
- -2, "/foo/bar", "", "", "", "foobar")
+ -2, "/foo/bar", "", "", "", nil)
// post without content-length and content-type
testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc",
- -2, "/aaa", "aaa.com", "", "", "zxc")
+ -2, "/aaa", "aaa.com", "", "", nil)
// post without content-type
testRequestHeaderReadSuccess(t, h, "POST /abc HTTP/1.1\r\nHost: aa.com\r\nContent-Length: 123\r\n\r\npoiuy",
- 123, "/abc", "aa.com", "", "", "poiuy")
+ 123, "/abc", "aa.com", "", "", nil)
// post without content-length
testRequestHeaderReadSuccess(t, h, "POST /abc HTTP/1.1\r\nHost: aa.com\r\nContent-Type: adv\r\n\r\n123456",
- -2, "/abc", "aa.com", "", "adv", "123456")
+ -2, "/abc", "aa.com", "", "adv", nil)
// invalid method
testRequestHeaderReadSuccess(t, h, "POST /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\nmnbv",
- -2, "/foo/bar", "google.com", "", "", "mnbv")
+ -2, "/foo/bar", "google.com", "", "", nil)
// put request
testRequestHeaderReadSuccess(t, h, "PUT /faa HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 123\r\nContent-Type: aaa\r\n\r\nxwwere",
- 123, "/faa", "aaa.com", "", "aaa", "xwwere")
+ 123, "/faa", "aaa.com", "", "aaa", nil)
}
func TestResponseHeaderReadError(t *testing.T) {
@@ -2425,11 +2621,19 @@ func TestResponseHeaderReadError(t *testing.T) {
testResponseHeaderReadError(t, h, "HTTP/1.1 123foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
testResponseHeaderReadError(t, h, "HTTP/1.1 foobar344 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n")
+ // non-numeric content-length
+ testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: faaa\r\nContent-Type: text/html\r\n\r\nfoobar")
+ testResponseHeaderReadError(t, h, "HTTP/1.1 201 OK\r\nContent-Length: 123aa\r\nContent-Type: text/ht\r\n\r\naaa")
+ testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: aa124\r\nContent-Type: html\r\n\r\nxx")
+
// no headers
testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\n")
// no trailing crlf
testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n")
+
+ // forbidden trailer
+ testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: -1\r\nTrailer: Foo, Content-Length\r\n\r\n")
}
func TestResponseHeaderReadErrorSecureLog(t *testing.T) {
@@ -2474,6 +2678,9 @@ func TestRequestHeaderReadError(t *testing.T) {
// post with invalid content-length
testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty")
+
+ // forbidden trailer
+ testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nContent-Length: -1\r\nTrailer: Foo, Content-Length\r\n\r\n")
}
func TestRequestHeaderReadSecuredError(t *testing.T) {
@@ -2504,7 +2711,7 @@ func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string
}
// make sure response header works after error
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 12345\r\n\r\nsss",
- 200, 12345, "foo/bar", "sss")
+ 200, 12345, "foo/bar")
}
func testResponseHeaderReadSecuredError(t *testing.T, h *ResponseHeader, headers string) {
@@ -2519,7 +2726,7 @@ func testResponseHeaderReadSecuredError(t *testing.T, h *ResponseHeader, headers
}
// make sure response header works after error
testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 12345\r\n\r\nsss",
- 200, 12345, "foo/bar", "sss")
+ 200, 12345, "foo/bar")
}
func testRequestHeaderReadError(t *testing.T, h *RequestHeader, headers string) {
@@ -2532,7 +2739,7 @@ func testRequestHeaderReadError(t *testing.T, h *RequestHeader, headers string)
// make sure request header works after error
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaaa\r\n\r\nxxx",
- -2, "/foo/bar", "aaaa", "", "", "xxx")
+ -2, "/foo/bar", "aaaa", "", "", nil)
}
func testRequestHeaderReadSecuredError(t *testing.T, h *RequestHeader, headers string) {
@@ -2547,42 +2754,43 @@ func testRequestHeaderReadSecuredError(t *testing.T, h *RequestHeader, headers s
}
// make sure request header works after error
testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaaa\r\n\r\nxxx",
- -2, "/foo/bar", "aaaa", "", "", "xxx")
+ -2, "/foo/bar", "aaaa", "", "", nil)
}
func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers string, expectedStatusCode, expectedContentLength int,
- expectedContentType, expectedTrailer string) {
+ expectedContentType string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err != nil {
- t.Fatalf("Unexpected error when parsing response headers: %s. headers=%q", err, headers)
+ t.Fatalf("Unexpected error when parsing response headers: %v. headers=%q", err, headers)
}
- verifyResponseHeader(t, h, expectedStatusCode, expectedContentLength, expectedContentType)
- verifyTrailer(t, br, expectedTrailer)
+ verifyResponseHeader(t, h, expectedStatusCode, expectedContentLength, expectedContentType, "")
}
func testRequestHeaderReadSuccess(t *testing.T, h *RequestHeader, headers string, expectedContentLength int,
- expectedRequestURI, expectedHost, expectedReferer, expectedContentType, expectedTrailer string) {
+ expectedRequestURI, expectedHost, expectedReferer, expectedContentType string, expectedTrailer map[string]string) {
r := bytes.NewBufferString(headers)
br := bufio.NewReader(r)
err := h.Read(br)
if err != nil {
- t.Fatalf("Unexpected error when parsing request headers: %s. headers=%q", err, headers)
+ t.Fatalf("Unexpected error when parsing request headers: %v. headers=%q", err, headers)
}
verifyRequestHeader(t, h, expectedContentLength, expectedRequestURI, expectedHost, expectedReferer, expectedContentType)
- verifyTrailer(t, br, expectedTrailer)
}
-func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType string) {
+func verifyResponseHeader(t *testing.T, h *ResponseHeader, expectedStatusCode, expectedContentLength int, expectedContentType, expectedContentEncoding string) {
if h.StatusCode() != expectedStatusCode {
t.Fatalf("Unexpected status code %d. Expected %d", h.StatusCode(), expectedStatusCode)
}
if h.ContentLength() != expectedContentLength {
t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength)
}
- if string(h.Peek(HeaderContentType)) != expectedContentType {
- t.Fatalf("Unexpected content type %q. Expected %q", h.Peek(HeaderContentType), expectedContentType)
+ if string(h.ContentType()) != expectedContentType {
+ t.Fatalf("Unexpected content type %q. Expected %q", h.ContentType(), expectedContentType)
+ }
+ if string(h.ContentEncoding()) != expectedContentEncoding {
+ t.Fatalf("Unexpected content encoding %q. Expected %q", h.ContentEncoding(), expectedContentEncoding)
}
}
@@ -2611,12 +2819,159 @@ func verifyRequestHeader(t *testing.T, h *RequestHeader, expectedContentLength i
}
}
-func verifyTrailer(t *testing.T, r *bufio.Reader, expectedTrailer string) {
- trailer, err := ioutil.ReadAll(r)
+func verifyResponseTrailer(t *testing.T, h *ResponseHeader, expectedTrailers map[string]string) {
+ for k, v := range expectedTrailers {
+ got := h.Peek(k)
+ if !bytes.Equal(got, []byte(v)) {
+ t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got)
+ }
+ }
+}
+
+func verifyRequestTrailer(t *testing.T, h *RequestHeader, expectedTrailers map[string]string) {
+ for k, v := range expectedTrailers {
+ got := h.Peek(k)
+ if !bytes.Equal(got, []byte(v)) {
+ t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got)
+ }
+ }
+}
+
+func verifyTrailer(t *testing.T, r *bufio.Reader, expectedTrailers map[string]string, isReq bool) {
+ if isReq {
+ req := Request{}
+ err := req.Header.ReadTrailer(r)
+ if err == io.EOF && expectedTrailers == nil {
+ return
+ }
+ if err != nil {
+ t.Fatalf("Cannot read trailer: %v", err)
+ }
+ verifyRequestTrailer(t, &req.Header, expectedTrailers)
+ return
+ }
+
+ resp := Response{}
+ err := resp.Header.ReadTrailer(r)
+ if err == io.EOF && expectedTrailers == nil {
+ return
+ }
+ if err != nil {
+ t.Fatalf("Cannot read trailer: %v", err)
+ }
+ verifyResponseTrailer(t, &resp.Header, expectedTrailers)
+}
+
+func TestRequestHeader_PeekAll(t *testing.T) {
+ t.Parallel()
+ h := &RequestHeader{}
+ h.Add(HeaderConnection, "keep-alive")
+ h.Add("Content-Type", "aaa")
+ h.Add(HeaderHost, "aaabbb")
+ h.Add("User-Agent", "asdfas")
+ h.Add("Content-Length", "1123")
+ h.Add("Cookie", "foobar=baz")
+ h.Add(HeaderTrailer, "foo, bar")
+ h.Add("aaa", "aaa")
+ h.Add("aaa", "bbb")
+
+ expectRequestHeaderAll(t, h, HeaderConnection, [][]byte{s2b("keep-alive")})
+ expectRequestHeaderAll(t, h, "Content-Type", [][]byte{s2b("aaa")})
+ expectRequestHeaderAll(t, h, HeaderHost, [][]byte{s2b("aaabbb")})
+ expectRequestHeaderAll(t, h, "User-Agent", [][]byte{s2b("asdfas")})
+ expectRequestHeaderAll(t, h, "Content-Length", [][]byte{s2b("1123")})
+ expectRequestHeaderAll(t, h, "Cookie", [][]byte{s2b("foobar=baz")})
+ expectRequestHeaderAll(t, h, HeaderTrailer, [][]byte{s2b("Foo, Bar")})
+ expectRequestHeaderAll(t, h, "aaa", [][]byte{s2b("aaa"), s2b("bbb")})
+
+ h.Del("Content-Type")
+ h.Del(HeaderHost)
+ h.Del("aaa")
+ expectRequestHeaderAll(t, h, "Content-Type", [][]byte{})
+ expectRequestHeaderAll(t, h, HeaderHost, [][]byte{})
+ expectRequestHeaderAll(t, h, "aaa", [][]byte{})
+}
+func expectRequestHeaderAll(t *testing.T, h *RequestHeader, key string, expectedValue [][]byte) {
+ if len(h.PeekAll(key)) != len(expectedValue) {
+ t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue))
+ }
+ if !reflect.DeepEqual(h.PeekAll(key), expectedValue) {
+ t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.PeekAll(key), expectedValue)
+ }
+}
+
+func TestResponseHeader_PeekAll(t *testing.T) {
+ t.Parallel()
+
+ h := &ResponseHeader{}
+ h.Add(HeaderContentType, "aaa/bbb")
+ h.Add(HeaderContentEncoding, "gzip")
+ h.Add(HeaderConnection, "close")
+ h.Add(HeaderContentLength, "1234")
+ h.Add(HeaderServer, "aaaa")
+ h.Add(HeaderSetCookie, "cccc")
+ h.Add("aaa", "aaa")
+ h.Add("aaa", "bbb")
+
+ expectResponseHeaderAll(t, h, HeaderContentType, [][]byte{s2b("aaa/bbb")})
+ expectResponseHeaderAll(t, h, HeaderContentEncoding, [][]byte{s2b("gzip")})
+ expectResponseHeaderAll(t, h, HeaderConnection, [][]byte{s2b("close")})
+ expectResponseHeaderAll(t, h, HeaderContentLength, [][]byte{s2b("1234")})
+ expectResponseHeaderAll(t, h, HeaderServer, [][]byte{s2b("aaaa")})
+ expectResponseHeaderAll(t, h, HeaderSetCookie, [][]byte{s2b("cccc")})
+ expectResponseHeaderAll(t, h, "aaa", [][]byte{s2b("aaa"), s2b("bbb")})
+
+ h.Del(HeaderContentType)
+ h.Del(HeaderContentEncoding)
+ expectResponseHeaderAll(t, h, HeaderContentType, [][]byte{defaultContentType})
+ expectResponseHeaderAll(t, h, HeaderContentEncoding, [][]byte{})
+}
+
+func expectResponseHeaderAll(t *testing.T, h *ResponseHeader, key string, expectedValue [][]byte) {
+ if len(h.PeekAll(key)) != len(expectedValue) {
+ t.Fatalf("Unexpected size for key %q: %d. Expected %d", key, len(h.PeekAll(key)), len(expectedValue))
+ }
+ if !reflect.DeepEqual(h.PeekAll(key), expectedValue) {
+ t.Fatalf("Unexpected value for key %q: %q. Expected %q", key, h.PeekAll(key), expectedValue)
+ }
+}
+
+func TestRequestHeader_Keys(t *testing.T) {
+ h := &RequestHeader{}
+ h.Add(HeaderConnection, "keep-alive")
+ h.Add("Content-Type", "aaa")
+ err := h.SetTrailer("aaa,bbb,ccc")
if err != nil {
- t.Fatalf("Cannot read trailer: %s", err)
+ t.Fatal(err)
+ }
+ actualKeys := h.PeekKeys()
+ expectedKeys := [][]byte{s2b("keep-alive"), s2b("aaa")}
+ if reflect.DeepEqual(actualKeys, expectedKeys) {
+ t.Fatalf("Unexpected value %q. Expected %q", actualKeys, expectedKeys)
+ }
+ actualTrailerKeys := h.PeekTrailerKeys()
+ expectedTrailerKeys := [][]byte{s2b("aaa"), s2b("bbb"), s2b("ccc")}
+ if reflect.DeepEqual(actualTrailerKeys, expectedTrailerKeys) {
+ t.Fatalf("Unexpected value %q. Expected %q", actualTrailerKeys, expectedTrailerKeys)
+ }
+}
+
+func TestResponseHeader_Keys(t *testing.T) {
+ h := &ResponseHeader{}
+ h.Add(HeaderConnection, "keep-alive")
+ h.Add("Content-Type", "aaa")
+ err := h.SetTrailer("aaa,bbb,ccc")
+ if err != nil {
+ t.Fatal(err)
+ }
+ actualKeys := h.PeekKeys()
+ expectedKeys := [][]byte{s2b("keep-alive"), s2b("aaa")}
+ if reflect.DeepEqual(actualKeys, expectedKeys) {
+ t.Fatalf("Unexpected value %q. Expected %q", actualKeys, expectedKeys)
}
- if !bytes.Equal(trailer, []byte(expectedTrailer)) {
- t.Fatalf("Unexpected trailer %q. Expected %q", trailer, expectedTrailer)
+ actualTrailerKeys := h.PeekTrailerKeys()
+ expectedTrailerKeys := [][]byte{s2b("aaa"), s2b("bbb"), s2b("ccc")}
+ if reflect.DeepEqual(actualTrailerKeys, expectedTrailerKeys) {
+ t.Fatalf("Unexpected value %q. Expected %q", actualTrailerKeys, expectedTrailerKeys)
}
}
diff --git a/header_timing_test.go b/header_timing_test.go
index 745b5c3..81ff7fb 100644
--- a/header_timing_test.go
+++ b/header_timing_test.go
@@ -12,6 +12,9 @@ import (
var strFoobar = []byte("foobar.com")
+// it has the same length as Content-Type
+var strNonSpecialHeader = []byte("Dontent-Type")
+
type benchReadBuf struct {
s []byte
n int
@@ -38,7 +41,7 @@ func BenchmarkRequestHeaderRead(b *testing.B) {
buf.n = 0
br.Reset(buf)
if err := h.Read(br); err != nil {
- b.Fatalf("unexpected error when reading header: %s", err)
+ b.Fatalf("unexpected error when reading header: %v", err)
}
}
})
@@ -55,7 +58,7 @@ func BenchmarkResponseHeaderRead(b *testing.B) {
buf.n = 0
br.Reset(buf)
if err := h.Read(br); err != nil {
- b.Fatalf("unexpected error when reading header: %s", err)
+ b.Fatalf("unexpected error when reading header: %v", err)
}
}
})
@@ -71,7 +74,7 @@ func BenchmarkRequestHeaderWrite(b *testing.B) {
var w bytebufferpool.ByteBuffer
for pb.Next() {
if _, err := h.WriteTo(&w); err != nil {
- b.Fatalf("unexpected error when writing header: %s", err)
+ b.Fatalf("unexpected error when writing header: %v", err)
}
w.Reset()
}
@@ -89,19 +92,20 @@ func BenchmarkResponseHeaderWrite(b *testing.B) {
var w bytebufferpool.ByteBuffer
for pb.Next() {
if _, err := h.WriteTo(&w); err != nil {
- b.Fatalf("unexpected error when writing header: %s", err)
+ b.Fatalf("unexpected error when writing header: %v", err)
}
w.Reset()
}
})
}
-func BenchmarkRequestHeaderPeekBytesCanonical(b *testing.B) {
+// Result: 2.2 ns/op
+func BenchmarkRequestHeaderPeekBytesSpecialHeader(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h RequestHeader
- h.SetBytesV("Host", strFoobar)
+ h.SetContentTypeBytes(strFoobar)
for pb.Next() {
- v := h.PeekBytes(strHost)
+ v := h.PeekBytes(strContentType)
if !bytes.Equal(v, strFoobar) {
b.Fatalf("unexpected result: %q. Expected %q", v, strFoobar)
}
@@ -109,13 +113,41 @@ func BenchmarkRequestHeaderPeekBytesCanonical(b *testing.B) {
})
}
-func BenchmarkRequestHeaderPeekBytesNonCanonical(b *testing.B) {
+// Result: 2.9 ns/op
+func BenchmarkRequestHeaderPeekBytesNonSpecialHeader(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
var h RequestHeader
- h.SetBytesV("Host", strFoobar)
- hostBytes := []byte("HOST")
+ h.SetBytesKV(strNonSpecialHeader, strFoobar)
+ for pb.Next() {
+ v := h.PeekBytes(strNonSpecialHeader)
+ if !bytes.Equal(v, strFoobar) {
+ b.Fatalf("unexpected result: %q. Expected %q", v, strFoobar)
+ }
+ }
+ })
+}
+
+// Result: 2.3 ns/op
+func BenchmarkResponseHeaderPeekBytesSpecialHeader(b *testing.B) {
+ b.RunParallel(func(pb *testing.PB) {
+ var h ResponseHeader
+ h.SetContentTypeBytes(strFoobar)
+ for pb.Next() {
+ v := h.PeekBytes(strContentType)
+ if !bytes.Equal(v, strFoobar) {
+ b.Fatalf("unexpected result: %q. Expected %q", v, strFoobar)
+ }
+ }
+ })
+}
+
+// Result: 2.9 ns/op
+func BenchmarkResponseHeaderPeekBytesNonSpecialHeader(b *testing.B) {
+ b.RunParallel(func(pb *testing.PB) {
+ var h ResponseHeader
+ h.SetBytesKV(strNonSpecialHeader, strFoobar)
for pb.Next() {
- v := h.PeekBytes(hostBytes)
+ v := h.PeekBytes(strNonSpecialHeader)
if !bytes.Equal(v, strFoobar) {
b.Fatalf("unexpected result: %q. Expected %q", v, strFoobar)
}
diff --git a/headers.go b/headers.go
index 378dfec..676a0da 100644
--- a/headers.go
+++ b/headers.go
@@ -36,8 +36,9 @@ const (
HeaderVary = "Vary"
// Connection management
- HeaderConnection = "Connection"
- HeaderKeepAlive = "Keep-Alive"
+ HeaderConnection = "Connection"
+ HeaderKeepAlive = "Keep-Alive"
+ HeaderProxyConnection = "Proxy-Connection"
// Content negotiation
HeaderAccept = "Accept"
diff --git a/http.go b/http.go
index 18e5f81..3d1429c 100644
--- a/http.go
+++ b/http.go
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
+ "math"
"mime/multipart"
"net"
"os"
@@ -17,6 +18,18 @@ import (
"github.com/valyala/bytebufferpool"
)
+var (
+ requestBodyPoolSizeLimit = -1
+ responseBodyPoolSizeLimit = -1
+)
+
+// SetBodySizePoolLimit sets the maximum body size for bodies that are returned to the pool.
+// If the body size is larger it will be released instead of put back into the pool for reuse.
+func SetBodySizePoolLimit(reqBodyLimit, respBodyLimit int) {
+ requestBodyPoolSizeLimit = reqBodyLimit
+ responseBodyPoolSizeLimit = respBodyLimit
+}
+
// Request represents HTTP request.
//
// It is forbidden copying Request instances. Create new instances
@@ -56,6 +69,9 @@ type Request struct {
// Request timeout. Usually set by DoDeadline or DoTimeout
// if <= 0, means not set
timeout time.Duration
+
+ // Use Host header (request.Header.SetHost) instead of the host from SetRequestURI, SetHost, or URI().SetHost
+ UseHostHeader bool
}
// Response represents HTTP response.
@@ -239,14 +255,14 @@ func (resp *Response) IsBodyStream() bool {
//
// This function may be used in the following cases:
//
-// * if request body is too big (more than 10MB).
-// * if request body is streamed from slow external sources.
-// * if request body must be streamed to the server in chunks
+// - if request body is too big (more than 10MB).
+// - if request body is streamed from slow external sources.
+// - if request body must be streamed to the server in chunks
// (aka `http client push` or `chunked transfer-encoding`).
//
// Note that GET and HEAD requests cannot have body.
//
-/// See also SetBodyStream.
+// See also SetBodyStream.
func (req *Request) SetBodyStreamWriter(sw StreamWriter) {
sr := NewStreamReader(sw)
req.SetBodyStream(sr, -1)
@@ -256,9 +272,9 @@ func (req *Request) SetBodyStreamWriter(sw StreamWriter) {
//
// This function may be used in the following cases:
//
-// * if response body is too big (more than 10MB).
-// * if response body is streamed from slow external sources.
-// * if response body must be streamed to the client in chunks
+// - if response body is too big (more than 10MB).
+// - if response body is streamed from slow external sources.
+// - if response body must be streamed to the client in chunks
// (aka `http server push` or `chunked transfer-encoding`).
//
// See also SetBodyStream.
@@ -471,6 +487,48 @@ func inflateData(p []byte) ([]byte, error) {
return bb.B, nil
}
+var ErrContentEncodingUnsupported = errors.New("unsupported Content-Encoding")
+
+// BodyUncompressed returns body data, decompressing it from gzip, deflate or Brotli if needed.
+//
+// This method may be used if the request header contains
+// 'Content-Encoding' for reading uncompressed request body.
+// Use Body for reading the raw request body.
+func (req *Request) BodyUncompressed() ([]byte, error) {
+ switch string(req.Header.ContentEncoding()) {
+ case "":
+ return req.Body(), nil
+ case "deflate":
+ return req.BodyInflate()
+ case "gzip":
+ return req.BodyGunzip()
+ case "br":
+ return req.BodyUnbrotli()
+ default:
+ return nil, ErrContentEncodingUnsupported
+ }
+}
+
+// BodyUncompressed returns body data, decompressing it from gzip, deflate or Brotli if needed.
+//
+// This method may be used if the response header contains
+// 'Content-Encoding' for reading uncompressed response body.
+// Use Body for reading the raw response body.
+func (resp *Response) BodyUncompressed() ([]byte, error) {
+ switch string(resp.Header.ContentEncoding()) {
+ case "":
+ return resp.Body(), nil
+ case "deflate":
+ return resp.BodyInflate()
+ case "gzip":
+ return resp.BodyGunzip()
+ case "br":
+ return resp.BodyUnbrotli()
+ default:
+ return nil, ErrContentEncodingUnsupported
+ }
+}
+
// BodyWriteTo writes request body to w.
func (req *Request) BodyWriteTo(w io.Writer) error {
if req.bodyStream != nil {
@@ -567,6 +625,9 @@ func (req *Request) SetBodyRaw(body []byte) {
// The majority of workloads don't need this method.
func (resp *Response) ReleaseBody(size int) {
resp.bodyRaw = nil
+ if resp.body == nil {
+ return
+ }
if cap(resp.body.B) > size {
resp.closeBodyStream() //nolint:errcheck
resp.body = nil
@@ -582,6 +643,9 @@ func (resp *Response) ReleaseBody(size int) {
// The majority of workloads don't need this method.
func (req *Request) ReleaseBody(size int) {
req.bodyRaw = nil
+ if req.body == nil {
+ return
+ }
if cap(req.body.B) > size {
req.closeBodyStream() //nolint:errcheck
req.body = nil
@@ -729,6 +793,8 @@ func (req *Request) copyToSkipBody(dst *Request) {
dst.parsedPostArgs = req.parsedPostArgs
dst.isTLS = req.isTLS
+ dst.UseHostHeader = req.UseHostHeader
+
// do not copy multipartForm - it will be automatically
// re-created on the first call to MultipartForm.
}
@@ -760,6 +826,14 @@ func swapRequestBody(a, b *Request) {
a.body, b.body = b.body, a.body
a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw
a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream
+
+ // This code assumes that if a requestStream was swapped the headers are also swapped or copied.
+ if rs, ok := a.bodyStream.(*requestStream); ok {
+ rs.header = &a.Header
+ }
+ if rs, ok := b.bodyStream.(*requestStream); ok {
+ rs.header = &b.Header
+ }
}
func swapResponseBody(a, b *Response) {
@@ -774,6 +848,20 @@ func (req *Request) URI() *URI {
return &req.uri
}
+// SetURI initializes request URI
+// Use this method if a single URI may be reused across multiple requests.
+// Otherwise, you can just use SetRequestURI() and it will be parsed as new URI.
+// The URI is copied and can be safely modified later.
+func (req *Request) SetURI(newUri *URI) {
+ if newUri != nil {
+ newUri.CopyTo(&req.uri)
+ req.parsedURI = true
+ return
+ }
+ req.uri.Reset()
+ req.parsedURI = false
+}
+
func (req *Request) parseURI() error {
if req.parsedURI {
return nil
@@ -830,7 +918,7 @@ func (req *Request) MultipartForm() (*multipart.Form, error) {
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if bodyStream, err = gzip.NewReader(bodyStream); err != nil {
- return nil, fmt.Errorf("cannot gunzip request body: %s", err)
+ return nil, fmt.Errorf("cannot gunzip request body: %w", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
@@ -839,14 +927,14 @@ func (req *Request) MultipartForm() (*multipart.Form, error) {
mr := multipart.NewReader(bodyStream, req.multipartFormBoundary)
req.multipartForm, err = mr.ReadForm(8 * 1024)
if err != nil {
- return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err)
+ return nil, fmt.Errorf("cannot read multipart/form-data body: %w", err)
}
} else {
body := req.bodyBytes()
if bytes.Equal(ce, strGzip) {
// Do not care about memory usage here.
if body, err = AppendGunzipBytes(nil, body); err != nil {
- return nil, fmt.Errorf("cannot gunzip request body: %s", err)
+ return nil, fmt.Errorf("cannot gunzip request body: %w", err)
}
} else if len(ce) > 0 {
return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
@@ -880,14 +968,14 @@ func WriteMultipartForm(w io.Writer, f *multipart.Form, boundary string) error {
mw := multipart.NewWriter(w)
if err := mw.SetBoundary(boundary); err != nil {
- return fmt.Errorf("cannot use form boundary %q: %s", boundary, err)
+ return fmt.Errorf("cannot use form boundary %q: %w", boundary, err)
}
// marshal values
for k, vv := range f.Value {
for _, v := range vv {
if err := mw.WriteField(k, v); err != nil {
- return fmt.Errorf("cannot write form field %q value %q: %s", k, v, err)
+ return fmt.Errorf("cannot write form field %q value %q: %w", k, v, err)
}
}
}
@@ -897,23 +985,23 @@ func WriteMultipartForm(w io.Writer, f *multipart.Form, boundary string) error {
for _, fv := range fvv {
vw, err := mw.CreatePart(fv.Header)
if err != nil {
- return fmt.Errorf("cannot create form file %q (%q): %s", k, fv.Filename, err)
+ return fmt.Errorf("cannot create form file %q (%q): %w", k, fv.Filename, err)
}
fh, err := fv.Open()
if err != nil {
- return fmt.Errorf("cannot open form file %q (%q): %s", k, fv.Filename, err)
+ return fmt.Errorf("cannot open form file %q (%q): %w", k, fv.Filename, err)
}
if _, err = copyZeroAlloc(vw, fh); err != nil {
- return fmt.Errorf("error when copying form file %q (%q): %s", k, fv.Filename, err)
+ return fmt.Errorf("error when copying form file %q (%q): %w", k, fv.Filename, err)
}
if err = fh.Close(); err != nil {
- return fmt.Errorf("cannot close form file %q (%q): %s", k, fv.Filename, err)
+ return fmt.Errorf("cannot close form file %q (%q): %w", k, fv.Filename, err)
}
}
}
if err := mw.Close(); err != nil {
- return fmt.Errorf("error when closing multipart form writer: %s", err)
+ return fmt.Errorf("error when closing multipart form writer: %w", err)
}
return nil
@@ -931,16 +1019,20 @@ func readMultipartForm(r io.Reader, boundary string, size, maxInMemoryFileSize i
mr := multipart.NewReader(lr, boundary)
f, err := mr.ReadForm(int64(maxInMemoryFileSize))
if err != nil {
- return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err)
+ return nil, fmt.Errorf("cannot read multipart/form-data body: %w", err)
}
return f, nil
}
// Reset clears request contents.
func (req *Request) Reset() {
+ if requestBodyPoolSizeLimit >= 0 && req.body != nil {
+ req.ReleaseBody(requestBodyPoolSizeLimit)
+ }
req.Header.Reset()
req.resetSkipHeader()
req.timeout = 0
+ req.UseHostHeader = false
}
func (req *Request) resetSkipHeader() {
@@ -966,6 +1058,9 @@ func (req *Request) RemoveMultipartFormFiles() {
// Reset clears response contents.
func (resp *Response) Reset() {
+ if responseBodyPoolSizeLimit >= 0 && resp.body != nil {
+ resp.ReleaseBody(responseBodyPoolSizeLimit)
+ }
resp.Header.Reset()
resp.resetSkipHeader()
resp.SkipBody = false
@@ -986,11 +1081,11 @@ func (resp *Response) resetSkipHeader() {
//
// If MayContinue returns true, the caller must:
//
-// - Either send StatusExpectationFailed response if request headers don't
-// satisfy the caller.
-// - Or send StatusContinue response before reading request body
-// with ContinueReadBody.
-// - Or close the connection.
+// - Either send StatusExpectationFailed response if request headers don't
+// satisfy the caller.
+// - Or send StatusContinue response before reading request body
+// with ContinueReadBody.
+// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (req *Request) Read(r *bufio.Reader) error {
@@ -1014,11 +1109,11 @@ var ErrGetOnly = errors.New("non-GET request received")
//
// If MayContinue returns true, the caller must:
//
-// - Either send StatusExpectationFailed response if request headers don't
-// satisfy the caller.
-// - Or send StatusContinue response before reading request body
-// with ContinueReadBody.
-// - Or close the connection.
+// - Either send StatusExpectationFailed response if request headers don't
+// satisfy the caller.
+// - Or send StatusContinue response before reading request body
+// with ContinueReadBody.
+// - Or close the connection.
//
// io.EOF is returned if r is closed before reading the first header byte.
func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
@@ -1034,7 +1129,7 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
// Do not reset the request here - the caller must reset it before
// calling this method.
- if getOnly && !req.Header.IsGet() {
+ if getOnly && !req.Header.IsGet() && !req.Header.IsHead() {
return ErrGetOnly
}
@@ -1052,7 +1147,7 @@ func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly boo
// Do not reset the request here - the caller must reset it before
// calling this method.
- if getOnly && !req.Header.IsGet() {
+ if getOnly && !req.Header.IsGet() && !req.Header.IsHead() {
return ErrGetOnly
}
@@ -1071,11 +1166,11 @@ func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly boo
//
// The caller must do one of the following actions if MayContinue returns true:
//
-// - Either send StatusExpectationFailed response if request headers don't
-// satisfy the caller.
-// - Or send StatusContinue response before reading request body
-// with ContinueReadBody.
-// - Or close the connection.
+// - Either send StatusExpectationFailed response if request headers don't
+// satisfy the caller.
+// - Or send StatusContinue response before reading request body
+// with ContinueReadBody.
+// - Or close the connection.
func (req *Request) MayContinue() bool {
return bytes.Equal(req.Header.peek(strExpect), str100Continue)
}
@@ -1122,18 +1217,46 @@ func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int, preParseM
return nil
}
+ if err = req.ReadBody(r, contentLength, maxBodySize); err != nil {
+ return err
+ }
+
+ if req.Header.ContentLength() == -1 {
+ err = req.Header.ReadTrailer(r)
+ if err != nil && err != io.EOF {
+ return err
+ }
+ }
+ return nil
+}
+
+// ReadBody reads request body from the given r, limiting the body size.
+//
+// If maxBodySize > 0 and the body size exceeds maxBodySize,
+// then ErrBodyTooLarge is returned.
+func (req *Request) ReadBody(r *bufio.Reader, contentLength int, maxBodySize int) (err error) {
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
- bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)
+
+ if contentLength >= 0 {
+ bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)
+
+ } else if contentLength == -1 {
+ bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)
+
+ } else {
+ bodyBuf.B, err = readBodyIdentity(r, maxBodySize, bodyBuf.B)
+ req.Header.SetContentLength(len(bodyBuf.B))
+ }
+
if err != nil {
req.Reset()
return err
}
- req.Header.SetContentLength(len(bodyBuf.B))
return nil
}
-// ContinueReadBody reads request body if request header contains
+// ContinueReadBodyStream reads request body if request header contains
// 'Expect: 100-continue'.
//
// The caller must send StatusContinue response before calling this method.
@@ -1164,7 +1287,11 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
// the end of body is determined by connection close.
// So just ignore request body for requests without
// 'Content-Length' and 'Transfer-Encoding' headers.
- req.Header.SetContentLength(0)
+
+ // refer to https://tools.ietf.org/html/rfc7230#section-3.3.2
+ if !req.Header.ignoreBody() {
+ req.Header.SetContentLength(0)
+ }
return nil
}
@@ -1175,12 +1302,12 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
if err == ErrBodyTooLarge {
req.Header.SetContentLength(contentLength)
req.body = bodyBuf
- req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
+ req.bodyStream = acquireRequestStream(bodyBuf, r, &req.Header)
return nil
}
if err == errChunkedStream {
req.body = bodyBuf
- req.bodyStream = acquireRequestStream(bodyBuf, r, -1)
+ req.bodyStream = acquireRequestStream(bodyBuf, r, &req.Header)
return nil
}
req.Reset()
@@ -1188,7 +1315,7 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
}
req.body = bodyBuf
- req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
+ req.bodyStream = acquireRequestStream(bodyBuf, r, &req.Header)
req.Header.SetContentLength(contentLength)
return nil
}
@@ -1200,7 +1327,10 @@ func (resp *Response) Read(r *bufio.Reader) error {
return resp.ReadLimitBody(r, 0)
}
-// ReadLimitBody reads response from the given r, limiting the body size.
+// ReadLimitBody reads response headers from the given r,
+// then reads the body using the ReadBody function and limiting the body size.
+//
+// If resp.SkipBody is true then it skips reading the response body.
//
// If maxBodySize > 0 and the body size exceeds maxBodySize,
// then ErrBodyTooLarge is returned.
@@ -1220,17 +1350,49 @@ func (resp *Response) ReadLimitBody(r *bufio.Reader, maxBodySize int) error {
}
if !resp.mustSkipBody() {
- bodyBuf := resp.bodyBuffer()
- bodyBuf.Reset()
- bodyBuf.B, err = readBody(r, resp.Header.ContentLength(), maxBodySize, bodyBuf.B)
+ err = resp.ReadBody(r, maxBodySize)
if err != nil {
+ if isConnectionReset(err) {
+ return nil
+ }
+ return err
+ }
+ }
+
+ if resp.Header.ContentLength() == -1 {
+ err = resp.Header.ReadTrailer(r)
+ if err != nil && err != io.EOF {
+ if isConnectionReset(err) {
+ return nil
+ }
return err
}
- resp.Header.SetContentLength(len(bodyBuf.B))
}
return nil
}
+// ReadBody reads response body from the given r, limiting the body size.
+//
+// If maxBodySize > 0 and the body size exceeds maxBodySize,
+// then ErrBodyTooLarge is returned.
+func (resp *Response) ReadBody(r *bufio.Reader, maxBodySize int) (err error) {
+ bodyBuf := resp.bodyBuffer()
+ bodyBuf.Reset()
+
+ contentLength := resp.Header.ContentLength()
+ if contentLength >= 0 {
+ bodyBuf.B, err = readBody(r, contentLength, maxBodySize, bodyBuf.B)
+
+ } else if contentLength == -1 {
+ bodyBuf.B, err = readBodyChunked(r, maxBodySize, bodyBuf.B)
+
+ } else {
+ bodyBuf.B, err = readBodyIdentity(r, maxBodySize, bodyBuf.B)
+ resp.Header.SetContentLength(len(bodyBuf.B))
+ }
+ return err
+}
+
func (resp *Response) mustSkipBody() bool {
return resp.SkipBody || resp.Header.mustSkipContentLength()
}
@@ -1323,10 +1485,15 @@ func (req *Request) Write(w *bufio.Writer) error {
if len(req.Header.Host()) == 0 || req.parsedURI {
uri := req.URI()
host := uri.Host()
- if len(host) == 0 {
- return errRequestHostRequired
+ if len(req.Header.Host()) == 0 {
+ if len(host) == 0 {
+ return errRequestHostRequired
+ } else {
+ req.Header.SetHostBytes(host)
+ }
+ } else if !req.UseHostHeader {
+ req.Header.SetHostBytes(host)
}
- req.Header.SetHostBytes(host)
req.Header.SetRequestURIBytes(uri.RequestURI())
if len(uri.username) > 0 {
@@ -1358,7 +1525,7 @@ func (req *Request) Write(w *bufio.Writer) error {
if req.onlyMultipartForm() {
body, err = marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
- return fmt.Errorf("error when marshaling multipart form: %s", err)
+ return fmt.Errorf("error when marshaling multipart form: %w", err)
}
req.Header.SetMultipartFormBoundary(req.multipartFormBoundary)
}
@@ -1399,11 +1566,11 @@ func (resp *Response) WriteGzip(w *bufio.Writer) error {
//
// Level is the desired compression level:
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
//
// The method gzips response body and sets 'Content-Encoding: gzip'
// header before writing response to w.
@@ -1430,11 +1597,11 @@ func (resp *Response) WriteDeflate(w *bufio.Writer) error {
//
// Level is the desired compression level:
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
//
// The method deflates response body and sets 'Content-Encoding: deflate'
// header before writing response to w.
@@ -1448,7 +1615,7 @@ func (resp *Response) WriteDeflateLevel(w *bufio.Writer, level int) error {
}
func (resp *Response) brotliBody(level int) error {
- if len(resp.Header.peek(strContentEncoding)) > 0 {
+ if len(resp.Header.ContentEncoding()) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return nil
@@ -1498,12 +1665,12 @@ func (resp *Response) brotliBody(level int) error {
resp.body = w
resp.bodyRaw = nil
}
- resp.Header.SetCanonical(strContentEncoding, strBr)
+ resp.Header.SetContentEncodingBytes(strBr)
return nil
}
func (resp *Response) gzipBody(level int) error {
- if len(resp.Header.peek(strContentEncoding)) > 0 {
+ if len(resp.Header.ContentEncoding()) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return nil
@@ -1553,12 +1720,12 @@ func (resp *Response) gzipBody(level int) error {
resp.body = w
resp.bodyRaw = nil
}
- resp.Header.SetCanonical(strContentEncoding, strGzip)
+ resp.Header.SetContentEncodingBytes(strGzip)
return nil
}
func (resp *Response) deflateBody(level int) error {
- if len(resp.Header.peek(strContentEncoding)) > 0 {
+ if len(resp.Header.ContentEncoding()) > 0 {
// It looks like the body is already compressed.
// Do not compress it again.
return nil
@@ -1608,7 +1775,7 @@ func (resp *Response) deflateBody(level int) error {
resp.body = w
resp.bodyRaw = nil
}
- resp.Header.SetCanonical(strContentEncoding, strDeflate)
+ resp.Header.SetContentEncodingBytes(strDeflate)
return nil
}
@@ -1689,9 +1856,13 @@ func (req *Request) writeBodyStream(w *bufio.Writer) error {
}
} else {
req.Header.SetContentLength(-1)
- if err = req.Header.Write(w); err == nil {
+ err = req.Header.Write(w)
+ if err == nil {
err = writeBodyChunked(w, req.bodyStream)
}
+ if err == nil {
+ err = req.Header.writeTrailer(w)
+ }
}
err1 := req.closeBodyStream()
if err == nil {
@@ -1745,6 +1916,9 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error
if err == nil && sendBody {
err = writeBodyChunked(w, resp.bodyStream)
}
+ if err == nil {
+ err = resp.Header.writeTrailer(w)
+ }
}
}
err1 := resp.closeBodyStream()
@@ -1762,6 +1936,9 @@ func (req *Request) closeBodyStream() error {
if bsc, ok := req.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
+ if rs, ok := req.bodyStream.(*requestStream); ok {
+ releaseRequestStream(rs)
+ }
req.bodyStream = nil
return err
}
@@ -1798,6 +1975,8 @@ func (resp *Response) String() string {
func getHTTPString(hw httpWriter) string {
w := bytebufferpool.Get()
+ defer bytebufferpool.Put(w)
+
bw := bufio.NewWriter(w)
if err := hw.Write(bw); err != nil {
return err.Error()
@@ -1806,7 +1985,6 @@ func getHTTPString(hw httpWriter) string {
return err.Error()
}
s := string(w.B)
- bytebufferpool.Put(w)
return s
}
@@ -1824,7 +2002,7 @@ func writeBodyChunked(w *bufio.Writer, r io.Reader) error {
n, err = r.Read(buf)
if n == 0 {
if err == nil {
- panic("BUG: io.Reader returned 0, nil")
+ continue
}
if err == io.EOF {
if err = writeChunk(w, buf[:0]); err != nil {
@@ -1893,12 +2071,13 @@ func writeChunk(w *bufio.Writer, b []byte) error {
if _, err := w.Write(b); err != nil {
return err
}
- _, err := w.Write(strCRLF)
- err1 := w.Flush()
- if err == nil {
- err = err1
+ // For the end chunk (n == 0), the final CRLF is written after the trailer.
+ if n > 0 {
+ if _, err := w.Write(strCRLF); err != nil {
+ return err
+ }
}
- return err
+ return w.Flush()
}
// ErrBodyTooLarge is returned if either request or response body exceeds
@@ -1906,17 +2085,10 @@ func writeChunk(w *bufio.Writer, b []byte) error {
var ErrBodyTooLarge = errors.New("body size exceeds the given limit")
func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) ([]byte, error) {
- dst = dst[:0]
- if contentLength >= 0 {
- if maxBodySize > 0 && contentLength > maxBodySize {
- return dst, ErrBodyTooLarge
- }
- return appendBodyFixedSize(r, dst, contentLength)
- }
- if contentLength == -1 {
- return readBodyChunked(r, maxBodySize, dst)
+ if maxBodySize > 0 && contentLength > maxBodySize {
+ return dst, ErrBodyTooLarge
}
- return readBodyIdentity(r, maxBodySize, dst)
+ return appendBodyFixedSize(r, dst, contentLength)
}
var errChunkedStream = errors.New("chunked stream")
@@ -2033,6 +2205,9 @@ func readBodyChunked(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, erro
if err != nil {
return dst, err
}
+ if chunkSize == 0 {
+ return dst, err
+ }
if maxBodySize > 0 && len(dst)+chunkSize > maxBodySize {
return dst, ErrBodyTooLarge
}
@@ -2046,9 +2221,6 @@ func readBodyChunked(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, erro
}
}
dst = dst[:len(dst)-strCRLFLen]
- if chunkSize == 0 {
- return dst, nil
- }
}
}
@@ -2061,16 +2233,17 @@ func parseChunkSize(r *bufio.Reader) (int, error) {
c, err := r.ReadByte()
if err != nil {
return -1, ErrBrokenChunk{
- error: fmt.Errorf("cannot read '\r' char at the end of chunk size: %s", err),
+ error: fmt.Errorf("cannot read '\r' char at the end of chunk size: %w", err),
}
}
- // Skip any trailing whitespace after chunk size.
- if c == ' ' {
+ // Skip chunk extension after chunk size.
+ // Add support later if anyone needs it.
+ if c != '\r' {
continue
}
if err := r.UnreadByte(); err != nil {
return -1, ErrBrokenChunk{
- error: fmt.Errorf("cannot unread '\r' char at the end of chunk size: %s", err),
+ error: fmt.Errorf("cannot unread '\r' char at the end of chunk size: %w", err),
}
}
break
@@ -2087,7 +2260,7 @@ func readCrLf(r *bufio.Reader) error {
c, err := r.ReadByte()
if err != nil {
return ErrBrokenChunk{
- error: fmt.Errorf("cannot read %q char at the end of chunk size: %s", exp, err),
+ error: fmt.Errorf("cannot read %q char at the end of chunk size: %w", exp, err),
}
}
if c != exp {
@@ -2111,5 +2284,18 @@ func round2(n int) int {
x |= x >> 8
x |= x >> 16
+ // Make sure we don't return 0 due to overflow, even on 32 bit systems
+ if x >= uint32(math.MaxInt32) {
+ return math.MaxInt32
+ }
+
return int(x + 1)
}
+
+// SetTimeout sets timeout for the request.
+//
+// req.SetTimeout(t); c.Do(&req, &resp) is equivalent to
+// c.DoTimeout(&req, &resp, t)
+func (req *Request) SetTimeout(t time.Duration) {
+ req.timeout = t
+}
diff --git a/http_test.go b/http_test.go
index 0e76a10..52d8a67 100644
--- a/http_test.go
+++ b/http_test.go
@@ -3,10 +3,14 @@ package fasthttp
import (
"bufio"
"bytes"
+ "encoding/base64"
+ "errors"
"fmt"
"io"
- "io/ioutil"
+ "math"
"mime/multipart"
+ "net/http"
+ "net/http/httptest"
"reflect"
"strconv"
"strings"
@@ -16,6 +20,28 @@ import (
"github.com/valyala/bytebufferpool"
)
+func TestInvalidTrailers(t *testing.T) {
+ t.Parallel()
+
+ if err := (&Response{}).Read(bufio.NewReader(bytes.NewReader([]byte{0x20, 0x30, 0x0a, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x66, 0x65, 0x72, 0x2d, 0x45, 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x3a, 0xff, 0x0a, 0x0a, 0x30, 0x0d, 0x0a, 0x30}))); !errors.Is(err, io.EOF) {
+ t.Fatalf("%#v", err)
+ }
+ if err := (&Response{}).Read(bufio.NewReader(bytes.NewReader([]byte{0xff, 0x20, 0x0a, 0x54, 0x52, 0x61, 0x49, 0x4c, 0x65, 0x52, 0x3a, 0x2c, 0x0a, 0x0a}))); !errors.Is(err, errEmptyInt) {
+ t.Fatal(err)
+ }
+ if err := (&Response{}).Read(bufio.NewReader(bytes.NewReader([]byte{0x54, 0x52, 0x61, 0x49, 0x4c, 0x65, 0x52, 0x3a, 0x2c, 0x0a, 0x0a}))); !strings.Contains(err.Error(), "cannot find whitespace in the first line of response") {
+ t.Fatal(err)
+ }
+ if err := (&Request{}).Read(bufio.NewReader(bytes.NewReader([]byte{0xff, 0x20, 0x0a, 0x54, 0x52, 0x61, 0x49, 0x4c, 0x65, 0x52, 0x3a, 0x2c, 0x0a, 0x0a}))); !strings.Contains(err.Error(), "contain forbidden trailer") {
+ t.Fatal(err)
+ }
+
+ b, _ := base64.StdEncoding.DecodeString("tCAKIDoKCToKICAKCToKICAKCToKIAogOgoJOgogIAoJOgovIC8vOi4KOh0KVFJhSUxlUjo9HT09HQpUUmFJTGVSOicQAApUUmFJTGVSOj0gHSAKCT09HQoKOgoKCgo=")
+ if err := (&Request{}).Read(bufio.NewReader(bytes.NewReader(b))); !strings.Contains(err.Error(), "error when reading request headers: invalid header key") {
+ t.Fatalf("%#v", err)
+ }
+}
+
func TestResponseEmptyTransferEncoding(t *testing.T) {
t.Parallel()
@@ -58,7 +84,7 @@ func TestIssue875(t *testing.T) {
expectedLocation string
}
- var testcases = []testcase{
+ testcases := []testcase{
{
uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`,
expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n",
@@ -92,7 +118,7 @@ func TestIssue875(t *testing.T) {
ctx.Response.Header.Set("Location", q)
if !strings.Contains(ctx.Response.String(), tcase.expectedLocation) {
- subT.Errorf("invalid escaping, got\n%s", ctx.Response.String())
+ subT.Errorf("invalid escaping, got\n%q", ctx.Response.String())
}
})
}
@@ -114,10 +140,9 @@ func TestRequestCopyTo(t *testing.T) {
expectedHost, expectedContentType, len(expectedBody), expectedBody)
br := bufio.NewReader(bytes.NewBufferString(s))
if err := req.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
testRequestCopyTo(t, &req)
-
}
func TestResponseCopyTo(t *testing.T) {
@@ -134,7 +159,6 @@ func TestResponseCopyTo(t *testing.T) {
resp.Header.SetStatusCode(200)
resp.SetBodyString("test")
testResponseCopyTo(t, &resp)
-
}
func testRequestCopyTo(t *testing.T, src *Request) {
@@ -155,6 +179,122 @@ func testResponseCopyTo(t *testing.T, src *Response) {
}
}
+func TestRequestBodyStreamWithTrailer(t *testing.T) {
+ t.Parallel()
+
+ testRequestBodyStreamWithTrailer(t, nil, false)
+
+ body := createFixedBody(1e5)
+ testRequestBodyStreamWithTrailer(t, body, false)
+ testRequestBodyStreamWithTrailer(t, body, true)
+}
+
+func testRequestBodyStreamWithTrailer(t *testing.T, body []byte, disableNormalizing bool) {
+ expectedTrailer := map[string]string{
+ "foo": "testfoo",
+ "bar": "testbar",
+ }
+
+ var req1 Request
+ req1.Header.disableNormalizing = disableNormalizing
+ req1.SetHost("google.com")
+ req1.SetBodyStream(bytes.NewBuffer(body), -1)
+ for k, v := range expectedTrailer {
+ err := req1.Header.AddTrailer(k)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ req1.Header.Set(k, v)
+ }
+
+ w := &bytes.Buffer{}
+ bw := bufio.NewWriter(w)
+ if err := req1.Write(bw); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if err := bw.Flush(); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ var req2 Request
+ req2.Header.disableNormalizing = disableNormalizing
+ br := bufio.NewReader(w)
+ if err := req2.Read(br); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ reqBody := req2.Body()
+ if !bytes.Equal(reqBody, body) {
+ t.Fatalf("unexpected body: %q. Expecting %q", reqBody, body)
+ }
+
+ for k, v := range expectedTrailer {
+ kBytes := []byte(k)
+ normalizeHeaderKey(kBytes, disableNormalizing)
+ r := req2.Header.Peek(k)
+ if string(r) != v {
+ t.Fatalf("unexpected trailer header %q: %q. Expecting %q", kBytes, r, v)
+ }
+ }
+}
+
+func TestResponseBodyStreamWithTrailer(t *testing.T) {
+ t.Parallel()
+
+ testResponseBodyStreamWithTrailer(t, nil, false)
+
+ body := createFixedBody(1e5)
+ testResponseBodyStreamWithTrailer(t, body, false)
+ testResponseBodyStreamWithTrailer(t, body, true)
+}
+
+func testResponseBodyStreamWithTrailer(t *testing.T, body []byte, disableNormalizing bool) {
+ expectedTrailer := map[string]string{
+ "foo": "testfoo",
+ "bar": "testbar",
+ }
+ var resp1 Response
+ resp1.Header.disableNormalizing = disableNormalizing
+ resp1.SetBodyStream(bytes.NewReader(body), -1)
+ for k, v := range expectedTrailer {
+ err := resp1.Header.AddTrailer(k)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ resp1.Header.Set(k, v)
+ }
+
+ w := &bytes.Buffer{}
+ bw := bufio.NewWriter(w)
+ if err := resp1.Write(bw); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if err := bw.Flush(); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ var resp2 Response
+ resp2.Header.disableNormalizing = disableNormalizing
+ br := bufio.NewReader(w)
+ if err := resp2.Read(br); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ respBody := resp2.Body()
+ if !bytes.Equal(respBody, body) {
+ t.Fatalf("unexpected body: %q. Expecting %q", respBody, body)
+ }
+
+ for k, v := range expectedTrailer {
+ kBytes := []byte(k)
+ normalizeHeaderKey(kBytes, disableNormalizing)
+ r := resp2.Header.Peek(k)
+ if string(r) != v {
+ t.Fatalf("unexpected trailer header %q: %q. Expecting %q", kBytes, r, v)
+ }
+ }
+}
+
func TestResponseBodyStreamDeflate(t *testing.T) {
t.Parallel()
@@ -188,25 +328,31 @@ func testResponseBodyStreamDeflate(t *testing.T, body []byte, bodySize int) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteDeflate(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var resp Response
br := bufio.NewReader(w)
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
respBody, err := resp.BodyInflate()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !bytes.Equal(respBody, body) {
t.Fatalf("unexpected body: %q. Expecting %q", respBody, body)
}
+ // check for invalid
+ resp.SetBodyRaw([]byte("invalid"))
+ _, errDeflate := resp.BodyInflate()
+ if errDeflate == nil || errDeflate.Error() != "zlib: invalid header" {
+ t.Fatalf("expected error: 'zlib: invalid header' but was %v", errDeflate)
+ }
}
func testResponseBodyStreamGzip(t *testing.T, body []byte, bodySize int) {
@@ -216,25 +362,31 @@ func testResponseBodyStreamGzip(t *testing.T, body []byte, bodySize int) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteGzip(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var resp Response
br := bufio.NewReader(w)
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
respBody, err := resp.BodyGunzip()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !bytes.Equal(respBody, body) {
t.Fatalf("unexpected body: %q. Expecting %q", respBody, body)
}
+ // check for invalid
+ resp.SetBodyRaw([]byte("invalid"))
+ _, errUnzip := resp.BodyGunzip()
+ if errUnzip == nil || errUnzip.Error() != "unexpected EOF" {
+ t.Fatalf("expected error: 'unexpected EOF' but was %v", errUnzip)
+ }
}
func TestResponseWriteGzipNilBody(t *testing.T) {
@@ -244,10 +396,10 @@ func TestResponseWriteGzipNilBody(t *testing.T) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteGzip(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -258,10 +410,50 @@ func TestResponseWriteDeflateNilBody(t *testing.T) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := r.WriteDeflate(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestResponseBodyUncompressed(t *testing.T) {
+ body := "body"
+ var r Response
+ r.SetBodyStream(bytes.NewReader([]byte(body)), len(body))
+
+ w := &bytes.Buffer{}
+ bw := bufio.NewWriter(w)
+ if err := r.WriteDeflate(bw); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if err := bw.Flush(); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ var resp Response
+ br := bufio.NewReader(w)
+ if err := resp.Read(br); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ ce := resp.Header.ContentEncoding()
+ if string(ce) != "deflate" {
+ t.Fatalf("unexpected Content-Encoding: %s", ce)
+ }
+ respBody, err := resp.BodyUncompressed()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if string(respBody) != body {
+ t.Fatalf("unexpected body: %q. Expecting %q", respBody, body)
+ }
+
+ // check for invalid encoding
+ resp.Header.SetContentEncoding("invalid")
+ _, decodeErr := resp.BodyUncompressed()
+ if decodeErr != ErrContentEncodingUnsupported {
+ t.Fatalf("unexpected error: %v", decodeErr)
}
}
@@ -409,7 +601,7 @@ func TestRequestContentTypeWithCharsetIssue100(t *testing.T) {
br := bufio.NewReader(bytes.NewBufferString(s))
var r Request
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
body := r.Body()
@@ -459,12 +651,12 @@ tailfoobar`
var r Request
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- tail, err := ioutil.ReadAll(br)
+ tail, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(tail) != "tailfoobar" {
t.Fatalf("unexpected tail %q. Expecting %q", tail, "tailfoobar")
@@ -472,7 +664,7 @@ tailfoobar`
f, err := r.MultipartForm()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
defer r.RemoveMultipartFormFiles()
@@ -515,6 +707,31 @@ tailfoobar`
}
}
+func TestRequestSetURI(t *testing.T) {
+ t.Parallel()
+
+ var r Request
+
+ uri := "/foo/bar?baz"
+ u := &URI{}
+ u.Parse(nil, []byte(uri)) //nolint:errcheck
+ // Set request uri via SetURI()
+ r.SetURI(u) // copies URI
+ // modifying an original URI struct doesn't affect stored URI inside of request
+ u.SetPath("newPath")
+ if string(r.RequestURI()) != uri {
+ t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri)
+ }
+
+ // Set request uri to nil just resets the URI
+ r.Reset()
+ uri = "/"
+ r.SetURI(nil)
+ if string(r.RequestURI()) != uri {
+ t.Fatalf("unexpected request uri %q. Expecting %q", r.RequestURI(), uri)
+ }
+}
+
func TestRequestRequestURI(t *testing.T) {
t.Parallel()
@@ -569,6 +786,80 @@ func TestRequestUpdateURI(t *testing.T) {
}
}
+func TestUseHostHeader(t *testing.T) {
+ t.Parallel()
+
+ var r Request
+ r.UseHostHeader = true
+ r.Header.SetHost("aaa.bbb")
+ r.SetRequestURI("/lkjkl/kjl")
+
+ // Modify request uri and host via URI() object and make sure
+ // the requestURI and Host header are properly updated
+ u := r.URI()
+ u.SetPath("/123/432.html")
+ u.SetHost("foobar.com")
+ a := u.QueryArgs()
+ a.Set("aaa", "bcse")
+
+ s := r.String()
+ if !strings.HasPrefix(s, "GET /123/432.html?aaa=bcse") {
+ t.Fatalf("cannot find %q in %q", "GET /123/432.html?aaa=bcse", s)
+ }
+ if !strings.Contains(s, "\r\nHost: aaa.bbb\r\n") {
+ t.Fatalf("cannot find %q in %q", "\r\nHost: aaa.bbb\r\n", s)
+ }
+}
+
+func TestUseHostHeader2(t *testing.T) {
+ t.Parallel()
+ testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.Host != "SomeHost" {
+ http.Error(w, fmt.Sprintf("Expected Host header to be '%q', but got '%q'", "SomeHost", r.Host), http.StatusBadRequest)
+ } else {
+ w.WriteHeader(http.StatusOK)
+ }
+ }))
+ defer testServer.Close()
+
+ client := &Client{}
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
+ resp := AcquireResponse()
+ defer ReleaseResponse(resp)
+
+ req.SetRequestURI(testServer.URL)
+ req.UseHostHeader = true
+ req.Header.SetHost("SomeHost")
+ if err := client.DoTimeout(req, resp, 1*time.Second); err != nil {
+ t.Fatalf("DoTimeout returned an error '%v'", err)
+ } else {
+ if resp.StatusCode() != http.StatusOK {
+ t.Fatalf("DoTimeout: %v", resp.body)
+ }
+ }
+ if err := client.Do(req, resp); err != nil {
+ t.Fatalf("DoTimeout returned an error '%v'", err)
+ } else {
+ if resp.StatusCode() != http.StatusOK {
+ t.Fatalf("Do: %q", resp.body)
+ }
+ }
+}
+
+func TestUseHostHeaderAfterRelease(t *testing.T) {
+ t.Parallel()
+ req := AcquireRequest()
+ req.UseHostHeader = true
+ ReleaseRequest(req)
+
+ req = AcquireRequest()
+ defer ReleaseRequest(req)
+ if req.UseHostHeader {
+ t.Fatalf("UseHostHeader was not released in ReleaseRequest()")
+ }
+}
+
func TestRequestBodyStreamMultipleBodyCalls(t *testing.T) {
t.Parallel()
@@ -661,7 +952,7 @@ func TestRequestBodyWriteToMultipart(t *testing.T) {
var r Request
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
testBodyWriteTo(t, &r, expectedS, true)
@@ -675,7 +966,7 @@ type bodyWriterTo interface {
func testBodyWriteTo(t *testing.T, bw bodyWriterTo, expectedS string, isRetainedBody bool) {
var buf bytebufferpool.ByteBuffer
if err := bw.BodyWriteTo(&buf); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
s := buf.B
@@ -706,7 +997,7 @@ func TestRequestReadEOF(t *testing.T) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, io.EOF)
}
// incomplete request mustn't return io.EOF
@@ -731,7 +1022,7 @@ func TestResponseReadEOF(t *testing.T) {
t.Fatalf("expecting error")
}
if err != io.EOF {
- t.Fatalf("unexpected error: %s. Expecting %s", err, io.EOF)
+ t.Fatalf("unexpected error: %v. Expecting %v", err, io.EOF)
}
// incomplete response mustn't return io.EOF
@@ -754,7 +1045,26 @@ func TestRequestReadNoBody(t *testing.T) {
err := r.Read(br)
r.SetHost("foobar")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
+ }
+ s := r.String()
+ if strings.Contains(s, "Content-Length: ") {
+ t.Fatalf("unexpected Content-Length")
+ }
+}
+
+func TestRequestReadNoBodyStreaming(t *testing.T) {
+ t.Parallel()
+
+ var r Request
+
+ r.Header.contentLength = -2
+
+ br := bufio.NewReader(bytes.NewBufferString("GET / HTTP/1.1\r\n\r\n"))
+ err := r.ContinueReadBodyStream(br, 0)
+ r.SetHost("foobar")
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
}
s := r.String()
if strings.Contains(s, "Content-Length: ") {
@@ -773,7 +1083,7 @@ func TestResponseWriteTo(t *testing.T) {
var buf bytebufferpool.ByteBuffer
n, err := r.WriteTo(&buf)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != int64(len(s)) {
t.Fatalf("unexpected response length %d. Expecting %d", n, len(s))
@@ -794,7 +1104,7 @@ func TestRequestWriteTo(t *testing.T) {
var buf bytebufferpool.ByteBuffer
n, err := r.WriteTo(&buf)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != int64(len(s)) {
t.Fatalf("unexpected request length %d. Expecting %d", n, len(s))
@@ -837,6 +1147,24 @@ func TestResponseSkipBody(t *testing.T) {
t.Fatalf("unexpected content-type in response %q", s)
}
+ // set StatusNoContent with statusMessage
+ r.Header.SetStatusCode(StatusNoContent)
+ r.Header.SetStatusMessage([]byte("NC"))
+ r.SetBodyString("foobar")
+ s = r.String()
+ if strings.Contains(s, "\r\n\r\nfoobar") {
+ t.Fatalf("unexpected non-zero body in response %q", s)
+ }
+ if strings.Contains(s, "Content-Length: ") {
+ t.Fatalf("unexpected content-length in response %q", s)
+ }
+ if strings.Contains(s, "Content-Type: ") {
+ t.Fatalf("unexpected content-type in response %q", s)
+ }
+ if !strings.HasPrefix(s, "HTTP/1.1 204 NC\r\n") {
+ t.Fatalf("expecting non-default status line in response %q", s)
+ }
+
// explicitly skip body
r.Header.SetStatusCode(StatusOK)
r.SkipBody = true
@@ -885,11 +1213,11 @@ func TestRequestReadGzippedBody(t *testing.T) {
len(body), body)
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- if string(r.Header.Peek(HeaderContentEncoding)) != "gzip" {
- t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.Peek(HeaderContentEncoding), "gzip")
+ if string(r.Header.ContentEncoding()) != "gzip" {
+ t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.ContentEncoding(), "gzip")
}
if r.Header.ContentLength() != len(body) {
t.Fatalf("unexpected content-length: %d. Expecting %d", r.Header.ContentLength(), len(body))
@@ -900,7 +1228,7 @@ func TestRequestReadGzippedBody(t *testing.T) {
bodyGunzipped, err := AppendGunzipBytes(nil, r.Body())
if err != nil {
- t.Fatalf("unexpected error when uncompressing data: %s", err)
+ t.Fatalf("unexpected error when uncompressing data: %v", err)
}
if string(bodyGunzipped) != bodyOriginal {
t.Fatalf("unexpected uncompressed body %q. Expecting %q", bodyGunzipped, bodyOriginal)
@@ -915,7 +1243,7 @@ func TestRequestReadPostNoBody(t *testing.T) {
s := "POST /foo/bar HTTP/1.1\r\nContent-Type: aaa/bbb\r\n\r\naaaa"
br := bufio.NewReader(bytes.NewBufferString(s))
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(r.Header.RequestURI()) != "/foo/bar" {
@@ -931,9 +1259,9 @@ func TestRequestReadPostNoBody(t *testing.T) {
t.Fatalf("unexpected content-length: %d. Expecting 0", r.Header.ContentLength())
}
- tail, err := ioutil.ReadAll(br)
+ tail, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(tail) != "aaaa" {
t.Fatalf("unexpected tail %q. Expecting %q", tail, "aaaa")
@@ -948,23 +1276,23 @@ func TestRequestContinueReadBody(t *testing.T) {
var r Request
if err := r.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !r.MayContinue() {
t.Fatalf("MayContinue must return true")
}
if err := r.ContinueReadBody(br, 0, true); err != nil {
- t.Fatalf("error when reading request body: %s", err)
+ t.Fatalf("error when reading request body: %v", err)
}
body := r.Body()
if string(body) != "abcde" {
t.Fatalf("unexpected body %q. Expecting %q", body, "abcde")
}
- tail, err := ioutil.ReadAll(br)
+ tail, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(tail) != "f4343" {
t.Fatalf("unexpected tail %q. Expecting %q", tail, "f4343")
@@ -980,12 +1308,12 @@ func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
if err := mw.WriteField(k, v); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
boundary := mw.Boundary()
if err := mw.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
formData := w.Bytes()
@@ -996,11 +1324,11 @@ func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) {
var r Request
if err := r.Header.Read(br); err != nil {
- t.Fatalf("unexpected error reading headers: %s", err)
+ t.Fatalf("unexpected error reading headers: %v", err)
}
if err := r.readLimitBody(br, 10000, false, false); err != nil {
- t.Fatalf("unexpected error reading body: %s", err)
+ t.Fatalf("unexpected error reading body: %v", err)
}
if r.multipartForm != nil {
@@ -1010,7 +1338,6 @@ func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) {
if string(formData) != string(r.Body()) {
t.Fatalf("The body given must equal the body in the Request")
}
-
}
func TestRequestMayContinue(t *testing.T) {
@@ -1048,7 +1375,7 @@ func TestResponseGzipStream(t *testing.T) {
time.Sleep(time.Millisecond)
fmt.Fprintf(w, "1234") //nolint:errcheck
if err := w.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
})
if !r.IsBodyStream() {
@@ -1071,7 +1398,7 @@ func TestResponseDeflateStream(t *testing.T) {
w.Flush() //nolint:errcheck
w.Write([]byte("1234")) //nolint:errcheck
if err := w.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
})
if !r.IsBodyStream() {
@@ -1115,19 +1442,19 @@ func testResponseDeflateExt(t *testing.T, r *Response, s string) {
var err error
bw := bufio.NewWriter(&buf)
if err = r.WriteDeflate(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err = bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var r1 Response
br := bufio.NewReader(&buf)
if err = r1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce := r1.Header.Peek(HeaderContentEncoding)
+ ce := r1.Header.ContentEncoding()
var body []byte
if isCompressible {
if string(ce) != "deflate" {
@@ -1136,7 +1463,7 @@ func testResponseDeflateExt(t *testing.T, r *Response, s string) {
}
body, err = r1.BodyInflate()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
} else {
if len(ce) > 0 {
@@ -1168,19 +1495,19 @@ func testResponseGzipExt(t *testing.T, r *Response, s string) {
var err error
bw := bufio.NewWriter(&buf)
if err = r.WriteGzip(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err = bw.Flush(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var r1 Response
br := bufio.NewReader(&buf)
if err = r1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce := r1.Header.Peek(HeaderContentEncoding)
+ ce := r1.Header.ContentEncoding()
var body []byte
if isCompressible {
if string(ce) != "gzip" {
@@ -1189,7 +1516,7 @@ func testResponseGzipExt(t *testing.T, r *Response, s string) {
}
body, err = r1.BodyGunzip()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
} else {
if len(ce) > 0 {
@@ -1219,12 +1546,12 @@ func TestRequestMultipartForm(t *testing.T) {
k := fmt.Sprintf("key_%d", i)
v := fmt.Sprintf("value_%d", i)
if err := mw.WriteField(k, v); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
boundary := mw.Boundary()
if err := mw.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
formData := w.Bytes()
@@ -1238,13 +1565,13 @@ func TestRequestMultipartForm(t *testing.T) {
var req Request
br := bufio.NewReader(bytes.NewBufferString(s))
if err := req.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
s = req.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := req.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
testRequestMultipartForm(t, "foobar", req.Body(), 3)
@@ -1259,12 +1586,12 @@ func testRequestMultipartForm(t *testing.T, boundary string, formData []byte, pa
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := req.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
f, err := req.MultipartForm()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
defer req.RemoveMultipartFormFiles()
@@ -1301,17 +1628,19 @@ func TestResponseReadLimitBody(t *testing.T) {
// response with content-length
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 10)
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 100)
- testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 9)
+ testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 9, ErrBodyTooLarge)
// chunked response
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9)
+ testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nFoo: bar\r\n\r\n", 9)
testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 100)
- testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 2)
+ testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nfoobar\r\n\r\n", 100)
+ testResponseReadLimitBodyError(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 2, ErrBodyTooLarge)
// identity response
testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 6)
testResponseReadLimitBodySuccess(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 106)
- testResponseReadLimitBodyError(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 5)
+ testResponseReadLimitBodyError(t, "HTTP/1.1 400 OK\r\nContent-Type: aa\r\n\r\n123456", 5, ErrBodyTooLarge)
}
func TestRequestReadLimitBody(t *testing.T) {
@@ -1320,37 +1649,39 @@ func TestRequestReadLimitBody(t *testing.T) {
// request with content-length
testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 9)
testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 92)
- testRequestReadLimitBodyError(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 5)
+ testRequestReadLimitBodyError(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 5, ErrBodyTooLarge)
// chunked request
testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 9)
+ testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\nHost: a.com\nTransfer-Encoding: chunked\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nFoo: bar\r\n\r\n", 9)
testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 999)
- testRequestReadLimitBodyError(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 8)
+ testRequestReadLimitBodySuccess(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\nfoobar\r\n\r\n", 999)
+ testRequestReadLimitBodyError(t, "POST /a HTTP/1.1\r\nHost: a.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa\r\n\r\n6\r\nfoobar\r\n3\r\nbaz\r\n0\r\n\r\n", 8, ErrBodyTooLarge)
}
-func testResponseReadLimitBodyError(t *testing.T, s string, maxBodySize int) {
- var req Response
+func testResponseReadLimitBodyError(t *testing.T, s string, maxBodySize int, expectedErr error) {
+ var resp Response
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
- err := req.ReadLimitBody(br, maxBodySize)
+ err := resp.ReadLimitBody(br, maxBodySize)
if err == nil {
t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize)
}
- if err != ErrBodyTooLarge {
- t.Fatalf("unexpected error: %s. Expecting %s. s=%q, maxBodySize=%d", err, ErrBodyTooLarge, s, maxBodySize)
+ if err != expectedErr {
+ t.Fatalf("unexpected error: %v. Expecting %v. s=%q, maxBodySize=%d", err, expectedErr, s, maxBodySize)
}
}
func testResponseReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) {
- var req Response
+ var resp Response
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
- if err := req.ReadLimitBody(br, maxBodySize); err != nil {
- t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize)
+ if err := resp.ReadLimitBody(br, maxBodySize); err != nil {
+ t.Fatalf("unexpected error: %v. s=%q, maxBodySize=%d", err, s, maxBodySize)
}
}
-func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int) {
+func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int, expectedErr error) {
var req Request
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
@@ -1358,8 +1689,8 @@ func testRequestReadLimitBodyError(t *testing.T, s string, maxBodySize int) {
if err == nil {
t.Fatalf("expecting error. s=%q, maxBodySize=%d", s, maxBodySize)
}
- if err != ErrBodyTooLarge {
- t.Fatalf("unexpected error: %s. Expecting %s. s=%q, maxBodySize=%d", err, ErrBodyTooLarge, s, maxBodySize)
+ if err != expectedErr {
+ t.Fatalf("unexpected error: %v. Expecting %v. s=%q, maxBodySize=%d", err, expectedErr, s, maxBodySize)
}
}
@@ -1368,7 +1699,7 @@ func testRequestReadLimitBodySuccess(t *testing.T, s string, maxBodySize int) {
r := bytes.NewBufferString(s)
br := bufio.NewReader(r)
if err := req.ReadLimitBody(br, maxBodySize); err != nil {
- t.Fatalf("unexpected error: %s. s=%q, maxBodySize=%d", err, s, maxBodySize)
+ t.Fatalf("unexpected error: %v. s=%q, maxBodySize=%d", err, s, maxBodySize)
}
}
@@ -1416,16 +1747,16 @@ func TestRequestWriteRequestURINoHost(t *testing.T) {
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := req.Write(bw); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
var req1 Request
br := bufio.NewReader(&w)
if err := req1.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(req1.Header.Host()) != "google.com" {
t.Fatalf("unexpected host: %q. Expecting %q", req1.Header.Host(), "google.com")
@@ -1447,52 +1778,49 @@ func TestRequestWriteRequestURINoHost(t *testing.T) {
func TestSetRequestBodyStreamFixedSize(t *testing.T) {
t.Parallel()
- testSetRequestBodyStream(t, "a", false)
- testSetRequestBodyStream(t, string(createFixedBody(4097)), false)
- testSetRequestBodyStream(t, string(createFixedBody(100500)), false)
+ testSetRequestBodyStream(t, "a")
+ testSetRequestBodyStream(t, string(createFixedBody(4097)))
+ testSetRequestBodyStream(t, string(createFixedBody(100500)))
}
func TestSetResponseBodyStreamFixedSize(t *testing.T) {
t.Parallel()
- testSetResponseBodyStream(t, "a", false)
- testSetResponseBodyStream(t, string(createFixedBody(4097)), false)
- testSetResponseBodyStream(t, string(createFixedBody(100500)), false)
+ testSetResponseBodyStream(t, "a")
+ testSetResponseBodyStream(t, string(createFixedBody(4097)))
+ testSetResponseBodyStream(t, string(createFixedBody(100500)))
}
func TestSetRequestBodyStreamChunked(t *testing.T) {
t.Parallel()
- testSetRequestBodyStream(t, "", true)
+ testSetRequestBodyStreamChunked(t, "", map[string]string{"Foo": "bar"})
body := "foobar baz aaa bbb ccc"
- testSetRequestBodyStream(t, body, true)
+ testSetRequestBodyStreamChunked(t, body, nil)
body = string(createFixedBody(10001))
- testSetRequestBodyStream(t, body, true)
+ testSetRequestBodyStreamChunked(t, body, map[string]string{"Foo": "test", "Bar": "test"})
}
func TestSetResponseBodyStreamChunked(t *testing.T) {
t.Parallel()
- testSetResponseBodyStream(t, "", true)
+ testSetResponseBodyStreamChunked(t, "", map[string]string{"Foo": "bar"})
body := "foobar baz aaa bbb ccc"
- testSetResponseBodyStream(t, body, true)
+ testSetResponseBodyStreamChunked(t, body, nil)
body = string(createFixedBody(10001))
- testSetResponseBodyStream(t, body, true)
+ testSetResponseBodyStreamChunked(t, body, map[string]string{"Foo": "test", "Bar": "test"})
}
-func testSetRequestBodyStream(t *testing.T, body string, chunked bool) {
+func testSetRequestBodyStream(t *testing.T, body string) {
var req Request
req.Header.SetHost("foobar.com")
req.Header.SetMethod(MethodPost)
bodySize := len(body)
- if chunked {
- bodySize = -1
- }
if req.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
@@ -1504,28 +1832,72 @@ func testSetRequestBodyStream(t *testing.T, body string, chunked bool) {
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := req.Write(bw); err != nil {
- t.Fatalf("unexpected error when writing request: %s. body=%q", err, body)
+ t.Fatalf("unexpected error when writing request: %v. body=%q", err, body)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error when flushing request: %s. body=%q", err, body)
+ t.Fatalf("unexpected error when flushing request: %v. body=%q", err, body)
}
var req1 Request
br := bufio.NewReader(&w)
if err := req1.Read(br); err != nil {
- t.Fatalf("unexpected error when reading request: %s. body=%q", err, body)
+ t.Fatalf("unexpected error when reading request: %v. body=%q", err, body)
}
if string(req1.Body()) != body {
t.Fatalf("unexpected body %q. Expecting %q", req1.Body(), body)
}
}
-func testSetResponseBodyStream(t *testing.T, body string, chunked bool) {
+func testSetRequestBodyStreamChunked(t *testing.T, body string, trailer map[string]string) {
+ var req Request
+ req.Header.SetHost("foobar.com")
+ req.Header.SetMethod(MethodPost)
+
+ if req.IsBodyStream() {
+ t.Fatalf("IsBodyStream must return false")
+ }
+ req.SetBodyStream(bytes.NewBufferString(body), -1)
+ if !req.IsBodyStream() {
+ t.Fatalf("IsBodyStream must return true")
+ }
+
+ var w bytes.Buffer
+ bw := bufio.NewWriter(&w)
+ for k := range trailer {
+ err := req.Header.AddTrailer(k)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ }
+ if err := req.Write(bw); err != nil {
+ t.Fatalf("unexpected error when writing request: %v. body=%q", err, body)
+ }
+ for k, v := range trailer {
+ req.Header.Set(k, v)
+ }
+ if err := bw.Flush(); err != nil {
+ t.Fatalf("unexpected error when flushing request: %v. body=%q", err, body)
+ }
+
+ var req1 Request
+ br := bufio.NewReader(&w)
+ if err := req1.Read(br); err != nil {
+ t.Fatalf("unexpected error when reading request: %v. body=%q", err, body)
+ }
+ if string(req1.Body()) != body {
+ t.Fatalf("unexpected body %q. Expecting %q", req1.Body(), body)
+ }
+ for k, v := range trailer {
+ r := req.Header.Peek(k)
+ if string(r) != v {
+ t.Fatalf("unexpected trailer %q. Expecting %q. Got %q", k, v, r)
+ }
+ }
+}
+
+func testSetResponseBodyStream(t *testing.T, body string) {
var resp Response
bodySize := len(body)
- if chunked {
- bodySize = -1
- }
if resp.IsBodyStream() {
t.Fatalf("IsBodyStream must return false")
}
@@ -1537,22 +1909,66 @@ func testSetResponseBodyStream(t *testing.T, body string, chunked bool) {
var w bytes.Buffer
bw := bufio.NewWriter(&w)
if err := resp.Write(bw); err != nil {
- t.Fatalf("unexpected error when writing response: %s. body=%q", err, body)
+ t.Fatalf("unexpected error when writing response: %v. body=%q", err, body)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("unexpected error when flushing response: %s. body=%q", err, body)
+ t.Fatalf("unexpected error when flushing response: %v. body=%q", err, body)
}
var resp1 Response
br := bufio.NewReader(&w)
if err := resp1.Read(br); err != nil {
- t.Fatalf("unexpected error when reading response: %s. body=%q", err, body)
+ t.Fatalf("unexpected error when reading response: %v. body=%q", err, body)
}
if string(resp1.Body()) != body {
t.Fatalf("unexpected body %q. Expecting %q", resp1.Body(), body)
}
}
+func testSetResponseBodyStreamChunked(t *testing.T, body string, trailer map[string]string) {
+ var resp Response
+ if resp.IsBodyStream() {
+ t.Fatalf("IsBodyStream must return false")
+ }
+ resp.SetBodyStream(bytes.NewBufferString(body), -1)
+ if !resp.IsBodyStream() {
+ t.Fatalf("IsBodyStream must return true")
+ }
+
+ var w bytes.Buffer
+ bw := bufio.NewWriter(&w)
+ for k := range trailer {
+ err := resp.Header.AddTrailer(k)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ }
+ if err := resp.Write(bw); err != nil {
+ t.Fatalf("unexpected error when writing response: %v. body=%q", err, body)
+ }
+ if err := bw.Flush(); err != nil {
+ t.Fatalf("unexpected error when flushing response: %v. body=%q", err, body)
+ }
+ for k, v := range trailer {
+ resp.Header.Set(k, v)
+ }
+
+ var resp1 Response
+ br := bufio.NewReader(&w)
+ if err := resp1.Read(br); err != nil {
+ t.Fatalf("unexpected error when reading response: %v. body=%q", err, body)
+ }
+ if string(resp1.Body()) != body {
+ t.Fatalf("unexpected body %q. Expecting %q", resp1.Body(), body)
+ }
+ for k, v := range trailer {
+ r := resp.Header.Peek(k)
+ if string(r) != v {
+ t.Fatalf("unexpected trailer %q. Expecting %q. Got %q", k, v, r)
+ }
+ }
+}
+
func TestRound2(t *testing.T) {
t.Parallel()
@@ -1566,6 +1982,7 @@ func TestRound2(t *testing.T) {
testRound2(t, 8, 8)
testRound2(t, 9, 16)
testRound2(t, 0x10001, 0x20000)
+ testRound2(t, math.MaxInt32-1, math.MaxInt32)
}
func testRound2(t *testing.T, n, expectedRound2 int) {
@@ -1579,19 +1996,19 @@ func TestRequestReadChunked(t *testing.T) {
var req Request
- s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\ntrail"
+ s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\nTrail: test\r\n\r\n"
r := bytes.NewBufferString(s)
rb := bufio.NewReader(r)
err := req.Read(rb)
if err != nil {
- t.Fatalf("Unexpected error when reading chunked request: %s", err)
+ t.Fatalf("Unexpected error when reading chunked request: %v", err)
}
expectedBody := "abc12345"
if string(req.Body()) != expectedBody {
t.Fatalf("Unexpected body %q. Expected %q", req.Body(), expectedBody)
}
- verifyRequestHeader(t, &req.Header, 8, "/foo", "google.com", "", "aa/bb")
- verifyTrailer(t, rb, "trail")
+ verifyRequestHeader(t, &req.Header, -1, "/foo", "google.com", "", "aa/bb")
+ verifyTrailer(t, rb, map[string]string{"Trail": "test"}, true)
}
// See: https://github.com/erikdubbelboer/fasthttp/issues/34
@@ -1605,7 +2022,7 @@ func TestRequestChunkedWhitespace(t *testing.T) {
rb := bufio.NewReader(r)
err := req.Read(rb)
if err != nil {
- t.Fatalf("Unexpected error when reading chunked request: %s", err)
+ t.Fatalf("Unexpected error when reading chunked request: %v", err)
}
expectedBody := "abc"
if string(req.Body()) != expectedBody {
@@ -1618,42 +2035,42 @@ func TestResponseReadWithoutBody(t *testing.T) {
var resp Response
- testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Length: 1235\r\n\r\nfoobar", false,
- 304, 1235, "aa", "foobar")
+ testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Length: 1235\r\n\r\n", false,
+ 304, 1235, "aa", nil)
- testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTransfer-Encoding: chunked\r\n\r\n123\r\nss", false,
- 204, -1, "aab", "123\r\nss")
+ testResponseReadWithoutBody(t, &resp, "HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\n", false,
+ 204, -1, "aab", map[string]string{"Foo": "bar"})
- testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\naaaa", false,
- 123, 3434, "xxx", "aaaa")
+ testResponseReadWithoutBody(t, &resp, "HTTP/1.1 123 AAA\r\nContent-Type: xxx\r\nContent-Length: 3434\r\n\r\n", false,
+ 123, 3434, "xxx", nil)
- testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 123\r\n\r\nxxxx", true,
- 200, 123, "text/xml", "xxxx")
+ testResponseReadWithoutBody(t, &resp, "HTTP 200 OK\r\nContent-Type: text/xml\r\nContent-Length: 123\r\n\r\nfoobar\r\n", true,
+ 200, 123, "text/xml", nil)
// '100 Continue' must be skipped.
- testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Length: 894\r\n\r\nfoobar", true,
- 329, 894, "qwe", "foobar")
+ testResponseReadWithoutBody(t, &resp, "HTTP/1.1 100 Continue\r\nFoo-bar: baz\r\n\r\nHTTP/1.1 329 aaa\r\nContent-Type: qwe\r\nContent-Length: 894\r\n\r\n", true,
+ 329, 894, "qwe", nil)
}
func testResponseReadWithoutBody(t *testing.T, resp *Response, s string, skipBody bool,
- expectedStatusCode, expectedContentLength int, expectedContentType, expectedTrailer string) {
+ expectedStatusCode, expectedContentLength int, expectedContentType string, expectedTrailer map[string]string) {
r := bytes.NewBufferString(s)
rb := bufio.NewReader(r)
resp.SkipBody = skipBody
err := resp.Read(rb)
if err != nil {
- t.Fatalf("Unexpected error when reading response without body: %s. response=%q", err, s)
+ t.Fatalf("Unexpected error when reading response without body: %v. response=%q", err, s)
}
if len(resp.Body()) != 0 {
t.Fatalf("Unexpected response body %q. Expected %q. response=%q", resp.Body(), "", s)
}
- verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType)
- verifyTrailer(t, rb, expectedTrailer)
+ verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, "")
+ verifyResponseTrailer(t, &resp.Header, expectedTrailer)
// verify that ordinal response is read after null-body response
resp.SkipBody = false
testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nContent-Length: 5\r\nContent-Type: bar\r\n\r\n56789aaa",
- 300, 5, "bar", "56789", "aaa")
+ 300, 5, "bar", "56789", nil)
}
func TestRequestSuccess(t *testing.T) {
@@ -1717,16 +2134,16 @@ func testResponseSuccess(t *testing.T, statusCode int, contentType, serverName,
bw := bufio.NewWriter(w)
err := resp.Write(bw)
if err != nil {
- t.Fatalf("Unexpected error when calling Response.Write(): %s", err)
+ t.Fatalf("Unexpected error when calling Response.Write(): %v", err)
}
if err = bw.Flush(); err != nil {
- t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err)
+ t.Fatalf("Unexpected error when flushing bufio.Writer: %v", err)
}
var resp1 Response
br := bufio.NewReader(w)
if err = resp1.Read(br); err != nil {
- t.Fatalf("Unexpected error when calling Response.Read(): %s", err)
+ t.Fatalf("Unexpected error when calling Response.Read(): %v", err)
}
if resp1.StatusCode() != expectedStatusCode {
t.Fatalf("Unexpected status code: %d. Expected %d", resp1.StatusCode(), expectedStatusCode)
@@ -1787,16 +2204,16 @@ func testRequestSuccess(t *testing.T, method, requestURI, host, userAgent, body,
bw := bufio.NewWriter(w)
err := req.Write(bw)
if err != nil {
- t.Fatalf("Unexpected error when calling Request.Write(): %s", err)
+ t.Fatalf("Unexpected error when calling Request.Write(): %v", err)
}
if err = bw.Flush(); err != nil {
- t.Fatalf("Unexpected error when flushing bufio.Writer: %s", err)
+ t.Fatalf("Unexpected error when flushing bufio.Writer: %v", err)
}
var req1 Request
br := bufio.NewReader(w)
if err = req1.Read(br); err != nil {
- t.Fatalf("Unexpected error when calling Request.Read(): %s", err)
+ t.Fatalf("Unexpected error when calling Request.Read(): %v", err)
}
if string(req1.Header.Method()) != expectedMethod {
t.Fatalf("Unexpected method: %q. Expected %q", req1.Header.Method(), expectedMethod)
@@ -1829,40 +2246,54 @@ func TestResponseReadSuccess(t *testing.T) {
// usual response
testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: foo/bar\r\n\r\n0123456789",
- 200, 10, "foo/bar", "0123456789", "")
+ 200, 10, "foo/bar", "0123456789", nil)
// zero response
testResponseReadSuccess(t, resp, "HTTP/1.1 500 OK\r\nContent-Length: 0\r\nContent-Type: foo/bar\r\n\r\n",
- 500, 0, "foo/bar", "", "")
+ 500, 0, "foo/bar", "", nil)
// response with trailer
- testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nContent-Length: 5\r\nContent-Type: bar\r\n\r\n56789aaa",
- 300, 5, "bar", "56789", "aaa")
+ testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
+ 300, -1, "bar", "56789", map[string]string{"Foo": "bar"})
+
+ // response with trailer disableNormalizing
+ resp.Header.DisableNormalizing()
+ testResponseReadSuccess(t, resp, "HTTP/1.1 300 OK\r\nTransfer-Encoding: chunked\r\nContent-Type: bar\r\n\r\n5\r\n56789\r\n0\r\nfoo: bar\r\n\r\n",
+ 300, -1, "bar", "56789", map[string]string{"foo": "bar"})
- // no conent-length ('identity' transfer-encoding)
- testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foobar\r\n\r\nzxxc",
- 200, 4, "foobar", "zxxc", "")
+ // no content-length ('identity' transfer-encoding)
+ testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foobar\r\n\r\nzxxxx",
+ 200, 5, "foobar", "zxxxx", nil)
// explicitly stated 'Transfer-Encoding: identity'
testResponseReadSuccess(t, resp, "HTTP/1.1 234 ss\r\nContent-Type: xxx\r\n\r\nxag",
- 234, 3, "xxx", "xag", "")
+ 234, 3, "xxx", "xag", nil)
// big 'identity' response
body := string(createFixedBody(100500))
testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\n\r\n"+body,
- 200, 100500, "aa", body, "")
+ 200, 100500, "aa", body, nil)
// chunked response
- testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\n\r\nzzzzz",
- 200, 6, "text/html", "qwerty", "zzzzz")
+ testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n4\r\nqwer\r\n2\r\nty\r\n0\r\nFoo2: bar2\r\n\r\n",
+ 200, -1, "text/html", "qwerty", map[string]string{"Foo2": "bar2"})
// chunked response with non-chunked Transfer-Encoding.
- testResponseReadSuccess(t, resp, "HTTP/1.1 230 OK\r\nContent-Type: text\r\nTransfer-Encoding: aaabbb\r\n\r\n2\r\ner\r\n2\r\nty\r\n0\r\n\r\nwe",
- 230, 4, "text", "erty", "we")
+ testResponseReadSuccess(t, resp, "HTTP/1.1 230 OK\r\nContent-Type: text\r\nTransfer-Encoding: aaabbb\r\n\r\n2\r\ner\r\n2\r\nty\r\n0\r\nFoo3: bar3\r\n\r\n",
+ 230, -1, "text", "erty", map[string]string{"Foo3": "bar3"})
+
+ // chunked response with content-length
+ testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 123\r\nTransfer-Encoding: chunked\r\n\r\n4\r\ntest\r\n0\r\nFoo4:bar4\r\n\r\n",
+ 200, -1, "foo/bar", "test", map[string]string{"Foo4": "bar4"})
+
+ // chunked response with empty body
+ testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo5: bar5\r\n\r\n",
+ 200, -1, "text/html", "", map[string]string{"Foo5": "bar5"})
+
+ // chunked response with chunk extension
+ testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n3;ext\r\naaa\r\n0\r\nFoo6: bar6\r\n\r\n",
+ 200, -1, "text/html", "aaa", map[string]string{"Foo6": "bar6"})
- // zero chunked response
- testResponseReadSuccess(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\nzzz",
- 200, 0, "text/html", "", "zzz")
}
func TestResponseReadError(t *testing.T) {
@@ -1879,8 +2310,13 @@ func TestResponseReadError(t *testing.T) {
// empty body
testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\n")
- // short body
+ // invalid chunked body
testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nContent-Length: 1234\r\n\r\nshort")
+
+ // chunked body without end chunk
+ testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nTransfer-Encoding: chunked\r\n\r\nfoo")
+
+ testResponseReadError(t, resp, "HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nTransfer-Encoding: chunked\r\n\r\n3\r\nfoo")
}
func testResponseReadError(t *testing.T, resp *Response, response string) {
@@ -1891,25 +2327,25 @@ func testResponseReadError(t *testing.T, resp *Response, response string) {
t.Fatalf("Expecting error for response=%q", response)
}
- testResponseReadSuccess(t, resp, "HTTP/1.1 303 Redisred sedfs sdf\r\nContent-Type: aaa\r\nContent-Length: 5\r\n\r\nHELLOaaa",
- 303, 5, "aaa", "HELLO", "aaa")
+ testResponseReadSuccess(t, resp, "HTTP/1.1 303 Redisred sedfs sdf\r\nContent-Type: aaa\r\nContent-Length: 5\r\n\r\nHELLO",
+ 303, 5, "aaa", "HELLO", nil)
}
func testResponseReadSuccess(t *testing.T, resp *Response, response string, expectedStatusCode, expectedContentLength int,
- expectedContenType, expectedBody, expectedTrailer string) {
+ expectedContentType, expectedBody string, expectedTrailer map[string]string) {
r := bytes.NewBufferString(response)
rb := bufio.NewReader(r)
err := resp.Read(rb)
if err != nil {
- t.Fatalf("Unexpected error: %s", err)
+ t.Fatalf("Unexpected error: %v", err)
}
- verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContenType)
+ verifyResponseHeader(t, &resp.Header, expectedStatusCode, expectedContentLength, expectedContentType, "")
if !bytes.Equal(resp.Body(), []byte(expectedBody)) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody))
}
- verifyTrailer(t, rb, expectedTrailer)
+ verifyResponseTrailer(t, &resp.Header, expectedTrailer)
}
func TestReadBodyFixedSize(t *testing.T) {
@@ -2033,7 +2469,7 @@ func testRequestPostArgsError(t *testing.T, req *Request, s string) {
br := bufio.NewReader(r)
err := req.Read(br)
if err != nil {
- t.Fatalf("Unexpected error when reading %q: %s", s, err)
+ t.Fatalf("Unexpected error when reading %q: %v", s, err)
}
ss := req.PostArgs().String()
if len(ss) != 0 {
@@ -2046,7 +2482,7 @@ func testRequestPostArgsSuccess(t *testing.T, req *Request, s string, expectedAr
br := bufio.NewReader(r)
err := req.Read(br)
if err != nil {
- t.Fatalf("Unexpected error when reading %q: %s", s, err)
+ t.Fatalf("Unexpected error when reading %q: %v", s, err)
}
args := req.PostArgs()
@@ -2066,37 +2502,33 @@ func testRequestPostArgsSuccess(t *testing.T, req *Request, s string, expectedAr
func testReadBodyChunked(t *testing.T, bodySize int) {
body := createFixedBody(bodySize)
- chunkedBody := createChunkedBody(body)
- expectedTrailer := []byte("chunked shit")
- chunkedBody = append(chunkedBody, expectedTrailer...)
+ expectedTrailer := map[string]string{"Foo": "bar"}
+ chunkedBody := createChunkedBody(body, expectedTrailer, true)
r := bytes.NewBuffer(chunkedBody)
br := bufio.NewReader(r)
- b, err := readBody(br, -1, 0, nil)
+ b, err := readBodyChunked(br, 0, nil)
if err != nil {
- t.Fatalf("Unexpected error for bodySize=%d: %s. body=%q, chunkedBody=%q", bodySize, err, body, chunkedBody)
+ t.Fatalf("Unexpected error for bodySize=%d: %v. body=%q, chunkedBody=%q", bodySize, err, body, chunkedBody)
}
if !bytes.Equal(b, body) {
t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q. chunkedBody=%q", bodySize, b, body, chunkedBody)
}
- verifyTrailer(t, br, string(expectedTrailer))
+ verifyTrailer(t, br, expectedTrailer, false)
}
func testReadBodyFixedSize(t *testing.T, bodySize int) {
body := createFixedBody(bodySize)
- expectedTrailer := []byte("traler aaaa")
- bodyWithTrailer := append(body, expectedTrailer...)
-
- r := bytes.NewBuffer(bodyWithTrailer)
+ r := bytes.NewBuffer(body)
br := bufio.NewReader(r)
b, err := readBody(br, bodySize, 0, nil)
if err != nil {
- t.Fatalf("Unexpected error in ReadResponseBody(%d): %s", bodySize, err)
+ t.Fatalf("Unexpected error in ReadResponseBody(%d): %v", bodySize, err)
}
if !bytes.Equal(b, body) {
t.Fatalf("Unexpected response read for bodySize=%d: %q. Expected %q", bodySize, b, body)
}
- verifyTrailer(t, br, string(expectedTrailer))
+ verifyTrailer(t, br, nil, false)
}
func createFixedBody(bodySize int) []byte {
@@ -2107,7 +2539,7 @@ func createFixedBody(bodySize int) []byte {
return b
}
-func createChunkedBody(body []byte) []byte {
+func createChunkedBody(body []byte, trailer map[string]string, withEnd bool) []byte {
var b []byte
chunkSize := 1
for len(body) > 0 {
@@ -2120,7 +2552,17 @@ func createChunkedBody(body []byte) []byte {
body = body[chunkSize:]
chunkSize++
}
- return append(b, []byte("0\r\n\r\n")...)
+ if withEnd {
+ b = append(b, "0\r\n"...)
+ for k, v := range trailer {
+ b = append(b, k...)
+ b = append(b, ": "...)
+ b = append(b, v...)
+ b = append(b, "\r\n"...)
+ }
+ b = append(b, "\r\n"...)
+ }
+ return b
}
func TestWriteMultipartForm(t *testing.T) {
@@ -2140,13 +2582,12 @@ Content-Type: application/json
`, "\n", "\r\n", -1)
mr := multipart.NewReader(strings.NewReader(s), "foo")
form, err := mr.ReadForm(1024)
-
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := WriteMultipartForm(&w, form, "foo"); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if w.String() != s {
@@ -2309,7 +2750,7 @@ func TestResponseImmediateHeaderFlushFixedLength(t *testing.T) {
go func() {
if err := bw.Write(bb); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
waitForIt <- struct{}{}
}()
@@ -2358,7 +2799,7 @@ func TestResponseImmediateHeaderFlushFixedLengthSkipBody(t *testing.T) {
bw := &r
if err := bw.Write(bb); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if !strings.Contains(headersOnClose, "Content-Length: 0") {
@@ -2390,7 +2831,7 @@ func TestResponseImmediateHeaderFlushChunked(t *testing.T) {
go func() {
if err := bw.Write(bb); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
waitForIt <- struct{}{}
@@ -2440,7 +2881,7 @@ func TestResponseImmediateHeaderFlushChunkedNoBody(t *testing.T) {
bw := &r
if err := bw.Write(bb); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if !strings.Contains(headersOnClose, "Transfer-Encoding: chunked") {
@@ -2508,3 +2949,87 @@ func TestResponseBodyStreamErrorOnPanicDuringClose(t *testing.T) {
t.Fatalf("unexpected error value, got: %+v.", e.Error())
}
}
+
+func TestRequestMultipartFormPipeEmptyFormField(t *testing.T) {
+ t.Parallel()
+
+ pr, pw := io.Pipe()
+ mw := multipart.NewWriter(pw)
+
+ errs := make(chan error, 1)
+
+ go func() {
+ defer func() {
+ err := mw.Close()
+ if err != nil {
+ errs <- err
+ }
+ err = pw.Close()
+ if err != nil {
+ errs <- err
+ }
+ close(errs)
+ }()
+
+ if err := mw.WriteField("emptyField", ""); err != nil {
+ errs <- err
+ }
+ }()
+
+ var b bytes.Buffer
+ bw := bufio.NewWriter(&b)
+ err := writeBodyChunked(bw, pr)
+
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ for e := range errs {
+ t.Fatalf("unexpected error in goroutine multiwriter: %v", e)
+ }
+
+ testRequestMultipartFormPipeEmptyFormField(t, mw.Boundary(), b.Bytes(), 1)
+}
+
+func testRequestMultipartFormPipeEmptyFormField(t *testing.T, boundary string, formData []byte, partsCount int) []byte {
+ s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nTransfer-Encoding: chunked\r\n\r\n%s",
+ boundary, formData)
+
+ var req Request
+
+ r := bytes.NewBufferString(s)
+ br := bufio.NewReader(r)
+ if err := req.Read(br); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ f, err := req.MultipartForm()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ defer req.RemoveMultipartFormFiles()
+
+ if len(f.File) > 0 {
+ t.Fatalf("unexpected files found in the multipart form: %d", len(f.File))
+ }
+
+ if len(f.Value) != partsCount {
+ t.Fatalf("unexpected number of values found: %d. Expecting %d", len(f.Value), partsCount)
+ }
+
+ for k, vv := range f.Value {
+ if len(vv) != 1 {
+ t.Fatalf("unexpected number of values found for key=%q: %d. Expecting 1", k, len(vv))
+ }
+ if k != "emptyField" {
+ t.Fatalf("unexpected key=%q. Expecting %q", k, "emptyField")
+ }
+
+ v := vv[0]
+ if v != "" {
+ t.Fatalf("unexpected value=%q. expecting %q", v, "")
+ }
+ }
+
+ return req.Body()
+}
diff --git a/lbclient.go b/lbclient.go
index 46d14b7..34ff719 100644
--- a/lbclient.go
+++ b/lbclient.go
@@ -50,6 +50,7 @@ type LBClient struct {
cs []*lbClient
once sync.Once
+ mu sync.RWMutex
}
// DefaultLBClientTimeout is the default request timeout used by LBClient
@@ -80,6 +81,8 @@ func (cc *LBClient) Do(req *Request, resp *Response) error {
}
func (cc *LBClient) init() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
if len(cc.Clients) == 0 {
panic("BUG: LBClient.Clients cannot be empty")
}
@@ -91,9 +94,44 @@ func (cc *LBClient) init() {
}
}
+// AddClient adds a new client to the balanced clients
+// returns the new total number of clients
+func (cc *LBClient) AddClient(c BalancingClient) int {
+ cc.mu.Lock()
+ cc.cs = append(cc.cs, &lbClient{
+ c: c,
+ healthCheck: cc.HealthCheck,
+ })
+ cc.mu.Unlock()
+ return len(cc.cs)
+}
+
+// RemoveClients removes clients using the provided callback
+// if rc returns true, the passed client will be removed
+// returns the new total number of clients
+func (cc *LBClient) RemoveClients(rc func(BalancingClient) bool) int {
+ cc.mu.Lock()
+ n := 0
+ for _, cs := range cc.cs {
+ if rc(cs.c) {
+ continue
+ }
+ cc.cs[n] = cs
+ n++
+ }
+ for i := n; i < len(cc.cs); i++ {
+ cc.cs[i] = nil
+ }
+ cc.cs = cc.cs[:n]
+
+ cc.mu.Unlock()
+ return len(cc.cs)
+}
+
func (cc *LBClient) get() *lbClient {
cc.once.Do(cc.init)
+ cc.mu.RLock()
cs := cc.cs
minC := cs[0]
@@ -108,6 +146,7 @@ func (cc *LBClient) get() *lbClient {
minT = t
}
}
+ cc.mu.RUnlock()
return minC
}
diff --git a/lbclient_example_test.go b/lbclient_example_test.go
index a9e240d..3cc7e24 100644
--- a/lbclient_example_test.go
+++ b/lbclient_example_test.go
@@ -31,7 +31,7 @@ func ExampleLBClient() {
url := fmt.Sprintf("http://abcedfg/foo/bar/%d", i)
req.SetRequestURI(url)
if err := lbc.Do(&req, &resp); err != nil {
- log.Fatalf("Error when sending request: %s", err)
+ log.Fatalf("Error when sending request: %v", err)
}
if resp.StatusCode() != fasthttp.StatusOK {
log.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusOK)
diff --git a/nocopy.go b/nocopy.go
index 9664cb0..5e41bdf 100644
--- a/nocopy.go
+++ b/nocopy.go
@@ -7,5 +7,5 @@ package fasthttp
// and also: https://stackoverflow.com/questions/52494458/nocopy-minimal-example
type noCopy struct{} //nolint:unused
-func (*noCopy) Lock() {}
-func (*noCopy) Unlock() {}
+func (*noCopy) Lock() {} //nolint:unused
+func (*noCopy) Unlock() {} //nolint:unused
diff --git a/peripconn.go b/peripconn.go
index afd2a92..c2d1827 100644
--- a/peripconn.go
+++ b/peripconn.go
@@ -48,8 +48,10 @@ type perIPConn struct {
func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn {
v := counter.pool.Get()
if v == nil {
- v = &perIPConn{
+ return &perIPConn{
perIPConnCounter: counter,
+ Conn: conn,
+ ip: ip,
}
}
c := v.(*perIPConn)
diff --git a/peripconn_test.go b/peripconn_test.go
index b67e8a9..e2137c0 100644
--- a/peripconn_test.go
+++ b/peripconn_test.go
@@ -16,7 +16,7 @@ func testIPxUint32(t *testing.T, n uint32) {
ip := uint322ip(n)
nn := ip2uint32(ip)
if n != nn {
- t.Fatalf("Unexpected value=%d for ip=%s. Expected %d", nn, ip, n)
+ t.Fatalf("Unexpected value=%d for ip=%q. Expected %d", nn, ip, n)
}
}
diff --git a/prefork/prefork_test.go b/prefork/prefork_test.go
index 19b2e47..9a9a873 100644
--- a/prefork/prefork_test.go
+++ b/prefork/prefork_test.go
@@ -48,7 +48,7 @@ func Test_New(t *testing.T) {
p := New(s)
if p.Network != defaultNetwork {
- t.Errorf("Prefork.Netork == %s, want %s", p.Network, defaultNetwork)
+ t.Errorf("Prefork.Netork == %q, want %q", p.Network, defaultNetwork)
}
if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() {
@@ -82,11 +82,11 @@ func Test_listen(t *testing.T) {
lnAddr := ln.Addr().String()
if lnAddr != addr {
- t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
+ t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.Network != defaultNetwork {
- t.Errorf("Prefork.Network == %s, want %s", p.Network, defaultNetwork)
+ t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
procs := runtime.GOMAXPROCS(0)
@@ -119,11 +119,11 @@ func Test_setTCPListenerFiles(t *testing.T) {
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
- t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
+ t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.Network != defaultNetwork {
- t.Errorf("Prefork.Network == %s, want %s", p.Network, defaultNetwork)
+ t.Errorf("Prefork.Network == %q, want %q", p.Network, defaultNetwork)
}
if len(p.files) != 1 {
@@ -155,7 +155,7 @@ func Test_ListenAndServe(t *testing.T) {
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
- t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
+ t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.ln == nil {
@@ -187,7 +187,7 @@ func Test_ListenAndServeTLS(t *testing.T) {
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
- t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
+ t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.ln == nil {
@@ -219,7 +219,7 @@ func Test_ListenAndServeTLSEmbed(t *testing.T) {
lnAddr := p.ln.Addr().String()
if lnAddr != addr {
- t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr)
+ t.Errorf("Prefork.Addr == %q, want %q", lnAddr, addr)
}
if p.ln == nil {
diff --git a/requestctx_setbodystreamwriter_example_test.go b/requestctx_setbodystreamwriter_example_test.go
index 6bdcb81..af9b09c 100644
--- a/requestctx_setbodystreamwriter_example_test.go
+++ b/requestctx_setbodystreamwriter_example_test.go
@@ -12,7 +12,7 @@ import (
func ExampleRequestCtx_SetBodyStreamWriter() {
// Start fasthttp server for streaming responses.
if err := fasthttp.ListenAndServe(":8080", responseStreamHandler); err != nil {
- log.Fatalf("unexpected error in server: %s", err)
+ log.Fatalf("unexpected error in server: %v", err)
}
}
diff --git a/reuseport/reuseport.go b/reuseport/reuseport.go
index 161f4e1..6e13acb 100644
--- a/reuseport/reuseport.go
+++ b/reuseport/reuseport.go
@@ -1,5 +1,5 @@
-//go:build !windows
-// +build !windows
+//go:build !windows && !aix
+// +build !windows,!aix
// Package reuseport provides TCP net.Listener with SO_REUSEPORT support.
//
@@ -21,8 +21,8 @@ import (
// The returned listener tries enabling the following TCP options, which usually
// have positive impact on performance:
//
-// - TCP_DEFER_ACCEPT. This option expects that the server reads from accepted
-// connections before writing to them.
+// - TCP_DEFER_ACCEPT. This option expects that the server reads from accepted
+// connections before writing to them.
//
// - TCP_FASTOPEN. See https://lwn.net/Articles/508865/ for details.
//
diff --git a/reuseport/reuseport_aix.go b/reuseport/reuseport_aix.go
new file mode 100644
index 0000000..c04c951
--- /dev/null
+++ b/reuseport/reuseport_aix.go
@@ -0,0 +1,25 @@
+package reuseport
+
+import (
+ "context"
+ "net"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+)
+
+var listenConfig = net.ListenConfig{
+ Control: func(network, address string, c syscall.RawConn) (err error) {
+ return c.Control(func(fd uintptr) {
+ err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEADDR, 1)
+ if err == nil {
+ err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_REUSEPORT, 1)
+ }
+ })
+ },
+}
+
+// Listen returns a TCP listener with the SO_REUSEADDR and SO_REUSEPORT options set.
+func Listen(network, addr string) (net.Listener, error) {
+ return listenConfig.Listen(context.Background(), network, addr)
+}
diff --git a/reuseport/reuseport_error.go b/reuseport/reuseport_error.go
index 3e29d42..062f09b 100644
--- a/reuseport/reuseport_error.go
+++ b/reuseport/reuseport_error.go
@@ -11,5 +11,5 @@ type ErrNoReusePort struct {
// Error implements error interface.
func (e *ErrNoReusePort) Error() string {
- return fmt.Sprintf("The OS doesn't support SO_REUSEPORT: %s", e.err)
+ return fmt.Sprintf("The OS doesn't support SO_REUSEPORT: %v", e.err)
}
diff --git a/reuseport/reuseport_example_test.go b/reuseport/reuseport_example_test.go
index 6bac22a..82b8250 100644
--- a/reuseport/reuseport_example_test.go
+++ b/reuseport/reuseport_example_test.go
@@ -11,11 +11,11 @@ import (
func ExampleListen() {
ln, err := reuseport.Listen("tcp4", "localhost:12345")
if err != nil {
- log.Fatalf("error in reuseport listener: %s", err)
+ log.Fatalf("error in reuseport listener: %v", err)
}
if err = fasthttp.Serve(ln, requestHandler); err != nil {
- log.Fatalf("error in fasthttp Server: %s", err)
+ log.Fatalf("error in fasthttp Server: %v", err)
}
}
diff --git a/reuseport/reuseport_test.go b/reuseport/reuseport_test.go
index a79b962..8ebf3d8 100644
--- a/reuseport/reuseport_test.go
+++ b/reuseport/reuseport_test.go
@@ -1,17 +1,14 @@
package reuseport
import (
- "fmt"
- "io/ioutil"
"net"
"testing"
- "time"
)
func TestTCP4(t *testing.T) {
t.Parallel()
- testNewListener(t, "tcp4", "localhost:10081", 20, 1000)
+ testNewListener(t, "tcp4", "localhost:10081")
}
func TestTCP6(t *testing.T) {
@@ -19,14 +16,14 @@ func TestTCP6(t *testing.T) {
// Run this test only if tcp6 interface exists.
if hasLocalIPv6(t) {
- testNewListener(t, "tcp6", "[::1]:10082", 20, 1000)
+ testNewListener(t, "tcp6", "[::1]:10082")
}
}
func hasLocalIPv6(t *testing.T) bool {
addrs, err := net.InterfaceAddrs()
if err != nil {
- t.Fatalf("cannot obtain local interfaces: %s", err)
+ t.Fatalf("cannot obtain local interfaces: %v", err)
}
for _, a := range addrs {
if a.String() == "::1/128" {
@@ -36,87 +33,17 @@ func hasLocalIPv6(t *testing.T) bool {
return false
}
-func testNewListener(t *testing.T, network, addr string, serversCount, requestsCount int) {
- var lns []net.Listener
- doneCh := make(chan struct{}, serversCount)
-
- for i := 0; i < serversCount; i++ {
- ln, err := Listen(network, addr)
- if err != nil {
- t.Fatalf("cannot create listener %d: %s", i, err)
- }
- go func() {
- serveEcho(t, ln)
- doneCh <- struct{}{}
- }()
- lns = append(lns, ln)
- }
-
- for i := 0; i < requestsCount; i++ {
- c, err := net.Dial(network, addr)
- if err != nil {
- t.Fatalf("%d. unexpected error when dialing: %s", i, err)
- }
- req := fmt.Sprintf("request number %d", i)
- if _, err = c.Write([]byte(req)); err != nil {
- t.Fatalf("%d. unexpected error when writing request: %s", i, err)
- }
- if err = c.(*net.TCPConn).CloseWrite(); err != nil {
- t.Fatalf("%d. unexpected error when closing write end of the connection: %s", i, err)
- }
-
- var resp []byte
- ch := make(chan struct{})
- go func() {
- if resp, err = ioutil.ReadAll(c); err != nil {
- t.Errorf("%d. unexpected error when reading response: %s", i, err)
- }
- close(ch)
- }()
- select {
- case <-ch:
- case <-time.After(250 * time.Millisecond):
- t.Fatalf("%d. timeout when waiting for response", i)
- }
-
- if string(resp) != req {
- t.Fatalf("%d. unexpected response %q. Expecting %q", i, resp, req)
- }
- if err = c.Close(); err != nil {
- t.Fatalf("%d. unexpected error when closing connection: %s", i, err)
- }
- }
-
- for _, ln := range lns {
- if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error when closing listener: %s", err)
- }
+func testNewListener(t *testing.T, network, addr string) {
+ ln1, err := Listen(network, addr)
+ if err != nil {
+ t.Fatalf("cannot create listener %v", err)
}
- for i := 0; i < serversCount; i++ {
- select {
- case <-doneCh:
- case <-time.After(200 * time.Millisecond):
- t.Fatalf("timeout when waiting for servers to be closed")
- }
+ ln2, err := Listen(network, addr)
+ if err != nil {
+ t.Fatalf("cannot create listener %v", err)
}
-}
-func serveEcho(t *testing.T, ln net.Listener) {
- for {
- c, err := ln.Accept()
- if err != nil {
- break
- }
- req, err := ioutil.ReadAll(c)
- if err != nil {
- t.Fatalf("unexpected error when reading request: %s", err)
- }
- if _, err = c.Write(req); err != nil {
- t.Fatalf("unexpected error when writing response: %s", err)
- }
- if err = c.Close(); err != nil {
- t.Fatalf("unexpected error when closing connection: %s", err)
- }
- }
+ _ = ln1.Close()
+ _ = ln2.Close()
}
diff --git a/server.go b/server.go
index 4fef94c..9c879ac 100644
--- a/server.go
+++ b/server.go
@@ -248,6 +248,10 @@ type Server struct {
// Deprecated: Use IdleTimeout instead.
MaxKeepaliveDuration time.Duration
+ // MaxIdleWorkerDuration is the maximum idle time of a single worker in the underlying
+ // worker pool of the Server. Idle workers beyond this time will be cleared.
+ MaxIdleWorkerDuration time.Duration
+
// Period between tcp keep-alive messages.
//
// TCP keep-alive period is determined by operation system by default.
@@ -288,7 +292,7 @@ type Server struct {
// Rejects all non-GET requests if set to true.
//
// This option is useful as anti-DoS protection for servers
- // accepting only GET requests. The request size is limited
+ // accepting only GET requests and HEAD requests. The request size is limited
// by ReadBufferSize if GetOnly is set.
//
// Server accepts all the requests by default.
@@ -391,22 +395,40 @@ type Server struct {
// By default standard logger from log package is used.
Logger Logger
- tlsConfig *tls.Config
+ // TLSConfig optionally provides a TLS configuration for use
+ // by ServeTLS, ServeTLSEmbed, ListenAndServeTLS, ListenAndServeTLSEmbed,
+ // AppendCert, AppendCertEmbed and NextProto.
+ //
+ // Note that this value is cloned by ServeTLS, ServeTLSEmbed, ListenAndServeTLS
+ // and ListenAndServeTLSEmbed, so it's not possible to modify the configuration
+ // with methods like tls.Config.SetSessionTicketKeys.
+ // To use SetSessionTicketKeys, use Server.Serve with a TLS Listener
+ // instead.
+ TLSConfig *tls.Config
+
+ // FormValueFunc, which is used by RequestCtx.FormValue and support for customising
+ // the behaviour of the RequestCtx.FormValue function.
+ //
+ // NetHttpFormValueFunc gives a FormValueFunc func implementation that is consistent with net/http.
+ FormValueFunc FormValueFunc
+
nextProtos map[string]ServeHandler
concurrency uint32
concurrencyCh chan struct{}
perIPConnCounter perIPConnCounter
- serverName atomic.Value
ctxPool sync.Pool
readerPool sync.Pool
writerPool sync.Pool
hijackConnPool sync.Pool
- // We need to know our listeners so we can close them in Shutdown().
+ // We need to know our listeners and idle connections so we can close them in Shutdown().
ln []net.Listener
+ idleConns map[net.Conn]time.Time
+ idleConnsMu sync.Mutex
+
mu sync.Mutex
open int32
stop int32
@@ -465,7 +487,7 @@ func TimeoutWithCodeHandler(h RequestHandler, timeout time.Duration, msg string,
}
}
-//RequestConfig configure the per request deadline and body limits
+// RequestConfig configure the per request deadline and body limits
type RequestConfig struct {
// ReadTimeout is the maximum duration for reading the entire
// request body.
@@ -493,11 +515,11 @@ func CompressHandler(h RequestHandler) RequestHandler {
//
// Level is the desired compression level:
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
func CompressHandlerLevel(h RequestHandler, level int) RequestHandler {
return func(ctx *RequestCtx) {
h(ctx)
@@ -515,18 +537,18 @@ func CompressHandlerLevel(h RequestHandler, level int) RequestHandler {
//
// brotliLevel is the desired compression level for brotli.
//
-// * CompressBrotliNoCompression
-// * CompressBrotliBestSpeed
-// * CompressBrotliBestCompression
-// * CompressBrotliDefaultCompression
+// - CompressBrotliNoCompression
+// - CompressBrotliBestSpeed
+// - CompressBrotliBestCompression
+// - CompressBrotliDefaultCompression
//
// otherLevel is the desired compression level for gzip and deflate.
//
-// * CompressNoCompression
-// * CompressBestSpeed
-// * CompressBestCompression
-// * CompressDefaultCompression
-// * CompressHuffmanOnly
+// - CompressNoCompression
+// - CompressBestSpeed
+// - CompressBestCompression
+// - CompressDefaultCompression
+// - CompressHuffmanOnly
func CompressHandlerBrotliLevel(h RequestHandler, brotliLevel, otherLevel int) RequestHandler {
return func(ctx *RequestCtx) {
h(ctx)
@@ -587,6 +609,7 @@ type RequestCtx struct {
hijackHandler HijackHandler
hijackNoResponse bool
+ formValueFunc FormValueFunc
}
// HijackHandler must process the hijacked connection c.
@@ -609,8 +632,8 @@ type HijackHandler func(c net.Conn)
//
// The server skips calling the handler in the following cases:
//
-// * 'Connection: close' header exists in either request or response.
-// * Unexpected error during response writing to the connection.
+// - 'Connection: close' header exists in either request or response.
+// - Unexpected error during response writing to the connection.
//
// The server stops processing requests from hijacked connections.
//
@@ -622,9 +645,8 @@ type HijackHandler func(c net.Conn)
// Arbitrary 'Connection: Upgrade' protocols may be implemented
// with HijackHandler. For instance,
//
-// * WebSocket ( https://en.wikipedia.org/wiki/WebSocket )
-// * HTTP/2.0 ( https://en.wikipedia.org/wiki/HTTP/2 )
-//
+// - WebSocket ( https://en.wikipedia.org/wiki/WebSocket )
+// - HTTP/2.0 ( https://en.wikipedia.org/wiki/HTTP/2 )
func (ctx *RequestCtx) Hijack(handler HijackHandler) {
ctx.hijackHandler = handler
}
@@ -654,7 +676,7 @@ func (ctx *RequestCtx) Hijacked() bool {
// All the values are removed from ctx after returning from the top
// RequestHandler. Additionally, Close method is called on each value
// implementing io.Closer before removing the value from ctx.
-func (ctx *RequestCtx) SetUserValue(key string, value interface{}) {
+func (ctx *RequestCtx) SetUserValue(key interface{}, value interface{}) {
ctx.userValues.Set(key, value)
}
@@ -672,7 +694,7 @@ func (ctx *RequestCtx) SetUserValueBytes(key []byte, value interface{}) {
}
// UserValue returns the value stored via SetUserValue* under the given key.
-func (ctx *RequestCtx) UserValue(key string) interface{} {
+func (ctx *RequestCtx) UserValue(key interface{}) interface{} {
return ctx.userValues.Get(key)
}
@@ -682,11 +704,24 @@ func (ctx *RequestCtx) UserValueBytes(key []byte) interface{} {
return ctx.userValues.GetBytes(key)
}
-// VisitUserValues calls visitor for each existing userValue.
+// VisitUserValues calls visitor for each existing userValue with a key that is a string or []byte.
//
// visitor must not retain references to key and value after returning.
// Make key and/or value copies if you need storing them after returning.
func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) {
+ for i, n := 0, len(ctx.userValues); i < n; i++ {
+ kv := &ctx.userValues[i]
+ if _, ok := kv.key.(string); ok {
+ visitor(s2b(kv.key.(string)), kv.value)
+ }
+ }
+}
+
+// VisitUserValuesAll calls visitor for each existing userValue.
+//
+// visitor must not retain references to key and value after returning.
+// Make key and/or value copies if you need storing them after returning.
+func (ctx *RequestCtx) VisitUserValuesAll(visitor func(interface{}, interface{})) {
for i, n := 0, len(ctx.userValues); i < n; i++ {
kv := &ctx.userValues[i]
visitor(kv.key, kv.value)
@@ -699,7 +734,7 @@ func (ctx *RequestCtx) ResetUserValues() {
}
// RemoveUserValue removes the given key and the value under it in ctx.
-func (ctx *RequestCtx) RemoveUserValue(key string) {
+func (ctx *RequestCtx) RemoveUserValue(key interface{}) {
ctx.userValues.Remove(key)
}
@@ -760,12 +795,49 @@ func (ctx *RequestCtx) Conn() net.Conn {
return ctx.c
}
+func (ctx *RequestCtx) reset() {
+ ctx.userValues.Reset()
+ ctx.Request.Reset()
+ ctx.Response.Reset()
+ ctx.fbr.reset()
+
+ ctx.connID = 0
+ ctx.connRequestNum = 0
+ ctx.connTime = zeroTime
+ ctx.remoteAddr = nil
+ ctx.time = zeroTime
+ ctx.c = nil
+
+ // Don't reset ctx.s!
+ // We have a pool per server so the next time this ctx is used it
+ // will be assigned the same value again.
+ // ctx might still be in use for context.Done() and context.Err()
+ // which are safe to use as they only use ctx.s and no other value.
+
+ if ctx.timeoutResponse != nil {
+ ctx.timeoutResponse.Reset()
+ }
+
+ if ctx.timeoutTimer != nil {
+ stopTimer(ctx.timeoutTimer)
+ }
+
+ ctx.hijackHandler = nil
+ ctx.hijackNoResponse = false
+}
+
type firstByteReader struct {
c net.Conn
ch byte
byteRead bool
}
+func (r *firstByteReader) reset() {
+ r.c = nil
+ r.ch = 0
+ r.byteRead = false
+}
+
func (r *firstByteReader) Read(b []byte) (int, error) {
if len(b) == 0 {
return 0, nil
@@ -1030,35 +1102,66 @@ func SaveMultipartFile(fh *multipart.FileHeader, path string) (err error) {
//
// The value is searched in the following places:
//
-// * Query string.
-// * POST or PUT body.
+// - Query string.
+// - POST or PUT body.
//
// There are more fine-grained methods for obtaining form values:
//
-// * QueryArgs for obtaining values from query string.
-// * PostArgs for obtaining values from POST or PUT body.
-// * MultipartForm for obtaining values from multipart form.
-// * FormFile for obtaining uploaded files.
+// - QueryArgs for obtaining values from query string.
+// - PostArgs for obtaining values from POST or PUT body.
+// - MultipartForm for obtaining values from multipart form.
+// - FormFile for obtaining uploaded files.
//
// The returned value is valid until your request handler returns.
func (ctx *RequestCtx) FormValue(key string) []byte {
- v := ctx.QueryArgs().Peek(key)
- if len(v) > 0 {
- return v
+ if ctx.formValueFunc != nil {
+ return ctx.formValueFunc(ctx, key)
}
- v = ctx.PostArgs().Peek(key)
- if len(v) > 0 {
- return v
+ return defaultFormValue(ctx, key)
+}
+
+type FormValueFunc func(*RequestCtx, string) []byte
+
+var (
+ defaultFormValue = func(ctx *RequestCtx, key string) []byte {
+ v := ctx.QueryArgs().Peek(key)
+ if len(v) > 0 {
+ return v
+ }
+ v = ctx.PostArgs().Peek(key)
+ if len(v) > 0 {
+ return v
+ }
+ mf, err := ctx.MultipartForm()
+ if err == nil && mf.Value != nil {
+ vv := mf.Value[key]
+ if len(vv) > 0 {
+ return []byte(vv[0])
+ }
+ }
+ return nil
}
- mf, err := ctx.MultipartForm()
- if err == nil && mf.Value != nil {
- vv := mf.Value[key]
- if len(vv) > 0 {
- return []byte(vv[0])
+
+ // NetHttpFormValueFunc gives consistent behavior with net/http. POST and PUT body parameters take precedence over URL query string values.
+ NetHttpFormValueFunc = func(ctx *RequestCtx, key string) []byte {
+ v := ctx.PostArgs().Peek(key)
+ if len(v) > 0 {
+ return v
+ }
+ mf, err := ctx.MultipartForm()
+ if err == nil && mf.Value != nil {
+ vv := mf.Value[key]
+ if len(vv) > 0 {
+ return []byte(vv[0])
+ }
+ }
+ v = ctx.QueryArgs().Peek(key)
+ if len(v) > 0 {
+ return v
}
+ return nil
}
- return nil
-}
+)
// IsGet returns true if request method is GET.
func (ctx *RequestCtx) IsGet() bool {
@@ -1200,11 +1303,11 @@ func (ctx *RequestCtx) SuccessString(contentType, body string) {
//
// statusCode must have one of the following values:
//
-// * StatusMovedPermanently (301)
-// * StatusFound (302)
-// * StatusSeeOther (303)
-// * StatusTemporaryRedirect (307)
-// * StatusPermanentRedirect (308)
+// - StatusMovedPermanently (301)
+// - StatusFound (302)
+// - StatusSeeOther (303)
+// - StatusTemporaryRedirect (307)
+// - StatusPermanentRedirect (308)
//
// All other statusCode values are replaced by StatusFound (302).
//
@@ -1212,10 +1315,9 @@ func (ctx *RequestCtx) SuccessString(contentType, body string) {
// request uri. Fasthttp will always send an absolute uri back to the client.
// To send a relative uri you can use the following code:
//
-// strLocation = []byte("Location") // Put this with your top level var () declarations.
-// ctx.Response.Header.SetCanonical(strLocation, "/relative?uri")
-// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
-//
+// strLocation = []byte("Location") // Put this with your top level var () declarations.
+// ctx.Response.Header.SetCanonical(strLocation, "/relative?uri")
+// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
func (ctx *RequestCtx) Redirect(uri string, statusCode int) {
u := AcquireURI()
ctx.URI().CopyTo(u)
@@ -1229,11 +1331,11 @@ func (ctx *RequestCtx) Redirect(uri string, statusCode int) {
//
// statusCode must have one of the following values:
//
-// * StatusMovedPermanently (301)
-// * StatusFound (302)
-// * StatusSeeOther (303)
-// * StatusTemporaryRedirect (307)
-// * StatusPermanentRedirect (308)
+// - StatusMovedPermanently (301)
+// - StatusFound (302)
+// - StatusSeeOther (303)
+// - StatusTemporaryRedirect (307)
+// - StatusPermanentRedirect (308)
//
// All other statusCode values are replaced by StatusFound (302).
//
@@ -1241,17 +1343,16 @@ func (ctx *RequestCtx) Redirect(uri string, statusCode int) {
// request uri. Fasthttp will always send an absolute uri back to the client.
// To send a relative uri you can use the following code:
//
-// strLocation = []byte("Location") // Put this with your top level var () declarations.
-// ctx.Response.Header.SetCanonical(strLocation, "/relative?uri")
-// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
-//
+// strLocation = []byte("Location") // Put this with your top level var () declarations.
+// ctx.Response.Header.SetCanonical(strLocation, "/relative?uri")
+// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently)
func (ctx *RequestCtx) RedirectBytes(uri []byte, statusCode int) {
s := b2s(uri)
ctx.Redirect(s, statusCode)
}
func (ctx *RequestCtx) redirect(uri []byte, statusCode int) {
- ctx.Response.Header.SetCanonical(strLocation, uri)
+ ctx.Response.Header.setNonSpecial(strLocation, uri)
statusCode = getRedirectStatusCode(statusCode)
ctx.Response.SetStatusCode(statusCode)
}
@@ -1289,6 +1390,10 @@ func (ctx *RequestCtx) ResetBody() {
// SendFile logs all the errors via ctx.Logger.
//
// See also ServeFile, FSHandler and FS.
+//
+// WARNING: do not pass any user supplied paths to this function!
+// WARNING: if path is based on user input users will be able to request
+// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func (ctx *RequestCtx) SendFile(path string) {
ServeFile(ctx, path)
}
@@ -1300,6 +1405,10 @@ func (ctx *RequestCtx) SendFile(path string) {
// SendFileBytes logs all the errors via ctx.Logger.
//
// See also ServeFileBytes, FSHandler and FS.
+//
+// WARNING: do not pass any user supplied paths to this function!
+// WARNING: if path is based on user input users will be able to request
+// any file on your filesystem! Use fasthttp.FS with a sane Root instead.
func (ctx *RequestCtx) SendFileBytes(path []byte) {
ServeFileBytes(ctx, path)
}
@@ -1375,9 +1484,9 @@ func (ctx *RequestCtx) SetBodyStream(bodyStream io.Reader, bodySize int) {
//
// This function may be used in the following cases:
//
-// * if response body is too big (more than 10MB).
-// * if response body is streamed from slow external sources.
-// * if response body must be streamed to the client in chunks.
+// - if response body is too big (more than 10MB).
+// - if response body is streamed from slow external sources.
+// - if response body must be streamed to the client in chunks.
// (aka `http server push`).
func (ctx *RequestCtx) SetBodyStreamWriter(sw StreamWriter) {
ctx.Response.SetBodyStreamWriter(sw)
@@ -1464,8 +1573,9 @@ func (s *Server) NextProto(key string, nph ServeHandler) {
if s.nextProtos == nil {
s.nextProtos = make(map[string]ServeHandler)
}
+
s.configTLS()
- s.tlsConfig.NextProtos = append(s.tlsConfig.NextProtos, key)
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, key)
s.nextProtos[key] = nph
}
@@ -1473,13 +1583,13 @@ func (s *Server) getNextProto(c net.Conn) (proto string, err error) {
if tlsConn, ok := c.(connTLSer); ok {
if s.ReadTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
- panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err))
+ panic(fmt.Sprintf("BUG: error in SetReadDeadline(%v): %v", s.ReadTimeout, err))
}
}
if s.WriteTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil {
- panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err))
+ panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%v): %v", s.WriteTimeout, err))
}
}
@@ -1491,34 +1601,6 @@ func (s *Server) getNextProto(c net.Conn) (proto string, err error) {
return
}
-// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
-// connections. It's used by ListenAndServe, ListenAndServeTLS and
-// ListenAndServeTLSEmbed so dead TCP connections (e.g. closing laptop mid-download)
-// eventually go away.
-type tcpKeepaliveListener struct {
- *net.TCPListener
- keepalive bool
- keepalivePeriod time.Duration
-}
-
-func (ln tcpKeepaliveListener) Accept() (net.Conn, error) {
- tc, err := ln.AcceptTCP()
- if err != nil {
- return nil, err
- }
- if err := tc.SetKeepAlive(ln.keepalive); err != nil {
- tc.Close() //nolint:errcheck
- return nil, err
- }
- if ln.keepalivePeriod > 0 {
- if err := tc.SetKeepAlivePeriod(ln.keepalivePeriod); err != nil {
- tc.Close() //nolint:errcheck
- return nil, err
- }
- }
- return tc, nil
-}
-
// ListenAndServe serves HTTP requests from the given TCP4 addr.
//
// Pass custom listener to Serve if you need listening on non-TCP4 media
@@ -1530,13 +1612,6 @@ func (s *Server) ListenAndServe(addr string) error {
if err != nil {
return err
}
- if tcpln, ok := ln.(*net.TCPListener); ok {
- return s.Serve(tcpKeepaliveListener{
- TCPListener: tcpln,
- keepalive: s.TCPKeepalive,
- keepalivePeriod: s.TCPKeepalivePeriod,
- })
- }
return s.Serve(ln)
}
@@ -1547,14 +1622,14 @@ func (s *Server) ListenAndServe(addr string) error {
// The server sets the given file mode for the UNIX addr.
func (s *Server) ListenAndServeUNIX(addr string, mode os.FileMode) error {
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
- return fmt.Errorf("unexpected error when trying to remove unix socket file %q: %s", addr, err)
+ return fmt.Errorf("unexpected error when trying to remove unix socket file %q: %w", addr, err)
}
ln, err := net.Listen("unix", addr)
if err != nil {
return err
}
if err = os.Chmod(addr, mode); err != nil {
- return fmt.Errorf("cannot chmod %#o for %q: %s", mode, addr, err)
+ return fmt.Errorf("cannot chmod %#o for %q: %w", mode, addr, err)
}
return s.Serve(ln)
}
@@ -1575,13 +1650,6 @@ func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error {
if err != nil {
return err
}
- if tcpln, ok := ln.(*net.TCPListener); ok {
- return s.ServeTLS(tcpKeepaliveListener{
- TCPListener: tcpln,
- keepalive: s.TCPKeepalive,
- keepalivePeriod: s.TCPKeepalivePeriod,
- }, certFile, keyFile)
- }
return s.ServeTLS(ln, certFile, keyFile)
}
@@ -1601,13 +1669,6 @@ func (s *Server) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) e
if err != nil {
return err
}
- if tcpln, ok := ln.(*net.TCPListener); ok {
- return s.ServeTLSEmbed(tcpKeepaliveListener{
- TCPListener: tcpln,
- keepalive: s.TCPKeepalive,
- keepalivePeriod: s.TCPKeepalivePeriod,
- }, certData, keyData)
- }
return s.ServeTLSEmbed(ln, certData, keyData)
}
@@ -1624,19 +1685,19 @@ func (s *Server) ServeTLS(ln net.Listener, certFile, keyFile string) error {
s.mu.Unlock()
return err
}
- if s.tlsConfig == nil {
+ if s.TLSConfig == nil {
s.mu.Unlock()
return errNoCertOrKeyProvided
}
// BuildNameToCertificate has been deprecated since 1.14.
// But since we also support older versions we'll keep this here.
- s.tlsConfig.BuildNameToCertificate() //nolint:staticcheck
+ s.TLSConfig.BuildNameToCertificate() //nolint:staticcheck
s.mu.Unlock()
return s.Serve(
- tls.NewListener(ln, s.tlsConfig),
+ tls.NewListener(ln, s.TLSConfig.Clone()),
)
}
@@ -1654,19 +1715,19 @@ func (s *Server) ServeTLSEmbed(ln net.Listener, certData, keyData []byte) error
s.mu.Unlock()
return err
}
- if s.tlsConfig == nil {
+ if s.TLSConfig == nil {
s.mu.Unlock()
return errNoCertOrKeyProvided
}
// BuildNameToCertificate has been deprecated since 1.14.
// But since we also support older versions we'll keep this here.
- s.tlsConfig.BuildNameToCertificate() //nolint:staticcheck
+ s.TLSConfig.BuildNameToCertificate() //nolint:staticcheck
s.mu.Unlock()
return s.Serve(
- tls.NewListener(ln, s.tlsConfig),
+ tls.NewListener(ln, s.TLSConfig.Clone()),
)
}
@@ -1681,12 +1742,12 @@ func (s *Server) AppendCert(certFile, keyFile string) error {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
- return fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err)
+ return fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %w", certFile, keyFile, err)
}
s.configTLS()
+ s.TLSConfig.Certificates = append(s.TLSConfig.Certificates, cert)
- s.tlsConfig.Certificates = append(s.tlsConfig.Certificates, cert)
return nil
}
@@ -1698,21 +1759,19 @@ func (s *Server) AppendCertEmbed(certData, keyData []byte) error {
cert, err := tls.X509KeyPair(certData, keyData)
if err != nil {
- return fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %s",
+ return fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %w",
len(certData), len(keyData), err)
}
s.configTLS()
+ s.TLSConfig.Certificates = append(s.TLSConfig.Certificates, cert)
- s.tlsConfig.Certificates = append(s.tlsConfig.Certificates, cert)
return nil
}
func (s *Server) configTLS() {
- if s.tlsConfig == nil {
- s.tlsConfig = &tls.Config{
- PreferServerCipherSuites: true,
- }
+ if s.TLSConfig == nil {
+ s.TLSConfig = &tls.Config{}
}
}
@@ -1745,11 +1804,12 @@ func (s *Server) Serve(ln net.Listener) error {
s.mu.Unlock()
wp := &workerPool{
- WorkerFunc: s.serveConn,
- MaxWorkersCount: maxWorkersCount,
- LogAllErrors: s.LogAllErrors,
- Logger: s.logger(),
- connState: s.setState,
+ WorkerFunc: s.serveConn,
+ MaxWorkersCount: maxWorkersCount,
+ LogAllErrors: s.LogAllErrors,
+ MaxIdleWorkerDuration: s.MaxIdleWorkerDuration,
+ Logger: s.logger(),
+ connState: s.setState,
}
wp.Start()
@@ -1806,6 +1866,17 @@ func (s *Server) Serve(ln net.Listener) error {
//
// Shutdown does not close keepalive connections so its recommended to set ReadTimeout and IdleTimeout to something else than 0.
func (s *Server) Shutdown() error {
+ return s.ShutdownWithContext(context.Background())
+}
+
+// ShutdownWithContext gracefully shuts down the server without interrupting any active connections.
+// ShutdownWithContext works by first closing all open listeners and then waiting for all connections to return to idle or context timeout and then shut down.
+//
+// When ShutdownWithContext is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return nil.
+// Make sure the program doesn't exit and waits instead for Shutdown to return.
+//
+// ShutdownWithContext does not close keepalive connections so its recommended to set ReadTimeout and IdleTimeout to something else than 0.
+func (s *Server) ShutdownWithContext(ctx context.Context) (err error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -1817,7 +1888,7 @@ func (s *Server) Shutdown() error {
}
for _, ln := range s.ln {
- if err := ln.Close(); err != nil {
+ if err = ln.Close(); err != nil {
return err
}
}
@@ -1828,20 +1899,31 @@ func (s *Server) Shutdown() error {
// Closing the listener will make Serve() call Stop on the worker pool.
// Setting .stop to 1 will make serveConn() break out of its loop.
- // Now we just have to wait until all workers are done.
+ // Now we just have to wait until all workers are done or timeout.
+ ticker := time.NewTicker(time.Millisecond * 100)
+ defer ticker.Stop()
+END:
for {
+ s.closeIdleConns()
+
if open := atomic.LoadInt32(&s.open); open == 0 {
break
}
// This is not an optimal solution but using a sync.WaitGroup
// here causes data races as it's hard to prevent Add() to be called
// while Wait() is waiting.
- time.Sleep(time.Millisecond * 100)
+ select {
+ case <-ctx.Done():
+ err = ctx.Err()
+ break END
+ case <-ticker.C:
+ continue
+ }
}
s.done = nil
s.ln = nil
- return nil
+ return err
}
func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) {
@@ -1851,13 +1933,13 @@ func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.
if c != nil {
panic("BUG: net.Listener returned non-nil conn and non-nil error")
}
- if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
- s.logger().Printf("Temporary error when accepting new connections: %s", netErr)
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+ s.logger().Printf("Timeout error when accepting new connections: %v", netErr)
time.Sleep(time.Second)
continue
}
if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") {
- s.logger().Printf("Permanent error when accepting new connections: %s", err)
+ s.logger().Printf("Permanent error when accepting new connections: %v", err)
return nil, err
}
return nil, io.EOF
@@ -1865,6 +1947,20 @@ func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.
if c == nil {
panic("BUG: net.Listener returned (nil, nil)")
}
+
+ if tc, ok := c.(*net.TCPConn); ok && s.TCPKeepalive {
+ if err := tc.SetKeepAlive(s.TCPKeepalive); err != nil {
+ tc.Close() //nolint:errcheck
+ return nil, err
+ }
+ if s.TCPKeepalivePeriod > 0 {
+ if err := tc.SetKeepAlivePeriod(s.TCPKeepalivePeriod); err != nil {
+ tc.Close() //nolint:errcheck
+ return nil, err
+ }
+ }
+ }
+
if s.MaxConnsPerIP > 0 {
pic := wrapPerIPConn(s, c)
if pic == nil {
@@ -2030,17 +2126,14 @@ func (s *Server) serveConn(c net.Conn) (err error) {
// The next handler is responsible for setting its own deadlines.
if s.ReadTimeout > 0 || s.WriteTimeout > 0 {
if err := c.SetDeadline(zeroTime); err != nil {
- panic(fmt.Sprintf("BUG: error in SetDeadline(zeroTime): %s", err))
+ panic(fmt.Sprintf("BUG: error in SetDeadline(zeroTime): %v", err))
}
}
return handler(c)
}
- var serverName []byte
- if !s.NoDefaultServerHeader {
- serverName = s.getServerName()
- }
+ serverName := s.getServerName()
connRequestNum := uint64(0)
connID := nextConnID()
connTime := time.Now()
@@ -2063,10 +2156,8 @@ func (s *Server) serveConn(c net.Conn) (err error) {
hijackNoResponse bool
connectionClose bool
- isHTTP11 bool
- reqReset bool
- continueReadingRequest bool = true
+ continueReadingRequest = true
)
for {
connRequestNum++
@@ -2075,7 +2166,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {
if connRequestNum > 1 {
if d := s.idleTimeout(); d > 0 {
if err := c.SetReadDeadline(time.Now().Add(d)); err != nil {
- panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", d, err))
+ break
}
}
}
@@ -2115,15 +2206,17 @@ func (s *Server) serveConn(c net.Conn) (err error) {
ctx.Response.secureErrorLogMessage = s.SecureErrorLogMessage
if err == nil {
+ s.setState(c, StateActive)
+
if s.ReadTimeout > 0 {
if err := c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil {
- panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err))
+ break
}
} else if s.IdleTimeout > 0 && connRequestNum > 1 {
// If this was an idle connection and the server has an IdleTimeout but
// no ReadTimeout then we should remove the ReadTimeout.
if err := c.SetReadDeadline(zeroTime); err != nil {
- panic(fmt.Sprintf("BUG: error in SetReadDeadline(zeroTime): %s", err))
+ break
}
}
if s.DisableHeaderNamesNormalizing {
@@ -2157,17 +2250,23 @@ func (s *Server) serveConn(c net.Conn) (err error) {
if reqConf.ReadTimeout > 0 {
deadline := time.Now().Add(reqConf.ReadTimeout)
if err := c.SetReadDeadline(deadline); err != nil {
- panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", deadline, err))
+ panic(fmt.Sprintf("BUG: error in SetReadDeadline(%v): %v", deadline, err))
}
}
if reqConf.MaxRequestBodySize > 0 {
maxRequestBodySize = reqConf.MaxRequestBodySize
+ } else if s.MaxRequestBodySize > 0 {
+ maxRequestBodySize = s.MaxRequestBodySize
+ } else {
+ maxRequestBodySize = DefaultMaxRequestBodySize
}
if reqConf.WriteTimeout > 0 {
writeTimeout = reqConf.WriteTimeout
+ } else {
+ writeTimeout = s.WriteTimeout
}
}
- //read body
+ // read body
if s.StreamRequestBody {
err = ctx.Request.readBodyStream(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm)
} else {
@@ -2175,11 +2274,6 @@ func (s *Server) serveConn(c net.Conn) (err error) {
}
}
- if err == nil {
- // If we read any bytes off the wire, we're active.
- s.setState(c, StateActive)
- }
-
if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil {
releaseReader(s, br)
br = nil
@@ -2263,11 +2357,11 @@ func (s *Server) serveConn(c net.Conn) (err error) {
}
}
+ // store req.ConnectionClose so even if it was changed inside of handler
connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose()
- isHTTP11 = ctx.Request.Header.IsHTTP11()
- if serverName != nil {
- ctx.Response.Header.SetServerBytes(serverName)
+ if serverName != "" {
+ ctx.Response.Header.SetServer(serverName)
}
ctx.connID = connID
ctx.connRequestNum = connRequestNum
@@ -2285,47 +2379,43 @@ func (s *Server) serveConn(c net.Conn) (err error) {
timeoutResponse.CopyTo(&ctx.Response)
}
- if !ctx.IsGet() && ctx.IsHead() {
+ if ctx.IsHead() {
ctx.Response.SkipBody = true
}
- reqReset = true
- ctx.Request.Reset()
hijackHandler = ctx.hijackHandler
ctx.hijackHandler = nil
hijackNoResponse = ctx.hijackNoResponse && hijackHandler != nil
ctx.hijackNoResponse = false
- if s.MaxRequestsPerConn > 0 && connRequestNum >= uint64(s.MaxRequestsPerConn) {
- ctx.SetConnectionClose()
- }
-
if writeTimeout > 0 {
if err := c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
- panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", writeTimeout, err))
+ panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%v): %v", writeTimeout, err))
}
previousWriteTimeout = writeTimeout
} else if previousWriteTimeout > 0 {
// We don't want a write timeout but we previously set one, remove it.
if err := c.SetWriteDeadline(zeroTime); err != nil {
- panic(fmt.Sprintf("BUG: error in SetWriteDeadline(zeroTime): %s", err))
+ panic(fmt.Sprintf("BUG: error in SetWriteDeadline(zeroTime): %v", err))
}
previousWriteTimeout = 0
}
- connectionClose = connectionClose || ctx.Response.ConnectionClose()
- connectionClose = connectionClose || ctx.Response.ConnectionClose() || (s.CloseOnShutdown && atomic.LoadInt32(&s.stop) == 1)
+ connectionClose = connectionClose ||
+ (s.MaxRequestsPerConn > 0 && connRequestNum >= uint64(s.MaxRequestsPerConn)) ||
+ ctx.Response.Header.ConnectionClose() ||
+ (s.CloseOnShutdown && atomic.LoadInt32(&s.stop) == 1)
if connectionClose {
- ctx.Response.Header.SetCanonical(strConnection, strClose)
- } else if !isHTTP11 {
- // Set 'Connection: keep-alive' response header for non-HTTP/1.1 request.
+ ctx.Response.Header.SetConnectionClose()
+ } else if !ctx.Request.Header.IsHTTP11() {
+ // Set 'Connection: keep-alive' response header for HTTP/1.0 request.
// There is no need in setting this header for http/1.1, since in http/1.1
// connections are keep-alive by default.
- ctx.Response.Header.SetCanonical(strConnection, strKeepAlive)
+ ctx.Response.Header.setNonSpecial(strConnection, strKeepAlive)
}
- if serverName != nil && len(ctx.Response.Header.Server()) == 0 {
- ctx.Response.Header.SetServerBytes(serverName)
+ if serverName != "" && len(ctx.Response.Header.Server()) == 0 {
+ ctx.Response.Header.SetServer(serverName)
}
if !hijackNoResponse {
@@ -2361,9 +2451,6 @@ func (s *Server) serveConn(c net.Conn) (err error) {
if br != nil {
hjr = br
br = nil
-
- // br may point to ctx.fbr, so do not return ctx into pool below.
- ctx = nil
}
if bw != nil {
err = bw.Flush()
@@ -2377,7 +2464,7 @@ func (s *Server) serveConn(c net.Conn) (err error) {
if err != nil {
break
}
- go hijackConnHandler(hjr, c, s, hijackHandler)
+ go hijackConnHandler(ctx, hjr, c, s, hijackHandler)
err = errHijacked
break
}
@@ -2386,10 +2473,13 @@ func (s *Server) serveConn(c net.Conn) (err error) {
if rs, ok := ctx.Request.bodyStream.(*requestStream); ok {
releaseRequestStream(rs)
}
+ ctx.Request.bodyStream = nil
}
s.setState(c, StateIdle)
ctx.userValues.Reset()
+ ctx.Request.Reset()
+ ctx.Response.Reset()
if atomic.LoadInt32(&s.stop) == 1 {
err = nil
@@ -2403,25 +2493,21 @@ func (s *Server) serveConn(c net.Conn) (err error) {
if bw != nil {
releaseWriter(s, bw)
}
- if ctx != nil {
- // in unexpected cases the for loop will break
- // before request reset call. in such cases, call it before
- // release to fix #548
- if !reqReset {
- ctx.Request.Reset()
- }
+ if hijackHandler == nil {
s.releaseCtx(ctx)
}
+
return
}
func (s *Server) setState(nc net.Conn, state ConnState) {
+ s.trackConn(nc, state)
if hook := s.ConnState; hook != nil {
hook(nc, state)
}
}
-func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) {
+func hijackConnHandler(ctx *RequestCtx, r io.Reader, c net.Conn, s *Server, h HijackHandler) {
hjc := s.acquireHijackConn(r, c)
h(hjc)
@@ -2432,6 +2518,7 @@ func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) {
c.Close()
s.releaseHijackConn(hjc)
}
+ s.releaseCtx(ctx)
}
func (s *Server) acquireHijackConn(r io.Reader, c net.Conn) *hijackConn {
@@ -2495,7 +2582,7 @@ func writeResponse(ctx *RequestCtx, w *bufio.Writer) error {
panic("BUG: cannot write timed out response")
}
err := ctx.Response.Write(w)
- ctx.Response.Reset()
+
return err
}
@@ -2576,17 +2663,21 @@ func releaseWriter(s *Server, w *bufio.Writer) {
func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) {
v := s.ctxPool.Get()
if v == nil {
- ctx = &RequestCtx{
- s: s,
- }
keepBodyBuffer := !s.ReduceMemoryUsage
+
+ ctx = new(RequestCtx)
ctx.Request.keepBodyBuffer = keepBodyBuffer
ctx.Response.keepBodyBuffer = keepBodyBuffer
+ ctx.s = s
} else {
ctx = v.(*RequestCtx)
}
+ if s.FormValueFunc != nil {
+ ctx.formValueFunc = s.FormValueFunc
+ }
ctx.c = c
- return
+
+ return ctx
}
// Init2 prepares ctx for passing to RequestHandler.
@@ -2642,6 +2733,9 @@ func (ctx *RequestCtx) Deadline() (deadline time.Time, ok bool) {
// Done returns a channel that's closed when work done on behalf of this
// context should be canceled. Done may return nil if this context can
// never be canceled. Successive calls to Done return the same value.
+//
+// Note: Because creating a new channel for every request is just too expensive, so
+// RequestCtx.s.done is only closed when the server is shutting down
func (ctx *RequestCtx) Done() <-chan struct{} {
return ctx.s.done
}
@@ -2652,6 +2746,9 @@ func (ctx *RequestCtx) Done() <-chan struct{} {
// If Done is closed, Err returns a non-nil error explaining why:
// Canceled if the context was canceled (via server Shutdown)
// or DeadlineExceeded if the context's deadline passed.
+//
+// Note: Because creating a new channel for every request is just too expensive, so
+// RequestCtx.s.done is only closed when the server is shutting down
func (ctx *RequestCtx) Err() error {
select {
case <-ctx.s.done:
@@ -2668,10 +2765,7 @@ func (ctx *RequestCtx) Err() error {
// This method is present to make RequestCtx implement the context interface.
// This method is the same as calling ctx.UserValue(key)
func (ctx *RequestCtx) Value(key interface{}) interface{} {
- if keyString, ok := key.(string); ok {
- return ctx.UserValue(keyString)
- }
- return nil
+ return ctx.UserValue(key)
}
var fakeServer = &Server{
@@ -2709,36 +2803,28 @@ func (s *Server) releaseCtx(ctx *RequestCtx) {
if ctx.timeoutResponse != nil {
panic("BUG: cannot release timed out RequestCtx")
}
- ctx.c = nil
- ctx.remoteAddr = nil
- ctx.fbr.c = nil
- ctx.userValues.Reset()
+
+ ctx.reset()
s.ctxPool.Put(ctx)
}
-func (s *Server) getServerName() []byte {
- v := s.serverName.Load()
- var serverName []byte
- if v == nil {
- serverName = []byte(s.Name)
- if len(serverName) == 0 {
+func (s *Server) getServerName() string {
+ serverName := s.Name
+ if serverName == "" {
+ if !s.NoDefaultServerHeader {
serverName = defaultServerName
}
- s.serverName.Store(serverName)
- } else {
- serverName = v.([]byte)
}
return serverName
}
func (s *Server) writeFastError(w io.Writer, statusCode int, msg string) {
- w.Write(statusLine(statusCode)) //nolint:errcheck
+ w.Write(formatStatusLine(nil, strHTTP11, statusCode, s2b(StatusMessage(statusCode)))) //nolint:errcheck
- server := ""
- if !s.NoDefaultServerHeader {
- server = fmt.Sprintf("Server: %s\r\n", s.getServerName())
+ server := s.getServerName()
+ if server != "" {
+ server = fmt.Sprintf("Server: %s\r\n", server)
}
-
date := ""
if !s.NoDefaultDate {
serverDateOnce.Do(updateServerDate)
@@ -2765,7 +2851,7 @@ func defaultErrorHandler(ctx *RequestCtx, err error) {
}
}
-func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverName []byte, err error) *bufio.Writer {
+func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverName string, err error) *bufio.Writer {
errorHandler := defaultErrorHandler
if s.ErrorHandler != nil {
errorHandler = s.ErrorHandler
@@ -2773,18 +2859,55 @@ func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverNam
errorHandler(ctx, err)
- if serverName != nil {
- ctx.Response.Header.SetServerBytes(serverName)
+ if serverName != "" {
+ ctx.Response.Header.SetServer(serverName)
}
ctx.SetConnectionClose()
if bw == nil {
bw = acquireWriter(ctx)
}
+
writeResponse(ctx, bw) //nolint:errcheck
+ ctx.Response.Reset()
bw.Flush()
+
return bw
}
+func (s *Server) trackConn(c net.Conn, state ConnState) {
+ s.idleConnsMu.Lock()
+ switch state {
+ case StateIdle:
+ if s.idleConns == nil {
+ s.idleConns = make(map[net.Conn]time.Time)
+ }
+ s.idleConns[c] = time.Now()
+ case StateNew:
+ if s.idleConns == nil {
+ s.idleConns = make(map[net.Conn]time.Time)
+ }
+ // Count the connection as Idle after 5 seconds.
+ // Same as net/http.Server: https://github.com/golang/go/blob/85d7bab91d9a3ed1f76842e4328973ea75efef54/src/net/http/server.go#L2834-L2836
+ s.idleConns[c] = time.Now().Add(time.Second * 5)
+
+ default:
+ delete(s.idleConns, c)
+ }
+ s.idleConnsMu.Unlock()
+}
+
+func (s *Server) closeIdleConns() {
+ s.idleConnsMu.Lock()
+ now := time.Now()
+ for c, t := range s.idleConns {
+ if now.Sub(t) >= 0 {
+ _ = c.Close()
+ delete(s.idleConns, c)
+ }
+ }
+ s.idleConnsMu.Unlock()
+}
+
// A ConnState represents the state of a client connection to a server.
// It's used by the optional Server.ConnState hook.
type ConnState int
diff --git a/server_example_test.go b/server_example_test.go
index 68321fb..cb26cdc 100644
--- a/server_example_test.go
+++ b/server_example_test.go
@@ -27,7 +27,7 @@ func ExampleListenAndServe() {
//
// ListenAndServe returns only on error, so usually it blocks forever.
if err := fasthttp.ListenAndServe(listenAddr, requestHandler); err != nil {
- log.Fatalf("error in ListenAndServe: %s", err)
+ log.Fatalf("error in ListenAndServe: %v", err)
}
}
@@ -39,7 +39,7 @@ func ExampleServe() {
// For example, unix socket listener or TLS listener.
ln, err := net.Listen("tcp4", "127.0.0.1:8080")
if err != nil {
- log.Fatalf("error in net.Listen: %s", err)
+ log.Fatalf("error in net.Listen: %v", err)
}
// This function will be called by the server for each incoming request.
@@ -55,7 +55,7 @@ func ExampleServe() {
//
// Serve returns on ln.Close() or error, so usually it blocks forever.
if err := fasthttp.Serve(ln, requestHandler); err != nil {
- log.Fatalf("error in Serve: %s", err)
+ log.Fatalf("error in Serve: %v", err)
}
}
@@ -82,7 +82,7 @@ func ExampleServer() {
//
// ListenAndServe returns only on error, so usually it blocks forever.
if err := s.ListenAndServe("127.0.0.1:80"); err != nil {
- log.Fatalf("error in ListenAndServe: %s", err)
+ log.Fatalf("error in ListenAndServe: %v", err)
}
}
@@ -94,7 +94,7 @@ func ExampleRequestCtx_Hijack() {
var buf [1]byte
for {
if _, err := c.Read(buf[:]); err != nil {
- log.Printf("error when reading from hijacked connection: %s", err)
+ log.Printf("error when reading from hijacked connection: %v", err)
return
}
fmt.Fprintf(c, "You sent me %q. Waiting for new data\n", buf[:])
@@ -120,7 +120,7 @@ func ExampleRequestCtx_Hijack() {
}
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
- log.Fatalf("error in ListenAndServe: %s", err)
+ log.Fatalf("error in ListenAndServe: %v", err)
}
}
@@ -151,7 +151,7 @@ func ExampleRequestCtx_TimeoutError() {
}
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
- log.Fatalf("error in ListenAndServe: %s", err)
+ log.Fatalf("error in ListenAndServe: %v", err)
}
}
@@ -172,6 +172,6 @@ func ExampleRequestCtx_Logger() {
}
if err := fasthttp.ListenAndServe(":80", requestHandler); err != nil {
- log.Fatalf("error in ListenAndServe: %s", err)
+ log.Fatalf("error in ListenAndServe: %v", err)
}
}
diff --git a/server_test.go b/server_test.go
index 484fdf4..4599208 100644
--- a/server_test.go
+++ b/server_test.go
@@ -5,15 +5,18 @@ import (
"bytes"
"context"
"crypto/tls"
+ "errors"
"fmt"
"io"
- "io/ioutil"
"mime/multipart"
"net"
"os"
"reflect"
+ "regexp"
+ "runtime"
"strings"
"sync"
+ "sync/atomic"
"testing"
"time"
@@ -23,6 +26,15 @@ import (
// Make sure RequestCtx implements context.Context
var _ context.Context = &RequestCtx{}
+type closerWithRequestCtx struct {
+ ctx *RequestCtx
+ closeFunc func(ctx *RequestCtx) error
+}
+
+func (c *closerWithRequestCtx) Close() error {
+ return c.closeFunc(c.ctx)
+}
+
func TestServerCRNLAfterPost_Pipeline(t *testing.T) {
t.Parallel()
@@ -37,13 +49,13 @@ func TestServerCRNLAfterPost_Pipeline(t *testing.T) {
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
}()
c, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
defer c.Close()
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
@@ -55,13 +67,13 @@ func TestServerCRNLAfterPost_Pipeline(t *testing.T) {
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -83,13 +95,13 @@ func TestServerCRNLAfterPost(t *testing.T) {
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
}()
c, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
defer c.Close()
if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" +
@@ -101,7 +113,7 @@ func TestServerCRNLAfterPost(t *testing.T) {
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -122,13 +134,13 @@ func TestServerPipelineFlush(t *testing.T) {
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
}()
c, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
t.Fatal(err)
@@ -152,7 +164,7 @@ func TestServerPipelineFlush(t *testing.T) {
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -160,12 +172,12 @@ func TestServerPipelineFlush(t *testing.T) {
// Since the second request takes 200ms to finish we expect the first one to be flushed earlier.
d := time.Since(start)
- if d > time.Millisecond*100 {
+ if d >= time.Millisecond*200 {
t.Fatalf("had to wait for %v", d)
}
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -188,13 +200,13 @@ func TestServerInvalidHeader(t *testing.T) {
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
}()
c, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil {
t.Fatal(err)
@@ -203,7 +215,7 @@ func TestServerInvalidHeader(t *testing.T) {
br := bufio.NewReader(c)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusBadRequest {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest)
@@ -211,7 +223,7 @@ func TestServerInvalidHeader(t *testing.T) {
c, err = ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil {
t.Fatal(err)
@@ -219,7 +231,7 @@ func TestServerInvalidHeader(t *testing.T) {
br = bufio.NewReader(c)
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusBadRequest {
@@ -227,10 +239,10 @@ func TestServerInvalidHeader(t *testing.T) {
}
if err := c.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -240,7 +252,7 @@ func TestServerConnState(t *testing.T) {
states := make([]string, 0)
s := &Server{
Handler: func(ctx *RequestCtx) {},
- ConnState: func(conn net.Conn, state ConnState) {
+ ConnState: func(_ net.Conn, state ConnState) {
states = append(states, state.String())
},
}
@@ -250,7 +262,7 @@ func TestServerConnState(t *testing.T) {
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverCh)
}()
@@ -259,24 +271,24 @@ func TestServerConnState(t *testing.T) {
go func() {
c, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(c)
// Send 2 requests on the same connection.
for i := 0; i < 2; i++ {
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
var resp Response
if err := resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
}
}
if err := c.Close(); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
// Give the server a little bit of time to transition the connection to the close state.
time.Sleep(time.Millisecond * 100)
@@ -290,7 +302,7 @@ func TestServerConnState(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -303,7 +315,7 @@ func TestServerConnState(t *testing.T) {
expected := []string{"new", "active", "idle", "active", "idle", "closed"}
if !reflect.DeepEqual(expected, states) {
- t.Fatalf("wrong state, expected %s, got %s", expected, states)
+ t.Fatalf("wrong state, expected %q, got %q", expected, states)
}
}
@@ -337,7 +349,7 @@ func TestSaveMultipartFile(t *testing.T) {
}
defer os.Remove("filea.txt")
- if c, err := ioutil.ReadFile("filea.txt"); err != nil {
+ if c, err := os.ReadFile("filea.txt"); err != nil {
t.Fatal(err)
} else if string(c) != filea {
t.Fatalf("filea changed expected %q got %q", filea, c)
@@ -357,7 +369,7 @@ func TestSaveMultipartFile(t *testing.T) {
}
defer os.Remove("fileb.txt")
- if c, err := ioutil.ReadFile("fileb.txt"); err != nil {
+ if c, err := os.ReadFile("fileb.txt"); err != nil {
t.Fatal(err)
} else if string(c) != fileb {
t.Fatalf("fileb changed expected %q got %q", fileb, c)
@@ -377,20 +389,20 @@ func TestServerName(t *testing.T) {
rw.r.WriteString("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
- resp, err := ioutil.ReadAll(&rw.w)
+ resp, err := io.ReadAll(&rw.w)
if err != nil {
- t.Fatalf("Unexpected error from ReadAll: %s", err)
+ t.Fatalf("Unexpected error from ReadAll: %v", err)
}
return resp
}
resp := getReponse()
- if !bytes.Contains(resp, []byte("\r\nServer: "+string(defaultServerName)+"\r\n")) {
- t.Fatalf("Unexpected response %q expected Server: "+string(defaultServerName), resp)
+ if !bytes.Contains(resp, []byte("\r\nServer: "+defaultServerName+"\r\n")) {
+ t.Fatalf("Unexpected response %q expected Server: "+defaultServerName, resp)
}
// We can't just overwrite s.Name as fasthttp caches the name in an atomic.Value
@@ -478,7 +490,7 @@ func TestServerErrSmallBuffer(t *testing.T) {
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
statusCode := resp.StatusCode()
if statusCode != StatusRequestHeaderFieldsTooLarge {
@@ -534,7 +546,7 @@ func TestRequestCtxRedirectHTTPSSchemeless(t *testing.T) {
s := "GET /foo/bar?baz HTTP/1.1\nHost: aaa.com\n\n"
br := bufio.NewReader(bytes.NewBufferString(s))
if err := ctx.Request.Read(br); err != nil {
- t.Fatalf("cannot read request: %s", err)
+ t.Fatalf("cannot read request: %v", err)
}
ctx.Request.isTLS = true
@@ -560,12 +572,15 @@ func TestRequestCtxRedirect(t *testing.T) {
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?b=1#aaa=bbb&cc=ddd", "http://qqq/foo/x.html?b=1#aaa=bbb&cc=ddd")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html", "http://qqq/x.html")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html#aaa=bbb&cc=ddd", "http://qqq/x.html#aaa=bbb&cc=ddd")
- testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../x.html", "http://qqq/x.html")
- testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../../x.html", "http://qqq/x.html")
- testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "http://foo.bar/baz", "http://foo.bar/baz")
testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "https://foo.bar/baz", "https://foo.bar/baz")
testRequestCtxRedirect(t, "https://foo.com/bar?aaa", "//google.com/aaa?bb", "https://google.com/aaa?bb")
+
+ if runtime.GOOS != "windows" {
+ testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../x.html", "http://qqq/x.html")
+ testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../../x.html", "http://qqq/x.html")
+ testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html")
+ }
}
func testRequestCtxRedirect(t *testing.T, origURL, redirectURL, expectedURL string) {
@@ -606,7 +621,7 @@ func TestServerResponseServerHeader(t *testing.T) {
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverCh)
}()
@@ -615,15 +630,15 @@ func TestServerResponseServerHeader(t *testing.T) {
go func() {
c, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(c)
var resp Response
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusNotFound {
@@ -636,7 +651,7 @@ func TestServerResponseServerHeader(t *testing.T) {
t.Errorf("unexpected server header: %q. Expecting %q", resp.Header.Server(), serverName)
}
if err = c.Close(); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(clientCh)
}()
@@ -648,7 +663,7 @@ func TestServerResponseServerHeader(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -687,7 +702,7 @@ func TestServerResponseBodyStream(t *testing.T) {
serverCh := make(chan struct{})
go func() {
if err := Serve(ln, h); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverCh)
}()
@@ -696,15 +711,15 @@ func TestServerResponseBodyStream(t *testing.T) {
go func() {
c, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(c)
var respH ResponseHeader
if err = respH.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if respH.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", respH.StatusCode(), StatusOK)
@@ -713,7 +728,7 @@ func TestServerResponseBodyStream(t *testing.T) {
buf := make([]byte, 1024)
n, err := br.Read(buf)
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
b := buf[:n]
if string(b) != "5\r\nfirst\r\n" {
@@ -721,9 +736,9 @@ func TestServerResponseBodyStream(t *testing.T) {
}
close(readyCh)
- tail, err := ioutil.ReadAll(br)
+ tail, err := io.ReadAll(br)
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if string(tail) != "6\r\nsecond\r\n0\r\n\r\n" {
t.Errorf("unexpected tail %q. Expecting %q", tail, "6\r\nsecond\r\n0\r\n\r\n")
@@ -739,7 +754,7 @@ func TestServerResponseBodyStream(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -764,7 +779,7 @@ func TestServerDisableKeepalive(t *testing.T) {
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverCh)
}()
@@ -773,15 +788,15 @@ func TestServerDisableKeepalive(t *testing.T) {
go func() {
c, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(c)
var resp Response
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -794,9 +809,9 @@ func TestServerDisableKeepalive(t *testing.T) {
}
// make sure the connection is closed
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if len(data) > 0 {
t.Errorf("unexpected data read from the connection: %q. Expecting empty data", data)
@@ -812,7 +827,7 @@ func TestServerDisableKeepalive(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -841,7 +856,7 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) {
Listener: ln,
}
if err := s.Serve(fakeLN); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverCh)
}()
@@ -850,16 +865,16 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) {
go func() {
c1, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
c2, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(c2)
var resp Response
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusTooManyRequests {
t.Errorf("unexpected status code for the second connection: %d. Expecting %d",
@@ -867,11 +882,11 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) {
}
if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
- t.Errorf("unexpected error when writing to the first connection: %s", err)
+ t.Errorf("unexpected error when writing to the first connection: %v", err)
}
br = bufio.NewReader(c1)
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code for the first connection: %d. Expecting %d",
@@ -890,7 +905,7 @@ func TestServerMaxConnsPerIPLimit(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -921,7 +936,7 @@ type fakeIPConn struct {
func (conn *fakeIPConn) RemoteAddr() net.Addr {
addr, err := net.ResolveTCPAddr("tcp4", "1.2.3.4:5789")
if err != nil {
- panic(fmt.Sprintf("BUG: unexpected error: %s", err))
+ panic(fmt.Sprintf("BUG: unexpected error: %v", err))
}
return addr
}
@@ -942,7 +957,7 @@ func TestServerConcurrencyLimit(t *testing.T) {
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(serverCh)
}()
@@ -951,16 +966,16 @@ func TestServerConcurrencyLimit(t *testing.T) {
go func() {
c1, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
c2, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(c2)
var resp Response
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusServiceUnavailable {
t.Errorf("unexpected status code for the second connection: %d. Expecting %d",
@@ -968,11 +983,11 @@ func TestServerConcurrencyLimit(t *testing.T) {
}
if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
- t.Errorf("unexpected error when writing to the first connection: %s", err)
+ t.Errorf("unexpected error when writing to the first connection: %v", err)
}
br = bufio.NewReader(c1)
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code for the first connection: %d. Expecting %d",
@@ -991,7 +1006,7 @@ func TestServerConcurrencyLimit(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -1014,7 +1029,7 @@ func TestServerWriteFastError(t *testing.T) {
br := bufio.NewReader(&buf)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if resp.StatusCode() != StatusForbidden {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusForbidden)
@@ -1130,7 +1145,7 @@ func TestServerTLSReadTimeout(t *testing.T) {
select {
case err = <-r:
- case <-time.After(time.Second):
+ case <-time.After(time.Second * 2):
}
if err == nil {
@@ -1156,15 +1171,14 @@ func TestServerServeTLSEmbed(t *testing.T) {
ctx.Error("expecting tls", StatusBadRequest)
return
}
- scheme := ctx.URI().Scheme()
- if string(scheme) != "https" {
- ctx.Error(fmt.Sprintf("unexpected scheme=%q. Expecting %q", scheme, "https"), StatusBadRequest)
+ if !ctx.URI().isHttps() {
+ ctx.Error(fmt.Sprintf("unexpected scheme=%q. Expecting %q", ctx.URI().Scheme(), "https"), StatusBadRequest)
return
}
ctx.WriteString("success") //nolint:errcheck
})
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(ch)
}()
@@ -1172,7 +1186,7 @@ func TestServerServeTLSEmbed(t *testing.T) {
// establish connection to the server
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
tlsConn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
@@ -1180,7 +1194,7 @@ func TestServerServeTLSEmbed(t *testing.T) {
// send request
if _, err = tlsConn.Write([]byte("GET / HTTP/1.1\r\nHost: aaa\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
// read response
@@ -1205,7 +1219,7 @@ func TestServerServeTLSEmbed(t *testing.T) {
// close the server
if err = ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case <-ch:
@@ -1262,7 +1276,7 @@ Connection: close
case "/upload":
f, err := ctx.MultipartForm()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if len(f.Value) != 1 {
t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1)
@@ -1284,17 +1298,17 @@ Connection: close
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte(reqS)); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var resp Response
@@ -1302,7 +1316,7 @@ Connection: close
respCh := make(chan struct{})
go func() {
if err := resp.Read(br); err != nil {
- t.Errorf("error when reading response: %s", err)
+ t.Errorf("error when reading response: %v", err)
}
if resp.StatusCode() != StatusSeeOther {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther)
@@ -1313,7 +1327,7 @@ Connection: close
}
if err := resp.Read(br); err != nil {
- t.Errorf("error when reading the second response: %s", err)
+ t.Errorf("error when reading the second response: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -1332,7 +1346,7 @@ Connection: close
}
if err := ln.Close(); err != nil {
- t.Fatalf("error when closing listener: %s", err)
+ t.Fatalf("error when closing listener: %v", err)
}
select {
@@ -1357,12 +1371,12 @@ func TestServerGetWithContent(t *testing.T) {
rw.r.WriteString("GET / HTTP/1.1\r\nHost: mm.com\r\nContent-Length: 5\r\n\r\nabcde")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
resp := rw.w.String()
if !strings.HasSuffix(resp, "success") {
- t.Fatalf("unexpected response %s.", resp)
+ t.Fatalf("unexpected response %q.", resp)
}
}
@@ -1393,14 +1407,14 @@ func TestServerDisableHeaderNamesNormalizing(t *testing.T) {
rw.r.WriteString(fmt.Sprintf("GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", headerName, headerValue))
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
resp.Header.DisableNormalizing()
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
hv := resp.Header.Peek(headerName)
@@ -1426,7 +1440,7 @@ func TestServerReduceMemoryUsageSerial(t *testing.T) {
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(ch)
}()
@@ -1434,7 +1448,7 @@ func TestServerReduceMemoryUsageSerial(t *testing.T) {
testServerRequests(t, ln)
if err := ln.Close(); err != nil {
- t.Fatalf("error when closing listener: %s", err)
+ t.Fatalf("error when closing listener: %v", err)
}
select {
@@ -1457,7 +1471,7 @@ func TestServerReduceMemoryUsageConcurrent(t *testing.T) {
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(ch)
}()
@@ -1478,7 +1492,7 @@ func TestServerReduceMemoryUsageConcurrent(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("error when closing listener: %s", err)
+ t.Fatalf("error when closing listener: %v", err)
}
select {
@@ -1491,20 +1505,20 @@ func TestServerReduceMemoryUsageConcurrent(t *testing.T) {
func testServerRequests(t *testing.T, ln *fasthttputil.InmemoryListener) {
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
var resp Response
for i := 0; i < 10; i++ {
if _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: aaa\r\n\r\n"); err != nil {
- t.Fatalf("unexpected error on iteration %d: %s", i, err)
+ t.Fatalf("unexpected error on iteration %d: %v", i, err)
}
respCh := make(chan struct{})
go func() {
if err = resp.Read(br); err != nil {
- t.Errorf("unexpected error when reading response on iteration %d: %s", i, err)
+ t.Errorf("unexpected error when reading response on iteration %d: %v", i, err)
}
close(respCh)
}()
@@ -1516,7 +1530,7 @@ func testServerRequests(t *testing.T, ln *fasthttputil.InmemoryListener) {
}
if err = conn.Close(); err != nil {
- t.Fatalf("error when closing the connection: %s", err)
+ t.Fatalf("error when closing the connection: %v", err)
}
}
@@ -1533,34 +1547,34 @@ func TestServerHTTP10ConnectionKeepAlive(t *testing.T) {
}
})
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
_, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n")
if err != nil {
- t.Fatalf("error when writing request: %s", err)
+ t.Fatalf("error when writing request: %v", err)
}
_, err = fmt.Fprintf(conn, "%s", "GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n")
if err != nil {
- t.Fatalf("error when writing request: %s", err)
+ t.Fatalf("error when writing request: %v", err)
}
br := bufio.NewReader(conn)
var resp Response
if err = resp.Read(br); err != nil {
- t.Fatalf("error when reading response: %s", err)
+ t.Fatalf("error when reading response: %v", err)
}
if resp.ConnectionClose() {
t.Fatal("response mustn't have 'Connection: close' header")
}
if err = resp.Read(br); err != nil {
- t.Fatalf("error when reading response: %s", err)
+ t.Fatalf("error when reading response: %v", err)
}
if !resp.ConnectionClose() {
t.Fatal("response must have 'Connection: close' header")
@@ -1568,9 +1582,9 @@ func TestServerHTTP10ConnectionKeepAlive(t *testing.T) {
tailCh := make(chan struct{})
go func() {
- tail, err := ioutil.ReadAll(br)
+ tail, err := io.ReadAll(br)
if err != nil {
- t.Errorf("error when reading tail: %s", err)
+ t.Errorf("error when reading tail: %v", err)
}
if len(tail) > 0 {
t.Errorf("unexpected non-zero tail %q", tail)
@@ -1585,11 +1599,11 @@ func TestServerHTTP10ConnectionKeepAlive(t *testing.T) {
}
if err = conn.Close(); err != nil {
- t.Fatalf("error when closing the connection: %s", err)
+ t.Fatalf("error when closing the connection: %v", err)
}
if err = ln.Close(); err != nil {
- t.Fatalf("error when closing listener: %s", err)
+ t.Fatalf("error when closing listener: %v", err)
}
select {
@@ -1617,24 +1631,24 @@ func TestServerHTTP10ConnectionClose(t *testing.T) {
ctx.Response.Header.Set(HeaderConnection, "keep-alive")
})
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
_, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\n\r\n")
if err != nil {
- t.Fatalf("error when writing request: %s", err)
+ t.Fatalf("error when writing request: %v", err)
}
br := bufio.NewReader(conn)
var resp Response
if err = resp.Read(br); err != nil {
- t.Fatalf("error when reading response: %s", err)
+ t.Fatalf("error when reading response: %v", err)
}
if !resp.ConnectionClose() {
@@ -1643,9 +1657,9 @@ func TestServerHTTP10ConnectionClose(t *testing.T) {
tailCh := make(chan struct{})
go func() {
- tail, err := ioutil.ReadAll(br)
+ tail, err := io.ReadAll(br)
if err != nil {
- t.Errorf("error when reading tail: %s", err)
+ t.Errorf("error when reading tail: %v", err)
}
if len(tail) > 0 {
t.Errorf("unexpected non-zero tail %q", tail)
@@ -1660,11 +1674,11 @@ func TestServerHTTP10ConnectionClose(t *testing.T) {
}
if err = conn.Close(); err != nil {
- t.Fatalf("error when closing the connection: %s", err)
+ t.Fatalf("error when closing the connection: %v", err)
}
if err = ln.Close(); err != nil {
- t.Fatalf("error when closing listener: %s", err)
+ t.Fatalf("error when closing listener: %v", err)
}
select {
@@ -1699,6 +1713,20 @@ func TestRequestCtxFormValue(t *testing.T) {
}
}
+func TestSetStandardFormValueFunc(t *testing.T) {
+ t.Parallel()
+ var ctx RequestCtx
+ var req Request
+ req.SetRequestURI("/foo/bar?aaa=bbb")
+ req.SetBodyString("aaa=port")
+ req.Header.SetContentType("application/x-www-form-urlencoded")
+ ctx.Init(&req, nil, nil)
+ ctx.formValueFunc = NetHttpFormValueFunc
+ v := ctx.FormValue("aaa")
+ if string(v) != "port" {
+ t.Fatalf("unexpected value %q. Expecting %q", v, "port")
+ }
+}
func TestRequestCtxUserValue(t *testing.T) {
t.Parallel()
@@ -1724,7 +1752,7 @@ func TestRequestCtxUserValue(t *testing.T) {
vlen := 0
ctx.VisitUserValues(func(key []byte, value interface{}) {
vlen++
- v := ctx.UserValueBytes(key)
+ v := ctx.UserValue(key)
if v != value {
t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value)
}
@@ -1757,14 +1785,14 @@ func TestServerHeadRequest(t *testing.T) {
rw.r.WriteString("HEAD /foobar HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
resp.SkipBody = true
if err := resp.Read(br); err != nil {
- t.Fatalf("Unexpected error when parsing response: %s", err)
+ t.Fatalf("Unexpected error when parsing response: %v", err)
}
if resp.Header.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.Header.StatusCode(), StatusOK)
@@ -1779,9 +1807,9 @@ func TestServerHeadRequest(t *testing.T) {
t.Fatalf("unexpected content-type %q. Expecting %q", resp.Header.ContentType(), "aaa/bbb")
}
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
@@ -1814,15 +1842,15 @@ func TestServerExpect100Continue(t *testing.T) {
rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar")
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
@@ -1870,15 +1898,15 @@ func TestServerContinueHandler(t *testing.T) {
sendRequest := func(rw *readWriter, expectedStatusCode int, expectedResponse string) {
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, expectedStatusCode, string(defaultContentType), expectedResponse)
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) > 0 {
t.Fatalf("unexpected remaining data %q", data)
@@ -1925,9 +1953,9 @@ func TestCompressHandler(t *testing.T) {
s := ctx.Response.String()
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce := resp.Header.Peek(HeaderContentEncoding)
+ ce := resp.Header.ContentEncoding()
if string(ce) != "" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "")
}
@@ -1945,15 +1973,15 @@ func TestCompressHandler(t *testing.T) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "gzip" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip")
}
body, err := resp.BodyGunzip()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
@@ -1968,15 +1996,15 @@ func TestCompressHandler(t *testing.T) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "gzip" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip")
}
body, err = resp.BodyGunzip()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
@@ -1991,15 +2019,15 @@ func TestCompressHandler(t *testing.T) {
s = ctx.Response.String()
br = bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
- ce = resp.Header.Peek(HeaderContentEncoding)
+ ce = resp.Header.ContentEncoding()
if string(ce) != "deflate" {
t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "deflate")
}
body, err = resp.BodyInflate()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if string(body) != expectedBody {
t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody)
@@ -2012,14 +2040,14 @@ func TestRequestCtxWriteString(t *testing.T) {
var ctx RequestCtx
n, err := ctx.WriteString("foo")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != 3 {
t.Fatalf("unexpected n %d. Expecting 3", n)
}
n, err = ctx.WriteString("привет")
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != 12 {
t.Fatalf("unexpected n=%d. Expecting 12", n)
@@ -2031,6 +2059,140 @@ func TestRequestCtxWriteString(t *testing.T) {
}
}
+func TestServeConnKeepRequestAndResponseUntilResetUserValues(t *testing.T) {
+ t.Parallel()
+
+ reqStr := "POST /foo HTTP/1.0\r\nHost: google.com\r\nContent-Type: application/octet-stream\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n"
+ respRegex := regexp.MustCompile("HTTP/1.1 308 Permanent Redirect\r\nServer: fasthttp\r\nDate: (.*)\r\nContent-Length: 0\r\nConnection: keep-alive\r\n\r\n")
+
+ rw := &readWriter{}
+ rw.r.WriteString(reqStr)
+
+ var resultReqStr, resultRespStr string
+
+ ch := make(chan struct{})
+ go func() {
+ err := ServeConn(rw, func(ctx *RequestCtx) {
+ ctx.Response.SetStatusCode(StatusPermanentRedirect)
+
+ ctx.SetUserValue("myKey", &closerWithRequestCtx{
+ ctx: ctx,
+ closeFunc: func(closerCtx *RequestCtx) error {
+ resultReqStr = closerCtx.Request.String()
+ resultRespStr = closerCtx.Response.String()
+
+ return nil
+ }})
+ })
+ if err != nil {
+ t.Errorf("unexpected error in ServeConn: %v", err)
+ }
+ close(ch)
+ }()
+
+ select {
+ case <-ch:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+
+ if resultReqStr != reqStr {
+ t.Errorf("Request == %q, want %q", resultReqStr, reqStr)
+ }
+
+ if !respRegex.MatchString(resultRespStr) {
+ t.Errorf("Response == %q, want regex %q", resultRespStr, respRegex)
+ }
+}
+
+// TestServerErrorHandler tests unexpected cases the for loop will break
+// before request/response reset call. in such cases, call it before
+// release to fix #548.
+func TestServerErrorHandler(t *testing.T) {
+ t.Parallel()
+
+ var resultReqStr, resultRespStr string
+
+ s := &Server{
+ Handler: func(ctx *RequestCtx) {},
+ ErrorHandler: func(ctx *RequestCtx, _ error) {
+ resultReqStr = ctx.Request.String()
+ resultRespStr = ctx.Response.String()
+ },
+ MaxRequestBodySize: 10,
+ }
+
+ reqStrTpl := "POST %s HTTP/1.1\r\nHost: example.com\r\nContent-Type: application/octet-stream\r\nContent-Length: %d\r\nConnection: keep-alive\r\n\r\n"
+ respRegex := regexp.MustCompile("HTTP/1.1 200 OK\r\nDate: (.*)\r\nContent-Length: 0\r\n\r\n")
+
+ rw := &readWriter{}
+
+ for i := 0; i < 100; i++ {
+ body := strings.Repeat("@", s.MaxRequestBodySize+1)
+ path := fmt.Sprintf("/%d", i)
+
+ reqStr := fmt.Sprintf(reqStrTpl, path, len(body))
+ expectedReqStr := fmt.Sprintf(reqStrTpl, path, 0)
+
+ rw.r.WriteString(reqStr)
+ rw.r.WriteString(body)
+
+ ch := make(chan struct{})
+ go func() {
+ err := s.ServeConn(rw)
+ if err != nil && !errors.Is(err, ErrBodyTooLarge) {
+ t.Errorf("unexpected error in ServeConn: %v", err)
+ }
+ close(ch)
+ }()
+
+ select {
+ case <-ch:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+
+ if resultReqStr != expectedReqStr {
+ t.Errorf("[iter: %d] Request == %q, want %s", i, resultReqStr, reqStr)
+ }
+
+ if !respRegex.MatchString(resultRespStr) {
+ t.Errorf("[iter: %d] Response == %q, want regex %q", i, resultRespStr, respRegex)
+ }
+ }
+}
+
+func TestServeConnHijackResetUserValues(t *testing.T) {
+ t.Parallel()
+
+ rw := &readWriter{}
+ rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n")
+ rw.r.WriteString("")
+
+ ch := make(chan struct{})
+ go func() {
+ err := ServeConn(rw, func(ctx *RequestCtx) {
+ ctx.Hijack(func(c net.Conn) {})
+ ctx.SetUserValue("myKey", &closerWithRequestCtx{
+ closeFunc: func(_ *RequestCtx) error {
+ close(ch)
+
+ return nil
+ }},
+ )
+ })
+ if err != nil {
+ t.Errorf("unexpected error in ServeConn: %v", err)
+ }
+ }()
+
+ select {
+ case <-ch:
+ case <-time.After(time.Second):
+ t.Errorf("Timeout: UserValues should be reset")
+ }
+}
+
func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
t.Parallel()
@@ -2048,7 +2210,7 @@ func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
ctx.SuccessString("aaa/bbb", "foobar")
})
if err != nil {
- t.Errorf("unexpected error in ServeConn: %s", err)
+ t.Errorf("unexpected error in ServeConn: %v", err)
}
close(ch)
}()
@@ -2065,7 +2227,7 @@ func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
// verify the first response
if err := resp.Read(br); err != nil {
- t.Fatalf("Unexpected error when parsing response: %s", err)
+ t.Fatalf("Unexpected error when parsing response: %v", err)
}
if string(resp.Header.Peek(HeaderConnection)) != "keep-alive" {
t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek(HeaderConnection), "keep-alive")
@@ -2076,7 +2238,7 @@ func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
// verify the second response
if err := resp.Read(br); err != nil {
- t.Fatalf("Unexpected error when parsing response: %s", err)
+ t.Fatalf("Unexpected error when parsing response: %v", err)
}
if string(resp.Header.Peek(HeaderConnection)) != "close" {
t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek(HeaderConnection), "close")
@@ -2085,9 +2247,9 @@ func TestServeConnNonHTTP11KeepAlive(t *testing.T) {
t.Fatal("expecting Connection: close")
}
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after responses %q", data)
@@ -2111,7 +2273,7 @@ func TestRequestCtxSetBodyStreamWriter(t *testing.T) {
ctx.SetBodyStreamWriter(func(w *bufio.Writer) {
fmt.Fprintf(w, "body writer line 1\n")
if err := w.Flush(); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
fmt.Fprintf(w, "body writer line 2\n")
})
@@ -2124,7 +2286,7 @@ func TestRequestCtxSetBodyStreamWriter(t *testing.T) {
br := bufio.NewReader(bytes.NewBufferString(s))
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("Error when reading response: %s", err)
+ t.Fatalf("Error when reading response: %v", err)
}
body := string(resp.Body())
@@ -2174,7 +2336,7 @@ func TestRequestCtxSendFileNotModified(t *testing.T) {
filePath := "./server_test.go"
lastModified, err := FileLastModified(filePath)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
@@ -2185,7 +2347,7 @@ func TestRequestCtxSendFileNotModified(t *testing.T) {
var resp Response
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("error when reading response: %s", err)
+ t.Fatalf("error when reading response: %v", err)
}
if resp.StatusCode() != StatusNotModified {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotModified)
@@ -2205,7 +2367,7 @@ func TestRequestCtxSendFileModified(t *testing.T) {
filePath := "./server_test.go"
lastModified, err := FileLastModified(filePath)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
lastModified = lastModified.Add(-time.Hour)
ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified)))
@@ -2217,7 +2379,7 @@ func TestRequestCtxSendFileModified(t *testing.T) {
var resp Response
br := bufio.NewReader(bytes.NewBufferString(s))
if err := resp.Read(br); err != nil {
- t.Fatalf("error when reading response: %s", err)
+ t.Fatalf("error when reading response: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -2225,12 +2387,12 @@ func TestRequestCtxSendFileModified(t *testing.T) {
f, err := os.Open(filePath)
if err != nil {
- t.Fatalf("cannot open file: %s", err)
+ t.Fatalf("cannot open file: %v", err)
}
- body, err := ioutil.ReadAll(f)
+ body, err := io.ReadAll(f)
f.Close()
if err != nil {
- t.Fatalf("error when reading file: %s", err)
+ t.Fatalf("error when reading file: %v", err)
}
if !bytes.Equal(resp.Body(), body) {
@@ -2251,16 +2413,16 @@ func TestRequestCtxSendFile(t *testing.T) {
w := &bytes.Buffer{}
bw := bufio.NewWriter(w)
if err := ctx.Response.Write(bw); err != nil {
- t.Fatalf("error when writing response: %s", err)
+ t.Fatalf("error when writing response: %v", err)
}
if err := bw.Flush(); err != nil {
- t.Fatalf("error when flushing response: %s", err)
+ t.Fatalf("error when flushing response: %v", err)
}
var resp Response
br := bufio.NewReader(w)
if err := resp.Read(br); err != nil {
- t.Fatalf("error when reading response: %s", err)
+ t.Fatalf("error when reading response: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -2268,12 +2430,12 @@ func TestRequestCtxSendFile(t *testing.T) {
f, err := os.Open(filePath)
if err != nil {
- t.Fatalf("cannot open file: %s", err)
+ t.Fatalf("cannot open file: %v", err)
}
- body, err := ioutil.ReadAll(f)
+ body, err := io.ReadAll(f)
f.Close()
if err != nil {
- t.Fatalf("error when reading file: %s", err)
+ t.Fatalf("error when reading file: %v", err)
}
if !bytes.Equal(resp.Body(), body) {
@@ -2281,71 +2443,131 @@ func TestRequestCtxSendFile(t *testing.T) {
}
}
-func TestRequestCtxHijack(t *testing.T) {
- t.Parallel()
+func testRequestCtxHijack(t *testing.T, s *Server) {
+ t.Helper()
- hijackStartCh := make(chan struct{})
- hijackStopCh := make(chan struct{})
- s := &Server{
- Handler: func(ctx *RequestCtx) {
- if ctx.Hijacked() {
- t.Error("connection mustn't be hijacked")
- }
- ctx.Hijack(func(c net.Conn) {
- <-hijackStartCh
+ type hijackSignal struct {
+ id int
+ rw *readWriter
+ }
- b := make([]byte, 1)
- // ping-pong echo via hijacked conn
- for {
- n, err := c.Read(b)
- if n != 1 {
- if err == io.EOF {
- close(hijackStopCh)
- return
- }
- if err != nil {
- t.Errorf("unexpected error: %s", err)
- }
- t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
- }
- if _, err = c.Write(b); err != nil {
- t.Errorf("unexpected error when writing data: %s", err)
+ wg := sync.WaitGroup{}
+ totalConns := 100
+ hijackStartCh := make(chan *hijackSignal, totalConns)
+ hijackStopCh := make(chan *hijackSignal, totalConns)
+
+ s.Handler = func(ctx *RequestCtx) {
+ if ctx.Hijacked() {
+ t.Error("connection mustn't be hijacked")
+ }
+
+ ctx.Hijack(func(c net.Conn) {
+ signal := <-hijackStartCh
+ defer func() {
+ hijackStopCh <- signal
+ wg.Done()
+ }()
+
+ b := make([]byte, 1)
+ stop := false
+
+ // ping-pong echo via hijacked conn
+ for !stop {
+ n, err := c.Read(b)
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ stop = true
+
+ continue
}
+
+ t.Errorf("unexpected read error: %v", err)
+ } else if n != 1 {
+ t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
+ }
+
+ if _, err = c.Write(b); err != nil {
+ t.Errorf("unexpected error when writing data: %v", err)
}
- })
- if !ctx.Hijacked() {
- t.Error("connection must be hijacked")
}
- ctx.Success("foo/bar", []byte("hijack it!"))
- },
+ })
+
+ if !ctx.Hijacked() {
+ t.Error("connection must be hijacked")
+ }
+
+ ctx.Success("foo/bar", []byte("hijack it!"))
}
hijackedString := "foobar baz hijacked!!!"
- rw := &readWriter{}
- rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
- rw.r.WriteString(hijackedString)
- if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
- }
+ for i := 0; i < totalConns; i++ {
+ wg.Add(1)
- br := bufio.NewReader(&rw.w)
- verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!")
+ go func(t *testing.T, id int) {
+ t.Helper()
- close(hijackStartCh)
- select {
- case <-hijackStopCh:
- case <-time.After(100 * time.Millisecond):
- t.Fatal("timeout")
- }
+ rw := new(readWriter)
+ rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
+ rw.r.WriteString(hijackedString)
- data, err := ioutil.ReadAll(br)
- if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ if err := s.ServeConn(rw); err != nil {
+ t.Errorf("[iter: %d] Unexpected error from serveConn: %v", id, err)
+ }
+
+ hijackStartCh <- &hijackSignal{id, rw}
+ }(t, i)
}
- if string(data) != hijackedString {
- t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString)
+
+ wg.Wait()
+
+ count := 0
+ for count != totalConns {
+ select {
+ case signal := <-hijackStopCh:
+ count++
+
+ id := signal.id
+ rw := signal.rw
+
+ br := bufio.NewReader(&rw.w)
+ verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!")
+
+ data, err := io.ReadAll(br)
+ if err != nil {
+ t.Errorf("[iter: %d] Unexpected error when reading remaining data: %v", id, err)
+
+ return
+ }
+ if string(data) != hijackedString {
+ t.Errorf(
+ "[iter: %d] Unexpected response %q. Expecting %q",
+ id, data, hijackedString,
+ )
+
+ return
+ }
+ case <-time.After(200 * time.Millisecond):
+ t.Errorf("timeout")
+ }
}
+
+ close(hijackStartCh)
+ close(hijackStopCh)
+}
+
+func TestRequestCtxHijack(t *testing.T) {
+ t.Parallel()
+
+ testRequestCtxHijack(t, &Server{})
+}
+
+func TestRequestCtxHijackReduceMemoryUsage(t *testing.T) {
+ t.Parallel()
+
+ testRequestCtxHijack(t, &Server{
+ ReduceMemoryUsage: true,
+ })
}
func TestRequestCtxHijackNoResponse(t *testing.T) {
@@ -2366,13 +2588,13 @@ func TestRequestCtxHijackNoResponse(t *testing.T) {
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
select {
case err := <-hijackDone:
if err != nil {
- t.Fatalf("Unexpected error from hijack: %s", err)
+ t.Fatalf("Unexpected error from hijack: %v", err)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
@@ -2397,7 +2619,7 @@ func TestRequestCtxNoHijackNoResponse(t *testing.T) {
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
bf := bufio.NewReader(
@@ -2444,7 +2666,7 @@ func TestTimeoutHandlerSuccess(t *testing.T) {
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
close(serverCh)
}()
@@ -2455,10 +2677,10 @@ func TestTimeoutHandlerSuccess(t *testing.T) {
go func() {
conn, err := ln.Dial()
if err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
@@ -2475,7 +2697,7 @@ func TestTimeoutHandlerSuccess(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -2502,7 +2724,7 @@ func TestTimeoutHandlerTimeout(t *testing.T) {
serverCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
close(serverCh)
}()
@@ -2513,10 +2735,10 @@ func TestTimeoutHandlerTimeout(t *testing.T) {
go func() {
conn, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!")
@@ -2542,7 +2764,7 @@ func TestTimeoutHandlerTimeout(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
@@ -2567,27 +2789,27 @@ func TestTimeoutHandlerTimeoutReuse(t *testing.T) {
}
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
if _, err = conn.Write([]byte("GET /timeout HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!")
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
verifyResponse(t, br, StatusOK, string(defaultContentType), "ok")
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -2619,7 +2841,7 @@ func TestServerGetOnly(t *testing.T) {
t.Fatal("expecting error")
}
if err != ErrGetOnly {
- t.Fatalf("Unexpected error from serveConn: %s. Expecting %s", err, ErrGetOnly)
+ t.Fatalf("Unexpected error from serveConn: %v. Expecting %v", err, ErrGetOnly)
}
case <-time.After(100 * time.Millisecond):
t.Fatal("timeout")
@@ -2628,7 +2850,7 @@ func TestServerGetOnly(t *testing.T) {
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
statusCode := resp.StatusCode()
if statusCode != StatusBadRequest {
@@ -2667,16 +2889,16 @@ func TestServerTimeoutErrorWithResponse(t *testing.T) {
rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 456, "foo/bar", "path=/foo")
verifyResponse(t, br, 456, "foo/bar", "path=/bar")
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
@@ -2701,16 +2923,16 @@ func TestServerTimeoutErrorWithCode(t *testing.T) {
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx")
verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx")
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
@@ -2735,16 +2957,16 @@ func TestServerTimeoutError(t *testing.T) {
rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx")
verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx")
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
@@ -2764,22 +2986,22 @@ func TestServerMaxRequestsPerConn(t *testing.T) {
rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("Unexpected error when parsing response: %s", err)
+ t.Fatalf("Unexpected error when parsing response: %v", err)
}
if !resp.ConnectionClose() {
t.Fatal("Response must have 'connection: close' header")
}
- verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType))
+ verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType), "")
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
@@ -2800,22 +3022,22 @@ func TestServerConnectionClose(t *testing.T) {
rw.r.WriteString("GET /must/be/ignored HTTP/1.1\r\nHost: aaa.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("Unexpected error when parsing response: %s", err)
+ t.Fatalf("Unexpected error when parsing response: %v", err)
}
if !resp.ConnectionClose() {
t.Fatal("expecting Connection: close header")
}
- data, err := ioutil.ReadAll(br)
+ data, err := io.ReadAll(br)
if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ t.Fatalf("Unexpected error when reading remaining data: %v", err)
}
if len(data) != 0 {
t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "")
@@ -2837,7 +3059,7 @@ func TestServerRequestNumAndTime(t *testing.T) {
connT = ctx.ConnTime()
}
if ctx.ConnTime() != connT {
- t.Errorf("unexpected serve conn time: %s. Expecting %s", ctx.ConnTime(), connT)
+ t.Errorf("unexpected serve conn time: %q. Expecting %q", ctx.ConnTime(), connT)
}
},
}
@@ -2848,7 +3070,7 @@ func TestServerRequestNumAndTime(t *testing.T) {
rw.r.WriteString("GET /baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
if n != 3 {
@@ -2872,7 +3094,7 @@ func TestServerEmptyResponse(t *testing.T) {
rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
@@ -2910,7 +3132,7 @@ func TestServerLogger(t *testing.T) {
globalConnID = 0
if err := s.ServeConn(rwx); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
@@ -2950,7 +3172,7 @@ func TestServerRemoteAddr(t *testing.T) {
}
if err := s.ServeConn(rwx); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
@@ -2990,7 +3212,7 @@ func TestServerCustomRemoteAddr(t *testing.T) {
}
if err := s.ServeConn(rwx); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
@@ -3036,13 +3258,13 @@ func TestServerConnError(t *testing.T) {
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
var resp Response
if err := resp.Read(br); err != nil {
- t.Fatalf("Unexpected error when reading response: %s", err)
+ t.Fatalf("Unexpected error when reading response: %v", err)
}
if resp.Header.StatusCode() != 423 {
t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 423)
@@ -3072,13 +3294,35 @@ func TestServeConnSingleRequest(t *testing.T) {
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com")
}
+// TestServerSetFormValueFunc checks that RequestCtx.FormValue goes through a
+// custom Server.FormValueFunc: the stub echoes the requested key, so the body
+// is "aaa" even though the query string has no "aaa" argument.
+func TestServerSetFormValueFunc(t *testing.T) {
+ t.Parallel()
+
+ echoKey := func(_ *RequestCtx, key string) []byte {
+ return []byte(key)
+ }
+ handler := func(ctx *RequestCtx) {
+ ctx.Success("aaa", ctx.FormValue("aaa"))
+ }
+ s := &Server{Handler: handler, FormValueFunc: echoKey}
+
+ rw := &readWriter{}
+ rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n")
+ if err := s.ServeConn(rw); err != nil {
+ t.Fatalf("Unexpected error from serveConn: %v", err)
+ }
+
+ verifyResponse(t, bufio.NewReader(&rw.w), 200, "aaa", "aaa")
+}
+
func TestServeConnMultiRequests(t *testing.T) {
t.Parallel()
@@ -3093,7 +3337,7 @@ func TestServeConnMultiRequests(t *testing.T) {
rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\nGET /abc HTTP/1.1\r\nHost: foobar.com\r\n\r\n")
if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
br := bufio.NewReader(&rw.w)
@@ -3114,7 +3358,7 @@ func TestShutdown(t *testing.T) {
serveCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
_, err := ln.Dial()
if err == nil {
@@ -3126,10 +3370,10 @@ func TestShutdown(t *testing.T) {
go func() {
conn, err := ln.Dial()
if err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
@@ -3140,7 +3384,7 @@ func TestShutdown(t *testing.T) {
shutdownCh := make(chan struct{})
go func() {
if err := s.Shutdown(); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
shutdownCh <- struct{}{}
}()
@@ -3176,7 +3420,7 @@ func TestCloseOnShutdown(t *testing.T) {
serveCh := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
_, err := ln.Dial()
if err == nil {
@@ -3188,10 +3432,10 @@ func TestCloseOnShutdown(t *testing.T) {
go func() {
conn, err := ln.Dial()
if err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
@@ -3202,7 +3446,7 @@ func TestCloseOnShutdown(t *testing.T) {
shutdownCh := make(chan struct{})
go func() {
if err := s.Shutdown(); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
shutdownCh <- struct{}{}
}()
@@ -3237,38 +3481,38 @@ func TestShutdownReuse(t *testing.T) {
}
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
if err := s.Shutdown(); err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
ln = fasthttputil.NewInmemoryListener()
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
conn, err = ln.Dial()
if err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
br = bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
if err := s.Shutdown(); err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
}
@@ -3284,21 +3528,21 @@ func TestShutdownDone(t *testing.T) {
}
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
go func() {
// Shutdown won't return if the connection doesn't close,
// which doesn't happen until we read the response.
if err := s.Shutdown(); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
// We can only reach this point and get a valid response
@@ -3323,21 +3567,21 @@ func TestShutdownErr(t *testing.T) {
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
go func() {
// Shutdown won't return if the connection doesn't close,
// which doesn't happen until we read the response.
if err := s.Shutdown(); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
// We can only reach this point and get a valid response
@@ -3346,6 +3590,98 @@ func TestShutdownErr(t *testing.T) {
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
}
+// TestShutdownCloseIdleConns serves one request on a keep-alive connection and
+// then expects Shutdown to return within one second while that connection
+// sits idle.
+func TestShutdownCloseIdleConns(t *testing.T) {
+ t.Parallel()
+
+ ln := fasthttputil.NewInmemoryListener()
+ s := &Server{
+ Handler: func(ctx *RequestCtx) {
+ ctx.Success("aaa/bbb", []byte("real response"))
+ },
+ }
+ go func() {
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("unexepcted error: %v", err)
+ }
+ }()
+ conn, err := ln.Dial()
+ if err != nil {
+ t.Fatalf("unexepcted error: %v", err)
+ }
+
+ // Stop if the request could not be sent; reading a response afterwards
+ // would only add a second, misleading failure.
+ if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ br := bufio.NewReader(conn)
+ verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
+
+ // Buffered so the Shutdown goroutine can always deliver its result and
+ // exit, even if the test has already failed on the timer below.
+ shutdownErr := make(chan error, 1)
+ go func() {
+ shutdownErr <- s.Shutdown()
+ }()
+
+ timer := time.NewTimer(time.Second)
+ defer timer.Stop()
+ select {
+ case <-timer.C:
+ t.Fatal("idle connections not closed on shutdown")
+ case err = <-shutdownErr:
+ if err != nil {
+ t.Errorf("unexepcted error: %v", err)
+ }
+ }
+}
+
+// TestShutdownWithContext starts a request whose handler sleeps for five
+// seconds, then calls ShutdownWithContext with a 1ms deadline. It expects the
+// call to return context.DeadlineExceeded while that connection is still open.
+func TestShutdownWithContext(t *testing.T) {
+ t.Parallel()
+
+ ln := fasthttputil.NewInmemoryListener()
+ s := &Server{
+ Handler: func(ctx *RequestCtx) {
+ time.Sleep(5 * time.Second)
+ ctx.Success("aaa/bbb", []byte("real response"))
+ },
+ }
+ go func() {
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("unexepcted error: %v", err)
+ }
+ }()
+ time.Sleep(1 * time.Second)
+ go func() {
+ conn, err := ln.Dial()
+ if err != nil {
+ t.Errorf("unexepcted error: %v", err)
+ // conn is nil here; continuing would panic on Write.
+ return
+ }
+
+ if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
+ t.Errorf("unexpected error: %v", err)
+ return
+ }
+ // NOTE(review): verifyResponse may call t.Fatalf, which is only valid on
+ // the test goroutine. This goroutine can also outlive the test, since the
+ // handler sleeps for 5s. Consider reporting results back over a channel.
+ br := bufio.NewReader(conn)
+ verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
+ }()
+
+ time.Sleep(1 * time.Second)
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
+ defer cancel()
+ // Buffered so the ShutdownWithContext goroutine can always deliver its
+ // result and exit, even if the test has already failed on the timer below.
+ shutdownErr := make(chan error, 1)
+ go func() {
+ shutdownErr <- s.ShutdownWithContext(ctx)
+ }()
+
+ timer := time.NewTimer(time.Second)
+ defer timer.Stop()
+ select {
+ case <-timer.C:
+ t.Fatal("ShutdownWithContext did not return after its context deadline")
+ case err := <-shutdownErr:
+ // A nil error also differs from DeadlineExceeded, so no separate nil check is needed.
+ if err != context.DeadlineExceeded {
+ t.Fatalf("unexpected err %v. Expecting %v", err, context.DeadlineExceeded)
+ }
+ }
+ // The connection with the in-flight request must still be counted in s.open.
+ if atomic.LoadInt32(&s.open) != 1 {
+ t.Fatalf("unexpected open connection num: %#v. Expecting %#v", atomic.LoadInt32(&s.open), 1)
+ }
+}
+
func TestMultipleServe(t *testing.T) {
t.Parallel()
@@ -3360,31 +3696,31 @@ func TestMultipleServe(t *testing.T) {
go func() {
if err := s.Serve(ln1); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
go func() {
if err := s.Serve(ln2); err != nil {
- t.Errorf("unexepcted error: %s", err)
+ t.Errorf("unexepcted error: %v", err)
}
}()
conn, err := ln1.Dial()
if err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
br := bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
conn, err = ln2.Dial()
if err != nil {
- t.Fatalf("unexepcted error: %s", err)
+ t.Fatalf("unexepcted error: %v", err)
}
if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
br = bufio.NewReader(conn)
verifyResponse(t, br, StatusOK, "aaa/bbb", "real response")
@@ -3411,7 +3747,7 @@ func TestMaxBodySizePerRequest(t *testing.T) {
rw.r.WriteString(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", (5<<10)+1, strings.Repeat("a", (5<<10)+1)))
if err := s.ServeConn(rw); err != ErrBodyTooLarge {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
}
@@ -3430,11 +3766,12 @@ func TestStreamRequestBody(t *testing.T) {
checkReader(t, ctx.RequestBodyStream(), part2)
},
StreamRequestBody: true,
+ Logger: &testLogger{},
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
- //write headers and part1 body
+ // write headers and part1 body
if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", contentLength))); err != nil {
t.Fatal(err)
}
@@ -3462,8 +3799,8 @@ func TestStreamRequestBody(t *testing.T) {
select {
case err := <-ch:
- if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ if err != nil && err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
case <-time.After(500 * time.Millisecond):
t.Fatal("part2 timeout")
@@ -3491,7 +3828,7 @@ func TestStreamRequestBodyExceedMaxSize(t *testing.T) {
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
- //write headers and part1 body
+ // write headers and part1 body
if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, part1))); err != nil {
t.Error(err)
}
@@ -3521,7 +3858,7 @@ func TestStreamRequestBodyExceedMaxSize(t *testing.T) {
}
}
-func TestStreamBodyReqestContentLength(t *testing.T) {
+func TestStreamBodyRequestContentLength(t *testing.T) {
t.Parallel()
content := strings.Repeat("1", 1<<15) // 32K
contentLength := len(content)
@@ -3555,7 +3892,7 @@ func TestStreamBodyReqestContentLength(t *testing.T) {
select {
case err := <-ch:
if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
case <-time.After(time.Second):
t.Fatal("test timeout")
@@ -3565,7 +3902,7 @@ func TestStreamBodyReqestContentLength(t *testing.T) {
func checkReader(t *testing.T, r io.Reader, expected string) {
b := make([]byte, len(expected))
if _, err := io.ReadFull(r, b); err != nil {
- t.Fatalf("Unexpected error from reader: %s", err)
+ t.Fatalf("Unexpected error from reader: %v", err)
}
if string(b) != expected {
t.Fatal("incorrect request body")
@@ -3577,7 +3914,7 @@ func TestMaxReadTimeoutPerRequest(t *testing.T) {
headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024))
s := &Server{
- Handler: func(ctx *RequestCtx) {
+ Handler: func(_ *RequestCtx) {
t.Error("shouldn't reach handler")
},
HeaderReceived: func(header *RequestHeader) RequestConfig {
@@ -3593,12 +3930,12 @@ func TestMaxReadTimeoutPerRequest(t *testing.T) {
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
go func() {
- //write headers
+ // write headers
_, err := cc.Write(headers)
if err != nil {
t.Error(err)
}
- //write body
+ // write body
for i := 0; i < 5*1024; i++ {
time.Sleep(time.Millisecond)
cc.Write([]byte{'a'}) //nolint:errcheck
@@ -3612,7 +3949,7 @@ func TestMaxReadTimeoutPerRequest(t *testing.T) {
select {
case err := <-ch:
if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
case <-time.After(time.Second):
t.Fatal("test timeout")
@@ -3647,7 +3984,7 @@ func TestMaxWriteTimeoutPerRequest(t *testing.T) {
var resp Response
go func() {
- //write headers
+ // write headers
_, err := cc.Write(headers)
if err != nil {
t.Error(err)
@@ -3672,7 +4009,7 @@ func TestMaxWriteTimeoutPerRequest(t *testing.T) {
select {
case err := <-ch:
if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") {
- t.Fatalf("Unexpected error from serveConn: %s", err)
+ t.Fatalf("Unexpected error from serveConn: %v", err)
}
case <-time.After(time.Second):
t.Fatal("test timeout")
@@ -3696,16 +4033,77 @@ func TestIncompleteBodyReturnsUnexpectedEOF(t *testing.T) {
}
}
+// TestServerChunkedResponse streams a body in three flushed writes with
+// Transfer-Encoding: chunked and declared trailers. It then checks that the
+// parsed response has no Content-Length (-1), the concatenated body, and
+// every trailer value.
+func TestServerChunkedResponse(t *testing.T) {
+ t.Parallel()
+
+ trailer := map[string]string{
+ "AtEnd1": "1111",
+ "AtEnd2": "2222",
+ "AtEnd3": "3333",
+ }
+ const expectedBody = "message 0" + "message 1" + "message 2"
+
+ h := func(ctx *RequestCtx) {
+ ctx.Response.Header.DisableNormalizing()
+ ctx.Response.Header.Set("Transfer-Encoding", "chunked")
+ // Trailer names are declared up front; their values are set only after
+ // the body stream writer has been installed (end of this handler).
+ for k := range trailer {
+ err := ctx.Response.Header.AddTrailer(k)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ }
+ ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) {
+ for i := 0; i < 3; i++ {
+ fmt.Fprintf(w, "message %d", i)
+ if err := w.Flush(); err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ time.Sleep(time.Millisecond * 100)
+ }
+ })
+ for k, v := range trailer {
+ ctx.Response.Header.Set(k, v)
+ }
+ }
+ s := &Server{
+ Handler: h,
+ }
+
+ rw := &readWriter{}
+ rw.r.WriteString("GET / HTTP/1.1\r\nHost: test.com\r\n\r\n")
+
+ if err := s.ServeConn(rw); err != nil {
+ t.Fatalf("Unexpected error from serveConn: %v", err)
+ }
+
+ br := bufio.NewReader(&rw.w)
+ var resp Response
+ if err := resp.Read(br); err != nil {
+ t.Fatalf("Unexpected error when reading response: %v", err)
+ }
+ if resp.Header.ContentLength() != -1 {
+ t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), -1)
+ }
+ if !bytes.Equal(resp.Body(), []byte(expectedBody)) {
+ t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), expectedBody)
+ }
+ for k, v := range trailer {
+ got := resp.Header.Peek(k)
+ if !bytes.Equal(got, []byte(v)) {
+ t.Fatalf("Unexpected trailer %q. Expected %q. Got %q", k, v, got)
+ }
+ }
+}
+
+// verifyResponse reads a single response from r and fails the test if it
+// cannot be parsed or its body differs from expectedBody. It then passes the
+// expected status code, the actual body length and expectedContentType on to
+// verifyResponseHeader, and returns the parsed response.
func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response {
var resp Response
if err := resp.Read(r); err != nil {
- t.Fatalf("Unexpected error when parsing response: %s", err)
+ t.Fatalf("Unexpected error when parsing response: %v", err)
}
if !bytes.Equal(resp.Body(), []byte(expectedBody)) {
t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), []byte(expectedBody))
}
+ // NOTE(review): the trailing "" fills the new last parameter of
+ // verifyResponseHeader (defined elsewhere); presumably it means "no
+ // expectation" for that header. Confirm against its definition.
- verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(resp.Body()), expectedContentType)
+ verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(resp.Body()), expectedContentType, "")
return &resp
}
diff --git a/server_timing_test.go b/server_timing_test.go
index 6ed1cae..8e1150d 100644
--- a/server_timing_test.go
+++ b/server_timing_test.go
@@ -4,7 +4,6 @@ import (
"bytes"
"fmt"
"io"
- "io/ioutil"
"net"
"net/http"
"runtime"
@@ -317,10 +316,10 @@ func newFakeListener(requestsCount, clientsCount, requestsPerConn int, request s
var (
fakeResponse = []byte("Hello, world!")
getRequest = "GET /foobar?baz HTTP/1.1\r\nHost: google.com\r\nUser-Agent: aaa/bbb/ccc/ddd/eee Firefox Chrome MSIE Opera\r\n" +
- "Referer: http://xxx.com/aaa?bbb=ccc\r\nCookie: foo=bar; baz=baraz; aa=aakslsdweriwereowriewroire\r\n\r\n"
+ "Referer: http://example.com/aaa?bbb=ccc\r\nCookie: foo=bar; baz=baraz; aa=aakslsdweriwereowriewroire\r\n\r\n"
+ // The body is appended with %s (raw bytes), not %q. Quoting would add two
+ // bytes, so the payload would neither match Content-Length (len(fakeResponse))
+ // nor equal fakeResponse, and the POST benchmarks would fail their body check.
postRequest = fmt.Sprintf("POST /foobar?baz HTTP/1.1\r\nHost: google.com\r\nContent-Type: foo/bar\r\nContent-Length: %d\r\n"+
"User-Agent: Opera Chrome MSIE Firefox and other/1.2.34\r\nReferer: http://google.com/aaaa/bbb/ccc\r\n"+
"Cookie: foo=bar; baz=baraz; aa=aakslsdweriwereowriewroire\r\n\r\n%s",
len(fakeResponse), fakeResponse)
)
@@ -329,7 +328,7 @@ func benchmarkServerGet(b *testing.B, clientsCount, requestsPerConn int) {
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.IsGet() {
- b.Fatalf("Unexpected request method: %s", ctx.Method())
+ b.Fatalf("Unexpected request method: %q", ctx.Method())
}
ctx.Success("text/plain", fakeResponse)
if requestsPerConn == 1 {
@@ -348,7 +347,7 @@ func benchmarkNetHTTPServerGet(b *testing.B, clientsCount, requestsPerConn int)
s := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Method != MethodGet {
- b.Fatalf("Unexpected request method: %s", req.Method)
+ b.Fatalf("Unexpected request method: %q", req.Method)
}
h := w.Header()
h.Set("Content-Type", "text/plain")
@@ -368,7 +367,7 @@ func benchmarkServerPost(b *testing.B, clientsCount, requestsPerConn int) {
s := &Server{
Handler: func(ctx *RequestCtx) {
if !ctx.IsPost() {
- b.Fatalf("Unexpected request method: %s", ctx.Method())
+ b.Fatalf("Unexpected request method: %q", ctx.Method())
}
body := ctx.Request.Body()
if !bytes.Equal(body, fakeResponse) {
@@ -391,11 +390,11 @@ func benchmarkNetHTTPServerPost(b *testing.B, clientsCount, requestsPerConn int)
s := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.Method != MethodPost {
- b.Fatalf("Unexpected request method: %s", req.Method)
+ b.Fatalf("Unexpected request method: %q", req.Method)
}
- body, err := ioutil.ReadAll(req.Body)
+ body, err := io.ReadAll(req.Body)
if err != nil {
- b.Fatalf("Unexpected error: %s", err)
+ b.Fatalf("Unexpected error: %v", err)
}
req.Body.Close()
if !bytes.Equal(body, fakeResponse) {
diff --git a/stackless/func.go b/stackless/func.go
index 9a49bcc..a50b3eb 100644
--- a/stackless/func.go
+++ b/stackless/func.go
@@ -12,9 +12,9 @@ import (
// The wrapper may save a lot of stack space if the following conditions
// are met:
//
-// - f doesn't contain blocking calls on network, I/O or channels;
-// - f uses a lot of stack space;
-// - the wrapper is called from high number of concurrent goroutines.
+// - f doesn't contain blocking calls on network, I/O or channels;
+// - f uses a lot of stack space;
+// - the wrapper is called from high number of concurrent goroutines.
//
// The stackless wrapper returns false if the call cannot be processed
// at the moment due to high load.
diff --git a/stackless/func_test.go b/stackless/func_test.go
index 02ef2b8..6b2a8d5 100644
--- a/stackless/func_test.go
+++ b/stackless/func_test.go
@@ -66,7 +66,7 @@ func TestNewFuncMulti(t *testing.T) {
select {
case err := <-f1Done:
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
@@ -75,7 +75,7 @@ func TestNewFuncMulti(t *testing.T) {
select {
case err := <-f2Done:
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout")
diff --git a/stackless/writer_test.go b/stackless/writer_test.go
index 9c4748e..fdbe16b 100644
--- a/stackless/writer_test.go
+++ b/stackless/writer_test.go
@@ -6,7 +6,6 @@ import (
"compress/gzip"
"fmt"
"io"
- "io/ioutil"
"testing"
"time"
)
@@ -15,7 +14,7 @@ func TestCompressFlateSerial(t *testing.T) {
t.Parallel()
if err := testCompressFlate(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -23,7 +22,7 @@ func TestCompressFlateConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(testCompressFlate, 10); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -31,7 +30,7 @@ func testCompressFlate() error {
return testWriter(func(w io.Writer) Writer {
zw, err := flate.NewWriter(w, flate.DefaultCompression)
if err != nil {
- panic(fmt.Sprintf("BUG: unexpected error: %s", err))
+ panic(fmt.Sprintf("BUG: unexpected error: %v", err))
}
return zw
}, func(r io.Reader) io.Reader {
@@ -43,7 +42,7 @@ func TestCompressGzipSerial(t *testing.T) {
t.Parallel()
if err := testCompressGzip(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -51,7 +50,7 @@ func TestCompressGzipConcurrent(t *testing.T) {
t.Parallel()
if err := testConcurrent(testCompressGzip, 10); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
}
@@ -61,7 +60,7 @@ func testCompressGzip() error {
}, func(r io.Reader) io.Reader {
zr, err := gzip.NewReader(r)
if err != nil {
- panic(fmt.Sprintf("BUG: cannot create gzip reader: %s", err))
+ panic(fmt.Sprintf("BUG: cannot create gzip reader: %v", err))
}
return zr
})
@@ -73,7 +72,7 @@ func testWriter(newWriter NewWriterFunc, newReader func(io.Reader) io.Reader) er
for i := 0; i < 5; i++ {
if err := testWriterReuse(w, dstW, newReader); err != nil {
- return fmt.Errorf("unexpected error when re-using writer on iteration %d: %s", i, err)
+ return fmt.Errorf("unexpected error when re-using writer on iteration %d: %w", i, err)
}
dstW = &bytes.Buffer{}
w.Reset(dstW)
@@ -89,16 +88,16 @@ func testWriterReuse(w Writer, r io.Reader, newReader func(io.Reader) io.Reader)
fmt.Fprintf(mw, "foobar %d\n", i)
if i%13 == 0 {
if err := w.Flush(); err != nil {
- return fmt.Errorf("error on flush: %s", err)
+ return fmt.Errorf("error on flush: %w", err)
}
}
}
w.Close()
zr := newReader(r)
- data, err := ioutil.ReadAll(zr)
+ data, err := io.ReadAll(zr)
if err != nil {
- return fmt.Errorf("unexpected error: %s, data=%q", err, data)
+ return fmt.Errorf("unexpected error: %w, data=%q", err, data)
}
wantData := wantW.Bytes()
@@ -120,7 +119,7 @@ func testConcurrent(testFunc func() error, concurrency int) error {
select {
case err := <-ch:
if err != nil {
- return fmt.Errorf("unexpected error on goroutine %d: %s", i, err)
+ return fmt.Errorf("unexpected error on goroutine %d: %w", i, err)
}
case <-time.After(time.Second):
return fmt.Errorf("timeout on goroutine %d", i)
diff --git a/status.go b/status.go
index 28d1286..c88ba11 100644
--- a/status.go
+++ b/status.go
@@ -1,7 +1,6 @@
package fasthttp
import (
- "fmt"
"strconv"
)
@@ -81,7 +80,7 @@ const (
)
var (
- statusLines = make([][]byte, statusMessageMax+1)
+ unknownStatusCode = "Unknown Status Code"
statusMessages = []string{
StatusContinue: "Continue",
@@ -155,39 +154,24 @@ var (
// StatusMessage returns HTTP status message for the given status code.
+//
+// It returns unknownStatusCode for codes outside
+// [statusMessageMin, statusMessageMax] and for codes in that range that have
+// no entry in statusMessages.
func StatusMessage(statusCode int) string {
if statusCode < statusMessageMin || statusCode > statusMessageMax {
- return "Unknown Status Code"
+ return unknownStatusCode
}
- s := statusMessages[statusCode]
- if s == "" {
- s = "Unknown Status Code"
+ if s := statusMessages[statusCode]; s != "" {
+ return s
}
- return s
+ return unknownStatusCode
}
-func init() {
- // Fill all valid status lines
- for i := 0; i < len(statusLines); i++ {
- statusLines[i] = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", i, StatusMessage(i)))
+// formatStatusLine appends "<protocol> <statusCode> <statusText>" followed by
+// strCRLF to dst and returns the extended slice. When statusText is empty,
+// StatusMessage(statusCode) is used instead.
+func formatStatusLine(dst []byte, protocol []byte, statusCode int, statusText []byte) []byte {
+ dst = append(dst, protocol...)
+ dst = append(dst, ' ')
+ dst = strconv.AppendInt(dst, int64(statusCode), 10)
+ dst = append(dst, ' ')
+ if len(statusText) == 0 {
+ dst = append(dst, s2b(StatusMessage(statusCode))...)
+ } else {
+ dst = append(dst, statusText...)
}
-}
-
-func statusLine(statusCode int) []byte {
- if statusCode < 0 || statusCode > statusMessageMax {
- return invalidStatusLine(statusCode)
- }
-
- return statusLines[statusCode]
-}
-
-func invalidStatusLine(statusCode int) []byte {
- statusText := StatusMessage(statusCode)
- // xxx placeholder of status code
- var line = make([]byte, 0, len("HTTP/1.1 xxx \r\n")+len(statusText))
- line = append(line, "HTTP/1.1 "...)
- line = strconv.AppendInt(line, int64(statusCode), 10)
- line = append(line, ' ')
- line = append(line, statusText...)
- line = append(line, "\r\n"...)
- return line
+ return append(dst, strCRLF...)
}
diff --git a/status_test.go b/status_test.go
index e3bfa8a..ff794a3 100644
--- a/status_test.go
+++ b/status_test.go
@@ -17,8 +17,8 @@ func TestStatusLine(t *testing.T) {
}
func testStatusLine(t *testing.T, statusCode int, expected []byte) {
- line := statusLine(statusCode)
+ line := formatStatusLine(nil, strHTTP11, statusCode, s2b(StatusMessage(statusCode)))
if !bytes.Equal(expected, line) {
- t.Fatalf("unexpected status line %s. Expecting %s", string(line), string(expected))
+ t.Fatalf("unexpected status line %q. Expecting %q", string(line), string(expected))
}
}
diff --git a/status_timing_test.go b/status_timing_test.go
index 42b1c91..e35d8ce 100644
--- a/status_timing_test.go
+++ b/status_timing_test.go
@@ -20,9 +20,9 @@ func BenchmarkStatusLine512(b *testing.B) {
func benchmarkStatusLine(b *testing.B, statusCode int, expected []byte) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
- line := statusLine(statusCode)
+ line := formatStatusLine(nil, strHTTP11, statusCode, s2b(StatusMessage(statusCode)))
if !bytes.Equal(expected, line) {
- b.Fatalf("unexpected status line %s. Expecting %s", string(line), string(expected))
+ b.Fatalf("unexpected status line %q. Expecting %q", string(line), string(expected))
}
}
})
diff --git a/stream_test.go b/stream_test.go
index c5ce292..f4dcb5a 100644
--- a/stream_test.go
+++ b/stream_test.go
@@ -4,7 +4,6 @@ import (
"bufio"
"fmt"
"io"
- "io/ioutil"
"testing"
"time"
)
@@ -19,9 +18,9 @@ func TestNewStreamReader(t *testing.T) {
close(ch)
})
- data, err := ioutil.ReadAll(r)
+ data, err := io.ReadAll(r)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
expectedData := "Hello, world\nLine #2\n"
if string(data) != expectedData {
@@ -47,7 +46,7 @@ func TestStreamReaderClose(t *testing.T) {
r := NewStreamReader(func(w *bufio.Writer) {
fmt.Fprintf(w, "%s", firstLine)
if err := w.Flush(); err != nil {
- ch <- fmt.Errorf("unexpected error on first flush: %s", err)
+ ch <- fmt.Errorf("unexpected error on first flush: %w", err)
return
}
@@ -64,7 +63,7 @@ func TestStreamReaderClose(t *testing.T) {
buf := make([]byte, len(firstLine))
n, err := io.ReadFull(r, buf)
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if n != len(buf) {
t.Fatalf("unexpected number of bytes read: %d. Expecting %d", n, len(buf))
@@ -74,13 +73,13 @@ func TestStreamReaderClose(t *testing.T) {
}
if err := r.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
select {
case err := <-ch:
if err != nil {
- t.Fatalf("error returned from stream reader: %s", err)
+ t.Fatalf("error returned from stream reader: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout when waiting for stream reader")
@@ -88,8 +87,8 @@ func TestStreamReaderClose(t *testing.T) {
// read trailing data
go func() {
- if _, err := ioutil.ReadAll(r); err != nil {
- ch <- fmt.Errorf("unexpected error when reading trailing data: %s", err)
+ if _, err := io.ReadAll(r); err != nil {
+ ch <- fmt.Errorf("unexpected error when reading trailing data: %w", err)
return
}
ch <- nil
@@ -98,7 +97,7 @@ func TestStreamReaderClose(t *testing.T) {
select {
case err := <-ch:
if err != nil {
- t.Fatalf("error returned when reading tail data: %s", err)
+ t.Fatalf("error returned when reading tail data: %v", err)
}
case <-time.After(time.Second):
t.Fatalf("timeout when reading tail data")
diff --git a/stream_timing_test.go b/stream_timing_test.go
index facca3a..b7fe03f 100644
--- a/stream_timing_test.go
+++ b/stream_timing_test.go
@@ -52,16 +52,16 @@ func benchmarkStreamReader(b *testing.B, size int) {
if err == io.EOF {
break
}
- b.Fatalf("unexpected error when reading from stream reader: %s", err)
+ b.Fatalf("unexpected error when reading from stream reader: %v", err)
}
}
if err := sr.Close(); err != nil {
- b.Fatalf("unexpected error when closing stream reader: %s", err)
+ b.Fatalf("unexpected error when closing stream reader: %v", err)
}
select {
case err := <-ch:
if err != nil {
- b.Fatalf("unexpected error from stream reader: %s", err)
+ b.Fatalf("unexpected error from stream reader: %v", err)
}
case <-time.After(time.Second):
b.Fatalf("timeout")
diff --git a/streaming.go b/streaming.go
index 1a3d748..fc04916 100644
--- a/streaming.go
+++ b/streaming.go
@@ -10,10 +10,10 @@ import (
)
type requestStream struct {
+ header *RequestHeader
prefetchedBytes *bytes.Reader
reader *bufio.Reader
totalBytesRead int
- contentLength int
chunkLeft int
}
@@ -22,18 +22,18 @@ func (rs *requestStream) Read(p []byte) (int, error) {
n int
err error
)
- if rs.contentLength == -1 {
+ if rs.header.contentLength == -1 {
if rs.chunkLeft == 0 {
chunkSize, err := parseChunkSize(rs.reader)
if err != nil {
return 0, err
}
if chunkSize == 0 {
- err = readCrLf(rs.reader)
- if err == nil {
- err = io.EOF
+ err = rs.header.ReadTrailer(rs.reader)
+ if err != nil && err != io.EOF {
+ return 0, err
}
- return 0, err
+ return 0, io.EOF
}
rs.chunkLeft = chunkSize
}
@@ -52,7 +52,7 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
return n, err
}
- if rs.totalBytesRead == rs.contentLength {
+ if rs.totalBytesRead == rs.header.contentLength {
return 0, io.EOF
}
prefetchedSize := int(rs.prefetchedBytes.Size())
@@ -63,12 +63,12 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
n, err := rs.prefetchedBytes.Read(p)
rs.totalBytesRead += n
- if n == rs.contentLength {
+ if n == rs.header.contentLength {
return n, io.EOF
}
return n, err
} else {
- left := rs.contentLength - rs.totalBytesRead
+ left := rs.header.contentLength - rs.totalBytesRead
if len(p) > left {
p = p[:left]
}
@@ -79,18 +79,17 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
}
- if rs.totalBytesRead == rs.contentLength {
+ if rs.totalBytesRead == rs.header.contentLength {
err = io.EOF
}
return n, err
}
-func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, contentLength int) *requestStream {
+func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, h *RequestHeader) *requestStream {
rs := requestStreamPool.Get().(*requestStream)
rs.prefetchedBytes = bytes.NewReader(b.B)
rs.reader = r
- rs.contentLength = contentLength
-
+ rs.header = h
return rs
}
@@ -99,6 +98,7 @@ func releaseRequestStream(rs *requestStream) {
rs.totalBytesRead = 0
rs.chunkLeft = 0
rs.reader = nil
+ rs.header = nil
requestStreamPool.Put(rs)
}
diff --git a/streaming_test.go b/streaming_test.go
index 6066c39..084710b 100644
--- a/streaming_test.go
+++ b/streaming_test.go
@@ -3,7 +3,8 @@ package fasthttp
import (
"bufio"
"bytes"
- "io/ioutil"
+ "fmt"
+ "io"
"sync"
"testing"
"time"
@@ -35,7 +36,7 @@ aaaaaaaaaa`
if string(ctx.Path()) == "/one" {
body = string(ctx.PostBody())
} else {
- all, err := ioutil.ReadAll(ctx.RequestBodyStream())
+ all, err := io.ReadAll(ctx.RequestBodyStream())
if err != nil {
t.Error(err)
}
@@ -50,17 +51,17 @@ aaaaaaaaaa`
ch := make(chan struct{})
go func() {
if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
close(ch)
}()
conn, err := ln.Dial()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte(reqS)); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
var resp Response
@@ -68,14 +69,14 @@ aaaaaaaaaa`
respCh := make(chan struct{})
go func() {
if err := resp.Read(br); err != nil {
- t.Errorf("error when reading response: %s", err)
+ t.Errorf("error when reading response: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
}
if err := resp.Read(br); err != nil {
- t.Errorf("error when reading response: %s", err)
+ t.Errorf("error when reading response: %v", err)
}
if resp.StatusCode() != StatusOK {
t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
@@ -90,7 +91,7 @@ aaaaaaaaaa`
}
if err := ln.Close(); err != nil {
- t.Fatalf("error when closing listener: %s", err)
+ t.Fatalf("error when closing listener: %v", err)
}
select {
@@ -102,12 +103,12 @@ aaaaaaaaaa`
func getChunkedTestEnv(t testing.TB) (*fasthttputil.InmemoryListener, []byte) {
body := createFixedBody(128 * 1024)
- chunkedBody := createChunkedBody(body)
+ chunkedBody := createChunkedBody(body, nil, true)
testHandler := func(ctx *RequestCtx) {
- bodyBytes, err := ioutil.ReadAll(ctx.RequestBodyStream())
+ bodyBytes, err := io.ReadAll(ctx.RequestBodyStream())
if err != nil {
- t.Logf("ioutil read returned err=%s", err)
+ t.Logf("io read returned err=%v", err)
t.Error("unexpected error while reading request body stream")
}
@@ -126,7 +127,7 @@ func getChunkedTestEnv(t testing.TB) (*fasthttputil.InmemoryListener, []byte) {
go func() {
err := s.Serve(ln)
if err != nil {
- t.Errorf("could not serve listener: %s", err)
+ t.Errorf("could not serve listener: %v", err)
}
}()
@@ -142,6 +143,70 @@ func getChunkedTestEnv(t testing.TB) (*fasthttputil.InmemoryListener, []byte) {
return ln, formattedRequest
}
+func TestRequestStreamChunkedWithTrailer(t *testing.T) {
+ t.Parallel()
+
+ body := createFixedBody(10)
+ expectedTrailer := map[string]string{
+ "Foo": "footest",
+ "Bar": "bartest",
+ }
+ chunkedBody := createChunkedBody(body, expectedTrailer, true)
+ req := fmt.Sprintf(`POST / HTTP/1.1
+Host: example.com
+Transfer-Encoding: chunked
+Trailer: Foo, Bar
+
+%s
+`, chunkedBody)
+
+ ln := fasthttputil.NewInmemoryListener()
+ s := &Server{
+ StreamRequestBody: true,
+ Handler: func(ctx *RequestCtx) {
+ all, err := io.ReadAll(ctx.RequestBodyStream())
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if !bytes.Equal(all, body) {
+ t.Errorf("unexpected body %q. Expecting %q", all, body)
+ }
+
+ for k, v := range expectedTrailer {
+ r := ctx.Request.Header.Peek(k)
+ if string(r) != v {
+ t.Errorf("unexpected trailer %q. Expecting %q. Got %q", k, v, r)
+ }
+ }
+ },
+ }
+
+ ch := make(chan struct{})
+ go func() {
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ close(ch)
+ }()
+
+ conn, err := ln.Dial()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if _, err = conn.Write([]byte(req)); err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if err := ln.Close(); err != nil {
+ t.Fatalf("error when closing listener: %v", err)
+ }
+
+ select {
+ case <-ch:
+ case <-time.After(time.Second):
+ t.Fatal("timeout when waiting for the server to stop")
+ }
+}
+
func TestRequestStream(t *testing.T) {
t.Parallel()
@@ -149,16 +214,16 @@ func TestRequestStream(t *testing.T) {
c, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error while dialing: %s", err)
+ t.Errorf("unexpected error while dialing: %v", err)
}
if _, err = c.Write(formattedRequest); err != nil {
- t.Errorf("unexpected error while writing request: %s", err)
+ t.Errorf("unexpected error while writing request: %v", err)
}
br := bufio.NewReader(c)
var respH ResponseHeader
if err = respH.Read(br); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
}
@@ -172,16 +237,16 @@ func BenchmarkRequestStreamE2E(b *testing.B) {
for i := 0; i < b.N/4; i++ {
c, err := ln.Dial()
if err != nil {
- b.Errorf("unexpected error while dialing: %s", err)
+ b.Errorf("unexpected error while dialing: %v", err)
}
if _, err = c.Write(formattedRequest); err != nil {
- b.Errorf("unexpected error while writing request: %s", err)
+ b.Errorf("unexpected error while writing request: %v", err)
}
br := bufio.NewReaderSize(c, 128)
var respH ResponseHeader
if err = respH.Read(br); err != nil {
- b.Errorf("unexpected error: %s", err)
+ b.Errorf("unexpected error: %v", err)
}
c.Close()
}
diff --git a/strings.go b/strings.go
index e28eaac..0e201a1 100644
--- a/strings.go
+++ b/strings.go
@@ -1,50 +1,62 @@
package fasthttp
var (
- defaultServerName = []byte("fasthttp")
- defaultUserAgent = []byte("fasthttp")
+ defaultServerName = "fasthttp"
+ defaultUserAgent = "fasthttp"
defaultContentType = []byte("text/plain; charset=utf-8")
)
var (
- strSlash = []byte("/")
- strSlashSlash = []byte("//")
- strSlashDotDot = []byte("/..")
- strSlashDotSlash = []byte("/./")
- strSlashDotDotSlash = []byte("/../")
- strCRLF = []byte("\r\n")
- strHTTP = []byte("http")
- strHTTPS = []byte("https")
- strHTTP10 = []byte("HTTP/1.0")
- strHTTP11 = []byte("HTTP/1.1")
- strColon = []byte(":")
- strColonSlashSlash = []byte("://")
- strColonSpace = []byte(": ")
- strGMT = []byte("GMT")
+ strSlash = []byte("/")
+ strSlashSlash = []byte("//")
+ strSlashDotDot = []byte("/..")
+ strSlashDotSlash = []byte("/./")
+ strSlashDotDotSlash = []byte("/../")
+ strBackSlashDotDot = []byte(`\..`)
+ strBackSlashDotBackSlash = []byte(`\.\`)
+ strSlashDotDotBackSlash = []byte(`/..\`)
+ strBackSlashDotDotBackSlash = []byte(`\..\`)
+ strCRLF = []byte("\r\n")
+ strHTTP = []byte("http")
+ strHTTPS = []byte("https")
+ strHTTP10 = []byte("HTTP/1.0")
+ strHTTP11 = []byte("HTTP/1.1")
+ strColon = []byte(":")
+ strColonSlashSlash = []byte("://")
+ strColonSpace = []byte(": ")
+ strCommaSpace = []byte(", ")
+ strGMT = []byte("GMT")
strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n")
- strExpect = []byte(HeaderExpect)
- strConnection = []byte(HeaderConnection)
- strContentLength = []byte(HeaderContentLength)
- strContentType = []byte(HeaderContentType)
- strDate = []byte(HeaderDate)
- strHost = []byte(HeaderHost)
- strReferer = []byte(HeaderReferer)
- strServer = []byte(HeaderServer)
- strTransferEncoding = []byte(HeaderTransferEncoding)
- strContentEncoding = []byte(HeaderContentEncoding)
- strAcceptEncoding = []byte(HeaderAcceptEncoding)
- strUserAgent = []byte(HeaderUserAgent)
- strCookie = []byte(HeaderCookie)
- strSetCookie = []byte(HeaderSetCookie)
- strLocation = []byte(HeaderLocation)
- strIfModifiedSince = []byte(HeaderIfModifiedSince)
- strLastModified = []byte(HeaderLastModified)
- strAcceptRanges = []byte(HeaderAcceptRanges)
- strRange = []byte(HeaderRange)
- strContentRange = []byte(HeaderContentRange)
- strAuthorization = []byte(HeaderAuthorization)
+ strExpect = []byte(HeaderExpect)
+ strConnection = []byte(HeaderConnection)
+ strContentLength = []byte(HeaderContentLength)
+ strContentType = []byte(HeaderContentType)
+ strDate = []byte(HeaderDate)
+ strHost = []byte(HeaderHost)
+ strReferer = []byte(HeaderReferer)
+ strServer = []byte(HeaderServer)
+ strTransferEncoding = []byte(HeaderTransferEncoding)
+ strContentEncoding = []byte(HeaderContentEncoding)
+ strAcceptEncoding = []byte(HeaderAcceptEncoding)
+ strUserAgent = []byte(HeaderUserAgent)
+ strCookie = []byte(HeaderCookie)
+ strSetCookie = []byte(HeaderSetCookie)
+ strLocation = []byte(HeaderLocation)
+ strIfModifiedSince = []byte(HeaderIfModifiedSince)
+ strLastModified = []byte(HeaderLastModified)
+ strAcceptRanges = []byte(HeaderAcceptRanges)
+ strRange = []byte(HeaderRange)
+ strContentRange = []byte(HeaderContentRange)
+ strAuthorization = []byte(HeaderAuthorization)
+ strTE = []byte(HeaderTE)
+ strTrailer = []byte(HeaderTrailer)
+ strMaxForwards = []byte(HeaderMaxForwards)
+ strProxyConnection = []byte(HeaderProxyConnection)
+ strProxyAuthenticate = []byte(HeaderProxyAuthenticate)
+ strProxyAuthorization = []byte(HeaderProxyAuthorization)
+ strWWWAuthenticate = []byte(HeaderWWWAuthenticate)
strCookieExpires = []byte("expires")
strCookieDomain = []byte("domain")
diff --git a/tcp.go b/tcp.go
new file mode 100644
index 0000000..54d3033
--- /dev/null
+++ b/tcp.go
@@ -0,0 +1,13 @@
+//go:build !windows
+// +build !windows
+
+package fasthttp
+
+import (
+ "errors"
+ "syscall"
+)
+
+func isConnectionReset(err error) bool {
+ return errors.Is(err, syscall.ECONNRESET)
+}
diff --git a/tcp_windows.go b/tcp_windows.go
new file mode 100644
index 0000000..5c33025
--- /dev/null
+++ b/tcp_windows.go
@@ -0,0 +1,13 @@
+//go:build windows
+// +build windows
+
+package fasthttp
+
+import (
+ "errors"
+ "syscall"
+)
+
+func isConnectionReset(err error) bool {
+ return errors.Is(err, syscall.WSAECONNRESET)
+}
diff --git a/tcpdialer.go b/tcpdialer.go
index 261f62e..5f70b77 100644
--- a/tcpdialer.go
+++ b/tcpdialer.go
@@ -14,12 +14,12 @@ import (
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
-// * It returns ErrDialTimeout if connection cannot be established during
+// - It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
//
// This dialer is intended for custom code wrapping before passing
@@ -30,9 +30,9 @@ import (
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func Dial(addr string) (net.Conn, error) {
return defaultDialer.Dial(addr)
}
@@ -41,9 +41,9 @@ func Dial(addr string) (net.Conn, error) {
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
@@ -55,9 +55,9 @@ func Dial(addr string) (net.Conn, error) {
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialTimeout(addr, timeout)
}
@@ -66,12 +66,12 @@ func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
-// * It returns ErrDialTimeout if connection cannot be established during
+// - It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
// timeout.
//
@@ -83,9 +83,9 @@ func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func DialDualStack(addr string) (net.Conn, error) {
return defaultDialer.DialDualStack(addr)
}
@@ -95,9 +95,9 @@ func DialDualStack(addr string) (net.Conn, error) {
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
@@ -109,9 +109,9 @@ func DialDualStack(addr string) (net.Conn, error) {
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return defaultDialer.DialDualStackTimeout(addr, timeout)
}
@@ -127,7 +127,7 @@ type Resolver interface {
// TCPDialer contains options to control a group of Dial calls.
type TCPDialer struct {
- // Concurrency controls the maximum number of concurrent Dails
+ // Concurrency controls the maximum number of concurrent Dials
// that can be performed using this object.
// Setting this to 0 means unlimited.
//
@@ -167,12 +167,12 @@ type TCPDialer struct {
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
-// * It returns ErrDialTimeout if connection cannot be established during
+// - It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
//
// This dialer is intended for custom code wrapping before passing
@@ -183,9 +183,9 @@ type TCPDialer struct {
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
return d.dial(addr, false, DefaultDialTimeout)
}
@@ -194,9 +194,9 @@ func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
@@ -208,9 +208,9 @@ func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return d.dial(addr, false, timeout)
}
@@ -219,12 +219,12 @@ func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, e
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
-// * It returns ErrDialTimeout if connection cannot be established during
+// - It returns ErrDialTimeout if connection cannot be established during
// DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
// timeout.
//
@@ -236,9 +236,9 @@ func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, e
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
return d.dial(addr, true, DefaultDialTimeout)
}
@@ -248,9 +248,9 @@ func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
//
// This function has the following additional features comparing to net.Dial:
//
-// * It reduces load on DNS resolver by caching resolved TCP addressed
+// - It reduces load on DNS resolver by caching resolved TCP addressed
// for DNSCacheDuration.
-// * It dials all the resolved TCP addresses in round-robin manner until
+// - It dials all the resolved TCP addresses in round-robin manner until
// connection is established. This may be useful if certain addresses
// are temporarily unreachable.
//
@@ -262,9 +262,9 @@ func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
//
// The addr passed to the function must contain port. Example addr values:
//
-// * foobar.baz:443
-// * foo.bar:80
-// * aaa.com:8080
+// - foobar.baz:443
+// - foo.bar:80
+// - aaa.com:8080
func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
return d.dial(addr, true, timeout)
}
@@ -309,7 +309,7 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne
}
func (d *TCPDialer) tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) {
- timeout := -time.Since(deadline)
+ timeout := time.Until(deadline)
if timeout <= 0 {
return nil, ErrDialTimeout
}
@@ -358,8 +358,8 @@ type tcpAddrEntry struct {
addrs []net.TCPAddr
addrsIdx uint32
+ pending int32
resolveTime time.Time
- pending bool
}
// DefaultDNSCacheDuration is the duration for caching resolved TCP addresses
@@ -384,9 +384,11 @@ func (d *TCPDialer) tcpAddrsClean() {
func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uint32, error) {
item, exist := d.tcpAddrsMap.Load(addr)
e, ok := item.(*tcpAddrEntry)
- if exist && ok && e != nil && !e.pending && time.Since(e.resolveTime) > d.DNSCacheDuration {
- e.pending = true
- e = nil
+ if exist && ok && e != nil && time.Since(e.resolveTime) > d.DNSCacheDuration {
+ // Only let one goroutine re-resolve at a time.
+ if atomic.SwapInt32(&e.pending, 1) == 0 {
+ e = nil
+ }
}
if e == nil {
@@ -394,8 +396,9 @@ func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uin
if err != nil {
item, exist := d.tcpAddrsMap.Load(addr)
e, ok = item.(*tcpAddrEntry)
- if exist && ok && e != nil && e.pending {
- e.pending = false
+ if exist && ok && e != nil {
+ // Set pending to 0 so another goroutine can retry.
+ atomic.StoreInt32(&e.pending, 0)
}
return nil, 0, err
}
diff --git a/uri.go b/uri.go
index 2285f45..ab4cc65 100644
--- a/uri.go
+++ b/uri.go
@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
+ "path/filepath"
"strconv"
"sync"
)
@@ -70,14 +71,14 @@ type URI struct {
// CopyTo copies uri contents to dst.
func (u *URI) CopyTo(dst *URI) {
dst.Reset()
- dst.pathOriginal = append(dst.pathOriginal[:0], u.pathOriginal...)
- dst.scheme = append(dst.scheme[:0], u.scheme...)
- dst.path = append(dst.path[:0], u.path...)
- dst.queryString = append(dst.queryString[:0], u.queryString...)
- dst.hash = append(dst.hash[:0], u.hash...)
- dst.host = append(dst.host[:0], u.host...)
- dst.username = append(dst.username[:0], u.username...)
- dst.password = append(dst.password[:0], u.password...)
+ dst.pathOriginal = append(dst.pathOriginal, u.pathOriginal...)
+ dst.scheme = append(dst.scheme, u.scheme...)
+ dst.path = append(dst.path, u.path...)
+ dst.queryString = append(dst.queryString, u.queryString...)
+ dst.hash = append(dst.hash, u.hash...)
+ dst.host = append(dst.host, u.host...)
+ dst.username = append(dst.username, u.username...)
+ dst.password = append(dst.password, u.password...)
u.queryArgs.CopyTo(&dst.queryArgs)
dst.parsedQueryArgs = u.parsedQueryArgs
@@ -216,6 +217,14 @@ func (u *URI) SetSchemeBytes(scheme []byte) {
lowercaseBytes(u.scheme)
}
+func (u *URI) isHttps() bool {
+ return bytes.Equal(u.scheme, strHTTPS)
+}
+
+func (u *URI) isHttp() bool {
+ return len(u.scheme) == 0 || bytes.Equal(u.scheme, strHTTP)
+}
+
// Reset clears uri.
func (u *URI) Reset() {
u.pathOriginal = u.pathOriginal[:0]
@@ -282,14 +291,13 @@ func (u *URI) parse(host, uri []byte, isTLS bool) error {
if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) {
scheme, newHost, newURI := splitHostURI(host, uri)
- u.scheme = append(u.scheme, scheme...)
- lowercaseBytes(u.scheme)
+ u.SetSchemeBytes(scheme)
host = newHost
uri = newURI
}
if isTLS {
- u.scheme = append(u.scheme[:0], strHTTPS...)
+ u.SetSchemeBytes(strHTTPS)
}
if n := bytes.IndexByte(host, '@'); n >= 0 {
@@ -530,15 +538,9 @@ func shouldEscape(c byte, mode encoding) bool {
}
func ishex(c byte) bool {
- switch {
- case '0' <= c && c <= '9':
- return true
- case 'a' <= c && c <= 'f':
- return true
- case 'A' <= c && c <= 'F':
- return true
- }
- return false
+ return ('0' <= c && c <= '9') ||
+ ('a' <= c && c <= 'f') ||
+ ('A' <= c && c <= 'F')
}
func unhex(c byte) byte {
@@ -627,6 +629,60 @@ func normalizePath(dst, src []byte) []byte {
b = b[:nn+1]
}
+ if filepath.Separator == '\\' {
+ // remove \.\ parts
+ b = dst
+ for {
+ n := bytes.Index(b, strBackSlashDotBackSlash)
+ if n < 0 {
+ break
+ }
+ nn := n + len(strSlashDotSlash) - 1
+ copy(b[n:], b[nn:])
+ b = b[:len(b)-nn+n]
+ }
+
+ // remove /foo/..\ parts
+ for {
+ n := bytes.Index(b, strSlashDotDotBackSlash)
+ if n < 0 {
+ break
+ }
+ nn := bytes.LastIndexByte(b[:n], '/')
+ if nn < 0 {
+ nn = 0
+ }
+ n += len(strSlashDotDotBackSlash) - 1
+ copy(b[nn:], b[n:])
+ b = b[:len(b)-n+nn]
+ }
+
+ // remove /foo\..\ parts
+ for {
+ n := bytes.Index(b, strBackSlashDotDotBackSlash)
+ if n < 0 {
+ break
+ }
+ nn := bytes.LastIndexByte(b[:n], '/')
+ if nn < 0 {
+ nn = 0
+ }
+ n += len(strBackSlashDotDotBackSlash) - 1
+ copy(b[nn:], b[n:])
+ b = b[:len(b)-n+nn]
+ }
+
+ // remove trailing \foo\..
+ n := bytes.LastIndex(b, strBackSlashDotDot)
+ if n >= 0 && n+len(strSlashDotDot) == len(b) {
+ nn := bytes.LastIndexByte(b[:n], '/')
+ if nn < 0 {
+ return append(dst[:0], strSlash...)
+ }
+ b = b[:nn+1]
+ }
+ }
+
return b
}
@@ -653,9 +709,9 @@ func (u *URI) RequestURI() []byte {
//
// Examples:
//
-// * For /foo/bar/baz.html path returns baz.html.
-// * For /foo/bar/ returns empty byte slice.
-// * For /foobar.js returns foobar.js.
+// - For /foo/bar/baz.html path returns baz.html.
+// - For /foo/bar/ returns empty byte slice.
+// - For /foobar.js returns foobar.js.
//
// The returned bytes are valid until the next URI method call.
func (u *URI) LastPathSegment() []byte {
@@ -671,14 +727,14 @@ func (u *URI) LastPathSegment() []byte {
//
// The following newURI types are accepted:
//
-// * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
-// uri is replaced by newURI.
-// * Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
-// the original scheme is preserved.
-// * Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
-// of the original uri is replaced.
-// * Relative path, i.e. xx?yy=abc . In this case the original RequestURI
-// is updated according to the new relative path.
+// - Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
+// uri is replaced by newURI.
+// - Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
+// the original scheme is preserved.
+// - Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
+// of the original uri is replaced.
+// - Relative path, i.e. xx?yy=abc . In this case the original RequestURI
+// is updated according to the new relative path.
func (u *URI) Update(newURI string) {
u.UpdateBytes(s2b(newURI))
}
@@ -687,14 +743,14 @@ func (u *URI) Update(newURI string) {
//
// The following newURI types are accepted:
//
-// * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
-// uri is replaced by newURI.
-// * Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
-// the original scheme is preserved.
-// * Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
-// of the original uri is replaced.
-// * Relative path, i.e. xx?yy=abc . In this case the original RequestURI
-// is updated according to the new relative path.
+// - Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original
+// uri is replaced by newURI.
+// - Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case
+// the original scheme is preserved.
+// - Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part
+// of the original uri is replaced.
+// - Relative path, i.e. xx?yy=abc . In this case the original RequestURI
+// is updated according to the new relative path.
func (u *URI) UpdateBytes(newURI []byte) {
u.requestURI = u.updateBytes(newURI, u.requestURI)
}
@@ -746,7 +802,7 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte {
path := u.Path()
n = bytes.LastIndexByte(path, '/')
if n < 0 {
- panic(fmt.Sprintf("BUG: path must contain at least one slash: %s %s", u.Path(), newURI))
+ panic(fmt.Sprintf("BUG: path must contain at least one slash: %q %q", u.Path(), newURI))
}
buf = u.appendSchemeHost(buf[:0])
buf = appendQuotedPath(buf, path[:n+1])
diff --git a/uri_test.go b/uri_test.go
index fc0b2ef..f074392 100644
--- a/uri_test.go
+++ b/uri_test.go
@@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"reflect"
+ "runtime"
"testing"
"time"
)
@@ -107,28 +108,37 @@ func TestURIUpdate(t *testing.T) {
t.Parallel()
// full uri
- testURIUpdate(t, "http://foo.bar/baz?aaa=22#aaa", "https://aa.com/bb", "https://aa.com/bb")
+ testURIUpdate(t, "http://example.net/dir/path1.html?param1=val1#fragment1", "https://example.com/dir/path2.html", "https://example.com/dir/path2.html")
// empty uri
- testURIUpdate(t, "http://aaa.com/aaa.html?234=234#add", "", "http://aaa.com/aaa.html?234=234#add")
+ testURIUpdate(t, "http://example.com/dir/path1.html?param1=val1#fragment1", "", "http://example.com/dir/path1.html?param1=val1#fragment1")
// request uri
- testURIUpdate(t, "ftp://aaa/xxx/yyy?aaa=bb#aa", "/boo/bar?xx", "ftp://aaa/boo/bar?xx")
+ testURIUpdate(t, "http://example.com/dir/path1.html?param1=val1#fragment1", "/dir/path2.html?param2=val2#fragment2", "http://example.com/dir/path2.html?param2=val2#fragment2")
+
+ // schema
+ testURIUpdate(t, "http://example.com/dir/path1.html?param1=val1#fragment1", "https://example.com/dir/path1.html?param1=val1#fragment1", "https://example.com/dir/path1.html?param1=val1#fragment1")
// relative uri
- testURIUpdate(t, "http://foo.bar/baz/xxx.html?aaa=22#aaa", "bb.html?xx=12#pp", "http://foo.bar/baz/bb.html?xx=12#pp")
- testURIUpdate(t, "http://xx/a/b/c/d", "../qwe/p?zx=34", "http://xx/a/b/qwe/p?zx=34")
- testURIUpdate(t, "https://qqq/aaa.html?foo=bar", "?baz=434&aaa#xcv", "https://qqq/aaa.html?baz=434&aaa#xcv")
- testURIUpdate(t, "http://foo.bar/baz", "~a/%20b=c,тест?йцу=ке", "http://foo.bar/~a/%20b=c,%D1%82%D0%B5%D1%81%D1%82?йцу=ке")
- testURIUpdate(t, "http://foo.bar/baz", "/qwe#fragment", "http://foo.bar/qwe#fragment")
- testURIUpdate(t, "http://foobar/baz/xxx", "aaa.html#bb?cc=dd&ee=dfd", "http://foobar/baz/aaa.html#bb?cc=dd&ee=dfd")
+ testURIUpdate(t, "http://example.com/baz/xxx.html?aaa=22#aaa", "bb.html?xx=12#pp", "http://example.com/baz/bb.html?xx=12#pp")
+
+ testURIUpdate(t, "http://example.com/aaa.html?foo=bar", "?baz=434&aaa#xcv", "http://example.com/aaa.html?baz=434&aaa#xcv")
+ testURIUpdate(t, "http://example.com/baz", "~a/%20b=c,тест?йцу=ке", "http://example.com/~a/%20b=c,%D1%82%D0%B5%D1%81%D1%82?йцу=ке")
+ testURIUpdate(t, "http://example.com/baz", "/qwe#fragment", "http://example.com/qwe#fragment")
+ testURIUpdate(t, "http://example.com/baz/xxx", "aaa.html#bb?cc=dd&ee=dfd", "http://example.com/baz/aaa.html#bb?cc=dd&ee=dfd")
+
+ if runtime.GOOS != "windows" {
+ testURIUpdate(t, "http://example.com/a/b/c/d", "../qwe/p?zx=34", "http://example.com/a/b/qwe/p?zx=34")
+ }
// hash
- testURIUpdate(t, "http://foo.bar/baz#aaa", "#fragment", "http://foo.bar/baz#fragment")
+ testURIUpdate(t, "http://example.com/#fragment1", "#fragment2", "http://example.com/#fragment2")
// uri without scheme
- testURIUpdate(t, "https://foo.bar/baz", "//aaa.bbb/cc?dd", "https://aaa.bbb/cc?dd")
- testURIUpdate(t, "http://foo.bar/baz", "//aaa.bbb/cc?dd", "http://aaa.bbb/cc?dd")
+ testURIUpdate(t, "https://example.net/dir/path1.html", "//example.com/dir/path2.html", "https://example.com/dir/path2.html")
+ testURIUpdate(t, "http://example.net/dir/path1.html", "//example.com/dir/path2.html", "http://example.com/dir/path2.html")
+ // host with port
+ testURIUpdate(t, "http://example.net/", "//example.com:8080/", "http://example.com:8080/")
}
func testURIUpdate(t *testing.T, base, update, result string) {
@@ -142,6 +152,10 @@ func testURIUpdate(t *testing.T, base, update, result string) {
}
func TestURIPathNormalize(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.SkipNow()
+ }
+
t.Parallel()
var u URI
@@ -220,7 +234,7 @@ func TestURICopyTo(t *testing.T) {
t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet
}
- u.UpdateBytes([]byte("https://google.com/foo?bar=baz&baraz#qqqq"))
+ u.UpdateBytes([]byte("https://example.com/foo?bar=baz&baraz#qqqq"))
u.CopyTo(&copyU)
if !reflect.DeepEqual(u, copyU) { //nolint:govet
t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet
@@ -234,27 +248,27 @@ func TestURIFullURI(t *testing.T) {
var args Args
// empty scheme, path and hash
- testURIFullURI(t, "", "foobar.com", "", "", &args, "http://foobar.com/")
+ testURIFullURI(t, "", "example.com", "", "", &args, "http://example.com/")
// empty scheme and hash
- testURIFullURI(t, "", "aa.com", "/foo/bar", "", &args, "http://aa.com/foo/bar")
+ testURIFullURI(t, "", "example.com", "/foo/bar", "", &args, "http://example.com/foo/bar")
// empty hash
- testURIFullURI(t, "fTP", "XXx.com", "/foo", "", &args, "ftp://xxx.com/foo")
+ testURIFullURI(t, "fTP", "example.com", "/foo", "", &args, "ftp://example.com/foo")
// empty args
- testURIFullURI(t, "https", "xx.com", "/", "aaa", &args, "https://xx.com/#aaa")
+ testURIFullURI(t, "https", "example.com", "/", "aaa", &args, "https://example.com/#aaa")
// non-empty args and non-ASCII path
args.Set("foo", "bar")
args.Set("xxx", "йух")
- testURIFullURI(t, "", "xxx.com", "/тест123", "2er", &args, "http://xxx.com/%D1%82%D0%B5%D1%81%D1%82123?foo=bar&xxx=%D0%B9%D1%83%D1%85#2er")
+ testURIFullURI(t, "", "example.com", "/тест123", "2er", &args, "http://example.com/%D1%82%D0%B5%D1%81%D1%82123?foo=bar&xxx=%D0%B9%D1%83%D1%85#2er")
// test with empty args and non-empty query string
var u URI
- u.Parse([]byte("google.com"), []byte("/foo?bar=baz&baraz#qqqq")) //nolint:errcheck
+ u.Parse([]byte("example.com"), []byte("/foo?bar=baz&baraz#qqqq")) //nolint:errcheck
uri := u.FullURI()
- expectedURI := "http://google.com/foo?bar=baz&baraz#qqqq"
+ expectedURI := "http://example.com/foo?bar=baz&baraz#qqqq"
if string(uri) != expectedURI {
t.Fatalf("Unexpected URI: %q. Expected %q", uri, expectedURI)
}
@@ -278,19 +292,19 @@ func testURIFullURI(t *testing.T, scheme, host, path, hash string, args *Args, e
func TestURIParseNilHost(t *testing.T) {
t.Parallel()
- testURIParseScheme(t, "http://google.com/foo?bar#baz", "http", "google.com", "/foo?bar", "baz")
- testURIParseScheme(t, "HTtP://google.com/", "http", "google.com", "/", "")
- testURIParseScheme(t, "://google.com/xyz", "http", "google.com", "/xyz", "")
- testURIParseScheme(t, "//google.com/foobar", "http", "google.com", "/foobar", "")
- testURIParseScheme(t, "fTP://aaa.com", "ftp", "aaa.com", "/", "")
- testURIParseScheme(t, "httPS://aaa.com", "https", "aaa.com", "/", "")
+ testURIParseScheme(t, "http://example.com/foo?bar#baz", "http", "example.com", "/foo?bar", "baz")
+ testURIParseScheme(t, "HTtP://example.com/", "http", "example.com", "/", "")
+ testURIParseScheme(t, "://example.com/xyz", "http", "example.com", "/xyz", "")
+ testURIParseScheme(t, "//example.com/foobar", "http", "example.com", "/foobar", "")
+ testURIParseScheme(t, "fTP://example.com", "ftp", "example.com", "/", "")
+ testURIParseScheme(t, "httPS://example.com", "https", "example.com", "/", "")
// missing slash after hostname
- testURIParseScheme(t, "http://foobar.com?baz=111", "http", "foobar.com", "/?baz=111", "")
+ testURIParseScheme(t, "http://example.com?baz=111", "http", "example.com", "/?baz=111", "")
// slash in args
- testURIParseScheme(t, "http://foobar.com?baz=111/222/xyz", "http", "foobar.com", "/?baz=111/222/xyz", "")
- testURIParseScheme(t, "http://foobar.com?111/222/xyz", "http", "foobar.com", "/?111/222/xyz", "")
+ testURIParseScheme(t, "http://example.com?baz=111/222/xyz", "http", "example.com", "/?baz=111/222/xyz", "")
+ testURIParseScheme(t, "http://example.com?111/222/xyz", "http", "example.com", "/?111/222/xyz", "")
}
func testURIParseScheme(t *testing.T, uri, expectedScheme, expectedHost, expectedRequestURI, expectedHash string) {
@@ -310,60 +324,83 @@ func testURIParseScheme(t *testing.T, uri, expectedScheme, expectedHost, expecte
}
}
+func TestIsHttp(t *testing.T) {
+ var u URI
+ if !u.isHttp() || u.isHttps() {
+ t.Fatalf("http scheme is assumed by default and not https")
+ }
+ u.SetSchemeBytes([]byte{})
+ if !u.isHttp() || u.isHttps() {
+ t.Fatalf("empty scheme must be threaten as http and not https")
+ }
+ u.SetScheme("http")
+ if !u.isHttp() || u.isHttps() {
+ t.Fatalf("scheme must be threaten as http and not https")
+ }
+ u.SetScheme("https")
+ if !u.isHttps() || u.isHttp() {
+ t.Fatalf("scheme must be threaten as https and not http")
+ }
+ u.SetScheme("dav")
+ if u.isHttps() || u.isHttp() {
+ t.Fatalf("scheme must be threaten as not http and not https")
+ }
+}
+
func TestURIParse(t *testing.T) {
t.Parallel()
var u URI
// no args
- testURIParse(t, &u, "aaa", "sdfdsf",
- "http://aaa/sdfdsf", "aaa", "/sdfdsf", "sdfdsf", "", "")
+ testURIParse(t, &u, "example.com", "sdfdsf",
+ "http://example.com/sdfdsf", "example.com", "/sdfdsf", "sdfdsf", "", "")
// args
- testURIParse(t, &u, "xx", "/aa?ss",
- "http://xx/aa?ss", "xx", "/aa", "/aa", "ss", "")
+ testURIParse(t, &u, "example.com", "/aa?ss",
+ "http://example.com/aa?ss", "example.com", "/aa", "/aa", "ss", "")
// args and hash
- testURIParse(t, &u, "foobar.com", "/a.b.c?def=gkl#mnop",
- "http://foobar.com/a.b.c?def=gkl#mnop", "foobar.com", "/a.b.c", "/a.b.c", "def=gkl", "mnop")
+ testURIParse(t, &u, "example.com", "/a.b.c?def=gkl#mnop",
+ "http://example.com/a.b.c?def=gkl#mnop", "example.com", "/a.b.c", "/a.b.c", "def=gkl", "mnop")
// '?' and '#' in hash
- testURIParse(t, &u, "aaa.com", "/foo#bar?baz=aaa#bbb",
- "http://aaa.com/foo#bar?baz=aaa#bbb", "aaa.com", "/foo", "/foo", "", "bar?baz=aaa#bbb")
+ testURIParse(t, &u, "example.com", "/foo#bar?baz=aaa#bbb",
+ "http://example.com/foo#bar?baz=aaa#bbb", "example.com", "/foo", "/foo", "", "bar?baz=aaa#bbb")
// encoded path
- testURIParse(t, &u, "aa.com", "/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf",
- "http://aa.com/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", "aa.com", "/Test + при", "/Test%20+%20%D0%BF%D1%80%D0%B8", "asdf=%20%20&s=12", "sdf")
+ testURIParse(t, &u, "example.com", "/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf",
+ "http://example.com/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", "example.com", "/Test + при", "/Test%20+%20%D0%BF%D1%80%D0%B8", "asdf=%20%20&s=12", "sdf")
// host in uppercase
- testURIParse(t, &u, "FOObar.COM", "/bC?De=F#Gh",
- "http://foobar.com/bC?De=F#Gh", "foobar.com", "/bC", "/bC", "De=F", "Gh")
+ testURIParse(t, &u, "example.com", "/bC?De=F#Gh",
+ "http://example.com/bC?De=F#Gh", "example.com", "/bC", "/bC", "De=F", "Gh")
// uri with hostname
- testURIParse(t, &u, "xxx.com", "http://aaa.com/foo/bar?baz=aaa#ddd",
- "http://aaa.com/foo/bar?baz=aaa#ddd", "aaa.com", "/foo/bar", "/foo/bar", "baz=aaa", "ddd")
- testURIParse(t, &u, "xxx.com", "https://ab.com/f/b%20r?baz=aaa#ddd",
- "https://ab.com/f/b%20r?baz=aaa#ddd", "ab.com", "/f/b r", "/f/b%20r", "baz=aaa", "ddd")
+ testURIParse(t, &u, "example.com", "http://example.com/foo/bar?baz=aaa#ddd",
+ "http://example.com/foo/bar?baz=aaa#ddd", "example.com", "/foo/bar", "/foo/bar", "baz=aaa", "ddd")
+ testURIParse(t, &u, "example.net", "https://example.com/f/b%20r?baz=aaa#ddd",
+ "https://example.com/f/b%20r?baz=aaa#ddd", "example.com", "/f/b r", "/f/b%20r", "baz=aaa", "ddd")
// no slash after hostname in uri
- testURIParse(t, &u, "aaa.com", "http://google.com",
- "http://google.com/", "google.com", "/", "/", "", "")
+ testURIParse(t, &u, "example.com", "http://example.com",
+ "http://example.com/", "example.com", "/", "/", "", "")
// uppercase hostname in uri
- testURIParse(t, &u, "abc.com", "http://GoGLE.com/aaa",
- "http://gogle.com/aaa", "gogle.com", "/aaa", "/aaa", "", "")
+ testURIParse(t, &u, "example.net", "http://EXAMPLE.COM/aaa",
+ "http://example.com/aaa", "example.com", "/aaa", "/aaa", "", "")
// http:// in query params
- testURIParse(t, &u, "aaa.com", "/foo?bar=http://google.com",
- "http://aaa.com/foo?bar=http://google.com", "aaa.com", "/foo", "/foo", "bar=http://google.com", "")
+ testURIParse(t, &u, "example.com", "/foo?bar=http://example.org",
+ "http://example.com/foo?bar=http://example.org", "example.com", "/foo", "/foo", "bar=http://example.org", "")
- testURIParse(t, &u, "aaa.com", "//relative",
- "http://aaa.com/relative", "aaa.com", "/relative", "//relative", "", "")
+ testURIParse(t, &u, "example.com", "//relative",
+ "http://example.com/relative", "example.com", "/relative", "//relative", "", "")
- testURIParse(t, &u, "", "//aaa.com//absolute",
- "http://aaa.com/absolute", "aaa.com", "/absolute", "//absolute", "", "")
+ testURIParse(t, &u, "", "//example.com//absolute",
+ "http://example.com/absolute", "example.com", "/absolute", "//absolute", "", "")
- testURIParse(t, &u, "", "//aaa.com\r\n\r\nGET x",
+ testURIParse(t, &u, "", "//example.com\r\n\r\nGET x",
"http:///", "", "/", "", "", "")
testURIParse(t, &u, "", "http://[fe80::1%25en0]/",
@@ -413,7 +450,7 @@ func TestURIWithQuerystringOverride(t *testing.T) {
uriString := string(u.RequestURI())
if uriString != "/?q1=foo&q2=bar&q4=quux" {
- t.Fatalf("Expected Querystring to be overridden but was %s ", uriString)
+ t.Fatalf("Expected Querystring to be overridden but was %q ", uriString)
}
}
diff --git a/userdata.go b/userdata.go
index 9a7c988..40690f6 100644
--- a/userdata.go
+++ b/userdata.go
@@ -5,18 +5,21 @@ import (
)
type userDataKV struct {
- key []byte
+ key interface{}
value interface{}
}
type userData []userDataKV
-func (d *userData) Set(key string, value interface{}) {
+func (d *userData) Set(key interface{}, value interface{}) {
+ if b, ok := key.([]byte); ok {
+ key = string(b)
+ }
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
- if string(kv.key) == key {
+ if kv.key == key {
kv.value = value
return
}
@@ -30,28 +33,31 @@ func (d *userData) Set(key string, value interface{}) {
if c > n {
args = args[:n+1]
kv := &args[n]
- kv.key = append(kv.key[:0], key...)
+ kv.key = key
kv.value = value
*d = args
return
}
kv := userDataKV{}
- kv.key = append(kv.key[:0], key...)
+ kv.key = key
kv.value = value
*d = append(args, kv)
}
func (d *userData) SetBytes(key []byte, value interface{}) {
- d.Set(b2s(key), value)
+ d.Set(key, value)
}
-func (d *userData) Get(key string) interface{} {
+func (d *userData) Get(key interface{}) interface{} {
+ if b, ok := key.([]byte); ok {
+ key = b2s(b)
+ }
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
- if string(kv.key) == key {
+ if kv.key == key {
return kv.value
}
}
@@ -59,7 +65,7 @@ func (d *userData) Get(key string) interface{} {
}
func (d *userData) GetBytes(key []byte) interface{} {
- return d.Get(b2s(key))
+ return d.Get(key)
}
func (d *userData) Reset() {
@@ -74,14 +80,17 @@ func (d *userData) Reset() {
*d = (*d)[:0]
}
-func (d *userData) Remove(key string) {
+func (d *userData) Remove(key interface{}) {
+ if b, ok := key.([]byte); ok {
+ key = b2s(b)
+ }
args := *d
n := len(args)
for i := 0; i < n; i++ {
kv := &args[i]
- if string(kv.key) == key {
+ if kv.key == key {
n--
- args[i] = args[n]
+ args[i], args[n] = args[n], args[i]
args[n].value = nil
args = args[:n]
*d = args
@@ -91,5 +100,5 @@ func (d *userData) Remove(key string) {
}
func (d *userData) RemoveBytes(key []byte) {
- d.Remove(b2s(key))
+ d.Remove(key)
}
diff --git a/userdata_test.go b/userdata_test.go
index 94f04dd..3a081ae 100644
--- a/userdata_test.go
+++ b/userdata_test.go
@@ -92,7 +92,7 @@ func TestUserDataDelete(t *testing.T) {
k := fmt.Sprintf("key_%d", i)
u.Remove(k)
if val := u.Get(k); val != nil {
- t.Fatalf("unexpected key= %s, value =%v ,Expecting key= %s, value = nil", k, val, k)
+ t.Fatalf("unexpected key= %q, value =%v ,Expecting key= %q, value = nil", k, val, k)
}
kk := fmt.Sprintf("key_%d", i+1)
testUserDataGet(t, &u, []byte(kk), i+1)
@@ -104,3 +104,18 @@ func TestUserDataDelete(t *testing.T) {
}
}
+
+func TestUserDataSetAndRemove(t *testing.T) {
+ var (
+ u userData
+ shortKey = "[]"
+ longKey = "[ ]"
+ )
+
+ u.Set(shortKey, "")
+ u.Set(longKey, "")
+ u.Remove(shortKey)
+ u.Set(shortKey, "")
+ testUserDataGet(t, &u, []byte(shortKey), "")
+ testUserDataGet(t, &u, []byte(longKey), "")
+}
diff --git a/workerpool.go b/workerpool.go
index 9b1987e..50a5c75 100644
--- a/workerpool.go
+++ b/workerpool.go
@@ -1,6 +1,7 @@
package fasthttp
import (
+ "errors"
"net"
"runtime"
"strings"
@@ -226,8 +227,9 @@ func (wp *workerPool) workerFunc(ch *workerChan) {
strings.Contains(errStr, "reset by peer") ||
strings.Contains(errStr, "request headers: small read buffer") ||
strings.Contains(errStr, "unexpected EOF") ||
- strings.Contains(errStr, "i/o timeout")) {
- wp.Logger.Printf("error when serving connection %q<->%q: %s", c.LocalAddr(), c.RemoteAddr(), err)
+ strings.Contains(errStr, "i/o timeout") ||
+ errors.Is(err, ErrBadTrailer)) {
+ wp.Logger.Printf("error when serving connection %q<->%q: %v", c.LocalAddr(), c.RemoteAddr(), err)
}
}
if err == errHijacked {
diff --git a/workerpool_test.go b/workerpool_test.go
index 4a09d97..c7c8ed9 100644
--- a/workerpool_test.go
+++ b/workerpool_test.go
@@ -1,7 +1,7 @@
package fasthttp
import (
- "io/ioutil"
+ "io"
"net"
"testing"
"time"
@@ -86,14 +86,14 @@ func testWorkerPoolMaxWorkersCount(t *testing.T) {
buf := make([]byte, 100)
n, err := conn.Read(buf)
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
buf = buf[:n]
if string(buf) != "foobar" {
t.Errorf("unexpected data read: %q. Expecting %q", buf, "foobar")
}
if _, err = conn.Write([]byte("baz")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
<-ready
@@ -113,20 +113,20 @@ func testWorkerPoolMaxWorkersCount(t *testing.T) {
go func() {
conn, err := ln.Dial()
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if _, err = conn.Write([]byte("foobar")); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
- data, err := ioutil.ReadAll(conn)
+ data, err := io.ReadAll(conn)
if err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
if string(data) != "baz" {
t.Errorf("unexpected value read: %q. Expecting %q", data, "baz")
}
if err = conn.Close(); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
clientCh <- struct{}{}
}()
@@ -135,7 +135,7 @@ func testWorkerPoolMaxWorkersCount(t *testing.T) {
for i := 0; i < wp.MaxWorkersCount; i++ {
conn, err := ln.Accept()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
if !wp.Serve(conn) {
t.Fatalf("worker pool must have enough workers to serve the conn")
@@ -144,12 +144,12 @@ func testWorkerPoolMaxWorkersCount(t *testing.T) {
go func() {
if _, err := ln.Dial(); err != nil {
- t.Errorf("unexpected error: %s", err)
+ t.Errorf("unexpected error: %v", err)
}
}()
conn, err := ln.Accept()
if err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
for i := 0; i < 5; i++ {
if wp.Serve(conn) {
@@ -157,7 +157,7 @@ func testWorkerPoolMaxWorkersCount(t *testing.T) {
}
}
if err = conn.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
close(ready)
@@ -171,7 +171,7 @@ func testWorkerPoolMaxWorkersCount(t *testing.T) {
}
if err := ln.Close(); err != nil {
- t.Fatalf("unexpected error: %s", err)
+ t.Fatalf("unexpected error: %v", err)
}
wp.Stop()
}
More details
Historical runs
- failed: go1: internal compiler error: in do_get_backend, at go/gofrontend/expressions.cc:13792
- push-failed: Failed to push result branch: Connection closed: Connection closed early The remote server unexpectedly closed the connection.
- run-disappeared: Jenkins job https://jenkins.debian.net/job/janitor-worker/864330/ has disappeared
- run-disappeared: Jenkins job https://jenkins.debian.net/job/janitor-worker/855284/ has disappeared
- failed: go1: internal compiler error: in do_get_backend, at go/gofrontend/expressions.cc:13792
- build-failed-stage-build: FAIL github.com/valyala/fasthttp 661.751s
- run-disappeared: Jenkins job https://jenkins.debian.net/job/janitor-worker/603862/ has disappeared
- build-failed-stage-explain-bd-uninstallable: build failed stage explain-bd-uninstallable