14 | 14 |
package httpunix
|
15 | 15 |
|
16 | 16 |
import (
|
17 | |
"bufio"
|
|
17 |
"context"
|
18 | 18 |
"errors"
|
19 | 19 |
"net"
|
20 | 20 |
"net/http"
|
|
28 | 28 |
// Transport is a http.RoundTripper that connects to Unix domain
|
29 | 29 |
// sockets.
|
30 | 30 |
type Transport struct {
|
31 | |
DialTimeout time.Duration
|
32 | |
RequestTimeout time.Duration
|
|
31 |
// DialTimeout is deprecated. Use context instead.
|
|
32 |
DialTimeout time.Duration
|
|
33 |
// RequestTimeout is deprecated and has no effect.
|
|
34 |
RequestTimeout time.Duration
|
|
35 |
// ResponseHeaderTimeout is deprecated. Use context instead.
|
33 | 36 |
ResponseHeaderTimeout time.Duration
|
|
37 |
|
|
38 |
onceInit sync.Once
|
|
39 |
transport http.Transport
|
34 | 40 |
|
35 | 41 |
mu sync.Mutex
|
36 | 42 |
// map a URL "hostname" to a UNIX domain socket path
|
37 | 43 |
loc map[string]string
|
|
44 |
}
|
|
45 |
|
|
46 |
func (t *Transport) initTransport() {
|
|
47 |
t.transport.DialContext = t.dialContext
|
|
48 |
t.transport.DialTLS = t.dialTLS
|
|
49 |
t.transport.DisableCompression = true
|
|
50 |
t.transport.ResponseHeaderTimeout = t.ResponseHeaderTimeout
|
|
51 |
}
|
|
52 |
|
|
53 |
func (t *Transport) getTransport() *http.Transport {
|
|
54 |
t.onceInit.Do(t.initTransport)
|
|
55 |
return &t.transport
|
|
56 |
}
|
|
57 |
|
|
58 |
func (t *Transport) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
59 |
if network != "tcp" {
|
|
60 |
return nil, errors.New("httpunix internals are confused: network=" + network)
|
|
61 |
}
|
|
62 |
host, port, err := net.SplitHostPort(addr)
|
|
63 |
if err != nil {
|
|
64 |
return nil, err
|
|
65 |
}
|
|
66 |
if port != "80" {
|
|
67 |
return nil, errors.New("httpunix internals are confused: port=" + port)
|
|
68 |
}
|
|
69 |
t.mu.Lock()
|
|
70 |
path, ok := t.loc[host]
|
|
71 |
t.mu.Unlock()
|
|
72 |
if !ok {
|
|
73 |
return nil, errors.New("unknown location: " + host)
|
|
74 |
}
|
|
75 |
d := net.Dialer{
|
|
76 |
Timeout: t.DialTimeout,
|
|
77 |
}
|
|
78 |
return d.DialContext(ctx, "unix", path)
|
|
79 |
}
|
|
80 |
|
|
81 |
func (t *Transport) dialTLS(network, addr string) (net.Conn, error) {
|
|
82 |
return nil, errors.New("httpunix: TLS over UNIX domain sockets is not supported")
|
38 | 83 |
}
|
39 | 84 |
|
40 | 85 |
// RegisterLocation registers an URL location and maps it to the given
|
|
68 | 113 |
if req.URL.Host == "" {
|
69 | 114 |
return nil, errors.New("http+unix: no Host in request URL")
|
70 | 115 |
}
|
71 | |
t.mu.Lock()
|
72 | |
path, ok := t.loc[req.URL.Host]
|
73 | |
t.mu.Unlock()
|
74 | |
if !ok {
|
75 | |
return nil, errors.New("unknown location: " + req.Host)
|
76 | |
}
|
77 | 116 |
|
78 | |
c, err := net.DialTimeout("unix", path, t.DialTimeout)
|
79 | |
if err != nil {
|
80 | |
return nil, err
|
81 | |
}
|
82 | |
r := bufio.NewReader(c)
|
83 | |
if t.RequestTimeout > 0 {
|
84 | |
c.SetWriteDeadline(time.Now().Add(t.RequestTimeout))
|
85 | |
}
|
86 | |
if err := req.Write(c); err != nil {
|
87 | |
return nil, err
|
88 | |
}
|
89 | |
if t.ResponseHeaderTimeout > 0 {
|
90 | |
c.SetReadDeadline(time.Now().Add(t.ResponseHeaderTimeout))
|
91 | |
}
|
92 | |
resp, err := http.ReadResponse(r, req)
|
93 | |
return resp, err
|
|
117 |
tt := t.getTransport()
|
|
118 |
req = req.Clone(req.Context())
|
|
119 |
// get http.Transport to cooperate
|
|
120 |
req.URL.Scheme = "http"
|
|
121 |
return tt.RoundTrip(req)
|
94 | 122 |
}
|