diff --git a/LICENSE b/LICENSE index 33aec14..9a9852b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2013-2015 Tommi Virtanen. +Copyright (c) 2013-2019 Tommi Virtanen. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..ce469f1 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/tv42/httpunix + +go 1.13 diff --git a/httpunix.go b/httpunix.go index 95f5e95..d6584ac 100644 --- a/httpunix.go +++ b/httpunix.go @@ -15,7 +15,7 @@ package httpunix import ( - "bufio" + "context" "errors" "net" "net/http" @@ -29,13 +29,58 @@ // Transport is a http.RoundTripper that connects to Unix domain // sockets. type Transport struct { - DialTimeout time.Duration - RequestTimeout time.Duration + // DialTimeout is deprecated. Use context instead. + DialTimeout time.Duration + // RequestTimeout is deprecated and has no effect. + RequestTimeout time.Duration + // ResponseHeaderTimeout is deprecated. Use context instead. ResponseHeaderTimeout time.Duration + + onceInit sync.Once + transport http.Transport mu sync.Mutex // map a URL "hostname" to a UNIX domain socket path loc map[string]string +} + +func (t *Transport) initTransport() { + t.transport.DialContext = t.dialContext + t.transport.DialTLS = t.dialTLS + t.transport.DisableCompression = true + t.transport.ResponseHeaderTimeout = t.ResponseHeaderTimeout +} + +func (t *Transport) getTransport() *http.Transport { + t.onceInit.Do(t.initTransport) + return &t.transport +} + +func (t *Transport) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if network != "tcp" { + return nil, errors.New("httpunix internals are confused: network=" + network) + } + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if port != "80" { + return nil, errors.New("httpunix internals are confused: port=" + port) + } + t.mu.Lock() + path, ok := t.loc[host] + t.mu.Unlock() + if !ok { + return nil, errors.New("unknown location: " + host) + } + d := net.Dialer{ + Timeout: t.DialTimeout, + } + return d.DialContext(ctx, "unix", path) +} + +func (t *Transport) dialTLS(network, addr string) (net.Conn, error) { + return nil, errors.New("httpunix: TLS over UNIX domain sockets is not supported") } // RegisterLocation registers an URL location and maps it to the given @@ -69,27 +114,10 @@ if req.URL.Host == "" { return nil, errors.New("http+unix: no Host in request URL") } - t.mu.Lock() - path, ok := t.loc[req.URL.Host] - t.mu.Unlock() - if !ok { - return nil, errors.New("unknown location: " + req.Host) - } - c, err := net.DialTimeout("unix", path, t.DialTimeout) - if err != nil { - return nil, err - } - r := bufio.NewReader(c) - if t.RequestTimeout > 0 { - c.SetWriteDeadline(time.Now().Add(t.RequestTimeout)) - } - if err := req.Write(c); err != nil { - return nil, err - } - if t.ResponseHeaderTimeout > 0 { - c.SetReadDeadline(time.Now().Add(t.ResponseHeaderTimeout)) - } - resp, err := http.ReadResponse(r, req) - return resp, err + tt := t.getTransport() + req = req.Clone(req.Context()) + // get http.Transport to cooperate + req.URL.Scheme = "http" + return tt.RoundTrip(req) }