Update upstream source from tag 'upstream/1.4.0'
Update to upstream version '1.4.0'
with Debian dir c0565b70787791c81124f03c18c4625afd038cda
Anthony Fok
5 years ago
2 | 2 | |
3 | 3 | matrix: |
4 | 4 | include: |
5 | - go: 1.4 | |
6 | - go: 1.5 | |
7 | - go: 1.6 | |
8 | - go: 1.7 | |
9 | - go: 1.8 | |
5 | - go: 1.7.x | |
6 | - go: 1.8.x | |
7 | - go: 1.9.x | |
8 | - go: 1.10.x | |
9 | - go: 1.11.x | |
10 | 10 | - go: tip |
11 | 11 | allow_failures: |
12 | 12 | - go: tip |
3 | 3 | # Please keep the list sorted. |
4 | 4 | |
5 | 5 | Gary Burd <gary@beagledreams.com> |
6 | Google LLC (https://opensource.google.com/) | |
6 | 7 | Joachim Bauch <mail@joachim-bauch.de> |
7 | 8 |
50 | 50 | <tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr> |
51 | 51 | </table> |
52 | 52 | |
53 | Notes: | |
53 | Notes: | |
54 | 54 | |
55 | 55 | 1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html). |
56 | 56 | 2. The application can get the type of a received data message by implementing |
4 | 4 | package websocket |
5 | 5 | |
6 | 6 | import ( |
7 | "bufio" | |
8 | 7 | "bytes" |
8 | "context" | |
9 | 9 | "crypto/tls" |
10 | "encoding/base64" | |
11 | 10 | "errors" |
12 | 11 | "io" |
13 | 12 | "io/ioutil" |
14 | 13 | "net" |
15 | 14 | "net/http" |
15 | "net/http/httptrace" | |
16 | 16 | "net/url" |
17 | 17 | "strings" |
18 | 18 | "time" |
52 | 52 | // NetDial is nil, net.Dial is used. |
53 | 53 | NetDial func(network, addr string) (net.Conn, error) |
54 | 54 | |
55 | // NetDialContext specifies the dial function for creating TCP connections. If | |
56 | // NetDialContext is nil, net.DialContext is used. | |
57 | NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error) | |
58 | ||
55 | 59 | // Proxy specifies a function to return a proxy for a given |
56 | 60 | // Request. If the function returns a non-nil error, the |
57 | 61 | // request is aborted with the provided error. |
70 | 74 | // do not limit the size of the messages that can be sent or received. |
71 | 75 | ReadBufferSize, WriteBufferSize int |
72 | 76 | |
77 | // WriteBufferPool is a pool of buffers for write operations. If the value | |
78 | // is not set, then write buffers are allocated to the connection for the | |
79 | // lifetime of the connection. | |
80 | // | |
81 | // A pool is most useful when the application has a modest volume of writes | |
82 | // across a large number of connections. | |
83 | // | |
84 | // Applications should use a single pool for each unique value of | |
85 | // WriteBufferSize. | |
86 | WriteBufferPool BufferPool | |
87 | ||
73 | 88 | // Subprotocols specifies the client's requested subprotocols. |
74 | 89 | Subprotocols []string |
75 | 90 | |
85 | 100 | Jar http.CookieJar |
86 | 101 | } |
87 | 102 | |
103 | // Dial creates a new client connection by calling DialContext with a background context. | |
104 | func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { | |
105 | return d.DialContext(context.Background(), urlStr, requestHeader) | |
106 | } | |
107 | ||
88 | 108 | var errMalformedURL = errors.New("malformed ws or wss URL") |
89 | ||
90 | // parseURL parses the URL. | |
91 | // | |
92 | // This function is a replacement for the standard library url.Parse function. | |
93 | // In Go 1.4 and earlier, url.Parse loses information from the path. | |
94 | func parseURL(s string) (*url.URL, error) { | |
95 | // From the RFC: | |
96 | // | |
97 | // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] | |
98 | // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] | |
99 | var u url.URL | |
100 | switch { | |
101 | case strings.HasPrefix(s, "ws://"): | |
102 | u.Scheme = "ws" | |
103 | s = s[len("ws://"):] | |
104 | case strings.HasPrefix(s, "wss://"): | |
105 | u.Scheme = "wss" | |
106 | s = s[len("wss://"):] | |
107 | default: | |
108 | return nil, errMalformedURL | |
109 | } | |
110 | ||
111 | if i := strings.Index(s, "?"); i >= 0 { | |
112 | u.RawQuery = s[i+1:] | |
113 | s = s[:i] | |
114 | } | |
115 | ||
116 | if i := strings.Index(s, "/"); i >= 0 { | |
117 | u.Opaque = s[i:] | |
118 | s = s[:i] | |
119 | } else { | |
120 | u.Opaque = "/" | |
121 | } | |
122 | ||
123 | u.Host = s | |
124 | ||
125 | if strings.Contains(u.Host, "@") { | |
126 | // Don't bother parsing user information because user information is | |
127 | // not allowed in websocket URIs. | |
128 | return nil, errMalformedURL | |
129 | } | |
130 | ||
131 | return &u, nil | |
132 | } | |
133 | 109 | |
134 | 110 | func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { |
135 | 111 | hostPort = u.Host |
149 | 125 | return hostPort, hostNoPort |
150 | 126 | } |
151 | 127 | |
152 | // DefaultDialer is a dialer with all fields set to the default zero values. | |
128 | // DefaultDialer is a dialer with all fields set to the default values. | |
153 | 129 | var DefaultDialer = &Dialer{ |
154 | Proxy: http.ProxyFromEnvironment, | |
155 | } | |
156 | ||
157 | // Dial creates a new client connection. Use requestHeader to specify the | |
130 | Proxy: http.ProxyFromEnvironment, | |
131 | HandshakeTimeout: 45 * time.Second, | |
132 | } | |
133 | ||
134 | // nilDialer is dialer to use when receiver is nil. | |
135 | var nilDialer = *DefaultDialer | |
136 | ||
137 | // DialContext creates a new client connection. Use requestHeader to specify the | |
158 | 138 | // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). |
159 | 139 | // Use the response.Header to get the selected subprotocol |
160 | 140 | // (Sec-WebSocket-Protocol) and cookies (Set-Cookie). |
141 | // | |
142 | // The context will be used in the request and in the Dialer | |
161 | 143 | // |
162 | 144 | // If the WebSocket handshake fails, ErrBadHandshake is returned along with a |
163 | 145 | // non-nil *http.Response so that callers can handle redirects, authentication, |
164 | 146 | // etcetera. The response body may not contain the entire response and does not |
165 | 147 | // need to be closed by the application. |
166 | func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { | |
167 | ||
148 | func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { | |
168 | 149 | if d == nil { |
169 | d = &Dialer{ | |
170 | Proxy: http.ProxyFromEnvironment, | |
171 | } | |
150 | d = &nilDialer | |
172 | 151 | } |
173 | 152 | |
174 | 153 | challengeKey, err := generateChallengeKey() |
176 | 155 | return nil, nil, err |
177 | 156 | } |
178 | 157 | |
179 | u, err := parseURL(urlStr) | |
158 | u, err := url.Parse(urlStr) | |
180 | 159 | if err != nil { |
181 | 160 | return nil, nil, err |
182 | 161 | } |
204 | 183 | Header: make(http.Header), |
205 | 184 | Host: u.Host, |
206 | 185 | } |
186 | req = req.WithContext(ctx) | |
207 | 187 | |
208 | 188 | // Set the cookies present in the cookie jar of the dialer |
209 | 189 | if d.Jar != nil { |
236 | 216 | k == "Sec-Websocket-Extensions" || |
237 | 217 | (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): |
238 | 218 | return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) |
219 | case k == "Sec-Websocket-Protocol": | |
220 | req.Header["Sec-WebSocket-Protocol"] = vs | |
239 | 221 | default: |
240 | 222 | req.Header[k] = vs |
241 | 223 | } |
242 | 224 | } |
243 | 225 | |
244 | 226 | if d.EnableCompression { |
245 | req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") | |
227 | req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"} | |
228 | } | |
229 | ||
230 | if d.HandshakeTimeout != 0 { | |
231 | var cancel func() | |
232 | ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout) | |
233 | defer cancel() | |
234 | } | |
235 | ||
236 | // Get network dial function. | |
237 | var netDial func(network, add string) (net.Conn, error) | |
238 | ||
239 | if d.NetDialContext != nil { | |
240 | netDial = func(network, addr string) (net.Conn, error) { | |
241 | return d.NetDialContext(ctx, network, addr) | |
242 | } | |
243 | } else if d.NetDial != nil { | |
244 | netDial = d.NetDial | |
245 | } else { | |
246 | netDialer := &net.Dialer{} | |
247 | netDial = func(network, addr string) (net.Conn, error) { | |
248 | return netDialer.DialContext(ctx, network, addr) | |
249 | } | |
250 | } | |
251 | ||
252 | // If needed, wrap the dial function to set the connection deadline. | |
253 | if deadline, ok := ctx.Deadline(); ok { | |
254 | forwardDial := netDial | |
255 | netDial = func(network, addr string) (net.Conn, error) { | |
256 | c, err := forwardDial(network, addr) | |
257 | if err != nil { | |
258 | return nil, err | |
259 | } | |
260 | err = c.SetDeadline(deadline) | |
261 | if err != nil { | |
262 | c.Close() | |
263 | return nil, err | |
264 | } | |
265 | return c, nil | |
266 | } | |
267 | } | |
268 | ||
269 | // If needed, wrap the dial function to connect through a proxy. | |
270 | if d.Proxy != nil { | |
271 | proxyURL, err := d.Proxy(req) | |
272 | if err != nil { | |
273 | return nil, nil, err | |
274 | } | |
275 | if proxyURL != nil { | |
276 | dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial)) | |
277 | if err != nil { | |
278 | return nil, nil, err | |
279 | } | |
280 | netDial = dialer.Dial | |
281 | } | |
246 | 282 | } |
247 | 283 | |
248 | 284 | hostPort, hostNoPort := hostPortNoPort(u) |
249 | ||
250 | var proxyURL *url.URL | |
251 | // Check wether the proxy method has been configured | |
252 | if d.Proxy != nil { | |
253 | proxyURL, err = d.Proxy(req) | |
254 | } | |
255 | if err != nil { | |
256 | return nil, nil, err | |
257 | } | |
258 | ||
259 | var targetHostPort string | |
260 | if proxyURL != nil { | |
261 | targetHostPort, _ = hostPortNoPort(proxyURL) | |
262 | } else { | |
263 | targetHostPort = hostPort | |
264 | } | |
265 | ||
266 | var deadline time.Time | |
267 | if d.HandshakeTimeout != 0 { | |
268 | deadline = time.Now().Add(d.HandshakeTimeout) | |
269 | } | |
270 | ||
271 | netDial := d.NetDial | |
272 | if netDial == nil { | |
273 | netDialer := &net.Dialer{Deadline: deadline} | |
274 | netDial = netDialer.Dial | |
275 | } | |
276 | ||
277 | netConn, err := netDial("tcp", targetHostPort) | |
285 | trace := httptrace.ContextClientTrace(ctx) | |
286 | if trace != nil && trace.GetConn != nil { | |
287 | trace.GetConn(hostPort) | |
288 | } | |
289 | ||
290 | netConn, err := netDial("tcp", hostPort) | |
291 | if trace != nil && trace.GotConn != nil { | |
292 | trace.GotConn(httptrace.GotConnInfo{ | |
293 | Conn: netConn, | |
294 | }) | |
295 | } | |
278 | 296 | if err != nil { |
279 | 297 | return nil, nil, err |
280 | 298 | } |
284 | 302 | netConn.Close() |
285 | 303 | } |
286 | 304 | }() |
287 | ||
288 | if err := netConn.SetDeadline(deadline); err != nil { | |
289 | return nil, nil, err | |
290 | } | |
291 | ||
292 | if proxyURL != nil { | |
293 | connectHeader := make(http.Header) | |
294 | if user := proxyURL.User; user != nil { | |
295 | proxyUser := user.Username() | |
296 | if proxyPassword, passwordSet := user.Password(); passwordSet { | |
297 | credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) | |
298 | connectHeader.Set("Proxy-Authorization", "Basic "+credential) | |
299 | } | |
300 | } | |
301 | connectReq := &http.Request{ | |
302 | Method: "CONNECT", | |
303 | URL: &url.URL{Opaque: hostPort}, | |
304 | Host: hostPort, | |
305 | Header: connectHeader, | |
306 | } | |
307 | ||
308 | connectReq.Write(netConn) | |
309 | ||
310 | // Read response. | |
311 | // Okay to use and discard buffered reader here, because | |
312 | // TLS server will not speak until spoken to. | |
313 | br := bufio.NewReader(netConn) | |
314 | resp, err := http.ReadResponse(br, connectReq) | |
315 | if err != nil { | |
316 | return nil, nil, err | |
317 | } | |
318 | if resp.StatusCode != 200 { | |
319 | f := strings.SplitN(resp.Status, " ", 2) | |
320 | return nil, nil, errors.New(f[1]) | |
321 | } | |
322 | } | |
323 | 305 | |
324 | 306 | if u.Scheme == "https" { |
325 | 307 | cfg := cloneTLSConfig(d.TLSClientConfig) |
328 | 310 | } |
329 | 311 | tlsConn := tls.Client(netConn, cfg) |
330 | 312 | netConn = tlsConn |
331 | if err := tlsConn.Handshake(); err != nil { | |
313 | ||
314 | var err error | |
315 | if trace != nil { | |
316 | err = doHandshakeWithTrace(trace, tlsConn, cfg) | |
317 | } else { | |
318 | err = doHandshake(tlsConn, cfg) | |
319 | } | |
320 | ||
321 | if err != nil { | |
332 | 322 | return nil, nil, err |
333 | 323 | } |
334 | if !cfg.InsecureSkipVerify { | |
335 | if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { | |
336 | return nil, nil, err | |
337 | } | |
338 | } | |
339 | } | |
340 | ||
341 | conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) | |
324 | } | |
325 | ||
326 | conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil) | |
342 | 327 | |
343 | 328 | if err := req.Write(netConn); err != nil { |
344 | 329 | return nil, nil, err |
330 | } | |
331 | ||
332 | if trace != nil && trace.GotFirstResponseByte != nil { | |
333 | if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 { | |
334 | trace.GotFirstResponseByte() | |
335 | } | |
345 | 336 | } |
346 | 337 | |
347 | 338 | resp, err := http.ReadResponse(conn.br, req) |
389 | 380 | netConn = nil // to avoid close in defer. |
390 | 381 | return conn, resp, nil |
391 | 382 | } |
383 | ||
384 | func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error { | |
385 | if err := tlsConn.Handshake(); err != nil { | |
386 | return err | |
387 | } | |
388 | if !cfg.InsecureSkipVerify { | |
389 | if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { | |
390 | return err | |
391 | } | |
392 | } | |
393 | return nil | |
394 | } |
4 | 4 | package websocket |
5 | 5 | |
6 | 6 | import ( |
7 | "bytes" | |
8 | "context" | |
7 | 9 | "crypto/tls" |
8 | 10 | "crypto/x509" |
9 | 11 | "encoding/base64" |
12 | "encoding/binary" | |
10 | 13 | "io" |
11 | 14 | "io/ioutil" |
15 | "net" | |
12 | 16 | "net/http" |
13 | 17 | "net/http/cookiejar" |
14 | 18 | "net/http/httptest" |
19 | "net/http/httptrace" | |
15 | 20 | "net/url" |
16 | 21 | "reflect" |
17 | 22 | "strings" |
30 | 35 | } |
31 | 36 | |
32 | 37 | var cstDialer = Dialer{ |
38 | Subprotocols: []string{"p1", "p2"}, | |
39 | ReadBufferSize: 1024, | |
40 | WriteBufferSize: 1024, | |
41 | HandshakeTimeout: 30 * time.Second, | |
42 | } | |
43 | ||
44 | var cstDialerWithoutHandshakeTimeout = Dialer{ | |
33 | 45 | Subprotocols: []string{"p1", "p2"}, |
34 | 46 | ReadBufferSize: 1024, |
35 | 47 | WriteBufferSize: 1024, |
67 | 79 | func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
68 | 80 | if r.URL.Path != cstPath { |
69 | 81 | t.Logf("path=%v, want %v", r.URL.Path, cstPath) |
70 | http.Error(w, "bad path", 400) | |
82 | http.Error(w, "bad path", http.StatusBadRequest) | |
71 | 83 | return |
72 | 84 | } |
73 | 85 | if r.URL.RawQuery != cstRawQuery { |
74 | 86 | t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) |
75 | http.Error(w, "bad path", 400) | |
87 | http.Error(w, "bad path", http.StatusBadRequest) | |
76 | 88 | return |
77 | 89 | } |
78 | 90 | subprotos := Subprotocols(r) |
79 | 91 | if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { |
80 | 92 | t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) |
81 | http.Error(w, "bad protocol", 400) | |
93 | http.Error(w, "bad protocol", http.StatusBadRequest) | |
82 | 94 | return |
83 | 95 | } |
84 | 96 | ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}}) |
142 | 154 | s := newServer(t) |
143 | 155 | defer s.Close() |
144 | 156 | |
145 | surl, _ := url.Parse(s.URL) | |
146 | ||
157 | surl, _ := url.Parse(s.Server.URL) | |
158 | ||
159 | cstDialer := cstDialer // make local copy for modification on next line. | |
147 | 160 | cstDialer.Proxy = http.ProxyURL(surl) |
148 | 161 | |
149 | 162 | connect := false |
154 | 167 | func(w http.ResponseWriter, r *http.Request) { |
155 | 168 | if r.Method == "CONNECT" { |
156 | 169 | connect = true |
157 | w.WriteHeader(200) | |
170 | w.WriteHeader(http.StatusOK) | |
158 | 171 | return |
159 | 172 | } |
160 | 173 | |
161 | 174 | if !connect { |
162 | t.Log("connect not recieved") | |
163 | http.Error(w, "connect not recieved", 405) | |
175 | t.Log("connect not received") | |
176 | http.Error(w, "connect not received", http.StatusMethodNotAllowed) | |
164 | 177 | return |
165 | 178 | } |
166 | 179 | origHandler.ServeHTTP(w, r) |
172 | 185 | } |
173 | 186 | defer ws.Close() |
174 | 187 | sendRecv(t, ws) |
175 | ||
176 | cstDialer.Proxy = http.ProxyFromEnvironment | |
177 | 188 | } |
178 | 189 | |
179 | 190 | func TestProxyAuthorizationDial(t *testing.T) { |
180 | 191 | s := newServer(t) |
181 | 192 | defer s.Close() |
182 | 193 | |
183 | surl, _ := url.Parse(s.URL) | |
194 | surl, _ := url.Parse(s.Server.URL) | |
184 | 195 | surl.User = url.UserPassword("username", "password") |
196 | ||
197 | cstDialer := cstDialer // make local copy for modification on next line. | |
185 | 198 | cstDialer.Proxy = http.ProxyURL(surl) |
186 | 199 | |
187 | 200 | connect := false |
194 | 207 | expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) |
195 | 208 | if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { |
196 | 209 | connect = true |
197 | w.WriteHeader(200) | |
210 | w.WriteHeader(http.StatusOK) | |
198 | 211 | return |
199 | 212 | } |
200 | 213 | |
201 | 214 | if !connect { |
202 | t.Log("connect with proxy authorization not recieved") | |
203 | http.Error(w, "connect with proxy authorization not recieved", 405) | |
215 | t.Log("connect with proxy authorization not received") | |
216 | http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed) | |
204 | 217 | return |
205 | 218 | } |
206 | 219 | origHandler.ServeHTTP(w, r) |
212 | 225 | } |
213 | 226 | defer ws.Close() |
214 | 227 | sendRecv(t, ws) |
215 | ||
216 | cstDialer.Proxy = http.ProxyFromEnvironment | |
217 | 228 | } |
218 | 229 | |
219 | 230 | func TestDial(t *testing.T) { |
236 | 247 | d := cstDialer |
237 | 248 | d.Jar = jar |
238 | 249 | |
239 | u, _ := parseURL(s.URL) | |
250 | u, _ := url.Parse(s.URL) | |
240 | 251 | |
241 | 252 | switch u.Scheme { |
242 | 253 | case "ws": |
245 | 256 | u.Scheme = "https" |
246 | 257 | } |
247 | 258 | |
248 | cookies := []*http.Cookie{&http.Cookie{Name: "gorilla", Value: "ws", Path: "/"}} | |
259 | cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}} | |
249 | 260 | d.Jar.SetCookies(u, cookies) |
250 | 261 | |
251 | 262 | ws, _, err := d.Dial(s.URL, nil) |
338 | 349 | ws.Close() |
339 | 350 | t.Fatalf("Dial: nil") |
340 | 351 | } |
352 | } | |
353 | ||
354 | // requireDeadlineNetConn fails the current test when Read or Write are called | |
355 | // with no deadline. | |
356 | type requireDeadlineNetConn struct { | |
357 | t *testing.T | |
358 | c net.Conn | |
359 | readDeadlineIsSet bool | |
360 | writeDeadlineIsSet bool | |
361 | } | |
362 | ||
363 | func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error { | |
364 | c.writeDeadlineIsSet = !t.Equal(time.Time{}) | |
365 | c.readDeadlineIsSet = c.writeDeadlineIsSet | |
366 | return c.c.SetDeadline(t) | |
367 | } | |
368 | ||
369 | func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error { | |
370 | c.readDeadlineIsSet = !t.Equal(time.Time{}) | |
371 | return c.c.SetDeadline(t) | |
372 | } | |
373 | ||
374 | func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error { | |
375 | c.writeDeadlineIsSet = !t.Equal(time.Time{}) | |
376 | return c.c.SetDeadline(t) | |
377 | } | |
378 | ||
379 | func (c *requireDeadlineNetConn) Write(p []byte) (int, error) { | |
380 | if !c.writeDeadlineIsSet { | |
381 | c.t.Fatalf("write with no deadline") | |
382 | } | |
383 | return c.c.Write(p) | |
384 | } | |
385 | ||
386 | func (c *requireDeadlineNetConn) Read(p []byte) (int, error) { | |
387 | if !c.readDeadlineIsSet { | |
388 | c.t.Fatalf("read with no deadline") | |
389 | } | |
390 | return c.c.Read(p) | |
391 | } | |
392 | ||
393 | func (c *requireDeadlineNetConn) Close() error { return c.c.Close() } | |
394 | func (c *requireDeadlineNetConn) LocalAddr() net.Addr { return c.c.LocalAddr() } | |
395 | func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() } | |
396 | ||
397 | func TestHandshakeTimeout(t *testing.T) { | |
398 | s := newServer(t) | |
399 | defer s.Close() | |
400 | ||
401 | d := cstDialer | |
402 | d.NetDial = func(n, a string) (net.Conn, error) { | |
403 | c, err := net.Dial(n, a) | |
404 | return &requireDeadlineNetConn{c: c, t: t}, err | |
405 | } | |
406 | ws, _, err := d.Dial(s.URL, nil) | |
407 | if err != nil { | |
408 | t.Fatal("Dial:", err) | |
409 | } | |
410 | ws.Close() | |
411 | } | |
412 | ||
413 | func TestHandshakeTimeoutInContext(t *testing.T) { | |
414 | s := newServer(t) | |
415 | defer s.Close() | |
416 | ||
417 | d := cstDialerWithoutHandshakeTimeout | |
418 | d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) { | |
419 | netDialer := &net.Dialer{} | |
420 | c, err := netDialer.DialContext(ctx, n, a) | |
421 | return &requireDeadlineNetConn{c: c, t: t}, err | |
422 | } | |
423 | ||
424 | ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second)) | |
425 | defer cancel() | |
426 | ws, _, err := d.DialContext(ctx, s.URL, nil) | |
427 | if err != nil { | |
428 | t.Fatal("Dial:", err) | |
429 | } | |
430 | ws.Close() | |
341 | 431 | } |
342 | 432 | |
343 | 433 | func TestDialBadScheme(t *testing.T) { |
397 | 487 | })) |
398 | 488 | defer s.Close() |
399 | 489 | |
400 | resp, err := http.PostForm(s.URL, url.Values{}) | |
401 | if err != nil { | |
402 | t.Fatalf("PostForm returned error %v", err) | |
490 | req, err := http.NewRequest("POST", s.URL, strings.NewReader("")) | |
491 | if err != nil { | |
492 | t.Fatalf("NewRequest returned error %v", err) | |
493 | } | |
494 | req.Header.Set("Connection", "upgrade") | |
495 | req.Header.Set("Upgrade", "websocket") | |
496 | req.Header.Set("Sec-Websocket-Version", "13") | |
497 | ||
498 | resp, err := http.DefaultClient.Do(req) | |
499 | if err != nil { | |
500 | t.Fatalf("Do returned error %v", err) | |
403 | 501 | } |
404 | 502 | resp.Body.Close() |
405 | 503 | if resp.StatusCode != http.StatusMethodNotAllowed { |
509 | 607 | defer ws.Close() |
510 | 608 | sendRecv(t, ws) |
511 | 609 | } |
610 | ||
611 | func TestSocksProxyDial(t *testing.T) { | |
612 | s := newServer(t) | |
613 | defer s.Close() | |
614 | ||
615 | proxyListener, err := net.Listen("tcp", "127.0.0.1:0") | |
616 | if err != nil { | |
617 | t.Fatalf("listen failed: %v", err) | |
618 | } | |
619 | defer proxyListener.Close() | |
620 | go func() { | |
621 | c1, err := proxyListener.Accept() | |
622 | if err != nil { | |
623 | t.Errorf("proxy accept failed: %v", err) | |
624 | return | |
625 | } | |
626 | defer c1.Close() | |
627 | ||
628 | c1.SetDeadline(time.Now().Add(30 * time.Second)) | |
629 | ||
630 | buf := make([]byte, 32) | |
631 | if _, err := io.ReadFull(c1, buf[:3]); err != nil { | |
632 | t.Errorf("read failed: %v", err) | |
633 | return | |
634 | } | |
635 | if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) { | |
636 | t.Errorf("read %x, want %x", buf[:len(want)], want) | |
637 | } | |
638 | if _, err := c1.Write([]byte{5, 0}); err != nil { | |
639 | t.Errorf("write failed: %v", err) | |
640 | return | |
641 | } | |
642 | if _, err := io.ReadFull(c1, buf[:10]); err != nil { | |
643 | t.Errorf("read failed: %v", err) | |
644 | return | |
645 | } | |
646 | if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) { | |
647 | t.Errorf("read %x, want %x", buf[:len(want)], want) | |
648 | return | |
649 | } | |
650 | buf[1] = 0 | |
651 | if _, err := c1.Write(buf[:10]); err != nil { | |
652 | t.Errorf("write failed: %v", err) | |
653 | return | |
654 | } | |
655 | ||
656 | ip := net.IP(buf[4:8]) | |
657 | port := binary.BigEndian.Uint16(buf[8:10]) | |
658 | ||
659 | c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)}) | |
660 | if err != nil { | |
661 | t.Errorf("dial failed; %v", err) | |
662 | return | |
663 | } | |
664 | defer c2.Close() | |
665 | done := make(chan struct{}) | |
666 | go func() { | |
667 | io.Copy(c1, c2) | |
668 | close(done) | |
669 | }() | |
670 | io.Copy(c2, c1) | |
671 | <-done | |
672 | }() | |
673 | ||
674 | purl, err := url.Parse("socks5://" + proxyListener.Addr().String()) | |
675 | if err != nil { | |
676 | t.Fatalf("parse failed: %v", err) | |
677 | } | |
678 | ||
679 | cstDialer := cstDialer // make local copy for modification on next line. | |
680 | cstDialer.Proxy = http.ProxyURL(purl) | |
681 | ||
682 | ws, _, err := cstDialer.Dial(s.URL, nil) | |
683 | if err != nil { | |
684 | t.Fatalf("Dial: %v", err) | |
685 | } | |
686 | defer ws.Close() | |
687 | sendRecv(t, ws) | |
688 | } | |
689 | ||
690 | func TestTracingDialWithContext(t *testing.T) { | |
691 | ||
692 | var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool | |
693 | trace := &httptrace.ClientTrace{ | |
694 | WroteHeaders: func() { | |
695 | headersWrote = true | |
696 | }, | |
697 | WroteRequest: func(httptrace.WroteRequestInfo) { | |
698 | requestWrote = true | |
699 | }, | |
700 | GetConn: func(hostPort string) { | |
701 | getConn = true | |
702 | }, | |
703 | GotConn: func(info httptrace.GotConnInfo) { | |
704 | gotConn = true | |
705 | }, | |
706 | ConnectDone: func(network, addr string, err error) { | |
707 | connectDone = true | |
708 | }, | |
709 | GotFirstResponseByte: func() { | |
710 | gotFirstResponseByte = true | |
711 | }, | |
712 | } | |
713 | ctx := httptrace.WithClientTrace(context.Background(), trace) | |
714 | ||
715 | s := newTLSServer(t) | |
716 | defer s.Close() | |
717 | ||
718 | certs := x509.NewCertPool() | |
719 | for _, c := range s.TLS.Certificates { | |
720 | roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) | |
721 | if err != nil { | |
722 | t.Fatalf("error parsing server's root cert: %v", err) | |
723 | } | |
724 | for _, root := range roots { | |
725 | certs.AddCert(root) | |
726 | } | |
727 | } | |
728 | ||
729 | d := cstDialer | |
730 | d.TLSClientConfig = &tls.Config{RootCAs: certs} | |
731 | ||
732 | ws, _, err := d.DialContext(ctx, s.URL, nil) | |
733 | if err != nil { | |
734 | t.Fatalf("Dial: %v", err) | |
735 | } | |
736 | ||
737 | if !headersWrote { | |
738 | t.Fatal("Headers was not written") | |
739 | } | |
740 | if !requestWrote { | |
741 | t.Fatal("Request was not written") | |
742 | } | |
743 | if !getConn { | |
744 | t.Fatal("getConn was not called") | |
745 | } | |
746 | if !gotConn { | |
747 | t.Fatal("gotConn was not called") | |
748 | } | |
749 | if !connectDone { | |
750 | t.Fatal("connectDone was not called") | |
751 | } | |
752 | if !gotFirstResponseByte { | |
753 | t.Fatal("GotFirstResponseByte was not called") | |
754 | } | |
755 | ||
756 | defer ws.Close() | |
757 | sendRecv(t, ws) | |
758 | } | |
759 | ||
760 | func TestEmptyTracingDialWithContext(t *testing.T) { | |
761 | ||
762 | trace := &httptrace.ClientTrace{} | |
763 | ctx := httptrace.WithClientTrace(context.Background(), trace) | |
764 | ||
765 | s := newTLSServer(t) | |
766 | defer s.Close() | |
767 | ||
768 | certs := x509.NewCertPool() | |
769 | for _, c := range s.TLS.Certificates { | |
770 | roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) | |
771 | if err != nil { | |
772 | t.Fatalf("error parsing server's root cert: %v", err) | |
773 | } | |
774 | for _, root := range roots { | |
775 | certs.AddCert(root) | |
776 | } | |
777 | } | |
778 | ||
779 | d := cstDialer | |
780 | d.TLSClientConfig = &tls.Config{RootCAs: certs} | |
781 | ||
782 | ws, _, err := d.DialContext(ctx, s.URL, nil) | |
783 | if err != nil { | |
784 | t.Fatalf("Dial: %v", err) | |
785 | } | |
786 | ||
787 | defer ws.Close() | |
788 | sendRecv(t, ws) | |
789 | } |
5 | 5 | |
6 | 6 | import ( |
7 | 7 | "net/url" |
8 | "reflect" | |
9 | 8 | "testing" |
10 | 9 | ) |
11 | ||
12 | var parseURLTests = []struct { | |
13 | s string | |
14 | u *url.URL | |
15 | rui string | |
16 | }{ | |
17 | {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, | |
18 | {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, | |
19 | {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"}, | |
20 | {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"}, | |
21 | {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"}, | |
22 | {"ss://example.com/a/b", nil, ""}, | |
23 | {"ws://webmaster@example.com/", nil, ""}, | |
24 | {"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"}, | |
25 | {"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"}, | |
26 | } | |
27 | ||
28 | func TestParseURL(t *testing.T) { | |
29 | for _, tt := range parseURLTests { | |
30 | u, err := parseURL(tt.s) | |
31 | if tt.u != nil && err != nil { | |
32 | t.Errorf("parseURL(%q) returned error %v", tt.s, err) | |
33 | continue | |
34 | } | |
35 | if tt.u == nil { | |
36 | if err == nil { | |
37 | t.Errorf("parseURL(%q) did not return error", tt.s) | |
38 | } | |
39 | continue | |
40 | } | |
41 | if !reflect.DeepEqual(u, tt.u) { | |
42 | t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u) | |
43 | continue | |
44 | } | |
45 | if u.RequestURI() != tt.rui { | |
46 | t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui) | |
47 | } | |
48 | } | |
49 | } | |
50 | 10 | |
51 | 11 | var hostPortNoPortTests = []struct { |
52 | 12 | u *url.URL |
42 | 42 | |
43 | 43 | func BenchmarkWriteNoCompression(b *testing.B) { |
44 | 44 | w := ioutil.Discard |
45 | c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) | |
45 | c := newTestConn(nil, w, false) | |
46 | 46 | messages := textMessages(100) |
47 | 47 | b.ResetTimer() |
48 | 48 | for i := 0; i < b.N; i++ { |
53 | 53 | |
54 | 54 | func BenchmarkWriteWithCompression(b *testing.B) { |
55 | 55 | w := ioutil.Discard |
56 | c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) | |
56 | c := newTestConn(nil, w, false) | |
57 | 57 | messages := textMessages(100) |
58 | 58 | c.enableWriteCompression = true |
59 | 59 | c.newCompressionWriter = compressNoContextTakeover |
65 | 65 | } |
66 | 66 | |
67 | 67 | func TestValidCompressionLevel(t *testing.T) { |
68 | c := newConn(fakeNetConn{}, false, 1024, 1024) | |
68 | c := newTestConn(nil, nil, false) | |
69 | 69 | for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { |
70 | 70 | if err := c.SetCompressionLevel(level); err == nil { |
71 | 71 | t.Errorf("no error for level %d", level) |
75 | 75 | // is UTF-8 encoded text. |
76 | 76 | PingMessage = 9 |
77 | 77 | |
78 | // PongMessage denotes a ping control message. The optional message payload | |
78 | // PongMessage denotes a pong control message. The optional message payload | |
79 | 79 | // is UTF-8 encoded text. |
80 | 80 | PongMessage = 10 |
81 | 81 | ) |
99 | 99 | func (e *netError) Temporary() bool { return e.temporary } |
100 | 100 | func (e *netError) Timeout() bool { return e.timeout } |
101 | 101 | |
102 | // CloseError represents close frame. | |
102 | // CloseError represents a close message. | |
103 | 103 | type CloseError struct { |
104 | ||
105 | 104 | // Code is defined in RFC 6455, section 11.7. |
106 | 105 | Code int |
107 | 106 | |
223 | 222 | return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) |
224 | 223 | } |
225 | 224 | |
225 | // BufferPool represents a pool of buffers. The *sync.Pool type satisfies this | |
226 | // interface. The type of the value stored in a pool is not specified. | |
227 | type BufferPool interface { | |
228 | // Get gets a value from the pool or returns nil if the pool is empty. | |
229 | Get() interface{} | |
230 | // Put adds a value to the pool. | |
231 | Put(interface{}) | |
232 | } | |
233 | ||
234 | // writePoolData is the type added to the write buffer pool. This wrapper is | |
235 | // used to prevent applications from peeking at and depending on the values | |
236 | // added to the pool. | |
237 | type writePoolData struct{ buf []byte } | |
238 | ||
226 | 239 | // The Conn type represents a WebSocket connection. |
227 | 240 | type Conn struct { |
228 | 241 | conn net.Conn |
232 | 245 | // Write fields |
233 | 246 | mu chan bool // used as mutex to protect write to conn |
234 | 247 | writeBuf []byte // frame is constructed in this buffer. |
248 | writePool BufferPool | |
249 | writeBufSize int | |
235 | 250 | writeDeadline time.Time |
236 | 251 | writer io.WriteCloser // the current writer returned to the application |
237 | 252 | isWriting bool // for best-effort concurrent write detection |
263 | 278 | newDecompressionReader func(io.Reader) io.ReadCloser |
264 | 279 | } |
265 | 280 | |
266 | func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { | |
267 | return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil) | |
268 | } | |
269 | ||
270 | type writeHook struct { | |
271 | p []byte | |
272 | } | |
273 | ||
274 | func (wh *writeHook) Write(p []byte) (int, error) { | |
275 | wh.p = p | |
276 | return len(p), nil | |
277 | } | |
278 | ||
279 | func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn { | |
280 | mu := make(chan bool, 1) | |
281 | mu <- true | |
282 | ||
283 | var br *bufio.Reader | |
284 | if readBufferSize == 0 && brw != nil && brw.Reader != nil { | |
285 | // Reuse the supplied bufio.Reader if the buffer has a useful size. | |
286 | // This code assumes that peek on a reader returns | |
287 | // bufio.Reader.buf[:0]. | |
288 | brw.Reader.Reset(conn) | |
289 | if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 { | |
290 | br = brw.Reader | |
291 | } | |
292 | } | |
281 | func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { | |
282 | ||
293 | 283 | if br == nil { |
294 | 284 | if readBufferSize == 0 { |
295 | 285 | readBufferSize = defaultReadBufferSize |
296 | } | |
297 | if readBufferSize < maxControlFramePayloadSize { | |
286 | } else if readBufferSize < maxControlFramePayloadSize { | |
287 | // must be large enough for control frame | |
298 | 288 | readBufferSize = maxControlFramePayloadSize |
299 | 289 | } |
300 | 290 | br = bufio.NewReaderSize(conn, readBufferSize) |
301 | 291 | } |
302 | 292 | |
303 | var writeBuf []byte | |
304 | if writeBufferSize == 0 && brw != nil && brw.Writer != nil { | |
305 | // Use the bufio.Writer's buffer if the buffer has a useful size. This | |
306 | // code assumes that bufio.Writer.buf[:1] is passed to the | |
307 | // bufio.Writer's underlying writer. | |
308 | var wh writeHook | |
309 | brw.Writer.Reset(&wh) | |
310 | brw.Writer.WriteByte(0) | |
311 | brw.Flush() | |
312 | if cap(wh.p) >= maxFrameHeaderSize+256 { | |
313 | writeBuf = wh.p[:cap(wh.p)] | |
314 | } | |
315 | } | |
316 | ||
317 | if writeBuf == nil { | |
318 | if writeBufferSize == 0 { | |
319 | writeBufferSize = defaultWriteBufferSize | |
320 | } | |
321 | writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize) | |
322 | } | |
323 | ||
293 | if writeBufferSize <= 0 { | |
294 | writeBufferSize = defaultWriteBufferSize | |
295 | } | |
296 | writeBufferSize += maxFrameHeaderSize | |
297 | ||
298 | if writeBuf == nil && writeBufferPool == nil { | |
299 | writeBuf = make([]byte, writeBufferSize) | |
300 | } | |
301 | ||
302 | mu := make(chan bool, 1) | |
303 | mu <- true | |
324 | 304 | c := &Conn{ |
325 | 305 | isServer: isServer, |
326 | 306 | br: br, |
328 | 308 | mu: mu, |
329 | 309 | readFinal: true, |
330 | 310 | writeBuf: writeBuf, |
311 | writePool: writeBufferPool, | |
312 | writeBufSize: writeBufferSize, | |
331 | 313 | enableWriteCompression: true, |
332 | 314 | compressionLevel: defaultCompressionLevel, |
333 | 315 | } |
342 | 324 | return c.subprotocol |
343 | 325 | } |
344 | 326 | |
345 | // Close closes the underlying network connection without sending or waiting for a close frame. | |
327 | // Close closes the underlying network connection without sending or waiting | |
328 | // for a close message. | |
346 | 329 | func (c *Conn) Close() error { |
347 | 330 | return c.conn.Close() |
348 | 331 | } |
369 | 352 | return err |
370 | 353 | } |
371 | 354 | |
372 | func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { | |
355 | func (c *Conn) read(n int) ([]byte, error) { | |
356 | p, err := c.br.Peek(n) | |
357 | if err == io.EOF { | |
358 | err = errUnexpectedEOF | |
359 | } | |
360 | c.br.Discard(len(p)) | |
361 | return p, err | |
362 | } | |
363 | ||
364 | func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { | |
373 | 365 | <-c.mu |
374 | 366 | defer func() { c.mu <- true }() |
375 | 367 | |
381 | 373 | } |
382 | 374 | |
383 | 375 | c.conn.SetWriteDeadline(deadline) |
384 | for _, buf := range bufs { | |
385 | if len(buf) > 0 { | |
386 | _, err := c.conn.Write(buf) | |
387 | if err != nil { | |
388 | return c.writeFatal(err) | |
389 | } | |
390 | } | |
391 | } | |
392 | ||
376 | if len(buf1) == 0 { | |
377 | _, err = c.conn.Write(buf0) | |
378 | } else { | |
379 | err = c.writeBufs(buf0, buf1) | |
380 | } | |
381 | if err != nil { | |
382 | return c.writeFatal(err) | |
383 | } | |
393 | 384 | if frameType == CloseMessage { |
394 | 385 | c.writeFatal(ErrCloseSent) |
395 | 386 | } |
475 | 466 | c.writeErrMu.Lock() |
476 | 467 | err := c.writeErr |
477 | 468 | c.writeErrMu.Unlock() |
478 | return err | |
469 | if err != nil { | |
470 | return err | |
471 | } | |
472 | ||
473 | if c.writeBuf == nil { | |
474 | wpd, ok := c.writePool.Get().(writePoolData) | |
475 | if ok { | |
476 | c.writeBuf = wpd.buf | |
477 | } else { | |
478 | c.writeBuf = make([]byte, c.writeBufSize) | |
479 | } | |
480 | } | |
481 | return nil | |
479 | 482 | } |
480 | 483 | |
481 | 484 | // NextWriter returns a writer for the next message to send. The writer's Close |
483 | 486 | // |
484 | 487 | // There can be at most one open writer on a connection. NextWriter closes the |
485 | 488 | // previous writer if the application has not already done so. |
489 | // | |
490 | // All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and | |
491 | // PongMessage) are supported. | |
486 | 492 | func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { |
487 | 493 | if err := c.prepWrite(messageType); err != nil { |
488 | 494 | return nil, err |
598 | 604 | |
599 | 605 | if final { |
600 | 606 | c.writer = nil |
607 | if c.writePool != nil { | |
608 | c.writePool.Put(writePoolData{buf: c.writeBuf}) | |
609 | c.writeBuf = nil | |
610 | } | |
601 | 611 | return nil |
602 | 612 | } |
603 | 613 | |
763 | 773 | // Read methods |
764 | 774 | |
765 | 775 | func (c *Conn) advanceFrame() (int, error) { |
766 | ||
767 | 776 | // 1. Skip remainder of previous frame. |
768 | 777 | |
769 | 778 | if c.readRemaining > 0 { |
1032 | 1041 | } |
1033 | 1042 | |
1034 | 1043 | // SetReadLimit sets the maximum size for a message read from the peer. If a |
1035 | // message exceeds the limit, the connection sends a close frame to the peer | |
1044 | // message exceeds the limit, the connection sends a close message to the peer | |
1036 | 1045 | // and returns ErrReadLimit to the application. |
1037 | 1046 | func (c *Conn) SetReadLimit(limit int64) { |
1038 | 1047 | c.readLimit = limit |
1045 | 1054 | |
1046 | 1055 | // SetCloseHandler sets the handler for close messages received from the peer. |
1047 | 1056 | // The code argument to h is the received close code or CloseNoStatusReceived |
1048 | // if the close message is empty. The default close handler sends a close frame | |
1049 | // back to the peer. | |
1057 | // if the close message is empty. The default close handler sends a close | |
1058 | // message back to the peer. | |
1050 | 1059 | // |
1051 | // The application must read the connection to process close messages as | |
1052 | // described in the section on Control Frames above. | |
1060 | // The handler function is called from the NextReader, ReadMessage and message | |
1061 | // reader Read methods. The application must read the connection to process | |
1062 | // close messages as described in the section on Control Messages above. | |
1053 | 1063 | // |
1054 | // The connection read methods return a CloseError when a close frame is | |
1064 | // The connection read methods return a CloseError when a close message is | |
1055 | 1065 | // received. Most applications should handle close messages as part of their |
1056 | 1066 | // normal error handling. Applications should only set a close handler when the |
1057 | // application must perform some action before sending a close frame back to | |
1067 | // application must perform some action before sending a close message back to | |
1058 | 1068 | // the peer. |
1059 | 1069 | func (c *Conn) SetCloseHandler(h func(code int, text string) error) { |
1060 | 1070 | if h == nil { |
1061 | 1071 | h = func(code int, text string) error { |
1062 | message := []byte{} | |
1063 | if code != CloseNoStatusReceived { | |
1064 | message = FormatCloseMessage(code, "") | |
1065 | } | |
1072 | message := FormatCloseMessage(code, "") | |
1066 | 1073 | c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) |
1067 | 1074 | return nil |
1068 | 1075 | } |
1076 | 1083 | } |
1077 | 1084 | |
1078 | 1085 | // SetPingHandler sets the handler for ping messages received from the peer. |
1079 | // The appData argument to h is the PING frame application data. The default | |
1086 | // The appData argument to h is the PING message application data. The default | |
1080 | 1087 | // ping handler sends a pong to the peer. |
1081 | 1088 | // |
1082 | // The application must read the connection to process ping messages as | |
1083 | // described in the section on Control Frames above. | |
1089 | // The handler function is called from the NextReader, ReadMessage and message | |
1090 | // reader Read methods. The application must read the connection to process | |
1091 | // ping messages as described in the section on Control Messages above. | |
1084 | 1092 | func (c *Conn) SetPingHandler(h func(appData string) error) { |
1085 | 1093 | if h == nil { |
1086 | 1094 | h = func(message string) error { |
1102 | 1110 | } |
1103 | 1111 | |
1104 | 1112 | // SetPongHandler sets the handler for pong messages received from the peer. |
1105 | // The appData argument to h is the PONG frame application data. The default | |
1113 | // The appData argument to h is the PONG message application data. The default | |
1106 | 1114 | // pong handler does nothing. |
1107 | 1115 | // |
1108 | // The application must read the connection to process ping messages as | |
1109 | // described in the section on Control Frames above. | |
1116 | // The handler function is called from the NextReader, ReadMessage and message | |
1117 | // reader Read methods. The application must read the connection to process | |
1118 | // pong messages as described in the section on Control Messages above. | |
1110 | 1119 | func (c *Conn) SetPongHandler(h func(appData string) error) { |
1111 | 1120 | if h == nil { |
1112 | 1121 | h = func(string) error { return nil } |
1140 | 1149 | } |
1141 | 1150 | |
1142 | 1151 | // FormatCloseMessage formats closeCode and text as a WebSocket close message. |
1152 | // An empty message is returned for code CloseNoStatusReceived. | |
1143 | 1153 | func FormatCloseMessage(closeCode int, text string) []byte { |
1154 | if closeCode == CloseNoStatusReceived { | |
1155 | // Return empty message because it's illegal to send | |
1156 | // CloseNoStatusReceived. Return non-nil value in case application | |
1157 | // checks for nil. | |
1158 | return []byte{} | |
1159 | } | |
1144 | 1160 | buf := make([]byte, 2+len(text)) |
1145 | 1161 | binary.BigEndian.PutUint16(buf, uint16(closeCode)) |
1146 | 1162 | copy(buf[2:], text) |
0 | 0 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. |
1 | 1 | // Use of this source code is governed by a BSD-style |
2 | 2 | // license that can be found in the LICENSE file. |
3 | ||
4 | // +build go1.7 | |
5 | 3 | |
6 | 4 | package websocket |
7 | 5 | |
69 | 67 | conns := make([]*broadcastConn, numConns) |
70 | 68 | |
71 | 69 | for i := 0; i < numConns; i++ { |
72 | c := newConn(fakeNetConn{Reader: nil, Writer: b.w}, true, 1024, 1024) | |
70 | c := newTestConn(nil, b.w, true) | |
73 | 71 | if b.compression { |
74 | 72 | c.enableWriteCompression = true |
75 | 73 | c.newCompressionWriter = compressNoContextTakeover |
0 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. | |
1 | // Use of this source code is governed by a BSD-style | |
2 | // license that can be found in the LICENSE file. | |
3 | ||
4 | // +build go1.5 | |
5 | ||
6 | package websocket | |
7 | ||
8 | import "io" | |
9 | ||
10 | func (c *Conn) read(n int) ([]byte, error) { | |
11 | p, err := c.br.Peek(n) | |
12 | if err == io.EOF { | |
13 | err = errUnexpectedEOF | |
14 | } | |
15 | c.br.Discard(len(p)) | |
16 | return p, err | |
17 | } |
0 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. | |
1 | // Use of this source code is governed by a BSD-style | |
2 | // license that can be found in the LICENSE file. | |
3 | ||
4 | // +build !go1.5 | |
5 | ||
6 | package websocket | |
7 | ||
8 | import "io" | |
9 | ||
10 | func (c *Conn) read(n int) ([]byte, error) { | |
11 | p, err := c.br.Peek(n) | |
12 | if err == io.EOF { | |
13 | err = errUnexpectedEOF | |
14 | } | |
15 | if len(p) > 0 { | |
16 | // advance over the bytes just read | |
17 | io.ReadFull(c.br, p) | |
18 | } | |
19 | return p, err | |
20 | } |
12 | 12 | "io/ioutil" |
13 | 13 | "net" |
14 | 14 | "reflect" |
15 | "sync" | |
15 | 16 | "testing" |
16 | 17 | "testing/iotest" |
17 | 18 | "time" |
44 | 45 | |
45 | 46 | func (a fakeAddr) String() string { |
46 | 47 | return "str" |
48 | } | |
49 | ||
50 | // newTestConn creates a connnection backed by a fake network connection using | |
51 | // default values for buffering. | |
52 | func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { | |
53 | return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) | |
47 | 54 | } |
48 | 55 | |
49 | 56 | func TestFraming(t *testing.T) { |
81 | 88 | for _, chunker := range readChunkers { |
82 | 89 | |
83 | 90 | var connBuf bytes.Buffer |
84 | wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) | |
85 | rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024) | |
91 | wc := newTestConn(nil, &connBuf, isServer) | |
92 | rc := newTestConn(chunker.f(&connBuf), nil, !isServer) | |
86 | 93 | if compress { |
87 | 94 | wc.newCompressionWriter = compressNoContextTakeover |
88 | 95 | rc.newDecompressionReader = decompressNoContextTakeover |
142 | 149 | for _, isWriteControl := range []bool{true, false} { |
143 | 150 | name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) |
144 | 151 | var connBuf bytes.Buffer |
145 | wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024) | |
146 | rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024) | |
152 | wc := newTestConn(nil, &connBuf, isServer) | |
153 | rc := newTestConn(&connBuf, nil, !isServer) | |
147 | 154 | if isWriteControl { |
148 | 155 | wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) |
149 | 156 | } else { |
172 | 179 | } |
173 | 180 | } |
174 | 181 | |
182 | // simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. | |
183 | type simpleBufferPool struct { | |
184 | v interface{} | |
185 | } | |
186 | ||
187 | func (p *simpleBufferPool) Get() interface{} { | |
188 | v := p.v | |
189 | p.v = nil | |
190 | return v | |
191 | } | |
192 | ||
193 | func (p *simpleBufferPool) Put(v interface{}) { | |
194 | p.v = v | |
195 | } | |
196 | ||
197 | func TestWriteBufferPool(t *testing.T) { | |
198 | var buf bytes.Buffer | |
199 | var pool simpleBufferPool | |
200 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) | |
201 | rc := newTestConn(&buf, nil, false) | |
202 | ||
203 | if wc.writeBuf != nil { | |
204 | t.Fatal("writeBuf not nil after create") | |
205 | } | |
206 | ||
207 | // Part 1: test NextWriter/Write/Close | |
208 | ||
209 | w, err := wc.NextWriter(TextMessage) | |
210 | if err != nil { | |
211 | t.Fatalf("wc.NextWriter() returned %v", err) | |
212 | } | |
213 | ||
214 | if wc.writeBuf == nil { | |
215 | t.Fatal("writeBuf is nil after NextWriter") | |
216 | } | |
217 | ||
218 | writeBufAddr := &wc.writeBuf[0] | |
219 | ||
220 | const message = "Hello World!" | |
221 | ||
222 | if _, err := io.WriteString(w, message); err != nil { | |
223 | t.Fatalf("io.WriteString(w, message) returned %v", err) | |
224 | } | |
225 | ||
226 | if err := w.Close(); err != nil { | |
227 | t.Fatalf("w.Close() returned %v", err) | |
228 | } | |
229 | ||
230 | if wc.writeBuf != nil { | |
231 | t.Fatal("writeBuf not nil after w.Close()") | |
232 | } | |
233 | ||
234 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { | |
235 | t.Fatal("writeBuf not returned to pool") | |
236 | } | |
237 | ||
238 | opCode, p, err := rc.ReadMessage() | |
239 | if opCode != TextMessage || err != nil { | |
240 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) | |
241 | } | |
242 | ||
243 | if s := string(p); s != message { | |
244 | t.Fatalf("message is %s, want %s", s, message) | |
245 | } | |
246 | ||
247 | // Part 2: Test WriteMessage. | |
248 | ||
249 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { | |
250 | t.Fatalf("wc.WriteMessage() returned %v", err) | |
251 | } | |
252 | ||
253 | if wc.writeBuf != nil { | |
254 | t.Fatal("writeBuf not nil after wc.WriteMessage()") | |
255 | } | |
256 | ||
257 | if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { | |
258 | t.Fatal("writeBuf not returned to pool after WriteMessage") | |
259 | } | |
260 | ||
261 | opCode, p, err = rc.ReadMessage() | |
262 | if opCode != TextMessage || err != nil { | |
263 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) | |
264 | } | |
265 | ||
266 | if s := string(p); s != message { | |
267 | t.Fatalf("message is %s, want %s", s, message) | |
268 | } | |
269 | } | |
270 | ||
271 | func TestWriteBufferPoolSync(t *testing.T) { | |
272 | var buf bytes.Buffer | |
273 | var pool sync.Pool | |
274 | wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) | |
275 | rc := newTestConn(&buf, nil, false) | |
276 | ||
277 | const message = "Hello World!" | |
278 | for i := 0; i < 3; i++ { | |
279 | if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { | |
280 | t.Fatalf("wc.WriteMessage() returned %v", err) | |
281 | } | |
282 | opCode, p, err := rc.ReadMessage() | |
283 | if opCode != TextMessage || err != nil { | |
284 | t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) | |
285 | } | |
286 | if s := string(p); s != message { | |
287 | t.Fatalf("message is %s, want %s", s, message) | |
288 | } | |
289 | } | |
290 | } | |
291 | ||
175 | 292 | func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { |
176 | 293 | const bufSize = 512 |
177 | 294 | |
178 | 295 | expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} |
179 | 296 | |
180 | 297 | var b1, b2 bytes.Buffer |
181 | wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) | |
182 | rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) | |
298 | wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) | |
299 | rc := newTestConn(&b1, &b2, true) | |
183 | 300 | |
184 | 301 | w, _ := wc.NextWriter(BinaryMessage) |
185 | 302 | w.Write(make([]byte, bufSize+bufSize/2)) |
205 | 322 | |
206 | 323 | for n := 0; ; n++ { |
207 | 324 | var b bytes.Buffer |
208 | wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024) | |
209 | rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024) | |
325 | wc := newTestConn(nil, &b, false) | |
326 | rc := newTestConn(&b, nil, true) | |
210 | 327 | |
211 | 328 | w, _ := wc.NextWriter(BinaryMessage) |
212 | 329 | w.Write(make([]byte, bufSize)) |
239 | 356 | const bufSize = 512 |
240 | 357 | |
241 | 358 | var b1, b2 bytes.Buffer |
242 | wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) | |
243 | rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) | |
359 | wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) | |
360 | rc := newTestConn(&b1, &b2, true) | |
244 | 361 | |
245 | 362 | w, _ := wc.NextWriter(BinaryMessage) |
246 | 363 | w.Write(make([]byte, bufSize+bufSize/2)) |
260 | 377 | } |
261 | 378 | |
262 | 379 | func TestWriteAfterMessageWriterClose(t *testing.T) { |
263 | wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024) | |
380 | wc := newTestConn(nil, &bytes.Buffer{}, false) | |
264 | 381 | w, _ := wc.NextWriter(BinaryMessage) |
265 | 382 | io.WriteString(w, "hello") |
266 | 383 | if err := w.Close(); err != nil { |
291 | 408 | message := make([]byte, readLimit+1) |
292 | 409 | |
293 | 410 | var b1, b2 bytes.Buffer |
294 | wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2) | |
295 | rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) | |
411 | wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) | |
412 | rc := newTestConn(&b1, &b2, true) | |
296 | 413 | rc.SetReadLimit(readLimit) |
297 | 414 | |
298 | 415 | // Send message at the limit with interleaved pong. |
320 | 437 | } |
321 | 438 | |
322 | 439 | func TestAddrs(t *testing.T) { |
323 | c := newConn(&fakeNetConn{}, true, 1024, 1024) | |
440 | c := newTestConn(nil, nil, true) | |
324 | 441 | if c.LocalAddr() != localAddr { |
325 | 442 | t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) |
326 | 443 | } |
332 | 449 | func TestUnderlyingConn(t *testing.T) { |
333 | 450 | var b1, b2 bytes.Buffer |
334 | 451 | fc := fakeNetConn{Reader: &b1, Writer: &b2} |
335 | c := newConn(fc, true, 1024, 1024) | |
452 | c := newConn(fc, true, 1024, 1024, nil, nil, nil) | |
336 | 453 | ul := c.UnderlyingConn() |
337 | 454 | if ul != fc { |
338 | 455 | t.Fatalf("Underlying conn is not what it should be.") |
340 | 457 | } |
341 | 458 | |
342 | 459 | func TestBufioReadBytes(t *testing.T) { |
343 | ||
344 | 460 | // Test calling bufio.ReadBytes for value longer than read buffer size. |
345 | 461 | |
346 | 462 | m := make([]byte, 512) |
347 | 463 | m[len(m)-1] = '\n' |
348 | 464 | |
349 | 465 | var b1, b2 bytes.Buffer |
350 | wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) | |
351 | rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) | |
466 | wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) | |
467 | rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) | |
352 | 468 | |
353 | 469 | w, _ := wc.NextWriter(BinaryMessage) |
354 | 470 | w.Write(m) |
365 | 481 | t.Fatalf("ReadBytes() returned %v", err) |
366 | 482 | } |
367 | 483 | if len(p) != len(m) { |
368 | t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m)) | |
484 | t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) | |
369 | 485 | } |
370 | 486 | } |
371 | 487 | |
423 | 539 | |
424 | 540 | func TestConcurrentWritePanic(t *testing.T) { |
425 | 541 | w := blockingWriter{make(chan struct{}), make(chan struct{})} |
426 | c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024) | |
542 | c := newTestConn(nil, w, false) | |
427 | 543 | go func() { |
428 | 544 | c.WriteMessage(TextMessage, []byte{}) |
429 | 545 | }() |
449 | 565 | } |
450 | 566 | |
451 | 567 | func TestFailedConnectionReadPanic(t *testing.T) { |
452 | c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024) | |
568 | c := newTestConn(failingReader{}, nil, false) | |
453 | 569 | |
454 | 570 | defer func() { |
455 | 571 | if v := recover(); v != nil { |
462 | 578 | } |
463 | 579 | t.Fatal("should not get here") |
464 | 580 | } |
465 | ||
466 | func TestBufioReuse(t *testing.T) { | |
467 | brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil)) | |
468 | c := newConnBRW(nil, false, 0, 0, brw) | |
469 | ||
470 | if c.br != brw.Reader { | |
471 | t.Error("connection did not reuse bufio.Reader") | |
472 | } | |
473 | ||
474 | var wh writeHook | |
475 | brw.Writer.Reset(&wh) | |
476 | brw.WriteByte(0) | |
477 | brw.Flush() | |
478 | if &c.writeBuf[0] != &wh.p[0] { | |
479 | t.Error("connection did not reuse bufio.Writer") | |
480 | } | |
481 | ||
482 | brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0)) | |
483 | c = newConnBRW(nil, false, 0, 0, brw) | |
484 | ||
485 | if c.br == brw.Reader { | |
486 | t.Error("connection used bufio.Reader with small size") | |
487 | } | |
488 | ||
489 | brw.Writer.Reset(&wh) | |
490 | brw.WriteByte(0) | |
491 | brw.Flush() | |
492 | if &c.writeBuf[0] != &wh.p[0] { | |
493 | t.Error("connection used bufio.Writer with small size") | |
494 | } | |
495 | ||
496 | } |
0 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. | |
1 | // Use of this source code is governed by a BSD-style | |
2 | // license that can be found in the LICENSE file. | |
3 | ||
4 | // +build go1.8 | |
5 | ||
6 | package websocket | |
7 | ||
8 | import "net" | |
9 | ||
10 | func (c *Conn) writeBufs(bufs ...[]byte) error { | |
11 | b := net.Buffers(bufs) | |
12 | _, err := b.WriteTo(c.conn) | |
13 | return err | |
14 | } |
0 | // Copyright 2016 The Gorilla WebSocket Authors. All rights reserved. | |
1 | // Use of this source code is governed by a BSD-style | |
2 | // license that can be found in the LICENSE file. | |
3 | ||
4 | // +build !go1.8 | |
5 | ||
6 | package websocket | |
7 | ||
8 | func (c *Conn) writeBufs(bufs ...[]byte) error { | |
9 | for _, buf := range bufs { | |
10 | if len(buf) > 0 { | |
11 | if _, err := c.conn.Write(buf); err != nil { | |
12 | return err | |
13 | } | |
14 | } | |
15 | } | |
16 | return nil | |
17 | } |
5 | 5 | // |
6 | 6 | // Overview |
7 | 7 | // |
8 | // The Conn type represents a WebSocket connection. A server application uses | |
9 | // the Upgrade function from an Upgrader object with a HTTP request handler | |
10 | // to get a pointer to a Conn: | |
8 | // The Conn type represents a WebSocket connection. A server application calls | |
9 | // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn: | |
11 | 10 | // |
12 | 11 | // var upgrader = websocket.Upgrader{ |
13 | 12 | // ReadBufferSize: 1024, |
30 | 29 | // for { |
31 | 30 | // messageType, p, err := conn.ReadMessage() |
32 | 31 | // if err != nil { |
32 | // log.Println(err) | |
33 | 33 | // return |
34 | 34 | // } |
35 | // if err = conn.WriteMessage(messageType, p); err != nil { | |
36 | // return err | |
35 | // if err := conn.WriteMessage(messageType, p); err != nil { | |
36 | // log.Println(err) | |
37 | // return | |
37 | 38 | // } |
38 | 39 | // } |
39 | 40 | // |
84 | 85 | // and pong. Call the connection WriteControl, WriteMessage or NextWriter |
85 | 86 | // methods to send a control message to the peer. |
86 | 87 | // |
87 | // Connections handle received close messages by sending a close message to the | |
88 | // peer and returning a *CloseError from the the NextReader, ReadMessage or the | |
89 | // message Read method. | |
88 | // Connections handle received close messages by calling the handler function | |
89 | // set with the SetCloseHandler method and by returning a *CloseError from the | |
90 | // NextReader, ReadMessage or the message Read method. The default close | |
91 | // handler sends a close message to the peer. | |
90 | 92 | // |
91 | // Connections handle received ping and pong messages by invoking callback | |
92 | // functions set with SetPingHandler and SetPongHandler methods. The callback | |
93 | // functions are called from the NextReader, ReadMessage and the message Read | |
94 | // methods. | |
93 | // Connections handle received ping messages by calling the handler function | |
94 | // set with the SetPingHandler method. The default ping handler sends a pong | |
95 | // message to the peer. | |
95 | 96 | // |
96 | // The default ping handler sends a pong to the peer. The application's reading | |
97 | // goroutine can block for a short time while the handler writes the pong data | |
98 | // to the connection. | |
97 | // Connections handle received pong messages by calling the handler function | |
98 | // set with the SetPongHandler method. The default pong handler does nothing. | |
99 | // If an application sends ping messages, then the application should set a | |
100 | // pong handler to receive the corresponding pong. | |
99 | 101 | // |
100 | // The application must read the connection to process ping, pong and close | |
102 | // The control message handler functions are called from the NextReader, | |
103 | // ReadMessage and message reader Read methods. The default close and ping | |
104 | // handlers can block these methods for a short time when the handler writes to | |
105 | // the connection. | |
106 | // | |
107 | // The application must read the connection to process close, ping and pong | |
101 | 108 | // messages sent from the peer. If the application is not otherwise interested |
102 | 109 | // in messages from the peer, then the application should start a goroutine to |
103 | 110 | // read and discard messages from the peer. A simple example is: |
136 | 143 | // method fails the WebSocket handshake with HTTP status 403. |
137 | 144 | // |
138 | 145 | // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail |
139 | // the handshake if the Origin request header is present and not equal to the | |
140 | // Host request header. | |
146 | // the handshake if the Origin request header is present and the Origin host is | |
147 | // not equal to the Host request header. | |
141 | 148 | // |
142 | // An application can allow connections from any origin by specifying a | |
143 | // function that always returns true: | |
144 | // | |
145 | // var upgrader = websocket.Upgrader{ | |
146 | // CheckOrigin: func(r *http.Request) bool { return true }, | |
147 | // } | |
148 | // | |
149 | // The deprecated Upgrade function does not enforce an origin policy. It's the | |
150 | // application's responsibility to check the Origin header before calling | |
151 | // Upgrade. | |
149 | // The deprecated package-level Upgrade function does not perform origin | |
150 | // checking. The application is responsible for checking the Origin header | |
151 | // before calling the Upgrade function. | |
152 | 152 | // |
153 | 153 | // Compression EXPERIMENTAL |
154 | 154 | // |
156 | 156 | |
157 | 157 | func serveHome(w http.ResponseWriter, r *http.Request) { |
158 | 158 | if r.URL.Path != "/" { |
159 | http.Error(w, "Not found.", 404) | |
159 | http.Error(w, "Not found.", http.StatusNotFound) | |
160 | 160 | return |
161 | 161 | } |
162 | 162 | if r.Method != "GET" { |
163 | http.Error(w, "Method not allowed", 405) | |
163 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |
164 | 164 | return |
165 | 165 | } |
166 | 166 | w.Header().Set("Content-Type", "text/html; charset=utf-8") |
0 | 0 | # Chat Example |
1 | 1 | |
2 | This application shows how to use use the | |
2 | This application shows how to use the | |
3 | 3 | [websocket](https://github.com/gorilla/websocket) package to implement a simple |
4 | 4 | web chat application. |
5 | 5 |
63 | 63 | for { |
64 | 64 | _, message, err := c.conn.ReadMessage() |
65 | 65 | if err != nil { |
66 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { | |
66 | if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { | |
67 | 67 | log.Printf("error: %v", err) |
68 | 68 | } |
69 | 69 | break |
112 | 112 | } |
113 | 113 | case <-ticker.C: |
114 | 114 | c.conn.SetWriteDeadline(time.Now().Add(writeWait)) |
115 | if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { | |
115 | if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { | |
116 | 116 | return |
117 | 117 | } |
118 | 118 | } |
3 | 3 | |
4 | 4 | package main |
5 | 5 | |
6 | // hub maintains the set of active clients and broadcasts messages to the | |
6 | // Hub maintains the set of active clients and broadcasts messages to the | |
7 | 7 | // clients. |
8 | 8 | type Hub struct { |
9 | 9 | // Registered clients. |
14 | 14 | func serveHome(w http.ResponseWriter, r *http.Request) { |
15 | 15 | log.Println(r.URL) |
16 | 16 | if r.URL.Path != "/" { |
17 | http.Error(w, "Not found", 404) | |
17 | http.Error(w, "Not found", http.StatusNotFound) | |
18 | 18 | return |
19 | 19 | } |
20 | 20 | if r.Method != "GET" { |
21 | http.Error(w, "Method not allowed", 405) | |
21 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |
22 | 22 | return |
23 | 23 | } |
24 | 24 | http.ServeFile(w, r, "home.html") |
166 | 166 | |
167 | 167 | func serveHome(w http.ResponseWriter, r *http.Request) { |
168 | 168 | if r.URL.Path != "/" { |
169 | http.Error(w, "Not found", 404) | |
169 | http.Error(w, "Not found", http.StatusNotFound) | |
170 | 170 | return |
171 | 171 | } |
172 | 172 | if r.Method != "GET" { |
173 | http.Error(w, "Method not allowed", 405) | |
173 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |
174 | 174 | return |
175 | 175 | } |
176 | 176 | http.ServeFile(w, r, "home.html") |
37 | 37 | done := make(chan struct{}) |
38 | 38 | |
39 | 39 | go func() { |
40 | defer c.Close() | |
41 | 40 | defer close(done) |
42 | 41 | for { |
43 | 42 | _, message, err := c.ReadMessage() |
54 | 53 | |
55 | 54 | for { |
56 | 55 | select { |
56 | case <-done: | |
57 | return | |
57 | 58 | case t := <-ticker.C: |
58 | 59 | err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) |
59 | 60 | if err != nil { |
62 | 63 | } |
63 | 64 | case <-interrupt: |
64 | 65 | log.Println("interrupt") |
65 | // To cleanly close a connection, a client should send a close | |
66 | // frame and wait for the server to close the connection. | |
66 | ||
67 | // Cleanly close the connection by sending a close message and then | |
68 | // waiting (with timeout) for the server to close the connection. | |
67 | 69 | err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) |
68 | 70 | if err != nil { |
69 | 71 | log.Println("write close:", err) |
73 | 75 | case <-done: |
74 | 76 | case <-time.After(time.Second): |
75 | 77 | } |
76 | c.Close() | |
77 | 78 | return |
78 | 79 | } |
79 | 80 | } |
54 | 54 | |
55 | 55 | var homeTemplate = template.Must(template.New("").Parse(` |
56 | 56 | <!DOCTYPE html> |
57 | <html> | |
57 | 58 | <head> |
58 | 59 | <meta charset="utf-8"> |
59 | 60 | <script> |
129 | 129 | |
130 | 130 | func serveHome(w http.ResponseWriter, r *http.Request) { |
131 | 131 | if r.URL.Path != "/" { |
132 | http.Error(w, "Not found", 404) | |
132 | http.Error(w, "Not found", http.StatusNotFound) | |
133 | 133 | return |
134 | 134 | } |
135 | 135 | if r.Method != "GET" { |
136 | http.Error(w, "Method not allowed", 405) | |
136 | http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) | |
137 | 137 | return |
138 | 138 | } |
139 | 139 | w.Header().Set("Content-Type", "text/html; charset=utf-8") |
8 | 8 | "io" |
9 | 9 | ) |
10 | 10 | |
11 | // WriteJSON is deprecated, use c.WriteJSON instead. | |
11 | // WriteJSON writes the JSON encoding of v as a message. | |
12 | // | |
13 | // Deprecated: Use c.WriteJSON instead. | |
12 | 14 | func WriteJSON(c *Conn, v interface{}) error { |
13 | 15 | return c.WriteJSON(v) |
14 | 16 | } |
15 | 17 | |
16 | // WriteJSON writes the JSON encoding of v to the connection. | |
18 | // WriteJSON writes the JSON encoding of v as a message. | |
17 | 19 | // |
18 | 20 | // See the documentation for encoding/json Marshal for details about the |
19 | 21 | // conversion of Go values to JSON. |
30 | 32 | return err2 |
31 | 33 | } |
32 | 34 | |
33 | // ReadJSON is deprecated, use c.ReadJSON instead. | |
35 | // ReadJSON reads the next JSON-encoded message from the connection and stores | |
36 | // it in the value pointed to by v. | |
37 | // | |
38 | // Deprecated: Use c.ReadJSON instead. | |
34 | 39 | func ReadJSON(c *Conn, v interface{}) error { |
35 | 40 | return c.ReadJSON(v) |
36 | 41 | } |
13 | 13 | |
14 | 14 | func TestJSON(t *testing.T) { |
15 | 15 | var buf bytes.Buffer |
16 | c := fakeNetConn{&buf, &buf} | |
17 | wc := newConn(c, true, 1024, 1024) | |
18 | rc := newConn(c, false, 1024, 1024) | |
16 | wc := newTestConn(nil, &buf, true) | |
17 | rc := newTestConn(&buf, nil, false) | |
19 | 18 | |
20 | 19 | var actual, expect struct { |
21 | 20 | A int |
38 | 37 | } |
39 | 38 | |
40 | 39 | func TestPartialJSONRead(t *testing.T) { |
41 | var buf bytes.Buffer | |
42 | c := fakeNetConn{&buf, &buf} | |
43 | wc := newConn(c, true, 1024, 1024) | |
44 | rc := newConn(c, false, 1024, 1024) | |
40 | var buf0, buf1 bytes.Buffer | |
41 | wc := newTestConn(nil, &buf0, true) | |
42 | rc := newTestConn(&buf0, &buf1, false) | |
45 | 43 | |
46 | 44 | var v struct { |
47 | 45 | A int |
93 | 91 | |
94 | 92 | func TestDeprecatedJSON(t *testing.T) { |
95 | 93 | var buf bytes.Buffer |
96 | c := fakeNetConn{&buf, &buf} | |
97 | wc := newConn(c, true, 1024, 1024) | |
98 | rc := newConn(c, false, 1024, 1024) | |
94 | wc := newTestConn(nil, &buf, true) | |
95 | rc := newTestConn(&buf, nil, false) | |
99 | 96 | |
100 | 97 | var actual, expect struct { |
101 | 98 | A int |
10 | 10 | const wordSize = int(unsafe.Sizeof(uintptr(0))) |
11 | 11 | |
12 | 12 | func maskBytes(key [4]byte, pos int, b []byte) int { |
13 | ||
14 | 13 | // Mask one byte at a time for small buffers. |
15 | 14 | if len(b) < 2*wordSize { |
16 | 15 | for i := range b { |
1 | 1 | // this source code is governed by a BSD-style license that can be found in the |
2 | 2 | // LICENSE file. |
3 | 3 | |
4 | // Require 1.7 for sub-bencmarks | |
5 | // +build go1.7,!appengine | |
4 | // !appengine | |
6 | 5 | |
7 | 6 | package websocket |
8 | 7 |
18 | 18 | type PreparedMessage struct { |
19 | 19 | messageType int |
20 | 20 | data []byte |
21 | err error | |
22 | 21 | mu sync.Mutex |
23 | 22 | frames map[prepareKey]*preparedFrame |
24 | 23 | } |
35 | 35 | for _, tt := range preparedMessageTests { |
36 | 36 | var data = []byte("this is a test") |
37 | 37 | var buf bytes.Buffer |
38 | c := newConn(fakeNetConn{Reader: nil, Writer: &buf}, tt.isServer, 1024, 1024) | |
38 | c := newTestConn(nil, &buf, tt.isServer) | |
39 | 39 | if tt.enableWriteCompression { |
40 | 40 | c.newCompressionWriter = compressNoContextTakeover |
41 | 41 | } |
0 | // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. | |
1 | // Use of this source code is governed by a BSD-style | |
2 | // license that can be found in the LICENSE file. | |
3 | ||
4 | package websocket | |
5 | ||
6 | import ( | |
7 | "bufio" | |
8 | "encoding/base64" | |
9 | "errors" | |
10 | "net" | |
11 | "net/http" | |
12 | "net/url" | |
13 | "strings" | |
14 | ) | |
15 | ||
16 | type netDialerFunc func(network, addr string) (net.Conn, error) | |
17 | ||
18 | func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) { | |
19 | return fn(network, addr) | |
20 | } | |
21 | ||
22 | func init() { | |
23 | proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) { | |
24 | return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil | |
25 | }) | |
26 | } | |
27 | ||
28 | type httpProxyDialer struct { | |
29 | proxyURL *url.URL | |
30 | fowardDial func(network, addr string) (net.Conn, error) | |
31 | } | |
32 | ||
33 | func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) { | |
34 | hostPort, _ := hostPortNoPort(hpd.proxyURL) | |
35 | conn, err := hpd.fowardDial(network, hostPort) | |
36 | if err != nil { | |
37 | return nil, err | |
38 | } | |
39 | ||
40 | connectHeader := make(http.Header) | |
41 | if user := hpd.proxyURL.User; user != nil { | |
42 | proxyUser := user.Username() | |
43 | if proxyPassword, passwordSet := user.Password(); passwordSet { | |
44 | credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword)) | |
45 | connectHeader.Set("Proxy-Authorization", "Basic "+credential) | |
46 | } | |
47 | } | |
48 | ||
49 | connectReq := &http.Request{ | |
50 | Method: "CONNECT", | |
51 | URL: &url.URL{Opaque: addr}, | |
52 | Host: addr, | |
53 | Header: connectHeader, | |
54 | } | |
55 | ||
56 | if err := connectReq.Write(conn); err != nil { | |
57 | conn.Close() | |
58 | return nil, err | |
59 | } | |
60 | ||
61 | // Read response. It's OK to use and discard buffered reader here becaue | |
62 | // the remote server does not speak until spoken to. | |
63 | br := bufio.NewReader(conn) | |
64 | resp, err := http.ReadResponse(br, connectReq) | |
65 | if err != nil { | |
66 | conn.Close() | |
67 | return nil, err | |
68 | } | |
69 | ||
70 | if resp.StatusCode != 200 { | |
71 | conn.Close() | |
72 | f := strings.SplitN(resp.Status, " ", 2) | |
73 | return nil, errors.New(f[1]) | |
74 | } | |
75 | return conn, nil | |
76 | } |
6 | 6 | import ( |
7 | 7 | "bufio" |
8 | 8 | "errors" |
9 | "net" | |
9 | "io" | |
10 | 10 | "net/http" |
11 | 11 | "net/url" |
12 | 12 | "strings" |
32 | 32 | // or received. |
33 | 33 | ReadBufferSize, WriteBufferSize int |
34 | 34 | |
35 | // WriteBufferPool is a pool of buffers for write operations. If the value | |
36 | // is not set, then write buffers are allocated to the connection for the | |
37 | // lifetime of the connection. | |
38 | // | |
39 | // A pool is most useful when the application has a modest volume of writes | |
40 | // across a large number of connections. | |
41 | // | |
42 | // Applications should use a single pool for each unique value of | |
43 | // WriteBufferSize. | |
44 | WriteBufferPool BufferPool | |
45 | ||
35 | 46 | // Subprotocols specifies the server's supported protocols in order of |
36 | // preference. If this field is set, then the Upgrade method negotiates a | |
47 | // preference. If this field is not nil, then the Upgrade method negotiates a | |
37 | 48 | // subprotocol by selecting the first match in this list with a protocol |
38 | // requested by the client. | |
49 | // requested by the client. If there's no match, then no protocol is | |
50 | // negotiated (the Sec-Websocket-Protocol header is not included in the | |
51 | // handshake response). | |
39 | 52 | Subprotocols []string |
40 | 53 | |
41 | 54 | // Error specifies the function for generating HTTP error responses. If Error |
43 | 56 | Error func(w http.ResponseWriter, r *http.Request, status int, reason error) |
44 | 57 | |
45 | 58 | // CheckOrigin returns true if the request Origin header is acceptable. If |
46 | // CheckOrigin is nil, the host in the Origin header must not be set or | |
47 | // must match the host of the request. | |
59 | // CheckOrigin is nil, then a safe default is used: return false if the | |
60 | // Origin request header is present and the origin host is not equal to | |
61 | // request Host header. | |
62 | // | |
63 | // A CheckOrigin function should carefully validate the request origin to | |
64 | // prevent cross-site request forgery. | |
48 | 65 | CheckOrigin func(r *http.Request) bool |
49 | 66 | |
50 | 67 | // EnableCompression specify if the server should attempt to negotiate per |
75 | 92 | if err != nil { |
76 | 93 | return false |
77 | 94 | } |
78 | return u.Host == r.Host | |
95 | return equalASCIIFold(u.Host, r.Host) | |
79 | 96 | } |
80 | 97 | |
81 | 98 | func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { |
98 | 115 | // |
99 | 116 | // The responseHeader is included in the response to the client's upgrade |
100 | 117 | // request. Use the responseHeader to specify cookies (Set-Cookie) and the |
101 | // application negotiated subprotocol (Sec-Websocket-Protocol). | |
118 | // application negotiated subprotocol (Sec-WebSocket-Protocol). | |
102 | 119 | // |
103 | 120 | // If the upgrade fails, then Upgrade replies to the client with an HTTP error |
104 | 121 | // response. |
105 | 122 | func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { |
123 | const badHandshake = "websocket: the client is not using the websocket protocol: " | |
124 | ||
125 | if !tokenListContainsValue(r.Header, "Connection", "upgrade") { | |
126 | return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header") | |
127 | } | |
128 | ||
129 | if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { | |
130 | return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header") | |
131 | } | |
132 | ||
106 | 133 | if r.Method != "GET" { |
107 | return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET") | |
108 | } | |
109 | ||
110 | if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { | |
111 | return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported") | |
112 | } | |
113 | ||
114 | if !tokenListContainsValue(r.Header, "Connection", "upgrade") { | |
115 | return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header") | |
116 | } | |
117 | ||
118 | if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { | |
119 | return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header") | |
134 | return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET") | |
120 | 135 | } |
121 | 136 | |
122 | 137 | if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { |
123 | 138 | return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") |
139 | } | |
140 | ||
141 | if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok { | |
142 | return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") | |
124 | 143 | } |
125 | 144 | |
126 | 145 | checkOrigin := u.CheckOrigin |
128 | 147 | checkOrigin = checkSameOrigin |
129 | 148 | } |
130 | 149 | if !checkOrigin(r) { |
131 | return u.returnError(w, r, http.StatusForbidden, "websocket: 'Origin' header value not allowed") | |
150 | return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin") | |
132 | 151 | } |
133 | 152 | |
134 | 153 | challengeKey := r.Header.Get("Sec-Websocket-Key") |
135 | 154 | if challengeKey == "" { |
136 | return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank") | |
155 | return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank") | |
137 | 156 | } |
138 | 157 | |
139 | 158 | subprotocol := u.selectSubprotocol(r, responseHeader) |
150 | 169 | } |
151 | 170 | } |
152 | 171 | |
153 | var ( | |
154 | netConn net.Conn | |
155 | err error | |
156 | ) | |
157 | ||
158 | 172 | h, ok := w.(http.Hijacker) |
159 | 173 | if !ok { |
160 | 174 | return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") |
161 | 175 | } |
162 | 176 | var brw *bufio.ReadWriter |
163 | netConn, brw, err = h.Hijack() | |
177 | netConn, brw, err := h.Hijack() | |
164 | 178 | if err != nil { |
165 | 179 | return u.returnError(w, r, http.StatusInternalServerError, err.Error()) |
166 | 180 | } |
170 | 184 | return nil, errors.New("websocket: client sent data before handshake is complete") |
171 | 185 | } |
172 | 186 | |
173 | c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) | |
187 | var br *bufio.Reader | |
188 | if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 { | |
189 | // Reuse hijacked buffered reader as connection reader. | |
190 | br = brw.Reader | |
191 | } | |
192 | ||
193 | buf := bufioWriterBuffer(netConn, brw.Writer) | |
194 | ||
195 | var writeBuf []byte | |
196 | if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 { | |
197 | // Reuse hijacked write buffer as connection buffer. | |
198 | writeBuf = buf | |
199 | } | |
200 | ||
201 | c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf) | |
174 | 202 | c.subprotocol = subprotocol |
175 | 203 | |
176 | 204 | if compress { |
178 | 206 | c.newDecompressionReader = decompressNoContextTakeover |
179 | 207 | } |
180 | 208 | |
181 | p := c.writeBuf[:0] | |
209 | // Use larger of hijacked buffer and connection write buffer for header. | |
210 | p := buf | |
211 | if len(c.writeBuf) > len(p) { | |
212 | p = c.writeBuf | |
213 | } | |
214 | p = p[:0] | |
215 | ||
182 | 216 | p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) |
183 | 217 | p = append(p, computeAcceptKey(challengeKey)...) |
184 | 218 | p = append(p, "\r\n"...) |
185 | 219 | if c.subprotocol != "" { |
186 | p = append(p, "Sec-Websocket-Protocol: "...) | |
220 | p = append(p, "Sec-WebSocket-Protocol: "...) | |
187 | 221 | p = append(p, c.subprotocol...) |
188 | 222 | p = append(p, "\r\n"...) |
189 | 223 | } |
190 | 224 | if compress { |
191 | p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) | |
225 | p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) | |
192 | 226 | } |
193 | 227 | for k, vs := range responseHeader { |
194 | 228 | if k == "Sec-Websocket-Protocol" { |
229 | 263 | |
230 | 264 | // Upgrade upgrades the HTTP server connection to the WebSocket protocol. |
231 | 265 | // |
232 | // This function is deprecated, use websocket.Upgrader instead. | |
233 | // | |
234 | // The application is responsible for checking the request origin before | |
235 | // calling Upgrade. An example implementation of the same origin policy is: | |
266 | // Deprecated: Use websocket.Upgrader instead. | |
267 | // | |
268 | // Upgrade does not perform origin checking. The application is responsible for | |
269 | // checking the Origin header before calling Upgrade. An example implementation | |
270 | // of the same origin policy check is: | |
236 | 271 | // |
237 | 272 | // if req.Header.Get("Origin") != "http://"+req.Host { |
238 | // http.Error(w, "Origin not allowed", 403) | |
273 | // http.Error(w, "Origin not allowed", http.StatusForbidden) | |
239 | 274 | // return |
240 | 275 | // } |
241 | 276 | // |
288 | 323 | return tokenListContainsValue(r.Header, "Connection", "upgrade") && |
289 | 324 | tokenListContainsValue(r.Header, "Upgrade", "websocket") |
290 | 325 | } |
326 | ||
327 | // bufioReaderSize size returns the size of a bufio.Reader. | |
328 | func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int { | |
329 | // This code assumes that peek on a reset reader returns | |
330 | // bufio.Reader.buf[:0]. | |
331 | // TODO: Use bufio.Reader.Size() after Go 1.10 | |
332 | br.Reset(originalReader) | |
333 | if p, err := br.Peek(0); err == nil { | |
334 | return cap(p) | |
335 | } | |
336 | return 0 | |
337 | } | |
338 | ||
339 | // writeHook is an io.Writer that records the last slice passed to it vio | |
340 | // io.Writer.Write. | |
341 | type writeHook struct { | |
342 | p []byte | |
343 | } | |
344 | ||
345 | func (wh *writeHook) Write(p []byte) (int, error) { | |
346 | wh.p = p | |
347 | return len(p), nil | |
348 | } | |
349 | ||
350 | // bufioWriterBuffer grabs the buffer from a bufio.Writer. | |
351 | func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte { | |
352 | // This code assumes that bufio.Writer.buf[:1] is passed to the | |
353 | // bufio.Writer's underlying writer. | |
354 | var wh writeHook | |
355 | bw.Reset(&wh) | |
356 | bw.WriteByte(0) | |
357 | bw.Flush() | |
358 | ||
359 | bw.Reset(originalWriter) | |
360 | ||
361 | return wh.p[:cap(wh.p)] | |
362 | } |
4 | 4 | package websocket |
5 | 5 | |
6 | 6 | import ( |
7 | "bufio" | |
8 | "bytes" | |
9 | "net" | |
7 | 10 | "net/http" |
8 | 11 | "reflect" |
12 | "strings" | |
9 | 13 | "testing" |
10 | 14 | ) |
11 | 15 | |
48 | 52 | } |
49 | 53 | } |
50 | 54 | } |
55 | ||
56 | var checkSameOriginTests = []struct { | |
57 | ok bool | |
58 | r *http.Request | |
59 | }{ | |
60 | {false, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://other.org"}}}}, | |
61 | {true, &http.Request{Host: "example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, | |
62 | {true, &http.Request{Host: "Example.org", Header: map[string][]string{"Origin": {"https://example.org"}}}}, | |
63 | } | |
64 | ||
65 | func TestCheckSameOrigin(t *testing.T) { | |
66 | for _, tt := range checkSameOriginTests { | |
67 | ok := checkSameOrigin(tt.r) | |
68 | if tt.ok != ok { | |
69 | t.Errorf("checkSameOrigin(%+v) returned %v, want %v", tt.r, ok, tt.ok) | |
70 | } | |
71 | } | |
72 | } | |
73 | ||
74 | type reuseTestResponseWriter struct { | |
75 | brw *bufio.ReadWriter | |
76 | http.ResponseWriter | |
77 | } | |
78 | ||
79 | func (resp *reuseTestResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { | |
80 | return fakeNetConn{strings.NewReader(""), &bytes.Buffer{}}, resp.brw, nil | |
81 | } | |
82 | ||
83 | var bufioReuseTests = []struct { | |
84 | n int | |
85 | reuse bool | |
86 | }{ | |
87 | {4096, true}, | |
88 | {128, false}, | |
89 | } | |
90 | ||
91 | func TestBufioReuse(t *testing.T) { | |
92 | for i, tt := range bufioReuseTests { | |
93 | br := bufio.NewReaderSize(strings.NewReader(""), tt.n) | |
94 | bw := bufio.NewWriterSize(&bytes.Buffer{}, tt.n) | |
95 | resp := &reuseTestResponseWriter{ | |
96 | brw: bufio.NewReadWriter(br, bw), | |
97 | } | |
98 | upgrader := Upgrader{} | |
99 | c, err := upgrader.Upgrade(resp, &http.Request{ | |
100 | Method: "GET", | |
101 | Header: http.Header{ | |
102 | "Upgrade": []string{"websocket"}, | |
103 | "Connection": []string{"upgrade"}, | |
104 | "Sec-Websocket-Key": []string{"dGhlIHNhbXBsZSBub25jZQ=="}, | |
105 | "Sec-Websocket-Version": []string{"13"}, | |
106 | }}, nil) | |
107 | if err != nil { | |
108 | t.Fatal(err) | |
109 | } | |
110 | if reuse := c.br == br; reuse != tt.reuse { | |
111 | t.Errorf("%d: buffered reader reuse=%v, want %v", i, reuse, tt.reuse) | |
112 | } | |
113 | writeBuf := bufioWriterBuffer(c.UnderlyingConn(), bw) | |
114 | if reuse := &c.writeBuf[0] == &writeBuf[0]; reuse != tt.reuse { | |
115 | t.Errorf("%d: write buffer reuse=%v, want %v", i, reuse, tt.reuse) | |
116 | } | |
117 | } | |
118 | } |
0 | // +build go1.8 | |
1 | ||
2 | package websocket | |
3 | ||
4 | import ( | |
5 | "crypto/tls" | |
6 | "net/http/httptrace" | |
7 | ) | |
8 | ||
9 | func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { | |
10 | if trace.TLSHandshakeStart != nil { | |
11 | trace.TLSHandshakeStart() | |
12 | } | |
13 | err := doHandshake(tlsConn, cfg) | |
14 | if trace.TLSHandshakeDone != nil { | |
15 | trace.TLSHandshakeDone(tlsConn.ConnectionState(), err) | |
16 | } | |
17 | return err | |
18 | } |
0 | // +build !go1.8 | |
1 | ||
2 | package websocket | |
3 | ||
4 | import ( | |
5 | "crypto/tls" | |
6 | "net/http/httptrace" | |
7 | ) | |
8 | ||
9 | func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error { | |
10 | return doHandshake(tlsConn, cfg) | |
11 | } |
10 | 10 | "io" |
11 | 11 | "net/http" |
12 | 12 | "strings" |
13 | "unicode/utf8" | |
13 | 14 | ) |
14 | 15 | |
15 | 16 | var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") |
110 | 111 | case escape: |
111 | 112 | escape = false |
112 | 113 | p[j] = b |
113 | j += 1 | |
114 | j++ | |
114 | 115 | case b == '\\': |
115 | 116 | escape = true |
116 | 117 | case b == '"': |
117 | 118 | return string(p[:j]), s[i+1:] |
118 | 119 | default: |
119 | 120 | p[j] = b |
120 | j += 1 | |
121 | j++ | |
121 | 122 | } |
122 | 123 | } |
123 | 124 | return "", "" |
126 | 127 | return "", "" |
127 | 128 | } |
128 | 129 | |
130 | // equalASCIIFold returns true if s is equal to t with ASCII case folding. | |
131 | func equalASCIIFold(s, t string) bool { | |
132 | for s != "" && t != "" { | |
133 | sr, size := utf8.DecodeRuneInString(s) | |
134 | s = s[size:] | |
135 | tr, size := utf8.DecodeRuneInString(t) | |
136 | t = t[size:] | |
137 | if sr == tr { | |
138 | continue | |
139 | } | |
140 | if 'A' <= sr && sr <= 'Z' { | |
141 | sr = sr + 'a' - 'A' | |
142 | } | |
143 | if 'A' <= tr && tr <= 'Z' { | |
144 | tr = tr + 'a' - 'A' | |
145 | } | |
146 | if sr != tr { | |
147 | return false | |
148 | } | |
149 | } | |
150 | return s == t | |
151 | } | |
152 | ||
129 | 153 | // tokenListContainsValue returns true if the 1#token header with the given |
130 | // name contains token. | |
154 | // name contains a token equal to value with ASCII case folding. | |
131 | 155 | func tokenListContainsValue(header http.Header, name string, value string) bool { |
132 | 156 | headers: |
133 | 157 | for _, s := range header[name] { |
141 | 165 | if s != "" && s[0] != ',' { |
142 | 166 | continue headers |
143 | 167 | } |
144 | if strings.EqualFold(t, value) { | |
168 | if equalASCIIFold(t, value) { | |
145 | 169 | return true |
146 | 170 | } |
147 | 171 | if s == "" { |
153 | 177 | return false |
154 | 178 | } |
155 | 179 | |
156 | // parseExtensiosn parses WebSocket extensions from a header. | |
180 | // parseExtensions parses WebSocket extensions from a header. | |
157 | 181 | func parseExtensions(header http.Header) []map[string]string { |
158 | ||
159 | 182 | // From RFC 6455: |
160 | 183 | // |
161 | 184 | // Sec-WebSocket-Extensions = extension-list |
8 | 8 | "reflect" |
9 | 9 | "testing" |
10 | 10 | ) |
11 | ||
12 | var equalASCIIFoldTests = []struct { | |
13 | t, s string | |
14 | eq bool | |
15 | }{ | |
16 | {"WebSocket", "websocket", true}, | |
17 | {"websocket", "WebSocket", true}, | |
18 | {"Öyster", "öyster", false}, | |
19 | } | |
20 | ||
21 | func TestEqualASCIIFold(t *testing.T) { | |
22 | for _, tt := range equalASCIIFoldTests { | |
23 | eq := equalASCIIFold(tt.s, tt.t) | |
24 | if eq != tt.eq { | |
25 | t.Errorf("equalASCIIFold(%q, %q) = %v, want %v", tt.s, tt.t, eq, tt.eq) | |
26 | } | |
27 | } | |
28 | } | |
11 | 29 | |
12 | 30 | var tokenListContainsValueTests = []struct { |
13 | 31 | value string |
37 | 55 | value string |
38 | 56 | extensions []map[string]string |
39 | 57 | }{ |
40 | {`foo`, []map[string]string{map[string]string{"": "foo"}}}, | |
58 | {`foo`, []map[string]string{{"": "foo"}}}, | |
41 | 59 | {`foo, bar; baz=2`, []map[string]string{ |
42 | map[string]string{"": "foo"}, | |
43 | map[string]string{"": "bar", "baz": "2"}}}, | |
60 | {"": "foo"}, | |
61 | {"": "bar", "baz": "2"}}}, | |
44 | 62 | {`foo; bar="b,a;z"`, []map[string]string{ |
45 | map[string]string{"": "foo", "bar": "b,a;z"}}}, | |
63 | {"": "foo", "bar": "b,a;z"}}}, | |
46 | 64 | {`foo , bar; baz = 2`, []map[string]string{ |
47 | map[string]string{"": "foo"}, | |
48 | map[string]string{"": "bar", "baz": "2"}}}, | |
65 | {"": "foo"}, | |
66 | {"": "bar", "baz": "2"}}}, | |
49 | 67 | {`foo, bar; baz=2 junk`, []map[string]string{ |
50 | map[string]string{"": "foo"}}}, | |
68 | {"": "foo"}}}, | |
51 | 69 | {`foo junk, bar; baz=2 junk`, nil}, |
52 | 70 | {`mux; max-channels=4; flow-control, deflate-stream`, []map[string]string{ |
53 | map[string]string{"": "mux", "max-channels": "4", "flow-control": ""}, | |
54 | map[string]string{"": "deflate-stream"}}}, | |
71 | {"": "mux", "max-channels": "4", "flow-control": ""}, | |
72 | {"": "deflate-stream"}}}, | |
55 | 73 | {`permessage-foo; x="10"`, []map[string]string{ |
56 | map[string]string{"": "permessage-foo", "x": "10"}}}, | |
74 | {"": "permessage-foo", "x": "10"}}}, | |
57 | 75 | {`permessage-foo; use_y, permessage-foo`, []map[string]string{ |
58 | map[string]string{"": "permessage-foo", "use_y": ""}, | |
59 | map[string]string{"": "permessage-foo"}}}, | |
76 | {"": "permessage-foo", "use_y": ""}, | |
77 | {"": "permessage-foo"}}}, | |
60 | 78 | {`permessage-deflate; client_max_window_bits; server_max_window_bits=10 , permessage-deflate; client_max_window_bits`, []map[string]string{ |
61 | map[string]string{"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, | |
62 | map[string]string{"": "permessage-deflate", "client_max_window_bits": ""}}}, | |
79 | {"": "permessage-deflate", "client_max_window_bits": "", "server_max_window_bits": "10"}, | |
80 | {"": "permessage-deflate", "client_max_window_bits": ""}}}, | |
81 | {"permessage-deflate; server_no_context_takeover; client_max_window_bits=15", []map[string]string{ | |
82 | {"": "permessage-deflate", "server_no_context_takeover": "", "client_max_window_bits": "15"}, | |
83 | }}, | |
63 | 84 | } |
64 | 85 | |
65 | 86 | func TestParseExtensions(t *testing.T) { |
0 | // Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT. | |
1 | //go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy | |
2 | ||
3 | // Package proxy provides support for a variety of protocols to proxy network | |
4 | // data. | |
5 | // | |
6 | ||
7 | package websocket | |
8 | ||
9 | import ( | |
10 | "errors" | |
11 | "io" | |
12 | "net" | |
13 | "net/url" | |
14 | "os" | |
15 | "strconv" | |
16 | "strings" | |
17 | "sync" | |
18 | ) | |
19 | ||
20 | type proxy_direct struct{} | |
21 | ||
22 | // Direct is a direct proxy: one that makes network connections directly. | |
23 | var proxy_Direct = proxy_direct{} | |
24 | ||
25 | func (proxy_direct) Dial(network, addr string) (net.Conn, error) { | |
26 | return net.Dial(network, addr) | |
27 | } | |
28 | ||
29 | // A PerHost directs connections to a default Dialer unless the host name | |
30 | // requested matches one of a number of exceptions. | |
31 | type proxy_PerHost struct { | |
32 | def, bypass proxy_Dialer | |
33 | ||
34 | bypassNetworks []*net.IPNet | |
35 | bypassIPs []net.IP | |
36 | bypassZones []string | |
37 | bypassHosts []string | |
38 | } | |
39 | ||
40 | // NewPerHost returns a PerHost Dialer that directs connections to either | |
41 | // defaultDialer or bypass, depending on whether the connection matches one of | |
42 | // the configured rules. | |
43 | func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost { | |
44 | return &proxy_PerHost{ | |
45 | def: defaultDialer, | |
46 | bypass: bypass, | |
47 | } | |
48 | } | |
49 | ||
50 | // Dial connects to the address addr on the given network through either | |
51 | // defaultDialer or bypass. | |
52 | func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) { | |
53 | host, _, err := net.SplitHostPort(addr) | |
54 | if err != nil { | |
55 | return nil, err | |
56 | } | |
57 | ||
58 | return p.dialerForRequest(host).Dial(network, addr) | |
59 | } | |
60 | ||
61 | func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer { | |
62 | if ip := net.ParseIP(host); ip != nil { | |
63 | for _, net := range p.bypassNetworks { | |
64 | if net.Contains(ip) { | |
65 | return p.bypass | |
66 | } | |
67 | } | |
68 | for _, bypassIP := range p.bypassIPs { | |
69 | if bypassIP.Equal(ip) { | |
70 | return p.bypass | |
71 | } | |
72 | } | |
73 | return p.def | |
74 | } | |
75 | ||
76 | for _, zone := range p.bypassZones { | |
77 | if strings.HasSuffix(host, zone) { | |
78 | return p.bypass | |
79 | } | |
80 | if host == zone[1:] { | |
81 | // For a zone ".example.com", we match "example.com" | |
82 | // too. | |
83 | return p.bypass | |
84 | } | |
85 | } | |
86 | for _, bypassHost := range p.bypassHosts { | |
87 | if bypassHost == host { | |
88 | return p.bypass | |
89 | } | |
90 | } | |
91 | return p.def | |
92 | } | |
93 | ||
94 | // AddFromString parses a string that contains comma-separated values | |
95 | // specifying hosts that should use the bypass proxy. Each value is either an | |
96 | // IP address, a CIDR range, a zone (*.example.com) or a host name | |
97 | // (localhost). A best effort is made to parse the string and errors are | |
98 | // ignored. | |
99 | func (p *proxy_PerHost) AddFromString(s string) { | |
100 | hosts := strings.Split(s, ",") | |
101 | for _, host := range hosts { | |
102 | host = strings.TrimSpace(host) | |
103 | if len(host) == 0 { | |
104 | continue | |
105 | } | |
106 | if strings.Contains(host, "/") { | |
107 | // We assume that it's a CIDR address like 127.0.0.0/8 | |
108 | if _, net, err := net.ParseCIDR(host); err == nil { | |
109 | p.AddNetwork(net) | |
110 | } | |
111 | continue | |
112 | } | |
113 | if ip := net.ParseIP(host); ip != nil { | |
114 | p.AddIP(ip) | |
115 | continue | |
116 | } | |
117 | if strings.HasPrefix(host, "*.") { | |
118 | p.AddZone(host[1:]) | |
119 | continue | |
120 | } | |
121 | p.AddHost(host) | |
122 | } | |
123 | } | |
124 | ||
125 | // AddIP specifies an IP address that will use the bypass proxy. Note that | |
126 | // this will only take effect if a literal IP address is dialed. A connection | |
127 | // to a named host will never match an IP. | |
128 | func (p *proxy_PerHost) AddIP(ip net.IP) { | |
129 | p.bypassIPs = append(p.bypassIPs, ip) | |
130 | } | |
131 | ||
132 | // AddNetwork specifies an IP range that will use the bypass proxy. Note that | |
133 | // this will only take effect if a literal IP address is dialed. A connection | |
134 | // to a named host will never match. | |
135 | func (p *proxy_PerHost) AddNetwork(net *net.IPNet) { | |
136 | p.bypassNetworks = append(p.bypassNetworks, net) | |
137 | } | |
138 | ||
139 | // AddZone specifies a DNS suffix that will use the bypass proxy. A zone of | |
140 | // "example.com" matches "example.com" and all of its subdomains. | |
141 | func (p *proxy_PerHost) AddZone(zone string) { | |
142 | if strings.HasSuffix(zone, ".") { | |
143 | zone = zone[:len(zone)-1] | |
144 | } | |
145 | if !strings.HasPrefix(zone, ".") { | |
146 | zone = "." + zone | |
147 | } | |
148 | p.bypassZones = append(p.bypassZones, zone) | |
149 | } | |
150 | ||
151 | // AddHost specifies a host name that will use the bypass proxy. | |
152 | func (p *proxy_PerHost) AddHost(host string) { | |
153 | if strings.HasSuffix(host, ".") { | |
154 | host = host[:len(host)-1] | |
155 | } | |
156 | p.bypassHosts = append(p.bypassHosts, host) | |
157 | } | |
158 | ||
159 | // A Dialer is a means to establish a connection. | |
160 | type proxy_Dialer interface { | |
161 | // Dial connects to the given address via the proxy. | |
162 | Dial(network, addr string) (c net.Conn, err error) | |
163 | } | |
164 | ||
165 | // Auth contains authentication parameters that specific Dialers may require. | |
166 | type proxy_Auth struct { | |
167 | User, Password string | |
168 | } | |
169 | ||
170 | // FromEnvironment returns the dialer specified by the proxy related variables in | |
171 | // the environment. | |
172 | func proxy_FromEnvironment() proxy_Dialer { | |
173 | allProxy := proxy_allProxyEnv.Get() | |
174 | if len(allProxy) == 0 { | |
175 | return proxy_Direct | |
176 | } | |
177 | ||
178 | proxyURL, err := url.Parse(allProxy) | |
179 | if err != nil { | |
180 | return proxy_Direct | |
181 | } | |
182 | proxy, err := proxy_FromURL(proxyURL, proxy_Direct) | |
183 | if err != nil { | |
184 | return proxy_Direct | |
185 | } | |
186 | ||
187 | noProxy := proxy_noProxyEnv.Get() | |
188 | if len(noProxy) == 0 { | |
189 | return proxy | |
190 | } | |
191 | ||
192 | perHost := proxy_NewPerHost(proxy, proxy_Direct) | |
193 | perHost.AddFromString(noProxy) | |
194 | return perHost | |
195 | } | |
196 | ||
197 | // proxySchemes is a map from URL schemes to a function that creates a Dialer | |
198 | // from a URL with such a scheme. | |
199 | var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error) | |
200 | ||
201 | // RegisterDialerType takes a URL scheme and a function to generate Dialers from | |
202 | // a URL with that scheme and a forwarding Dialer. Registered schemes are used | |
203 | // by FromURL. | |
204 | func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) { | |
205 | if proxy_proxySchemes == nil { | |
206 | proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) | |
207 | } | |
208 | proxy_proxySchemes[scheme] = f | |
209 | } | |
210 | ||
211 | // FromURL returns a Dialer given a URL specification and an underlying | |
212 | // Dialer for it to make network requests. | |
213 | func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) { | |
214 | var auth *proxy_Auth | |
215 | if u.User != nil { | |
216 | auth = new(proxy_Auth) | |
217 | auth.User = u.User.Username() | |
218 | if p, ok := u.User.Password(); ok { | |
219 | auth.Password = p | |
220 | } | |
221 | } | |
222 | ||
223 | switch u.Scheme { | |
224 | case "socks5": | |
225 | return proxy_SOCKS5("tcp", u.Host, auth, forward) | |
226 | } | |
227 | ||
228 | // If the scheme doesn't match any of the built-in schemes, see if it | |
229 | // was registered by another package. | |
230 | if proxy_proxySchemes != nil { | |
231 | if f, ok := proxy_proxySchemes[u.Scheme]; ok { | |
232 | return f(u, forward) | |
233 | } | |
234 | } | |
235 | ||
236 | return nil, errors.New("proxy: unknown scheme: " + u.Scheme) | |
237 | } | |
238 | ||
239 | var ( | |
240 | proxy_allProxyEnv = &proxy_envOnce{ | |
241 | names: []string{"ALL_PROXY", "all_proxy"}, | |
242 | } | |
243 | proxy_noProxyEnv = &proxy_envOnce{ | |
244 | names: []string{"NO_PROXY", "no_proxy"}, | |
245 | } | |
246 | ) | |
247 | ||
248 | // envOnce looks up an environment variable (optionally by multiple | |
249 | // names) once. It mitigates expensive lookups on some platforms | |
250 | // (e.g. Windows). | |
251 | // (Borrowed from net/http/transport.go) | |
252 | type proxy_envOnce struct { | |
253 | names []string | |
254 | once sync.Once | |
255 | val string | |
256 | } | |
257 | ||
258 | func (e *proxy_envOnce) Get() string { | |
259 | e.once.Do(e.init) | |
260 | return e.val | |
261 | } | |
262 | ||
263 | func (e *proxy_envOnce) init() { | |
264 | for _, n := range e.names { | |
265 | e.val = os.Getenv(n) | |
266 | if e.val != "" { | |
267 | return | |
268 | } | |
269 | } | |
270 | } | |
271 | ||
272 | // SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address | |
273 | // with an optional username and password. See RFC 1928 and RFC 1929. | |
274 | func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) { | |
275 | s := &proxy_socks5{ | |
276 | network: network, | |
277 | addr: addr, | |
278 | forward: forward, | |
279 | } | |
280 | if auth != nil { | |
281 | s.user = auth.User | |
282 | s.password = auth.Password | |
283 | } | |
284 | ||
285 | return s, nil | |
286 | } | |
287 | ||
288 | type proxy_socks5 struct { | |
289 | user, password string | |
290 | network, addr string | |
291 | forward proxy_Dialer | |
292 | } | |
293 | ||
294 | const proxy_socks5Version = 5 | |
295 | ||
296 | const ( | |
297 | proxy_socks5AuthNone = 0 | |
298 | proxy_socks5AuthPassword = 2 | |
299 | ) | |
300 | ||
301 | const proxy_socks5Connect = 1 | |
302 | ||
303 | const ( | |
304 | proxy_socks5IP4 = 1 | |
305 | proxy_socks5Domain = 3 | |
306 | proxy_socks5IP6 = 4 | |
307 | ) | |
308 | ||
309 | var proxy_socks5Errors = []string{ | |
310 | "", | |
311 | "general failure", | |
312 | "connection forbidden", | |
313 | "network unreachable", | |
314 | "host unreachable", | |
315 | "connection refused", | |
316 | "TTL expired", | |
317 | "command not supported", | |
318 | "address type not supported", | |
319 | } | |
320 | ||
321 | // Dial connects to the address addr on the given network via the SOCKS5 proxy. | |
322 | func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) { | |
323 | switch network { | |
324 | case "tcp", "tcp6", "tcp4": | |
325 | default: | |
326 | return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network) | |
327 | } | |
328 | ||
329 | conn, err := s.forward.Dial(s.network, s.addr) | |
330 | if err != nil { | |
331 | return nil, err | |
332 | } | |
333 | if err := s.connect(conn, addr); err != nil { | |
334 | conn.Close() | |
335 | return nil, err | |
336 | } | |
337 | return conn, nil | |
338 | } | |
339 | ||
340 | // connect takes an existing connection to a socks5 proxy server, | |
341 | // and commands the server to extend that connection to target, | |
342 | // which must be a canonical address with a host and port. | |
343 | func (s *proxy_socks5) connect(conn net.Conn, target string) error { | |
344 | host, portStr, err := net.SplitHostPort(target) | |
345 | if err != nil { | |
346 | return err | |
347 | } | |
348 | ||
349 | port, err := strconv.Atoi(portStr) | |
350 | if err != nil { | |
351 | return errors.New("proxy: failed to parse port number: " + portStr) | |
352 | } | |
353 | if port < 1 || port > 0xffff { | |
354 | return errors.New("proxy: port number out of range: " + portStr) | |
355 | } | |
356 | ||
357 | // the size here is just an estimate | |
358 | buf := make([]byte, 0, 6+len(host)) | |
359 | ||
360 | buf = append(buf, proxy_socks5Version) | |
361 | if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 { | |
362 | buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword) | |
363 | } else { | |
364 | buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone) | |
365 | } | |
366 | ||
367 | if _, err := conn.Write(buf); err != nil { | |
368 | return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
369 | } | |
370 | ||
371 | if _, err := io.ReadFull(conn, buf[:2]); err != nil { | |
372 | return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
373 | } | |
374 | if buf[0] != 5 { | |
375 | return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0]))) | |
376 | } | |
377 | if buf[1] == 0xff { | |
378 | return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication") | |
379 | } | |
380 | ||
381 | // See RFC 1929 | |
382 | if buf[1] == proxy_socks5AuthPassword { | |
383 | buf = buf[:0] | |
384 | buf = append(buf, 1 /* password protocol version */) | |
385 | buf = append(buf, uint8(len(s.user))) | |
386 | buf = append(buf, s.user...) | |
387 | buf = append(buf, uint8(len(s.password))) | |
388 | buf = append(buf, s.password...) | |
389 | ||
390 | if _, err := conn.Write(buf); err != nil { | |
391 | return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
392 | } | |
393 | ||
394 | if _, err := io.ReadFull(conn, buf[:2]); err != nil { | |
395 | return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
396 | } | |
397 | ||
398 | if buf[1] != 0 { | |
399 | return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password") | |
400 | } | |
401 | } | |
402 | ||
403 | buf = buf[:0] | |
404 | buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */) | |
405 | ||
406 | if ip := net.ParseIP(host); ip != nil { | |
407 | if ip4 := ip.To4(); ip4 != nil { | |
408 | buf = append(buf, proxy_socks5IP4) | |
409 | ip = ip4 | |
410 | } else { | |
411 | buf = append(buf, proxy_socks5IP6) | |
412 | } | |
413 | buf = append(buf, ip...) | |
414 | } else { | |
415 | if len(host) > 255 { | |
416 | return errors.New("proxy: destination host name too long: " + host) | |
417 | } | |
418 | buf = append(buf, proxy_socks5Domain) | |
419 | buf = append(buf, byte(len(host))) | |
420 | buf = append(buf, host...) | |
421 | } | |
422 | buf = append(buf, byte(port>>8), byte(port)) | |
423 | ||
424 | if _, err := conn.Write(buf); err != nil { | |
425 | return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
426 | } | |
427 | ||
428 | if _, err := io.ReadFull(conn, buf[:4]); err != nil { | |
429 | return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
430 | } | |
431 | ||
432 | failure := "unknown error" | |
433 | if int(buf[1]) < len(proxy_socks5Errors) { | |
434 | failure = proxy_socks5Errors[buf[1]] | |
435 | } | |
436 | ||
437 | if len(failure) > 0 { | |
438 | return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure) | |
439 | } | |
440 | ||
441 | bytesToDiscard := 0 | |
442 | switch buf[3] { | |
443 | case proxy_socks5IP4: | |
444 | bytesToDiscard = net.IPv4len | |
445 | case proxy_socks5IP6: | |
446 | bytesToDiscard = net.IPv6len | |
447 | case proxy_socks5Domain: | |
448 | _, err := io.ReadFull(conn, buf[:1]) | |
449 | if err != nil { | |
450 | return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
451 | } | |
452 | bytesToDiscard = int(buf[0]) | |
453 | default: | |
454 | return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr) | |
455 | } | |
456 | ||
457 | if cap(buf) < bytesToDiscard { | |
458 | buf = make([]byte, bytesToDiscard) | |
459 | } else { | |
460 | buf = buf[:bytesToDiscard] | |
461 | } | |
462 | if _, err := io.ReadFull(conn, buf); err != nil { | |
463 | return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
464 | } | |
465 | ||
466 | // Also need to discard the port number | |
467 | if _, err := io.ReadFull(conn, buf[:2]); err != nil { | |
468 | return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error()) | |
469 | } | |
470 | ||
471 | return nil | |
472 | } |