Codebase list golang-websocket / 80afda1
Update upstream source from tag 'upstream/1.4.0' Update to upstream version '1.4.0' with Debian dir c0565b70787791c81124f03c18c4625afd038cda Anthony Fok 5 years ago
39 changed file(s) with 1600 addition(s) and 500 deletion(s). Raw diff Collapse all Expand all
2121 *.exe
2222
2323 .idea/
24 *.iml
24 *.iml
22
33 matrix:
44 include:
5 - go: 1.4
6 - go: 1.5
7 - go: 1.6
8 - go: 1.7
9 - go: 1.8
5 - go: 1.7.x
6 - go: 1.8.x
7 - go: 1.9.x
8 - go: 1.10.x
9 - go: 1.11.x
1010 - go: tip
1111 allow_failures:
1212 - go: tip
33 # Please keep the list sorted.
44
55 Gary Burd <gary@beagledreams.com>
6 Google LLC (https://opensource.google.com/)
67 Joachim Bauch <mail@joachim-bauch.de>
78
5050 <tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
5151 </table>
5252
53 Notes:
53 Notes:
5454
5555 1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
5656 2. The application can get the type of a received data message by implementing
44 package websocket
55
66 import (
7 "bufio"
87 "bytes"
8 "context"
99 "crypto/tls"
10 "encoding/base64"
1110 "errors"
1211 "io"
1312 "io/ioutil"
1413 "net"
1514 "net/http"
15 "net/http/httptrace"
1616 "net/url"
1717 "strings"
1818 "time"
5252 // NetDial is nil, net.Dial is used.
5353 NetDial func(network, addr string) (net.Conn, error)
5454
55 // NetDialContext specifies the dial function for creating TCP connections. If
56 // NetDialContext is nil, net.DialContext is used.
57 NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
58
5559 // Proxy specifies a function to return a proxy for a given
5660 // Request. If the function returns a non-nil error, the
5761 // request is aborted with the provided error.
7074 // do not limit the size of the messages that can be sent or received.
7175 ReadBufferSize, WriteBufferSize int
7276
77 // WriteBufferPool is a pool of buffers for write operations. If the value
78 // is not set, then write buffers are allocated to the connection for the
79 // lifetime of the connection.
80 //
81 // A pool is most useful when the application has a modest volume of writes
82 // across a large number of connections.
83 //
84 // Applications should use a single pool for each unique value of
85 // WriteBufferSize.
86 WriteBufferPool BufferPool
87
7388 // Subprotocols specifies the client's requested subprotocols.
7489 Subprotocols []string
7590
85100 Jar http.CookieJar
86101 }
87102
103 // Dial creates a new client connection by calling DialContext with a background context.
104 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
105 return d.DialContext(context.Background(), urlStr, requestHeader)
106 }
107
88108 var errMalformedURL = errors.New("malformed ws or wss URL")
89
90 // parseURL parses the URL.
91 //
92 // This function is a replacement for the standard library url.Parse function.
93 // In Go 1.4 and earlier, url.Parse loses information from the path.
94 func parseURL(s string) (*url.URL, error) {
95 // From the RFC:
96 //
97 // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
98 // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
99 var u url.URL
100 switch {
101 case strings.HasPrefix(s, "ws://"):
102 u.Scheme = "ws"
103 s = s[len("ws://"):]
104 case strings.HasPrefix(s, "wss://"):
105 u.Scheme = "wss"
106 s = s[len("wss://"):]
107 default:
108 return nil, errMalformedURL
109 }
110
111 if i := strings.Index(s, "?"); i >= 0 {
112 u.RawQuery = s[i+1:]
113 s = s[:i]
114 }
115
116 if i := strings.Index(s, "/"); i >= 0 {
117 u.Opaque = s[i:]
118 s = s[:i]
119 } else {
120 u.Opaque = "/"
121 }
122
123 u.Host = s
124
125 if strings.Contains(u.Host, "@") {
126 // Don't bother parsing user information because user information is
127 // not allowed in websocket URIs.
128 return nil, errMalformedURL
129 }
130
131 return &u, nil
132 }
133109
134110 func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
135111 hostPort = u.Host
149125 return hostPort, hostNoPort
150126 }
151127
152 // DefaultDialer is a dialer with all fields set to the default zero values.
128 // DefaultDialer is a dialer with all fields set to the default values.
153129 var DefaultDialer = &Dialer{
154 Proxy: http.ProxyFromEnvironment,
155 }
156
157 // Dial creates a new client connection. Use requestHeader to specify the
130 Proxy: http.ProxyFromEnvironment,
131 HandshakeTimeout: 45 * time.Second,
132 }
133
134 // nilDialer is dialer to use when receiver is nil.
135 var nilDialer = *DefaultDialer
136
137 // DialContext creates a new client connection. Use requestHeader to specify the
158138 // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
159139 // Use the response.Header to get the selected subprotocol
160140 // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
141 //
142 // The context will be used in the request and in the Dialer
161143 //
162144 // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
163145 // non-nil *http.Response so that callers can handle redirects, authentication,
164146 // etcetera. The response body may not contain the entire response and does not
165147 // need to be closed by the application.
166 func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
167
148 func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
168149 if d == nil {
169 d = &Dialer{
170 Proxy: http.ProxyFromEnvironment,
171 }
150 d = &nilDialer
172151 }
173152
174153 challengeKey, err := generateChallengeKey()
176155 return nil, nil, err
177156 }
178157
179 u, err := parseURL(urlStr)
158 u, err := url.Parse(urlStr)
180159 if err != nil {
181160 return nil, nil, err
182161 }
204183 Header: make(http.Header),
205184 Host: u.Host,
206185 }
186 req = req.WithContext(ctx)
207187
208188 // Set the cookies present in the cookie jar of the dialer
209189 if d.Jar != nil {
236216 k == "Sec-Websocket-Extensions" ||
237217 (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
238218 return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
219 case k == "Sec-Websocket-Protocol":
220 req.Header["Sec-WebSocket-Protocol"] = vs
239221 default:
240222 req.Header[k] = vs
241223 }
242224 }
243225
244226 if d.EnableCompression {
245 req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover")
227 req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
228 }
229
230 if d.HandshakeTimeout != 0 {
231 var cancel func()
232 ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
233 defer cancel()
234 }
235
236 // Get network dial function.
237 var netDial func(network, add string) (net.Conn, error)
238
239 if d.NetDialContext != nil {
240 netDial = func(network, addr string) (net.Conn, error) {
241 return d.NetDialContext(ctx, network, addr)
242 }
243 } else if d.NetDial != nil {
244 netDial = d.NetDial
245 } else {
246 netDialer := &net.Dialer{}
247 netDial = func(network, addr string) (net.Conn, error) {
248 return netDialer.DialContext(ctx, network, addr)
249 }
250 }
251
252 // If needed, wrap the dial function to set the connection deadline.
253 if deadline, ok := ctx.Deadline(); ok {
254 forwardDial := netDial
255 netDial = func(network, addr string) (net.Conn, error) {
256 c, err := forwardDial(network, addr)
257 if err != nil {
258 return nil, err
259 }
260 err = c.SetDeadline(deadline)
261 if err != nil {
262 c.Close()
263 return nil, err
264 }
265 return c, nil
266 }
267 }
268
269 // If needed, wrap the dial function to connect through a proxy.
270 if d.Proxy != nil {
271 proxyURL, err := d.Proxy(req)
272 if err != nil {
273 return nil, nil, err
274 }
275 if proxyURL != nil {
276 dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
277 if err != nil {
278 return nil, nil, err
279 }
280 netDial = dialer.Dial
281 }
246282 }
247283
248284 hostPort, hostNoPort := hostPortNoPort(u)
249
250 var proxyURL *url.URL
251 // Check wether the proxy method has been configured
252 if d.Proxy != nil {
253 proxyURL, err = d.Proxy(req)
254 }
255 if err != nil {
256 return nil, nil, err
257 }
258
259 var targetHostPort string
260 if proxyURL != nil {
261 targetHostPort, _ = hostPortNoPort(proxyURL)
262 } else {
263 targetHostPort = hostPort
264 }
265
266 var deadline time.Time
267 if d.HandshakeTimeout != 0 {
268 deadline = time.Now().Add(d.HandshakeTimeout)
269 }
270
271 netDial := d.NetDial
272 if netDial == nil {
273 netDialer := &net.Dialer{Deadline: deadline}
274 netDial = netDialer.Dial
275 }
276
277 netConn, err := netDial("tcp", targetHostPort)
285 trace := httptrace.ContextClientTrace(ctx)
286 if trace != nil && trace.GetConn != nil {
287 trace.GetConn(hostPort)
288 }
289
290 netConn, err := netDial("tcp", hostPort)
291 if trace != nil && trace.GotConn != nil {
292 trace.GotConn(httptrace.GotConnInfo{
293 Conn: netConn,
294 })
295 }
278296 if err != nil {
279297 return nil, nil, err
280298 }
284302 netConn.Close()
285303 }
286304 }()
287
288 if err := netConn.SetDeadline(deadline); err != nil {
289 return nil, nil, err
290 }
291
292 if proxyURL != nil {
293 connectHeader := make(http.Header)
294 if user := proxyURL.User; user != nil {
295 proxyUser := user.Username()
296 if proxyPassword, passwordSet := user.Password(); passwordSet {
297 credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
298 connectHeader.Set("Proxy-Authorization", "Basic "+credential)
299 }
300 }
301 connectReq := &http.Request{
302 Method: "CONNECT",
303 URL: &url.URL{Opaque: hostPort},
304 Host: hostPort,
305 Header: connectHeader,
306 }
307
308 connectReq.Write(netConn)
309
310 // Read response.
311 // Okay to use and discard buffered reader here, because
312 // TLS server will not speak until spoken to.
313 br := bufio.NewReader(netConn)
314 resp, err := http.ReadResponse(br, connectReq)
315 if err != nil {
316 return nil, nil, err
317 }
318 if resp.StatusCode != 200 {
319 f := strings.SplitN(resp.Status, " ", 2)
320 return nil, nil, errors.New(f[1])
321 }
322 }
323305
324306 if u.Scheme == "https" {
325307 cfg := cloneTLSConfig(d.TLSClientConfig)
328310 }
329311 tlsConn := tls.Client(netConn, cfg)
330312 netConn = tlsConn
331 if err := tlsConn.Handshake(); err != nil {
313
314 var err error
315 if trace != nil {
316 err = doHandshakeWithTrace(trace, tlsConn, cfg)
317 } else {
318 err = doHandshake(tlsConn, cfg)
319 }
320
321 if err != nil {
332322 return nil, nil, err
333323 }
334 if !cfg.InsecureSkipVerify {
335 if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
336 return nil, nil, err
337 }
338 }
339 }
340
341 conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
324 }
325
326 conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
342327
343328 if err := req.Write(netConn); err != nil {
344329 return nil, nil, err
330 }
331
332 if trace != nil && trace.GotFirstResponseByte != nil {
333 if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
334 trace.GotFirstResponseByte()
335 }
345336 }
346337
347338 resp, err := http.ReadResponse(conn.br, req)
389380 netConn = nil // to avoid close in defer.
390381 return conn, resp, nil
391382 }
383
384 func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
385 if err := tlsConn.Handshake(); err != nil {
386 return err
387 }
388 if !cfg.InsecureSkipVerify {
389 if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
390 return err
391 }
392 }
393 return nil
394 }
44 package websocket
55
66 import (
7 "bytes"
8 "context"
79 "crypto/tls"
810 "crypto/x509"
911 "encoding/base64"
12 "encoding/binary"
1013 "io"
1114 "io/ioutil"
15 "net"
1216 "net/http"
1317 "net/http/cookiejar"
1418 "net/http/httptest"
19 "net/http/httptrace"
1520 "net/url"
1621 "reflect"
1722 "strings"
3035 }
3136
3237 var cstDialer = Dialer{
38 Subprotocols: []string{"p1", "p2"},
39 ReadBufferSize: 1024,
40 WriteBufferSize: 1024,
41 HandshakeTimeout: 30 * time.Second,
42 }
43
44 var cstDialerWithoutHandshakeTimeout = Dialer{
3345 Subprotocols: []string{"p1", "p2"},
3446 ReadBufferSize: 1024,
3547 WriteBufferSize: 1024,
6779 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6880 if r.URL.Path != cstPath {
6981 t.Logf("path=%v, want %v", r.URL.Path, cstPath)
70 http.Error(w, "bad path", 400)
82 http.Error(w, "bad path", http.StatusBadRequest)
7183 return
7284 }
7385 if r.URL.RawQuery != cstRawQuery {
7486 t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
75 http.Error(w, "bad path", 400)
87 http.Error(w, "bad path", http.StatusBadRequest)
7688 return
7789 }
7890 subprotos := Subprotocols(r)
7991 if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
8092 t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
81 http.Error(w, "bad protocol", 400)
93 http.Error(w, "bad protocol", http.StatusBadRequest)
8294 return
8395 }
8496 ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
142154 s := newServer(t)
143155 defer s.Close()
144156
145 surl, _ := url.Parse(s.URL)
146
157 surl, _ := url.Parse(s.Server.URL)
158
159 cstDialer := cstDialer // make local copy for modification on next line.
147160 cstDialer.Proxy = http.ProxyURL(surl)
148161
149162 connect := false
154167 func(w http.ResponseWriter, r *http.Request) {
155168 if r.Method == "CONNECT" {
156169 connect = true
157 w.WriteHeader(200)
170 w.WriteHeader(http.StatusOK)
158171 return
159172 }
160173
161174 if !connect {
162 t.Log("connect not recieved")
163 http.Error(w, "connect not recieved", 405)
175 t.Log("connect not received")
176 http.Error(w, "connect not received", http.StatusMethodNotAllowed)
164177 return
165178 }
166179 origHandler.ServeHTTP(w, r)
172185 }
173186 defer ws.Close()
174187 sendRecv(t, ws)
175
176 cstDialer.Proxy = http.ProxyFromEnvironment
177188 }
178189
179190 func TestProxyAuthorizationDial(t *testing.T) {
180191 s := newServer(t)
181192 defer s.Close()
182193
183 surl, _ := url.Parse(s.URL)
194 surl, _ := url.Parse(s.Server.URL)
184195 surl.User = url.UserPassword("username", "password")
196
197 cstDialer := cstDialer // make local copy for modification on next line.
185198 cstDialer.Proxy = http.ProxyURL(surl)
186199
187200 connect := false
194207 expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
195208 if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
196209 connect = true
197 w.WriteHeader(200)
210 w.WriteHeader(http.StatusOK)
198211 return
199212 }
200213
201214 if !connect {
202 t.Log("connect with proxy authorization not recieved")
203 http.Error(w, "connect with proxy authorization not recieved", 405)
215 t.Log("connect with proxy authorization not received")
216 http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
204217 return
205218 }
206219 origHandler.ServeHTTP(w, r)
212225 }
213226 defer ws.Close()
214227 sendRecv(t, ws)
215
216 cstDialer.Proxy = http.ProxyFromEnvironment
217228 }
218229
219230 func TestDial(t *testing.T) {
236247 d := cstDialer
237248 d.Jar = jar
238249
239 u, _ := parseURL(s.URL)
250 u, _ := url.Parse(s.URL)
240251
241252 switch u.Scheme {
242253 case "ws":
245256 u.Scheme = "https"
246257 }
247258
248 cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}}
259 cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
249260 d.Jar.SetCookies(u, cookies)
250261
251262 ws, _, err := d.Dial(s.URL, nil)
338349 ws.Close()
339350 t.Fatalf("Dial: nil")
340351 }
352 }
353
354 // requireDeadlineNetConn fails the current test when Read or Write are called
355 // with no deadline.
356 type requireDeadlineNetConn struct {
357 t *testing.T
358 c net.Conn
359 readDeadlineIsSet bool
360 writeDeadlineIsSet bool
361 }
362
363 func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
364 c.writeDeadlineIsSet = !t.Equal(time.Time{})
365 c.readDeadlineIsSet = c.writeDeadlineIsSet
366 return c.c.SetDeadline(t)
367 }
368
369 func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
370 c.readDeadlineIsSet = !t.Equal(time.Time{})
371 return c.c.SetDeadline(t)
372 }
373
374 func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
375 c.writeDeadlineIsSet = !t.Equal(time.Time{})
376 return c.c.SetDeadline(t)
377 }
378
379 func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
380 if !c.writeDeadlineIsSet {
381 c.t.Fatalf("write with no deadline")
382 }
383 return c.c.Write(p)
384 }
385
386 func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
387 if !c.readDeadlineIsSet {
388 c.t.Fatalf("read with no deadline")
389 }
390 return c.c.Read(p)
391 }
392
393 func (c *requireDeadlineNetConn) Close() error { return c.c.Close() }
394 func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() }
395 func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
396
397 func TestHandshakeTimeout(t *testing.T) {
398 s := newServer(t)
399 defer s.Close()
400
401 d := cstDialer
402 d.NetDial = func(n, a string) (net.Conn, error) {
403 c, err := net.Dial(n, a)
404 return &requireDeadlineNetConn{c: c, t: t}, err
405 }
406 ws, _, err := d.Dial(s.URL, nil)
407 if err != nil {
408 t.Fatal("Dial:", err)
409 }
410 ws.Close()
411 }
412
413 func TestHandshakeTimeoutInContext(t *testing.T) {
414 s := newServer(t)
415 defer s.Close()
416
417 d := cstDialerWithoutHandshakeTimeout
418 d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
419 netDialer := &net.Dialer{}
420 c, err := netDialer.DialContext(ctx, n, a)
421 return &requireDeadlineNetConn{c: c, t: t}, err
422 }
423
424 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
425 defer cancel()
426 ws, _, err := d.DialContext(ctx, s.URL, nil)
427 if err != nil {
428 t.Fatal("Dial:", err)
429 }
430 ws.Close()
341431 }
342432
343433 func TestDialBadScheme(t *testing.T) {
397487 }))
398488 defer s.Close()
399489
400 resp, err := http.PostForm(s.URL, url.Values{})
401 if err != nil {
402 t.Fatalf("PostForm returned error %v", err)
490 req, err := http.NewRequest("POST", s.URL, strings.NewReader(""))
491 if err != nil {
492 t.Fatalf("NewRequest returned error %v", err)
493 }
494 req.Header.Set("Connection", "upgrade")
495 req.Header.Set("Upgrade", "websocket")
496 req.Header.Set("Sec-Websocket-Version", "13")
497
498 resp, err := http.DefaultClient.Do(req)
499 if err != nil {
500 t.Fatalf("Do returned error %v", err)
403501 }
404502 resp.Body.Close()
405503 if resp.StatusCode != http.StatusMethodNotAllowed {
509607 defer ws.Close()
510608 sendRecv(t, ws)
511609 }
610
611 func TestSocksProxyDial(t *testing.T) {
612 s := newServer(t)
613 defer s.Close()
614
615 proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
616 if err != nil {
617 t.Fatalf("listen failed: %v", err)
618 }
619 defer proxyListener.Close()
620 go func() {
621 c1, err := proxyListener.Accept()
622 if err != nil {
623 t.Errorf("proxy accept failed: %v", err)
624 return
625 }
626 defer c1.Close()
627
628 c1.SetDeadline(time.Now().Add(30 * time.Second))
629
630 buf := make([]byte, 32)
631 if _, err := io.ReadFull(c1, buf[:3]); err != nil {
632 t.Errorf("read failed: %v", err)
633 return
634 }
635 if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
636 t.Errorf("read %x, want %x", buf[:len(want)], want)
637 }
638 if _, err := c1.Write([]byte{5, 0}); err != nil {
639 t.Errorf("write failed: %v", err)
640 return
641 }
642 if _, err := io.ReadFull(c1, buf[:10]); err != nil {
643 t.Errorf("read failed: %v", err)
644 return
645 }
646 if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
647 t.Errorf("read %x, want %x", buf[:len(want)], want)
648 return
649 }
650 buf[1] = 0
651 if _, err := c1.Write(buf[:10]); err != nil {
652 t.Errorf("write failed: %v", err)
653 return
654 }
655
656 ip := net.IP(buf[4:8])
657 port := binary.BigEndian.Uint16(buf[8:10])
658
659 c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
660 if err != nil {
661 t.Errorf("dial failed; %v", err)
662 return
663 }
664 defer c2.Close()
665 done := make(chan struct{})
666 go func() {
667 io.Copy(c1, c2)
668 close(done)
669 }()
670 io.Copy(c2, c1)
671 <-done
672 }()
673
674 purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
675 if err != nil {
676 t.Fatalf("parse failed: %v", err)
677 }
678
679 cstDialer := cstDialer // make local copy for modification on next line.
680 cstDialer.Proxy = http.ProxyURL(purl)
681
682 ws, _, err := cstDialer.Dial(s.URL, nil)
683 if err != nil {
684 t.Fatalf("Dial: %v", err)
685 }
686 defer ws.Close()
687 sendRecv(t, ws)
688 }
689
690 func TestTracingDialWithContext(t *testing.T) {
691
692 var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
693 trace := &httptrace.ClientTrace{
694 WroteHeaders: func() {
695 headersWrote = true
696 },
697 WroteRequest: func(httptrace.WroteRequestInfo) {
698 requestWrote = true
699 },
700 GetConn: func(hostPort string) {
701 getConn = true
702 },
703 GotConn: func(info httptrace.GotConnInfo) {
704 gotConn = true
705 },
706 ConnectDone: func(network, addr string, err error) {
707 connectDone = true
708 },
709 GotFirstResponseByte: func() {
710 gotFirstResponseByte = true
711 },
712 }
713 ctx := httptrace.WithClientTrace(context.Background(), trace)
714
715 s := newTLSServer(t)
716 defer s.Close()
717
718 certs := x509.NewCertPool()
719 for _, c := range s.TLS.Certificates {
720 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
721 if err != nil {
722 t.Fatalf("error parsing server's root cert: %v", err)
723 }
724 for _, root := range roots {
725 certs.AddCert(root)
726 }
727 }
728
729 d := cstDialer
730 d.TLSClientConfig = &tls.Config{RootCAs: certs}
731
732 ws, _, err := d.DialContext(ctx, s.URL, nil)
733 if err != nil {
734 t.Fatalf("Dial: %v", err)
735 }
736
737 if !headersWrote {
738 t.Fatal("Headers was not written")
739 }
740 if !requestWrote {
741 t.Fatal("Request was not written")
742 }
743 if !getConn {
744 t.Fatal("getConn was not called")
745 }
746 if !gotConn {
747 t.Fatal("gotConn was not called")
748 }
749 if !connectDone {
750 t.Fatal("connectDone was not called")
751 }
752 if !gotFirstResponseByte {
753 t.Fatal("GotFirstResponseByte was not called")
754 }
755
756 defer ws.Close()
757 sendRecv(t, ws)
758 }
759
760 func TestEmptyTracingDialWithContext(t *testing.T) {
761
762 trace := &httptrace.ClientTrace{}
763 ctx := httptrace.WithClientTrace(context.Background(), trace)
764
765 s := newTLSServer(t)
766 defer s.Close()
767
768 certs := x509.NewCertPool()
769 for _, c := range s.TLS.Certificates {
770 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
771 if err != nil {
772 t.Fatalf("error parsing server's root cert: %v", err)
773 }
774 for _, root := range roots {
775 certs.AddCert(root)
776 }
777 }
778
779 d := cstDialer
780 d.TLSClientConfig = &tls.Config{RootCAs: certs}
781
782 ws, _, err := d.DialContext(ctx, s.URL, nil)
783 if err != nil {
784 t.Fatalf("Dial: %v", err)
785 }
786
787 defer ws.Close()
788 sendRecv(t, ws)
789 }
55
66 import (
77 "net/url"
8 "reflect"
98 "testing"
109 )
11
12 var parseURLTests = []struct {
13 s string
14 u *url.URL
15 rui string
16 }{
17 {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
18 {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"},
19 {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"},
20 {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"},
21 {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"},
22 {"ss://example.com/a/b", nil, ""},
23 {"ws://webmaster@example.com/", nil, ""},
24 {"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"},
25 {"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"},
26 }
27
28 func TestParseURL(t *testing.T) {
29 for _, tt := range parseURLTests {
30 u, err := parseURL(tt.s)
31 if tt.u != nil && err != nil {
32 t.Errorf("parseURL(%q) returned error %v", tt.s, err)
33 continue
34 }
35 if tt.u == nil {
36 if err == nil {
37 t.Errorf("parseURL(%q) did not return error", tt.s)
38 }
39 continue
40 }
41 if !reflect.DeepEqual(u, tt.u) {
42 t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u)
43 continue
44 }
45 if u.RequestURI() != tt.rui {
46 t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui)
47 }
48 }
49 }
5010
5111 var hostPortNoPortTests = []struct {
5212 u *url.URL
4242
4343 func BenchmarkWriteNoCompression(b *testing.B) {
4444 w := ioutil.Discard
45 c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
45 c := newTestConn(nil, w, false)
4646 messages := textMessages(100)
4747 b.ResetTimer()
4848 for i := 0; i < b.N; i++ {
5353
5454 func BenchmarkWriteWithCompression(b *testing.B) {
5555 w := ioutil.Discard
56 c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
56 c := newTestConn(nil, w, false)
5757 messages := textMessages(100)
5858 c.enableWriteCompression = true
5959 c.newCompressionWriter = compressNoContextTakeover
6565 }
6666
6767 func TestValidCompressionLevel(t *testing.T) {
68 c := newConn(fakeNetConn{}, false, 1024, 1024)
68 c := newTestConn(nil, nil, false)
6969 for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
7070 if err := c.SetCompressionLevel(level); err == nil {
7171 t.Errorf("no error for level %d", level)
7575 // is UTF-8 encoded text.
7676 PingMessage = 9
7777
78 // PongMessage denotes a ping control message. The optional message payload
78 // PongMessage denotes a pong control message. The optional message payload
7979 // is UTF-8 encoded text.
8080 PongMessage = 10
8181 )
9999 func (e *netError) Temporary() bool { return e.temporary }
100100 func (e *netError) Timeout() bool { return e.timeout }
101101
102 // CloseError represents close frame.
102 // CloseError represents a close message.
103103 type CloseError struct {
104
105104 // Code is defined in RFC 6455, section 11.7.
106105 Code int
107106
223222 return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
224223 }
225224
225 // BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
226 // interface. The type of the value stored in a pool is not specified.
227 type BufferPool interface {
228 // Get gets a value from the pool or returns nil if the pool is empty.
229 Get() interface{}
230 // Put adds a value to the pool.
231 Put(interface{})
232 }
233
234 // writePoolData is the type added to the write buffer pool. This wrapper is
235 // used to prevent applications from peeking at and depending on the values
236 // added to the pool.
237 type writePoolData struct{ buf []byte }
238
226239 // The Conn type represents a WebSocket connection.
227240 type Conn struct {
228241 conn net.Conn
232245 // Write fields
233246 mu chan bool // used as mutex to protect write to conn
234247 writeBuf []byte // frame is constructed in this buffer.
248 writePool BufferPool
249 writeBufSize int
235250 writeDeadline time.Time
236251 writer io.WriteCloser // the current writer returned to the application
237252 isWriting bool // for best-effort concurrent write detection
263278 newDecompressionReader func(io.Reader) io.ReadCloser
264279 }
265280
266 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn {
267 return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
268 }
269
270 type writeHook struct {
271 p []byte
272 }
273
274 func (wh *writeHook) Write(p []byte) (int, error) {
275 wh.p = p
276 return len(p), nil
277 }
278
279 func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
280 mu := make(chan bool, 1)
281 mu <- true
282
283 var br *bufio.Reader
284 if readBufferSize == 0 && brw != nil && brw.Reader != nil {
285 // Reuse the supplied bufio.Reader if the buffer has a useful size.
286 // This code assumes that peek on a reader returns
287 // bufio.Reader.buf[:0].
288 brw.Reader.Reset(conn)
289 if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
290 br = brw.Reader
291 }
292 }
281 func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
282
293283 if br == nil {
294284 if readBufferSize == 0 {
295285 readBufferSize = defaultReadBufferSize
296 }
297 if readBufferSize < maxControlFramePayloadSize {
286 } else if readBufferSize < maxControlFramePayloadSize {
287 // must be large enough for control frame
298288 readBufferSize = maxControlFramePayloadSize
299289 }
300290 br = bufio.NewReaderSize(conn, readBufferSize)
301291 }
302292
303 var writeBuf []byte
304 if writeBufferSize == 0 && brw != nil && brw.Writer != nil {
305 // Use the bufio.Writer's buffer if the buffer has a useful size. This
306 // code assumes that bufio.Writer.buf[:1] is passed to the
307 // bufio.Writer's underlying writer.
308 var wh writeHook
309 brw.Writer.Reset(&wh)
310 brw.Writer.WriteByte(0)
311 brw.Flush()
312 if cap(wh.p) >= maxFrameHeaderSize+256 {
313 writeBuf = wh.p[:cap(wh.p)]
314 }
315 }
316
317 if writeBuf == nil {
318 if writeBufferSize == 0 {
319 writeBufferSize = defaultWriteBufferSize
320 }
321 writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
322 }
323
293 if writeBufferSize <= 0 {
294 writeBufferSize = defaultWriteBufferSize
295 }
296 writeBufferSize += maxFrameHeaderSize
297
298 if writeBuf == nil && writeBufferPool == nil {
299 writeBuf = make([]byte, writeBufferSize)
300 }
301
302 mu := make(chan bool, 1)
303 mu <- true
324304 c := &Conn{
325305 isServer: isServer,
326306 br: br,
328308 mu: mu,
329309 readFinal: true,
330310 writeBuf: writeBuf,
311 writePool: writeBufferPool,
312 writeBufSize: writeBufferSize,
331313 enableWriteCompression: true,
332314 compressionLevel: defaultCompressionLevel,
333315 }
342324 return c.subprotocol
343325 }
344326
345 // Close closes the underlying network connection without sending or waiting for a close frame.
327 // Close closes the underlying network connection without sending or waiting
328 // for a close message.
346329 func (c *Conn) Close() error {
347330 return c.conn.Close()
348331 }
369352 return err
370353 }
371354
372 func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
355 func (c *Conn) read(n int) ([]byte, error) {
356 p, err := c.br.Peek(n)
357 if err == io.EOF {
358 err = errUnexpectedEOF
359 }
360 c.br.Discard(len(p))
361 return p, err
362 }
363
364 func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
373365 <-c.mu
374366 defer func() { c.mu <- true }()
375367
381373 }
382374
383375 c.conn.SetWriteDeadline(deadline)
384 for _, buf := range bufs {
385 if len(buf) > 0 {
386 _, err := c.conn.Write(buf)
387 if err != nil {
388 return c.writeFatal(err)
389 }
390 }
391 }
392
376 if len(buf1) == 0 {
377 _, err = c.conn.Write(buf0)
378 } else {
379 err = c.writeBufs(buf0, buf1)
380 }
381 if err != nil {
382 return c.writeFatal(err)
383 }
393384 if frameType == CloseMessage {
394385 c.writeFatal(ErrCloseSent)
395386 }
475466 c.writeErrMu.Lock()
476467 err := c.writeErr
477468 c.writeErrMu.Unlock()
478 return err
469 if err != nil {
470 return err
471 }
472
473 if c.writeBuf == nil {
474 wpd, ok := c.writePool.Get().(writePoolData)
475 if ok {
476 c.writeBuf = wpd.buf
477 } else {
478 c.writeBuf = make([]byte, c.writeBufSize)
479 }
480 }
481 return nil
479482 }
480483
481484 // NextWriter returns a writer for the next message to send. The writer's Close
483486 //
484487 // There can be at most one open writer on a connection. NextWriter closes the
485488 // previous writer if the application has not already done so.
489 //
490 // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
491 // PongMessage) are supported.
486492 func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
487493 if err := c.prepWrite(messageType); err != nil {
488494 return nil, err
598604
599605 if final {
600606 c.writer = nil
607 if c.writePool != nil {
608 c.writePool.Put(writePoolData{buf: c.writeBuf})
609 c.writeBuf = nil
610 }
601611 return nil
602612 }
603613
763773 // Read methods
764774
765775 func (c *Conn) advanceFrame() (int, error) {
766
767776 // 1. Skip remainder of previous frame.
768777
769778 if c.readRemaining > 0 {
10321041 }
10331042
10341043 // SetReadLimit sets the maximum size for a message read from the peer. If a
1035 // message exceeds the limit, the connection sends a close frame to the peer
1044 // message exceeds the limit, the connection sends a close message to the peer
10361045 // and returns ErrReadLimit to the application.
10371046 func (c *Conn) SetReadLimit(limit int64) {
10381047 c.readLimit = limit
10451054
10461055 // SetCloseHandler sets the handler for close messages received from the peer.
10471056 // The code argument to h is the received close code or CloseNoStatusReceived
1048 // if the close message is empty. The default close handler sends a close frame
1049 // back to the peer.
1057 // if the close message is empty. The default close handler sends a close
1058 // message back to the peer.
10501059 //
1051 // The application must read the connection to process close messages as
1052 // described in the section on Control Frames above.
1060 // The handler function is called from the NextReader, ReadMessage and message
1061 // reader Read methods. The application must read the connection to process
1062 // close messages as described in the section on Control Messages above.
10531063 //
1054 // The connection read methods return a CloseError when a close frame is
1064 // The connection read methods return a CloseError when a close message is
10551065 // received. Most applications should handle close messages as part of their
10561066 // normal error handling. Applications should only set a close handler when the
1057 // application must perform some action before sending a close frame back to
1067 // application must perform some action before sending a close message back to
10581068 // the peer.
10591069 func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
10601070 if h == nil {
10611071 h = func(code int, text string) error {
1062 message := []byte{}
1063 if code != CloseNoStatusReceived {
1064 message = FormatCloseMessage(code, "")
1065 }
1072 message := FormatCloseMessage(code, "")
10661073 c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
10671074 return nil
10681075 }
10761083 }
10771084
10781085 // SetPingHandler sets the handler for ping messages received from the peer.
1079 // The appData argument to h is the PING frame application data. The default
1086 // The appData argument to h is the PING message application data. The default
10801087 // ping handler sends a pong to the peer.
10811088 //
1082 // The application must read the connection to process ping messages as
1083 // described in the section on Control Frames above.
1089 // The handler function is called from the NextReader, ReadMessage and message
1090 // reader Read methods. The application must read the connection to process
1091 // ping messages as described in the section on Control Messages above.
10841092 func (c *Conn) SetPingHandler(h func(appData string) error) {
10851093 if h == nil {
10861094 h = func(message string) error {
11021110 }
11031111
11041112 // SetPongHandler sets the handler for pong messages received from the peer.
1105 // The appData argument to h is the PONG frame application data. The default
1113 // The appData argument to h is the PONG message application data. The default
11061114 // pong handler does nothing.
11071115 //
1108 // The application must read the connection to process ping messages as
1109 // described in the section on Control Frames above.
1116 // The handler function is called from the NextReader, ReadMessage and message
1117 // reader Read methods. The application must read the connection to process
1118 // pong messages as described in the section on Control Messages above.
11101119 func (c *Conn) SetPongHandler(h func(appData string) error) {
11111120 if h == nil {
11121121 h = func(string) error { return nil }
11401149 }
11411150
11421151 // FormatCloseMessage formats closeCode and text as a WebSocket close message.
1152 // An empty message is returned for code CloseNoStatusReceived.
11431153 func FormatCloseMessage(closeCode int, text string) []byte {
1154 if closeCode == CloseNoStatusReceived {
1155 // Return empty message because it's illegal to send
1156 // CloseNoStatusReceived. Return non-nil value in case application
1157 // checks for nil.
1158 return []byte{}
1159 }
11441160 buf := make([]byte, 2+len(text))
11451161 binary.BigEndian.PutUint16(buf, uint16(closeCode))
11461162 copy(buf[2:], text)
00 // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
11 // Use of this source code is governed by a BSD-style
22 // license that can be found in the LICENSE file.
3
4 // +build go1.7
53
64 package websocket
75
6967 conns := make([]*broadcastConn, numConns)
7068
7169 for i := 0; i < numConns; i++ {
72 c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024)
70 c := newTestConn(nil, b.w, true)
7371 if b.compression {
7472 c.enableWriteCompression = true
7573 c.newCompressionWriter = compressNoContextTakeover
+0
-18
conn_read.go less more
0 // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
1 // Use of this source code is governed by a BSD-style
2 // license that can be found in the LICENSE file.
3
4 // +build go1.5
5
6 package websocket
7
8 import "io"
9
10 func (c *Conn) read(n int) ([]byte, error) {
11 p, err := c.br.Peek(n)
12 if err == io.EOF {
13 err = errUnexpectedEOF
14 }
15 c.br.Discard(len(p))
16 return p, err
17 }
+0
-21
conn_read_legacy.go less more
0 // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
1 // Use of this source code is governed by a BSD-style
2 // license that can be found in the LICENSE file.
3
4 // +build !go1.5
5
6 package websocket
7
8 import "io"
9
10 func (c *Conn) read(n int) ([]byte, error) {
11 p, err := c.br.Peek(n)
12 if err == io.EOF {
13 err = errUnexpectedEOF
14 }
15 if len(p) > 0 {
16 // advance over the bytes just read
17 io.ReadFull(c.br, p)
18 }
19 return p, err
20 }
1212 "io/ioutil"
1313 "net"
1414 "reflect"
15 "sync"
1516 "testing"
1617 "testing/iotest"
1718 "time"
4445
4546 func (a fakeAddr) String() string {
4647 return "str"
48 }
49
50 // newTestConn creates a connnection backed by a fake network connection using
51 // default values for buffering.
52 func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn {
53 return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil)
4754 }
4855
4956 func TestFraming(t *testing.T) {
8188 for _, chunker := range readChunkers {
8289
8390 var connBuf bytes.Buffer
84 wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
85 rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
91 wc := newTestConn(nil, &connBuf, isServer)
92 rc := newTestConn(chunker.f(&connBuf), nil, !isServer)
8693 if compress {
8794 wc.newCompressionWriter = compressNoContextTakeover
8895 rc.newDecompressionReader = decompressNoContextTakeover
142149 for _, isWriteControl := range []bool{true, false} {
143150 name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
144151 var connBuf bytes.Buffer
145 wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
146 rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
152 wc := newTestConn(nil, &connBuf, isServer)
153 rc := newTestConn(&connBuf, nil, !isServer)
147154 if isWriteControl {
148155 wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
149156 } else {
172179 }
173180 }
174181
182 // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool.
183 type simpleBufferPool struct {
184 v interface{}
185 }
186
187 func (p *simpleBufferPool) Get() interface{} {
188 v := p.v
189 p.v = nil
190 return v
191 }
192
193 func (p *simpleBufferPool) Put(v interface{}) {
194 p.v = v
195 }
196
197 func TestWriteBufferPool(t *testing.T) {
198 var buf bytes.Buffer
199 var pool simpleBufferPool
200 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
201 rc := newTestConn(&buf, nil, false)
202
203 if wc.writeBuf != nil {
204 t.Fatal("writeBuf not nil after create")
205 }
206
207 // Part 1: test NextWriter/Write/Close
208
209 w, err := wc.NextWriter(TextMessage)
210 if err != nil {
211 t.Fatalf("wc.NextWriter() returned %v", err)
212 }
213
214 if wc.writeBuf == nil {
215 t.Fatal("writeBuf is nil after NextWriter")
216 }
217
218 writeBufAddr := &wc.writeBuf[0]
219
220 const message = "Hello World!"
221
222 if _, err := io.WriteString(w, message); err != nil {
223 t.Fatalf("io.WriteString(w, message) returned %v", err)
224 }
225
226 if err := w.Close(); err != nil {
227 t.Fatalf("w.Close() returned %v", err)
228 }
229
230 if wc.writeBuf != nil {
231 t.Fatal("writeBuf not nil after w.Close()")
232 }
233
234 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
235 t.Fatal("writeBuf not returned to pool")
236 }
237
238 opCode, p, err := rc.ReadMessage()
239 if opCode != TextMessage || err != nil {
240 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
241 }
242
243 if s := string(p); s != message {
244 t.Fatalf("message is %s, want %s", s, message)
245 }
246
247 // Part 2: Test WriteMessage.
248
249 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
250 t.Fatalf("wc.WriteMessage() returned %v", err)
251 }
252
253 if wc.writeBuf != nil {
254 t.Fatal("writeBuf not nil after wc.WriteMessage()")
255 }
256
257 if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr {
258 t.Fatal("writeBuf not returned to pool after WriteMessage")
259 }
260
261 opCode, p, err = rc.ReadMessage()
262 if opCode != TextMessage || err != nil {
263 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
264 }
265
266 if s := string(p); s != message {
267 t.Fatalf("message is %s, want %s", s, message)
268 }
269 }
270
271 func TestWriteBufferPoolSync(t *testing.T) {
272 var buf bytes.Buffer
273 var pool sync.Pool
274 wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil)
275 rc := newTestConn(&buf, nil, false)
276
277 const message = "Hello World!"
278 for i := 0; i < 3; i++ {
279 if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil {
280 t.Fatalf("wc.WriteMessage() returned %v", err)
281 }
282 opCode, p, err := rc.ReadMessage()
283 if opCode != TextMessage || err != nil {
284 t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err)
285 }
286 if s := string(p); s != message {
287 t.Fatalf("message is %s, want %s", s, message)
288 }
289 }
290 }
291
175292 func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
176293 const bufSize = 512
177294
178295 expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
179296
180297 var b1, b2 bytes.Buffer
181 wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
182 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
298 wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
299 rc := newTestConn(&b1, &b2, true)
183300
184301 w, _ := wc.NextWriter(BinaryMessage)
185302 w.Write(make([]byte, bufSize+bufSize/2))
205322
206323 for n := 0; ; n++ {
207324 var b bytes.Buffer
208 wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
209 rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
325 wc := newTestConn(nil, &b, false)
326 rc := newTestConn(&b, nil, true)
210327
211328 w, _ := wc.NextWriter(BinaryMessage)
212329 w.Write(make([]byte, bufSize))
239356 const bufSize = 512
240357
241358 var b1, b2 bytes.Buffer
242 wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
243 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
359 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil)
360 rc := newTestConn(&b1, &b2, true)
244361
245362 w, _ := wc.NextWriter(BinaryMessage)
246363 w.Write(make([]byte, bufSize+bufSize/2))
260377 }
261378
262379 func TestWriteAfterMessageWriterClose(t *testing.T) {
263 wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
380 wc := newTestConn(nil, &bytes.Buffer{}, false)
264381 w, _ := wc.NextWriter(BinaryMessage)
265382 io.WriteString(w, "hello")
266383 if err := w.Close(); err != nil {
291408 message := make([]byte, readLimit+1)
292409
293410 var b1, b2 bytes.Buffer
294 wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
295 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
411 wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil)
412 rc := newTestConn(&b1, &b2, true)
296413 rc.SetReadLimit(readLimit)
297414
298415 // Send message at the limit with interleaved pong.
320437 }
321438
322439 func TestAddrs(t *testing.T) {
323 c := newConn(&fakeNetConn{}, true, 1024, 1024)
440 c := newTestConn(nil, nil, true)
324441 if c.LocalAddr() != localAddr {
325442 t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
326443 }
332449 func TestUnderlyingConn(t *testing.T) {
333450 var b1, b2 bytes.Buffer
334451 fc := fakeNetConn{Reader: &b1, Writer: &b2}
335 c := newConn(fc, true, 1024, 1024)
452 c := newConn(fc, true, 1024, 1024, nil, nil, nil)
336453 ul := c.UnderlyingConn()
337454 if ul != fc {
338455 t.Fatalf("Underlying conn is not what it should be.")
340457 }
341458
342459 func TestBufioReadBytes(t *testing.T) {
343
344460 // Test calling bufio.ReadBytes for value longer than read buffer size.
345461
346462 m := make([]byte, 512)
347463 m[len(m)-1] = '\n'
348464
349465 var b1, b2 bytes.Buffer
350 wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
351 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
466 wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil)
467 rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil)
352468
353469 w, _ := wc.NextWriter(BinaryMessage)
354470 w.Write(m)
365481 t.Fatalf("ReadBytes() returned %v", err)
366482 }
367483 if len(p) != len(m) {
368 t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m))
484 t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
369485 }
370486 }
371487
423539
424540 func TestConcurrentWritePanic(t *testing.T) {
425541 w := blockingWriter{make(chan struct{}), make(chan struct{})}
426 c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
542 c := newTestConn(nil, w, false)
427543 go func() {
428544 c.WriteMessage(TextMessage, []byte{})
429545 }()
449565 }
450566
451567 func TestFailedConnectionReadPanic(t *testing.T) {
452 c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
568 c := newTestConn(failingReader{}, nil, false)
453569
454570 defer func() {
455571 if v := recover(); v != nil {
462578 }
463579 t.Fatal("should not get here")
464580 }
465
466 func TestBufioReuse(t *testing.T) {
467 brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
468 c := newConnBRW(nil, false, 0, 0, brw)
469
470 if c.br != brw.Reader {
471 t.Error("connection did not reuse bufio.Reader")
472 }
473
474 var wh writeHook
475 brw.Writer.Reset(&wh)
476 brw.WriteByte(0)
477 brw.Flush()
478 if &c.writeBuf[0] != &wh.p[0] {
479 t.Error("connection did not reuse bufio.Writer")
480 }
481
482 brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
483 c = newConnBRW(nil, false, 0, 0, brw)
484
485 if c.br == brw.Reader {
486 t.Error("connection used bufio.Reader with small size")
487 }
488
489 brw.Writer.Reset(&wh)
490 brw.WriteByte(0)
491 brw.Flush()
492 if &c.writeBuf[0] != &wh.p[0] {
493 t.Error("connection used bufio.Writer with small size")
494 }
495
496 }
0 // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
1 // Use of this source code is governed by a BSD-style
2 // license that can be found in the LICENSE file.
3
4 // +build go1.8
5
6 package websocket
7
8 import "net"
9
10 func (c *Conn) writeBufs(bufs ...[]byte) error {
11 b := net.Buffers(bufs)
12 _, err := b.WriteTo(c.conn)
13 return err
14 }
0 // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
1 // Use of this source code is governed by a BSD-style
2 // license that can be found in the LICENSE file.
3
4 // +build !go1.8
5
6 package websocket
7
8 func (c *Conn) writeBufs(bufs ...[]byte) error {
9 for _, buf := range bufs {
10 if len(buf) > 0 {
11 if _, err := c.conn.Write(buf); err != nil {
12 return err
13 }
14 }
15 }
16 return nil
17 }
55 //
66 // Overview
77 //
8 // The Conn type represents a WebSocket connection. A server application uses
9 // the Upgrade function from an Upgrader object with a HTTP request handler
10 // to get a pointer to a Conn:
8 // The Conn type represents a WebSocket connection. A server application calls
9 // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
1110 //
1211 // var upgrader = websocket.Upgrader{
1312 // ReadBufferSize: 1024,
3029 // for {
3130 // messageType, p, err := conn.ReadMessage()
3231 // if err != nil {
32 // log.Println(err)
3333 // return
3434 // }
35 // if err = conn.WriteMessage(messageType, p); err != nil {
36 // return err
35 // if err := conn.WriteMessage(messageType, p); err != nil {
36 // log.Println(err)
37 // return
3738 // }
3839 // }
3940 //
8485 // and pong. Call the connection WriteControl, WriteMessage or NextWriter
8586 // methods to send a control message to the peer.
8687 //
87 // Connections handle received close messages by sending a close message to the
88 // peer and returning a *CloseError from the the NextReader, ReadMessage or the
89 // message Read method.
88 // Connections handle received close messages by calling the handler function
89 // set with the SetCloseHandler method and by returning a *CloseError from the
90 // NextReader, ReadMessage or the message Read method. The default close
91 // handler sends a close message to the peer.
9092 //
91 // Connections handle received ping and pong messages by invoking callback
92 // functions set with SetPingHandler and SetPongHandler methods. The callback
93 // functions are called from the NextReader, ReadMessage and the message Read
94 // methods.
93 // Connections handle received ping messages by calling the handler function
94 // set with the SetPingHandler method. The default ping handler sends a pong
95 // message to the peer.
9596 //
96 // The default ping handler sends a pong to the peer. The application's reading
97 // goroutine can block for a short time while the handler writes the pong data
98 // to the connection.
97 // Connections handle received pong messages by calling the handler function
98 // set with the SetPongHandler method. The default pong handler does nothing.
99 // If an application sends ping messages, then the application should set a
100 // pong handler to receive the corresponding pong.
99101 //
100 // The application must read the connection to process ping, pong and close
102 // The control message handler functions are called from the NextReader,
103 // ReadMessage and message reader Read methods. The default close and ping
104 // handlers can block these methods for a short time when the handler writes to
105 // the connection.
106 //
107 // The application must read the connection to process close, ping and pong
101108 // messages sent from the peer. If the application is not otherwise interested
102109 // in messages from the peer, then the application should start a goroutine to
103110 // read and discard messages from the peer. A simple example is:
136143 // method fails the WebSocket handshake with HTTP status 403.
137144 //
138145 // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
139 // the handshake if the Origin request header is present and not equal to the
140 // Host request header.
146 // the handshake if the Origin request header is present and the Origin host is
147 // not equal to the Host request header.
141148 //
142 // An application can allow connections from any origin by specifying a
143 // function that always returns true:
144 //
145 // var upgrader = websocket.Upgrader{
146 // CheckOrigin: func(r *http.Request) bool { return true },
147 // }
148 //
149 // The deprecated Upgrade function does not enforce an origin policy. It's the
150 // application's responsibility to check the Origin header before calling
151 // Upgrade.
149 // The deprecated package-level Upgrade function does not perform origin
150 // checking. The application is responsible for checking the Origin header
151 // before calling the Upgrade function.
152152 //
153153 // Compression EXPERIMENTAL
154154 //
156156
157157 func serveHome(w http.ResponseWriter, r *http.Request) {
158158 if r.URL.Path != "/" {
159 http.Error(w, "Not found.", 404)
159 http.Error(w, "Not found.", http.StatusNotFound)
160160 return
161161 }
162162 if r.Method != "GET" {
163 http.Error(w, "Method not allowed", 405)
163 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
164164 return
165165 }
166166 w.Header().Set("Content-Type", "text/html; charset=utf-8")
00 # Chat Example
11
2 This application shows how to use use the
2 This application shows how to use the
33 [websocket](https://github.com/gorilla/websocket) package to implement a simple
44 web chat application.
55
6363 for {
6464 _, message, err := c.conn.ReadMessage()
6565 if err != nil {
66 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
66 if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
6767 log.Printf("error: %v", err)
6868 }
6969 break
112112 }
113113 case <-ticker.C:
114114 c.conn.SetWriteDeadline(time.Now().Add(writeWait))
115 if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
115 if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
116116 return
117117 }
118118 }
33
44 package main
55
6 // hub maintains the set of active clients and broadcasts messages to the
6 // Hub maintains the set of active clients and broadcasts messages to the
77 // clients.
88 type Hub struct {
99 // Registered clients.
1414 func serveHome(w http.ResponseWriter, r *http.Request) {
1515 log.Println(r.URL)
1616 if r.URL.Path != "/" {
17 http.Error(w, "Not found", 404)
17 http.Error(w, "Not found", http.StatusNotFound)
1818 return
1919 }
2020 if r.Method != "GET" {
21 http.Error(w, "Method not allowed", 405)
21 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
2222 return
2323 }
2424 http.ServeFile(w, r, "home.html")
166166
167167 func serveHome(w http.ResponseWriter, r *http.Request) {
168168 if r.URL.Path != "/" {
169 http.Error(w, "Not found", 404)
169 http.Error(w, "Not found", http.StatusNotFound)
170170 return
171171 }
172172 if r.Method != "GET" {
173 http.Error(w, "Method not allowed", 405)
173 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
174174 return
175175 }
176176 http.ServeFile(w, r, "home.html")
3737 done := make(chan struct{})
3838
3939 go func() {
40 defer c.Close()
4140 defer close(done)
4241 for {
4342 _, message, err := c.ReadMessage()
5453
5554 for {
5655 select {
56 case <-done:
57 return
5758 case t := <-ticker.C:
5859 err := c.WriteMessage(websocket.TextMessage, []byte(t.String()))
5960 if err != nil {
6263 }
6364 case <-interrupt:
6465 log.Println("interrupt")
65 // To cleanly close a connection, a client should send a close
66 // frame and wait for the server to close the connection.
66
67 // Cleanly close the connection by sending a close message and then
68 // waiting (with timeout) for the server to close the connection.
6769 err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
6870 if err != nil {
6971 log.Println("write close:", err)
7375 case <-done:
7476 case <-time.After(time.Second):
7577 }
76 c.Close()
7778 return
7879 }
7980 }
5454
5555 var homeTemplate = template.Must(template.New("").Parse(`
5656 <!DOCTYPE html>
57 <html>
5758 <head>
5859 <meta charset="utf-8">
5960 <script>
129129
130130 func serveHome(w http.ResponseWriter, r *http.Request) {
131131 if r.URL.Path != "/" {
132 http.Error(w, "Not found", 404)
132 http.Error(w, "Not found", http.StatusNotFound)
133133 return
134134 }
135135 if r.Method != "GET" {
136 http.Error(w, "Method not allowed", 405)
136 http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
137137 return
138138 }
139139 w.Header().Set("Content-Type", "text/html; charset=utf-8")
88 "io"
99 )
1010
11 // WriteJSON is deprecated, use c.WriteJSON instead.
11 // WriteJSON writes the JSON encoding of v as a message.
12 //
13 // Deprecated: Use c.WriteJSON instead.
1214 func WriteJSON(c *Conn, v interface{}) error {
1315 return c.WriteJSON(v)
1416 }
1517
16 // WriteJSON writes the JSON encoding of v to the connection.
18 // WriteJSON writes the JSON encoding of v as a message.
1719 //
1820 // See the documentation for encoding/json Marshal for details about the
1921 // conversion of Go values to JSON.
3032 return err2
3133 }
3234
33 // ReadJSON is deprecated, use c.ReadJSON instead.
35 // ReadJSON reads the next JSON-encoded message from the connection and stores
36 // it in the value pointed to by v.
37 //
38 // Deprecated: Use c.ReadJSON instead.
3439 func ReadJSON(c *Conn, v interface{}) error {
3540 return c.ReadJSON(v)
3641 }
1313
1414 func TestJSON(t *testing.T) {
1515 var buf bytes.Buffer
16 c := fakeNetConn{&buf, &buf}
17 wc := newConn(c, true, 1024, 1024)
18 rc := newConn(c, false, 1024, 1024)
16 wc := newTestConn(nil, &buf, true)
17 rc := newTestConn(&buf, nil, false)
1918
2019 var actual, expect struct {
2120 A int
3837 }
3938
4039 func TestPartialJSONRead(t *testing.T) {
41 var buf bytes.Buffer
42 c := fakeNetConn{&buf, &buf}
43 wc := newConn(c, true, 1024, 1024)
44 rc := newConn(c, false, 1024, 1024)
40 var buf0, buf1 bytes.Buffer
41 wc := newTestConn(nil, &buf0, true)
42 rc := newTestConn(&buf0, &buf1, false)
4543
4644 var v struct {
4745 A int
9391
9492 func TestDeprecatedJSON(t *testing.T) {
9593 var buf bytes.Buffer
96 c := fakeNetConn{&buf, &buf}
97 wc := newConn(c, true, 1024, 1024)
98 rc := newConn(c, false, 1024, 1024)
94 wc := newTestConn(nil, &buf, true)
95 rc := newTestConn(&buf, nil, false)
9996
10097 var actual, expect struct {
10198 A int
1010 const wordSize = int(unsafe.Sizeof(uintptr(0)))
1111
1212 func maskBytes(key [4]byte, pos int, b []byte) int {
13
1413 // Mask one byte at a time for small buffers.
1514 if len(b) < 2*wordSize {
1615 for i := range b {
11 // this source code is governed by a BSD-style license that can be found in the
22 // LICENSE file.
33
4 // Require 1.7 for sub-bencmarks
5 // +build go1.7,!appengine
4 // !appengine
65
76 package websocket
87
1818 type PreparedMessage struct {
1919 messageType int
2020 data []byte
21 err error
2221 mu sync.Mutex
2322 frames map[prepareKey]*preparedFrame
2423 }
3535 for _, tt := range preparedMessageTests {
3636 var data = []byte("this is a test")
3737 var buf bytes.Buffer
38 c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024)
38 c := newTestConn(nil, &buf, tt.isServer)
3939 if tt.enableWriteCompression {
4040 c.newCompressionWriter = compressNoContextTakeover
4141 }
0 // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
1 // Use of this source code is governed by a BSD-style
2 // license that can be found in the LICENSE file.
3
4 package websocket
5
6 import (
7 "bufio"
8 "encoding/base64"
9 "errors"
10 "net"
11 "net/http"
12 "net/url"
13 "strings"
14 )
15
16 type netDialerFunc func(network, addr string) (net.Conn, error)
17
18 func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
19 return fn(network, addr)
20 }
21
22 func init() {
23 proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
24 return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil
25 })
26 }
27
28 type httpProxyDialer struct {
29 proxyURL *url.URL
30 fowardDial func(network, addr string) (net.Conn, error)
31 }
32
33 func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
34 hostPort, _ := hostPortNoPort(hpd.proxyURL)
35 conn, err := hpd.fowardDial(network, hostPort)
36 if err != nil {
37 return nil, err
38 }
39
40 connectHeader := make(http.Header)
41 if user := hpd.proxyURL.User; user != nil {
42 proxyUser := user.Username()
43 if proxyPassword, passwordSet := user.Password(); passwordSet {
44 credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
45 connectHeader.Set("Proxy-Authorization", "Basic "+credential)
46 }
47 }
48
49 connectReq := &http.Request{
50 Method: "CONNECT",
51 URL: &url.URL{Opaque: addr},
52 Host: addr,
53 Header: connectHeader,
54 }
55
56 if err := connectReq.Write(conn); err != nil {
57 conn.Close()
58 return nil, err
59 }
60
61 // Read response. It's OK to use and discard buffered reader here becaue
62 // the remote server does not speak until spoken to.
63 br := bufio.NewReader(conn)
64 resp, err := http.ReadResponse(br, connectReq)
65 if err != nil {
66 conn.Close()
67 return nil, err
68 }
69
70 if resp.StatusCode != 200 {
71 conn.Close()
72 f := strings.SplitN(resp.Status, " ", 2)
73 return nil, errors.New(f[1])
74 }
75 return conn, nil
76 }
66 import (
77 "bufio"
88 "errors"
9 "net"
9 "io"
1010 "net/http"
1111 "net/url"
1212 "strings"
3232 // or received.
3333 ReadBufferSize, WriteBufferSize int
3434
35 // WriteBufferPool is a pool of buffers for write operations. If the value
36 // is not set, then write buffers are allocated to the connection for the
37 // lifetime of the connection.
38 //
39 // A pool is most useful when the application has a modest volume of writes
40 // across a large number of connections.
41 //
42 // Applications should use a single pool for each unique value of
43 // WriteBufferSize.
44 WriteBufferPool BufferPool
45
3546 // Subprotocols specifies the server's supported protocols in order of
36 // preference. If this field is set, then the Upgrade method negotiates a
47 // preference. If this field is not nil, then the Upgrade method negotiates a
3748 // subprotocol by selecting the first match in this list with a protocol
38 // requested by the client.
49 // requested by the client. If there's no match, then no protocol is
50 // negotiated (the Sec-Websocket-Protocol header is not included in the
51 // handshake response).
3952 Subprotocols []string
4053
4154 // Error specifies the function for generating HTTP error responses. If Error
4356 Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
4457
4558 // CheckOrigin returns true if the request Origin header is acceptable. If
46 // CheckOrigin is nil, the host in the Origin header must not be set or
47 // must match the host of the request.
59 // CheckOrigin is nil, then a safe default is used: return false if the
60 // Origin request header is present and the origin host is not equal to
61 // request Host header.
62 //
63 // A CheckOrigin function should carefully validate the request origin to
64 // prevent cross-site request forgery.
4865 CheckOrigin func(r *http.Request) bool
4966
5067 // EnableCompression specify if the server should attempt to negotiate per
7592 if err != nil {
7693 return false
7794 }
78 return u.Host == r.Host
95 return equalASCIIFold(u.Host, r.Host)
7996 }
8097
8198 func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
98115 //
99116 // The responseHeader is included in the response to the client's upgrade
100117 // request. Use the responseHeader to specify cookies (Set-Cookie) and the
101 // application negotiated subprotocol (Sec-Websocket-Protocol).
118 // application negotiated subprotocol (Sec-WebSocket-Protocol).
102119 //
103120 // If the upgrade fails, then Upgrade replies to the client with an HTTP error
104121 // response.
105122 func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
123 const badHandshake = "websocket: the client is not using the websocket protocol: "
124
125 if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
126 return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
127 }
128
129 if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
130 return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
131 }
132
106133 if r.Method != "GET" {
107 return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")
108 }
109
110 if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
111 return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")
112 }
113
114 if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
115 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header")
116 }
117
118 if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
119 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header")
134 return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
120135 }
121136
122137 if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
123138 return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
139 }
140
141 if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
142 return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
124143 }
125144
126145 checkOrigin := u.CheckOrigin
128147 checkOrigin = checkSameOrigin
129148 }
130149 if !checkOrigin(r) {
131 return u.returnError(w, r, http.StatusForbidden, "websocket: 'Origin' header value not allowed")
150 return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
132151 }
133152
134153 challengeKey := r.Header.Get("Sec-Websocket-Key")
135154 if challengeKey == "" {
136 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank")
155 return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
137156 }
138157
139158 subprotocol := u.selectSubprotocol(r, responseHeader)
150169 }
151170 }
152171
153 var (
154 netConn net.Conn
155 err error
156 )
157
158172 h, ok := w.(http.Hijacker)
159173 if !ok {
160174 return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
161175 }
162176 var brw *bufio.ReadWriter
163 netConn, brw, err = h.Hijack()
177 netConn, brw, err := h.Hijack()
164178 if err != nil {
165179 return u.returnError(w, r, http.StatusInternalServerError, err.Error())
166180 }
170184 return nil, errors.New("websocket: client sent data before handshake is complete")
171185 }
172186
173 c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw)
187 var br *bufio.Reader
188 if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
189 // Reuse hijacked buffered reader as connection reader.
190 br = brw.Reader
191 }
192
193 buf := bufioWriterBuffer(netConn, brw.Writer)
194
195 var writeBuf []byte
196 if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
197 // Reuse hijacked write buffer as connection buffer.
198 writeBuf = buf
199 }
200
201 c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
174202 c.subprotocol = subprotocol
175203
176204 if compress {
178206 c.newDecompressionReader = decompressNoContextTakeover
179207 }
180208
181 p := c.writeBuf[:0]
209 // Use larger of hijacked buffer and connection write buffer for header.
210 p := buf
211 if len(c.writeBuf) > len(p) {
212 p = c.writeBuf
213 }
214 p = p[:0]
215
182216 p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
183217 p = append(p, computeAcceptKey(challengeKey)...)
184218 p = append(p, "\r\n"...)
185219 if c.subprotocol != "" {
186 p = append(p, "Sec-Websocket-Protocol: "...)
220 p = append(p, "Sec-WebSocket-Protocol: "...)
187221 p = append(p, c.subprotocol...)
188222 p = append(p, "\r\n"...)
189223 }
190224 if compress {
191 p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
225 p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
192226 }
193227 for k, vs := range responseHeader {
194228 if k == "Sec-Websocket-Protocol" {
229263
230264 // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
231265 //
232 // This function is deprecated, use websocket.Upgrader instead.
233 //
234 // The application is responsible for checking the request origin before
235 // calling Upgrade. An example implementation of the same origin policy is:
266 // Deprecated: Use websocket.Upgrader instead.
267 //
268 // Upgrade does not perform origin checking. The application is responsible for
269 // checking the Origin header before calling Upgrade. An example implementation
270 // of the same origin policy check is:
236271 //
237272 // if req.Header.Get("Origin") != "http://"+req.Host {
238 // http.Error(w, "Origin not allowed", 403)
273 // http.Error(w, "Origin not allowed", http.StatusForbidden)
239274 // return
240275 // }
241276 //
288323 return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
289324 tokenListContainsValue(r.Header, "Upgrade", "websocket")
290325 }
326
327 // bufioReaderSize size returns the size of a bufio.Reader.
328 func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
329 // This code assumes that peek on a reset reader returns
330 // bufio.Reader.buf[:0].
331 // TODO: Use bufio.Reader.Size() after Go 1.10
332 br.Reset(originalReader)
333 if p, err := br.Peek(0); err == nil {
334 return cap(p)
335 }
336 return 0
337 }
338
339 // writeHook is an io.Writer that records the last slice passed to it vio
340 // io.Writer.Write.
341 type writeHook struct {
342 p []byte
343 }
344
345 func (wh *writeHook) Write(p []byte) (int, error) {
346 wh.p = p
347 return len(p), nil
348 }
349
350 // bufioWriterBuffer grabs the buffer from a bufio.Writer.
351 func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
352 // This code assumes that bufio.Writer.buf[:1] is passed to the
353 // bufio.Writer's underlying writer.
354 var wh writeHook
355 bw.Reset(&wh)
356 bw.WriteByte(0)
357 bw.Flush()
358
359 bw.Reset(originalWriter)
360
361 return wh.p[:cap(wh.p)]
362 }
44 package websocket
55
66 import (
7 "bufio"
8 "bytes"
9 "net"
710 "net/http"
811 "reflect"
12 "strings"
913 "testing"
1014 )
1115
4852 }
4953 }
5054 }
55
56 var checkSameOriginTests = []struct {
57 ok bool
58 r *http.Request
59 }{
60 {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}},
61 {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
62 {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}},
63 }
64
65 func TestCheckSameOrigin(t *testing.T) {
66 for _, tt := range checkSameOriginTests {
67 ok := checkSameOrigin(tt.r)
68 if tt.ok != ok {
69 t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok)
70 }
71 }
72 }
73
74 type reuseTestResponseWriter struct {
75 brw *bufio.ReadWriter
76 http.ResponseWriter
77 }
78
79 func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
80 return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil
81 }
82
83 var bufioReuseTests = []struct {
84 n int
85 reuse bool
86 }{
87 {4096, true},
88 {128, false},
89 }
90
91 func TestBufioReuse(t *testing.T) {
92 for i, tt := range bufioReuseTests {
93 br := bufio.NewReaderSize(strings.NewReader(""), tt.n)
94 bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n)
95 resp := &reuseTestResponseWriter{
96 brw: bufio.NewReadWriter(br, bw),
97 }
98 upgrader := Upgrader{}
99 c, err := upgrader.Upgrade(resp, &http.Request{
100 Method: "GET",
101 Header: http.Header{
102 "Upgrade": []string{"websocket"},
103 "Connection": []string{"upgrade"},
104 "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="},
105 "Sec-Websocket-Version": []string{"13"},
106 }}, nil)
107 if err != nil {
108 t.Fatal(err)
109 }
110 if reuse := c.br == br; reuse != tt.reuse {
111 t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse)
112 }
113 writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw)
114 if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse {
115 t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse)
116 }
117 }
118 }
0 // +build go1.8
1
2 package websocket
3
4 import (
5 "crypto/tls"
6 "net/http/httptrace"
7 )
8
9 func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
10 if trace.TLSHandshakeStart != nil {
11 trace.TLSHandshakeStart()
12 }
13 err := doHandshake(tlsConn, cfg)
14 if trace.TLSHandshakeDone != nil {
15 trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
16 }
17 return err
18 }
0 // +build !go1.8
1
2 package websocket
3
4 import (
5 "crypto/tls"
6 "net/http/httptrace"
7 )
8
9 func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
10 return doHandshake(tlsConn, cfg)
11 }
1010 "io"
1111 "net/http"
1212 "strings"
13 "unicode/utf8"
1314 )
1415
1516 var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
110111 case escape:
111112 escape = false
112113 p[j] = b
113 j += 1
114 j++
114115 case b == '\\':
115116 escape = true
116117 case b == '"':
117118 return string(p[:j]), s[i+1:]
118119 default:
119120 p[j] = b
120 j += 1
121 j++
121122 }
122123 }
123124 return "", ""
126127 return "", ""
127128 }
128129
130 // equalASCIIFold returns true if s is equal to t with ASCII case folding.
131 func equalASCIIFold(s, t string) bool {
132 for s != "" && t != "" {
133 sr, size := utf8.DecodeRuneInString(s)
134 s = s[size:]
135 tr, size := utf8.DecodeRuneInString(t)
136 t = t[size:]
137 if sr == tr {
138 continue
139 }
140 if 'A' <= sr && sr <= 'Z' {
141 sr = sr + 'a' - 'A'
142 }
143 if 'A' <= tr && tr <= 'Z' {
144 tr = tr + 'a' - 'A'
145 }
146 if sr != tr {
147 return false
148 }
149 }
150 return s == t
151 }
152
129153 // tokenListContainsValue returns true if the 1#token header with the given
130 // name contains token.
154 // name contains a token equal to value with ASCII case folding.
131155 func tokenListContainsValue(header http.Header, name string, value string) bool {
132156 headers:
133157 for _, s := range header[name] {
141165 if s != "" && s[0] != ',' {
142166 continue headers
143167 }
144 if strings.EqualFold(t, value) {
168 if equalASCIIFold(t, value) {
145169 return true
146170 }
147171 if s == "" {
153177 return false
154178 }
155179
156 // parseExtensiosn parses WebSocket extensions from a header.
180 // parseExtensions parses WebSocket extensions from a header.
157181 func parseExtensions(header http.Header) []map[string]string {
158
159182 // From RFC 6455:
160183 //
161184 // Sec-WebSocket-Extensions = extension-list
88 "reflect"
99 "testing"
1010 )
11
12 var equalASCIIFoldTests = []struct {
13 t, s string
14 eq bool
15 }{
16 {"WebSocket", "websocket", true},
17 {"websocket", "WebSocket", true},
18 {"Öyster", "öyster", false},
19 }
20
21 func TestEqualASCIIFold(t *testing.T) {
22 for _, tt := range equalASCIIFoldTests {
23 eq := equalASCIIFold(tt.s, tt.t)
24 if eq != tt.eq {
25 t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq)
26 }
27 }
28 }
1129
1230 var tokenListContainsValueTests = []struct {
1331 value string
3755 value string
3856 extensions []map[string]string
3957 }{
40 {`foo`, []map[string]string{map[string]string{"": "foo"}}},
58 {`foo`, []map[string]string{{"": "foo"}}},
4159 {`foo, bar; baz=2`, []map[string]string{
42 map[string]string{"": "foo"},
43 map[string]string{"": "bar", "baz": "2"}}},
60 {"": "foo"},
61 {"": "bar", "baz": "2"}}},
4462 {`foo; bar="b,a;z"`, []map[string]string{
45 map[string]string{"": "foo", "bar": "b,a;z"}}},
63 {"": "foo", "bar": "b,a;z"}}},
4664 {`foo , bar; baz = 2`, []map[string]string{
47 map[string]string{"": "foo"},
48 map[string]string{"": "bar", "baz": "2"}}},
65 {"": "foo"},
66 {"": "bar", "baz": "2"}}},
4967 {`foo, bar; baz=2 junk`, []map[string]string{
50 map[string]string{"": "foo"}}},
68 {"": "foo"}}},
5169 {`foo junk, bar; baz=2 junk`, nil},
5270 {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{
53 map[string]string{"": "mux", "max-channels": "4", "flow-control": ""},
54 map[string]string{"": "deflate-stream"}}},
71 {"": "mux", "max-channels": "4", "flow-control": ""},
72 {"": "deflate-stream"}}},
5573 {`permessage-foo; x="10"`, []map[string]string{
56 map[string]string{"": "permessage-foo", "x": "10"}}},
74 {"": "permessage-foo", "x": "10"}}},
5775 {`permessage-foo; use_y, permessage-foo`, []map[string]string{
58 map[string]string{"": "permessage-foo", "use_y": ""},
59 map[string]string{"": "permessage-foo"}}},
76 {"": "permessage-foo", "use_y": ""},
77 {"": "permessage-foo"}}},
6078 {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{
61 map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
62 map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}},
79 {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"},
80 {"": "permessage-deflate", "client_max_window_bits": ""}}},
81 {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{
82 {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"},
83 }},
6384 }
6485
6586 func TestParseExtensions(t *testing.T) {
0 // Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
1 //go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
2
3 // Package proxy provides support for a variety of protocols to proxy network
4 // data.
5 //
6
7 package websocket
8
9 import (
10 "errors"
11 "io"
12 "net"
13 "net/url"
14 "os"
15 "strconv"
16 "strings"
17 "sync"
18 )
19
20 type proxy_direct struct{}
21
22 // Direct is a direct proxy: one that makes network connections directly.
23 var proxy_Direct = proxy_direct{}
24
25 func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
26 return net.Dial(network, addr)
27 }
28
29 // A PerHost directs connections to a default Dialer unless the host name
30 // requested matches one of a number of exceptions.
31 type proxy_PerHost struct {
32 def, bypass proxy_Dialer
33
34 bypassNetworks []*net.IPNet
35 bypassIPs []net.IP
36 bypassZones []string
37 bypassHosts []string
38 }
39
40 // NewPerHost returns a PerHost Dialer that directs connections to either
41 // defaultDialer or bypass, depending on whether the connection matches one of
42 // the configured rules.
43 func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
44 return &proxy_PerHost{
45 def: defaultDialer,
46 bypass: bypass,
47 }
48 }
49
50 // Dial connects to the address addr on the given network through either
51 // defaultDialer or bypass.
52 func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
53 host, _, err := net.SplitHostPort(addr)
54 if err != nil {
55 return nil, err
56 }
57
58 return p.dialerForRequest(host).Dial(network, addr)
59 }
60
61 func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
62 if ip := net.ParseIP(host); ip != nil {
63 for _, net := range p.bypassNetworks {
64 if net.Contains(ip) {
65 return p.bypass
66 }
67 }
68 for _, bypassIP := range p.bypassIPs {
69 if bypassIP.Equal(ip) {
70 return p.bypass
71 }
72 }
73 return p.def
74 }
75
76 for _, zone := range p.bypassZones {
77 if strings.HasSuffix(host, zone) {
78 return p.bypass
79 }
80 if host == zone[1:] {
81 // For a zone ".example.com", we match "example.com"
82 // too.
83 return p.bypass
84 }
85 }
86 for _, bypassHost := range p.bypassHosts {
87 if bypassHost == host {
88 return p.bypass
89 }
90 }
91 return p.def
92 }
93
94 // AddFromString parses a string that contains comma-separated values
95 // specifying hosts that should use the bypass proxy. Each value is either an
96 // IP address, a CIDR range, a zone (*.example.com) or a host name
97 // (localhost). A best effort is made to parse the string and errors are
98 // ignored.
99 func (p *proxy_PerHost) AddFromString(s string) {
100 hosts := strings.Split(s, ",")
101 for _, host := range hosts {
102 host = strings.TrimSpace(host)
103 if len(host) == 0 {
104 continue
105 }
106 if strings.Contains(host, "/") {
107 // We assume that it's a CIDR address like 127.0.0.0/8
108 if _, net, err := net.ParseCIDR(host); err == nil {
109 p.AddNetwork(net)
110 }
111 continue
112 }
113 if ip := net.ParseIP(host); ip != nil {
114 p.AddIP(ip)
115 continue
116 }
117 if strings.HasPrefix(host, "*.") {
118 p.AddZone(host[1:])
119 continue
120 }
121 p.AddHost(host)
122 }
123 }
124
125 // AddIP specifies an IP address that will use the bypass proxy. Note that
126 // this will only take effect if a literal IP address is dialed. A connection
127 // to a named host will never match an IP.
128 func (p *proxy_PerHost) AddIP(ip net.IP) {
129 p.bypassIPs = append(p.bypassIPs, ip)
130 }
131
132 // AddNetwork specifies an IP range that will use the bypass proxy. Note that
133 // this will only take effect if a literal IP address is dialed. A connection
134 // to a named host will never match.
135 func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
136 p.bypassNetworks = append(p.bypassNetworks, net)
137 }
138
139 // AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
140 // "example.com" matches "example.com" and all of its subdomains.
141 func (p *proxy_PerHost) AddZone(zone string) {
142 if strings.HasSuffix(zone, ".") {
143 zone = zone[:len(zone)-1]
144 }
145 if !strings.HasPrefix(zone, ".") {
146 zone = "." + zone
147 }
148 p.bypassZones = append(p.bypassZones, zone)
149 }
150
151 // AddHost specifies a host name that will use the bypass proxy.
152 func (p *proxy_PerHost) AddHost(host string) {
153 if strings.HasSuffix(host, ".") {
154 host = host[:len(host)-1]
155 }
156 p.bypassHosts = append(p.bypassHosts, host)
157 }
158
159 // A Dialer is a means to establish a connection.
160 type proxy_Dialer interface {
161 // Dial connects to the given address via the proxy.
162 Dial(network, addr string) (c net.Conn, err error)
163 }
164
165 // Auth contains authentication parameters that specific Dialers may require.
166 type proxy_Auth struct {
167 User, Password string
168 }
169
170 // FromEnvironment returns the dialer specified by the proxy related variables in
171 // the environment.
172 func proxy_FromEnvironment() proxy_Dialer {
173 allProxy := proxy_allProxyEnv.Get()
174 if len(allProxy) == 0 {
175 return proxy_Direct
176 }
177
178 proxyURL, err := url.Parse(allProxy)
179 if err != nil {
180 return proxy_Direct
181 }
182 proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
183 if err != nil {
184 return proxy_Direct
185 }
186
187 noProxy := proxy_noProxyEnv.Get()
188 if len(noProxy) == 0 {
189 return proxy
190 }
191
192 perHost := proxy_NewPerHost(proxy, proxy_Direct)
193 perHost.AddFromString(noProxy)
194 return perHost
195 }
196
197 // proxySchemes is a map from URL schemes to a function that creates a Dialer
198 // from a URL with such a scheme.
199 var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
200
201 // RegisterDialerType takes a URL scheme and a function to generate Dialers from
202 // a URL with that scheme and a forwarding Dialer. Registered schemes are used
203 // by FromURL.
204 func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
205 if proxy_proxySchemes == nil {
206 proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
207 }
208 proxy_proxySchemes[scheme] = f
209 }
210
211 // FromURL returns a Dialer given a URL specification and an underlying
212 // Dialer for it to make network requests.
213 func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
214 var auth *proxy_Auth
215 if u.User != nil {
216 auth = new(proxy_Auth)
217 auth.User = u.User.Username()
218 if p, ok := u.User.Password(); ok {
219 auth.Password = p
220 }
221 }
222
223 switch u.Scheme {
224 case "socks5":
225 return proxy_SOCKS5("tcp", u.Host, auth, forward)
226 }
227
228 // If the scheme doesn't match any of the built-in schemes, see if it
229 // was registered by another package.
230 if proxy_proxySchemes != nil {
231 if f, ok := proxy_proxySchemes[u.Scheme]; ok {
232 return f(u, forward)
233 }
234 }
235
236 return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
237 }
238
239 var (
240 proxy_allProxyEnv = &proxy_envOnce{
241 names: []string{"ALL_PROXY", "all_proxy"},
242 }
243 proxy_noProxyEnv = &proxy_envOnce{
244 names: []string{"NO_PROXY", "no_proxy"},
245 }
246 )
247
248 // envOnce looks up an environment variable (optionally by multiple
249 // names) once. It mitigates expensive lookups on some platforms
250 // (e.g. Windows).
251 // (Borrowed from net/http/transport.go)
252 type proxy_envOnce struct {
253 names []string
254 once sync.Once
255 val string
256 }
257
258 func (e *proxy_envOnce) Get() string {
259 e.once.Do(e.init)
260 return e.val
261 }
262
263 func (e *proxy_envOnce) init() {
264 for _, n := range e.names {
265 e.val = os.Getenv(n)
266 if e.val != "" {
267 return
268 }
269 }
270 }
271
272 // SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
273 // with an optional username and password. See RFC 1928 and RFC 1929.
274 func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
275 s := &proxy_socks5{
276 network: network,
277 addr: addr,
278 forward: forward,
279 }
280 if auth != nil {
281 s.user = auth.User
282 s.password = auth.Password
283 }
284
285 return s, nil
286 }
287
288 type proxy_socks5 struct {
289 user, password string
290 network, addr string
291 forward proxy_Dialer
292 }
293
294 const proxy_socks5Version = 5
295
296 const (
297 proxy_socks5AuthNone = 0
298 proxy_socks5AuthPassword = 2
299 )
300
301 const proxy_socks5Connect = 1
302
303 const (
304 proxy_socks5IP4 = 1
305 proxy_socks5Domain = 3
306 proxy_socks5IP6 = 4
307 )
308
309 var proxy_socks5Errors = []string{
310 "",
311 "general failure",
312 "connection forbidden",
313 "network unreachable",
314 "host unreachable",
315 "connection refused",
316 "TTL expired",
317 "command not supported",
318 "address type not supported",
319 }
320
321 // Dial connects to the address addr on the given network via the SOCKS5 proxy.
322 func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
323 switch network {
324 case "tcp", "tcp6", "tcp4":
325 default:
326 return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
327 }
328
329 conn, err := s.forward.Dial(s.network, s.addr)
330 if err != nil {
331 return nil, err
332 }
333 if err := s.connect(conn, addr); err != nil {
334 conn.Close()
335 return nil, err
336 }
337 return conn, nil
338 }
339
340 // connect takes an existing connection to a socks5 proxy server,
341 // and commands the server to extend that connection to target,
342 // which must be a canonical address with a host and port.
343 func (s *proxy_socks5) connect(conn net.Conn, target string) error {
344 host, portStr, err := net.SplitHostPort(target)
345 if err != nil {
346 return err
347 }
348
349 port, err := strconv.Atoi(portStr)
350 if err != nil {
351 return errors.New("proxy: failed to parse port number: " + portStr)
352 }
353 if port < 1 || port > 0xffff {
354 return errors.New("proxy: port number out of range: " + portStr)
355 }
356
357 // the size here is just an estimate
358 buf := make([]byte, 0, 6+len(host))
359
360 buf = append(buf, proxy_socks5Version)
361 if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
362 buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
363 } else {
364 buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
365 }
366
367 if _, err := conn.Write(buf); err != nil {
368 return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
369 }
370
371 if _, err := io.ReadFull(conn, buf[:2]); err != nil {
372 return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
373 }
374 if buf[0] != 5 {
375 return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
376 }
377 if buf[1] == 0xff {
378 return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
379 }
380
381 // See RFC 1929
382 if buf[1] == proxy_socks5AuthPassword {
383 buf = buf[:0]
384 buf = append(buf, 1 /* password protocol version */)
385 buf = append(buf, uint8(len(s.user)))
386 buf = append(buf, s.user...)
387 buf = append(buf, uint8(len(s.password)))
388 buf = append(buf, s.password...)
389
390 if _, err := conn.Write(buf); err != nil {
391 return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
392 }
393
394 if _, err := io.ReadFull(conn, buf[:2]); err != nil {
395 return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
396 }
397
398 if buf[1] != 0 {
399 return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
400 }
401 }
402
403 buf = buf[:0]
404 buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
405
406 if ip := net.ParseIP(host); ip != nil {
407 if ip4 := ip.To4(); ip4 != nil {
408 buf = append(buf, proxy_socks5IP4)
409 ip = ip4
410 } else {
411 buf = append(buf, proxy_socks5IP6)
412 }
413 buf = append(buf, ip...)
414 } else {
415 if len(host) > 255 {
416 return errors.New("proxy: destination host name too long: " + host)
417 }
418 buf = append(buf, proxy_socks5Domain)
419 buf = append(buf, byte(len(host)))
420 buf = append(buf, host...)
421 }
422 buf = append(buf, byte(port>>8), byte(port))
423
424 if _, err := conn.Write(buf); err != nil {
425 return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
426 }
427
428 if _, err := io.ReadFull(conn, buf[:4]); err != nil {
429 return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
430 }
431
432 failure := "unknown error"
433 if int(buf[1]) < len(proxy_socks5Errors) {
434 failure = proxy_socks5Errors[buf[1]]
435 }
436
437 if len(failure) > 0 {
438 return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
439 }
440
441 bytesToDiscard := 0
442 switch buf[3] {
443 case proxy_socks5IP4:
444 bytesToDiscard = net.IPv4len
445 case proxy_socks5IP6:
446 bytesToDiscard = net.IPv6len
447 case proxy_socks5Domain:
448 _, err := io.ReadFull(conn, buf[:1])
449 if err != nil {
450 return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
451 }
452 bytesToDiscard = int(buf[0])
453 default:
454 return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
455 }
456
457 if cap(buf) < bytesToDiscard {
458 buf = make([]byte, bytesToDiscard)
459 } else {
460 buf = buf[:bytesToDiscard]
461 }
462 if _, err := io.ReadFull(conn, buf); err != nil {
463 return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
464 }
465
466 // Also need to discard the port number
467 if _, err := io.ReadFull(conn, buf[:2]); err != nil {
468 return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
469 }
470
471 return nil
472 }