|
0 |
From: Marshall Beddoe <mbeddoe@gmail.com>
|
|
1 |
Date: Fri, 9 Apr 2021 20:50:09 -0500
|
|
2 |
Subject: Add support for ReadHeaderTimeout
|
|
3 |
|
|
4 |
Set a read deadline when waiting for the PROXY protocol header.
|
|
5 |
Fix for #65
|
|
6 |
|
|
7 |
(cherry picked from commit cdc63867da24fc609b727231f682670d0d1cd346)
|
|
8 |
|
|
9 |
Closes: #991498, CVE-2021-23409
|
|
10 |
---
|
|
11 |
README.md | 32 ++++++++++++++++++++++++++
|
|
12 |
examples/client/client.go | 48 +++++++++++++++++++++++++++++++++++++++
|
|
13 |
examples/httpserver/httpserver.go | 39 +++++++++++++++++++++++++++++++
|
|
14 |
examples/server/server.go | 36 +++++++++++++++++++++++++++++
|
|
15 |
protocol.go | 11 ++++++---
|
|
16 |
protocol_test.go | 38 +++++++++++++++++++++++++++++++
|
|
17 |
6 files changed, 201 insertions(+), 3 deletions(-)
|
|
18 |
create mode 100644 examples/client/client.go
|
|
19 |
create mode 100644 examples/httpserver/httpserver.go
|
|
20 |
create mode 100644 examples/server/server.go
|
|
21 |
|
|
22 |
diff --git a/README.md b/README.md
|
|
23 |
index 1aedea5..982707c 100644
|
|
24 |
--- a/README.md
|
|
25 |
+++ b/README.md
|
|
26 |
@@ -119,6 +119,38 @@ func main() {
|
|
27 |
}
|
|
28 |
```
|
|
29 |
|
|
30 |
+### HTTP Server
|
|
31 |
+```go
|
|
32 |
+package main
|
|
33 |
+
|
|
34 |
+import (
|
|
35 |
+ "net"
|
|
36 |
+ "net/http"
|
|
37 |
+ "time"
|
|
38 |
+
|
|
39 |
+ "github.com/pires/go-proxyproto"
|
|
40 |
+)
|
|
41 |
+
|
|
42 |
+func main() {
|
|
43 |
+ server := http.Server{
|
|
44 |
+ Addr: ":8080",
|
|
45 |
+ }
|
|
46 |
+
|
|
47 |
+ ln, err := net.Listen("tcp", server.Addr)
|
|
48 |
+ if err != nil {
|
|
49 |
+ panic(err)
|
|
50 |
+ }
|
|
51 |
+
|
|
52 |
+ proxyListener := &proxyproto.Listener{
|
|
53 |
+ Listener: ln,
|
|
54 |
+ ReadHeaderTimeout: 10 * time.Second,
|
|
55 |
+ }
|
|
56 |
+ defer proxyListener.Close()
|
|
57 |
+
|
|
58 |
+ server.Serve(proxyListener)
|
|
59 |
+}
|
|
60 |
+```
|
|
61 |
+
|
|
62 |
## Special notes
|
|
63 |
|
|
64 |
### AWS
|
|
65 |
diff --git a/examples/client/client.go b/examples/client/client.go
|
|
66 |
new file mode 100644
|
|
67 |
index 0000000..7c795fa
|
|
68 |
--- /dev/null
|
|
69 |
+++ b/examples/client/client.go
|
|
70 |
@@ -0,0 +1,48 @@
|
|
71 |
+package main
|
|
72 |
+
|
|
73 |
+import (
|
|
74 |
+ "io"
|
|
75 |
+ "log"
|
|
76 |
+ "net"
|
|
77 |
+
|
|
78 |
+ proxyproto "github.com/pires/go-proxyproto"
|
|
79 |
+)
|
|
80 |
+
|
|
81 |
+func chkErr(err error) {
|
|
82 |
+ if err != nil {
|
|
83 |
+ log.Fatalf("Error: %s", err.Error())
|
|
84 |
+ }
|
|
85 |
+}
|
|
86 |
+
|
|
87 |
+func main() {
|
|
88 |
+ // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto
|
|
89 |
+ target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9876")
|
|
90 |
+ chkErr(err)
|
|
91 |
+
|
|
92 |
+ conn, err := net.DialTCP("tcp", nil, target)
|
|
93 |
+ chkErr(err)
|
|
94 |
+
|
|
95 |
+ defer conn.Close()
|
|
96 |
+
|
|
97 |
+ // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you
|
|
98 |
+ // have two conn's
|
|
99 |
+ header := &proxyproto.Header{
|
|
100 |
+ Version: 1,
|
|
101 |
+ Command: proxyproto.PROXY,
|
|
102 |
+ TransportProtocol: proxyproto.TCPv4,
|
|
103 |
+ SourceAddr: &net.TCPAddr{
|
|
104 |
+ IP: net.ParseIP("10.1.1.1"),
|
|
105 |
+ Port: 1000,
|
|
106 |
+ },
|
|
107 |
+ DestinationAddr: &net.TCPAddr{
|
|
108 |
+ IP: net.ParseIP("20.2.2.2"),
|
|
109 |
+ Port: 2000,
|
|
110 |
+ },
|
|
111 |
+ }
|
|
112 |
+ // After the connection was created write the proxy headers first
|
|
113 |
+ _, err = header.WriteTo(conn)
|
|
114 |
+ chkErr(err)
|
|
115 |
+ // Then your data... e.g.:
|
|
116 |
+ _, err = io.WriteString(conn, "HELO")
|
|
117 |
+ chkErr(err)
|
|
118 |
+}
|
|
119 |
diff --git a/examples/httpserver/httpserver.go b/examples/httpserver/httpserver.go
|
|
120 |
new file mode 100644
|
|
121 |
index 0000000..b04f2c7
|
|
122 |
--- /dev/null
|
|
123 |
+++ b/examples/httpserver/httpserver.go
|
|
124 |
@@ -0,0 +1,39 @@
|
|
125 |
+package main
|
|
126 |
+
|
|
127 |
+import (
|
|
128 |
+ "log"
|
|
129 |
+ "net"
|
|
130 |
+ "net/http"
|
|
131 |
+ "time"
|
|
132 |
+
|
|
133 |
+ "github.com/pires/go-proxyproto"
|
|
134 |
+)
|
|
135 |
+
|
|
136 |
+// TODO: add httpclient example
|
|
137 |
+
|
|
138 |
+func main() {
|
|
139 |
+ server := http.Server{
|
|
140 |
+ Addr: ":8080",
|
|
141 |
+ ConnState: func(c net.Conn, s http.ConnState) {
|
|
142 |
+ if s == http.StateNew {
|
|
143 |
+ log.Printf("[ConnState] %s -> %s", c.LocalAddr().String(), c.RemoteAddr().String())
|
|
144 |
+ }
|
|
145 |
+ },
|
|
146 |
+ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
147 |
+ log.Printf("[Handler] remote ip %q", r.RemoteAddr)
|
|
148 |
+ }),
|
|
149 |
+ }
|
|
150 |
+
|
|
151 |
+ ln, err := net.Listen("tcp", server.Addr)
|
|
152 |
+ if err != nil {
|
|
153 |
+ panic(err)
|
|
154 |
+ }
|
|
155 |
+
|
|
156 |
+ proxyListener := &proxyproto.Listener{
|
|
157 |
+ Listener: ln,
|
|
158 |
+ ReadHeaderTimeout: 10 * time.Second,
|
|
159 |
+ }
|
|
160 |
+ defer proxyListener.Close()
|
|
161 |
+
|
|
162 |
+ server.Serve(proxyListener)
|
|
163 |
+}
|
|
164 |
diff --git a/examples/server/server.go b/examples/server/server.go
|
|
165 |
new file mode 100644
|
|
166 |
index 0000000..286dc2c
|
|
167 |
--- /dev/null
|
|
168 |
+++ b/examples/server/server.go
|
|
169 |
@@ -0,0 +1,36 @@
|
|
170 |
+package main
|
|
171 |
+
|
|
172 |
+import (
|
|
173 |
+ "log"
|
|
174 |
+ "net"
|
|
175 |
+
|
|
176 |
+ proxyproto "github.com/pires/go-proxyproto"
|
|
177 |
+)
|
|
178 |
+
|
|
179 |
+func main() {
|
|
180 |
+ // Create a listener
|
|
181 |
+ addr := "localhost:9876"
|
|
182 |
+ list, err := net.Listen("tcp", addr)
|
|
183 |
+ if err != nil {
|
|
184 |
+ log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error())
|
|
185 |
+ }
|
|
186 |
+
|
|
187 |
+ // Wrap listener in a proxyproto listener
|
|
188 |
+ proxyListener := &proxyproto.Listener{Listener: list}
|
|
189 |
+ defer proxyListener.Close()
|
|
190 |
+
|
|
191 |
+ // Wait for a connection and accept it
|
|
192 |
+ conn, err := proxyListener.Accept()
|
|
193 |
+ defer conn.Close()
|
|
194 |
+
|
|
195 |
+ // Print connection details
|
|
196 |
+ if conn.LocalAddr() == nil {
|
|
197 |
+ log.Fatal("couldn't retrieve local address")
|
|
198 |
+ }
|
|
199 |
+ log.Printf("local address: %q", conn.LocalAddr().String())
|
|
200 |
+
|
|
201 |
+ if conn.RemoteAddr() == nil {
|
|
202 |
+ log.Fatal("couldn't retrieve remote address")
|
|
203 |
+ }
|
|
204 |
+ log.Printf("remote address: %q", conn.RemoteAddr().String())
|
|
205 |
+}
|
|
206 |
diff --git a/protocol.go b/protocol.go
|
|
207 |
index 878ca1c..ebad481 100644
|
|
208 |
--- a/protocol.go
|
|
209 |
+++ b/protocol.go
|
|
210 |
@@ -12,9 +12,10 @@ import (
|
|
211 |
// If the connection is using the protocol, the RemoteAddr() will return
|
|
212 |
// the correct client address.
|
|
213 |
type Listener struct {
|
|
214 |
- Listener net.Listener
|
|
215 |
- Policy PolicyFunc
|
|
216 |
- ValidateHeader Validator
|
|
217 |
+ Listener net.Listener
|
|
218 |
+ Policy PolicyFunc
|
|
219 |
+ ValidateHeader Validator
|
|
220 |
+ ReadHeaderTimeout time.Duration
|
|
221 |
}
|
|
222 |
|
|
223 |
// Conn is used to wrap and underlying connection which
|
|
224 |
@@ -51,6 +52,10 @@ func (p *Listener) Accept() (net.Conn, error) {
|
|
225 |
return nil, err
|
|
226 |
}
|
|
227 |
|
|
228 |
+ if d := p.ReadHeaderTimeout; d != 0 {
|
|
229 |
+ conn.SetReadDeadline(time.Now().Add(d))
|
|
230 |
+ }
|
|
231 |
+
|
|
232 |
proxyHeaderPolicy := USE
|
|
233 |
if p.Policy != nil {
|
|
234 |
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
|
|
235 |
diff --git a/protocol_test.go b/protocol_test.go
|
|
236 |
index f5c08be..b3d61ed 100644
|
|
237 |
--- a/protocol_test.go
|
|
238 |
+++ b/protocol_test.go
|
|
239 |
@@ -6,11 +6,13 @@ package proxyproto
|
|
240 |
|
|
241 |
import (
|
|
242 |
"bytes"
|
|
243 |
+ "context"
|
|
244 |
"crypto/tls"
|
|
245 |
"crypto/x509"
|
|
246 |
"fmt"
|
|
247 |
"net"
|
|
248 |
"testing"
|
|
249 |
+ "time"
|
|
250 |
)
|
|
251 |
|
|
252 |
func TestPassthrough(t *testing.T) {
|
|
253 |
@@ -59,6 +61,42 @@ func TestPassthrough(t *testing.T) {
|
|
254 |
}
|
|
255 |
}
|
|
256 |
|
|
257 |
+func TestReadHeaderTimeout(t *testing.T) {
|
|
258 |
+ l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
259 |
+ if err != nil {
|
|
260 |
+ t.Fatalf("err: %v", err)
|
|
261 |
+ }
|
|
262 |
+
|
|
263 |
+ pl := &Listener{
|
|
264 |
+ Listener: l,
|
|
265 |
+ ReadHeaderTimeout: 1 * time.Millisecond,
|
|
266 |
+ }
|
|
267 |
+
|
|
268 |
+ ctx, cancel := context.WithCancel(context.Background())
|
|
269 |
+ defer cancel()
|
|
270 |
+
|
|
271 |
+ go func() {
|
|
272 |
+ conn, err := net.Dial("tcp", pl.Addr().String())
|
|
273 |
+ if err != nil {
|
|
274 |
+ t.Fatalf("err: %v", err)
|
|
275 |
+ }
|
|
276 |
+ defer conn.Close()
|
|
277 |
+
|
|
278 |
+ <-ctx.Done()
|
|
279 |
+ }()
|
|
280 |
+
|
|
281 |
+ conn, err := pl.Accept()
|
|
282 |
+ if err != nil {
|
|
283 |
+ t.Fatalf("err: %v", err)
|
|
284 |
+ }
|
|
285 |
+ defer conn.Close()
|
|
286 |
+
|
|
287 |
+ // Read blocks forever if there is no ReadHeaderTimeout
|
|
288 |
+ recv := make([]byte, 4)
|
|
289 |
+ _, err = conn.Read(recv)
|
|
290 |
+
|
|
291 |
+}
|
|
292 |
+
|
|
293 |
func TestParse_ipv4(t *testing.T) {
|
|
294 |
l, err := net.Listen("tcp", "127.0.0.1:0")
|
|
295 |
if err != nil {
|