Skip to content

Commit

Permalink
fix: use proxy env vars via Go default HTTP Transport values
Browse files Browse the repository at this point in the history
Don't build `http.Transport` instances from scratch, but start from
`http.DefaultTransport` and override what is needed.
  • Loading branch information
ThinkChaos committed Apr 10, 2024
1 parent 5040ed8 commit d5b6ee9
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 28 deletions.
71 changes: 71 additions & 0 deletions helpertest/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package helpertest

import (
"fmt"
"net"
"net/http"
"net/url"
"sync/atomic"

"github.com/onsi/ginkgo/v2"
)

type HTTPProxy struct {
Addr net.Addr
requestTarget atomic.Value // string: HTTP Host of latest request
}

// TestHTTPProxy returns a new HTTPProxy server.
//
// All requests return http.StatusNotImplemented.
func TestHTTPProxy() *HTTPProxy {
proxyListener, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0})
if err != nil {
ginkgo.Fail(fmt.Sprintf("could not create HTTP proxy listener: %s", err))
}

proxy := &HTTPProxy{
Addr: proxyListener.Addr(),
}

proxySrv := http.Server{ //nolint:gosec
Addr: "127.0.0.1:0",
Handler: proxy,
}

go func() { _ = proxySrv.Serve(proxyListener) }()
ginkgo.DeferCleanup(proxySrv.Close)

return proxy
}

// URL returns the proxy's URL for use by clients.
func (p *HTTPProxy) URL() *url.URL {
return &url.URL{
Scheme: "http",
Host: p.Addr.String(),
}
}

// Check ReqURL has the right type signature for http.Transport.Proxy
var _ = http.Transport{Proxy: (*HTTPProxy)(nil).ReqURL}

func (p *HTTPProxy) ReqURL(*http.Request) (*url.URL, error) {
return p.URL(), nil
}

// RequestTarget returns the target of the last request.
func (p *HTTPProxy) RequestTarget() string {
val := p.requestTarget.Load()
if val == nil {
ginkgo.Fail(fmt.Sprintf("http proxy %s received no requests", p.Addr))
}

return val.(string)
}

func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
p.requestTarget.Store(req.Host)

w.WriteHeader(http.StatusNotImplemented)
}
15 changes: 14 additions & 1 deletion lists/downloader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var _ = Describe("Downloader", func() {

Describe("NewDownloader", func() {
It("Should use provided parameters", func() {
transport := &http.Transport{}
transport := new(http.Transport)

sut = NewDownloader(
config.Downloader{
Expand Down Expand Up @@ -96,6 +96,7 @@ var _ = Describe("Downloader", func() {
server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
rw.WriteHeader(http.StatusNotFound)
}))
DeferCleanup(server.Close)

sutConfig.Attempts = 3
})
Expand Down Expand Up @@ -212,5 +213,17 @@ var _ = Describe("Downloader", func() {
Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: "))
})
})
When("a proxy is configured", func() {
It("should be used", func(ctx context.Context) {
proxy := TestHTTPProxy()

sut.client.Transport = &http.Transport{Proxy: proxy.ReqURL}

_, err := sut.DownloadFile(ctx, "http://example.com")
Expect(err).Should(HaveOccurred())

Expect(proxy.RequestTarget()).Should(Equal("example.com"))
})
})
})
})
15 changes: 7 additions & 8 deletions resolver/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,17 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string

// NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames
func (b *Bootstrap) NewHTTPTransport() *http.Transport {
if b.resolver == nil {
return &http.Transport{
DialContext: b.dialer.DialContext,
}
}
transport := util.DefaultHTTPTransport()
transport.DialContext = b.dialContext

return &http.Transport{
DialContext: b.dialContext,
}
return transport
}

func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) {
if b.resolver == nil {
return b.dialer.DialContext(ctx, network, addr)
}

ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr})

host, port, err := net.SplitHostPort(addr)
Expand Down
10 changes: 8 additions & 2 deletions resolver/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"sync/atomic"

"github.com/0xERR0R/blocky/config"
Expand Down Expand Up @@ -77,10 +78,15 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() {
})

Describe("HTTP transport", func() {
It("should use the system resolver", func() {
It("should use Go default values", func() {
transport := sut.NewHTTPTransport()

Expect(transport).ShouldNot(BeNil())

Expect(
reflect.ValueOf(transport.Proxy).Pointer(),
).Should(Equal(
reflect.ValueOf(http.ProxyFromEnvironment).Pointer(),
))
})
})

Expand Down
8 changes: 4 additions & 4 deletions resolver/upstream_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ func createUpstreamClient(cfg upstreamConfig) upstreamClient {

switch cfg.Net {
case config.NetProtocolHttps:
transport := util.DefaultHTTPTransport()
transport.TLSClientConfig = &tlsConfig

return &httpUpstreamClient{
userAgent: cfg.UserAgent,
client: &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tlsConfig,
ForceAttemptHTTP2: true,
},
Transport: transport,
},
host: cfg.Host,
}
Expand Down
42 changes: 29 additions & 13 deletions resolver/upstream_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package resolver

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync/atomic"
"time"
Expand Down Expand Up @@ -195,7 +195,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
})
})

Describe("Using Dns over HTTP (DOH) upstream", func() {
Describe("Using DNS over HTTPS (DoH) upstream", func() {
var (
respFn func(request *dns.Msg) (response *dns.Msg)
modifyHTTPRespFn func(w http.ResponseWriter)
Expand All @@ -211,18 +211,34 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
}
})

transport := func() *http.Transport {
upstreamClient := sut.upstreamClient.(*httpUpstreamClient)

return upstreamClient.client.Transport.(*http.Transport)
}

JustBeforeEach(func() {
sutConfig.Upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn)
sut = newUpstreamResolverUnchecked(sutConfig, nil)

// use insecure certificates for test doh upstream
sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
// use insecure certificates for test DoH upstream
transport().TLSClientConfig.InsecureSkipVerify = true
})

When("a proxy is configured", func() {
It("should use it", func() {
proxy := TestHTTPProxy()

transport().Proxy = proxy.ReqURL

_, err := sut.Resolve(ctx, newRequest("example.com.", A))
Expect(err).Should(HaveOccurred())

upstreamHostPort := net.JoinHostPort(sutConfig.Upstream.Host, fmt.Sprint(sutConfig.Port))
Expect(proxy.RequestTarget()).Should(Equal(upstreamHostPort))
})
})
When("Configured DOH resolver can resolve query", func() {
When("Configured DoH resolver can resolve query", func() {
It("should return answer from DNS upstream", func() {
Expect(sut.Resolve(ctx, newRequest("example.com.", A))).
Should(
Expand All @@ -235,7 +251,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
))
})
})
When("Configured DOH resolver returns wrong http status code", func() {
When("Configured DoH resolver returns wrong http status code", func() {
BeforeEach(func() {
modifyHTTPRespFn = func(w http.ResponseWriter) {
w.WriteHeader(http.StatusInternalServerError)
Expand All @@ -247,7 +263,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500"))
})
})
When("Configured DOH resolver returns wrong content type", func() {
When("Configured DoH resolver returns wrong content type", func() {
BeforeEach(func() {
modifyHTTPRespFn = func(w http.ResponseWriter) {
w.Header().Set("content-type", "text")
Expand All @@ -260,7 +276,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
ContainSubstring("http return content type should be 'application/dns-message', but was 'text'"))
})
})
When("Configured DOH resolver returns wrong content", func() {
When("Configured DoH resolver returns wrong content", func() {
BeforeEach(func() {
modifyHTTPRespFn = func(w http.ResponseWriter) {
_, _ = w.Write([]byte("wrongcontent"))
Expand All @@ -272,7 +288,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() {
Expect(err.Error()).Should(ContainSubstring("can't unpack message"))
})
})
When("Configured DOH resolver does not respond", func() {
When("Configured DoH resolver does not respond", func() {
JustBeforeEach(func() {
sutConfig.Upstream = config.Upstream{
Net: config.NetProtocolHttps,
Expand Down
46 changes: 46 additions & 0 deletions util/http.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,56 @@
package util

import (
"fmt"
"net"
"net/http"
)

//nolint:gochecknoglobals
var baseTransport *http.Transport

//nolint:gochecknoinits
func init() {
base, ok := http.DefaultTransport.(*http.Transport)
if !ok {
panic(fmt.Errorf(
"unsupported Go version: http.DefaultTransport is not of type *http.Transport: it is a %T",
http.DefaultTransport,
))
}

baseTransport = base
}

// DefaultHTTPTransport returns a new Transport with the same defaults as net/http.
func DefaultHTTPTransport() *http.Transport {
return &http.Transport{
Dial: baseTransport.Dial, //nolint:staticcheck
DialContext: baseTransport.DialContext,
DialTLS: baseTransport.DialTLS, //nolint:staticcheck
DialTLSContext: baseTransport.DialTLSContext,
DisableCompression: baseTransport.DisableCompression,
DisableKeepAlives: baseTransport.DisableKeepAlives,
ExpectContinueTimeout: baseTransport.ExpectContinueTimeout,
ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2,
GetProxyConnectHeader: baseTransport.GetProxyConnectHeader,
IdleConnTimeout: baseTransport.IdleConnTimeout,
MaxConnsPerHost: baseTransport.MaxConnsPerHost,
MaxIdleConns: baseTransport.MaxIdleConns,
MaxIdleConnsPerHost: baseTransport.MaxConnsPerHost,
MaxResponseHeaderBytes: baseTransport.MaxResponseHeaderBytes,
OnProxyConnectResponse: baseTransport.OnProxyConnectResponse,
Proxy: baseTransport.Proxy,
ProxyConnectHeader: baseTransport.ProxyConnectHeader,
ReadBufferSize: baseTransport.ReadBufferSize,
ResponseHeaderTimeout: baseTransport.ResponseHeaderTimeout,
TLSClientConfig: baseTransport.TLSClientConfig,
TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout,
TLSNextProto: baseTransport.TLSNextProto,
WriteBufferSize: baseTransport.WriteBufferSize,
}
}

func HTTPClientIP(r *http.Request) net.IP {
addr := r.Header.Get("X-FORWARDED-FOR")
if addr == "" {
Expand Down
32 changes: 32 additions & 0 deletions util/http_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,39 @@
package util

import (
"context"
"net"
"net/http"
"net/url"
"reflect"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("HTTP Util", func() {
Describe("DefaultHTTPTransport", func() {
It("returns a new transport", func() {
a := DefaultHTTPTransport()
Expect(a).Should(BeIdenticalTo(a))

b := DefaultHTTPTransport()
Expect(a).ShouldNot(BeIdenticalTo(b))
})

It("returns a copy of http.DefaultTransport", func() {
Expect(cmp.Diff(
DefaultHTTPTransport(), http.DefaultTransport,
cmpopts.IgnoreUnexported(http.Transport{}),
// Non nil func field comparers
cmp.Comparer(cmpAsPtrs[func(context.Context, string, string) (net.Conn, error)]),
cmp.Comparer(cmpAsPtrs[func(*http.Request) (*url.URL, error)]),
)).Should(BeEmpty())
})
})

Describe("HTTPClientIP", func() {
It("extracts the IP from RemoteAddr", func() {
r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
Expand Down Expand Up @@ -43,3 +68,10 @@ var _ = Describe("HTTP Util", func() {
})
})
})

// Go and cmp don't define func comparisons, besides with nil.
// In practice we can just compare them as pointers.
// See https://github.com/google/go-cmp/issues/162
func cmpAsPtrs[T any](x, y T) bool {
return reflect.ValueOf(x).Pointer() == reflect.ValueOf(y).Pointer()
}

0 comments on commit d5b6ee9

Please sign in to comment.