diff --git a/helpertest/http.go b/helpertest/http.go new file mode 100644 index 000000000..9096e204d --- /dev/null +++ b/helpertest/http.go @@ -0,0 +1,71 @@ +package helpertest + +import ( + "fmt" + "net" + "net/http" + "net/url" + "sync/atomic" + + "github.com/onsi/ginkgo/v2" +) + +type HTTPProxy struct { + Addr net.Addr + requestTarget atomic.Value // string: HTTP Host of latest request +} + +// TestHTTPProxy returns a new HTTPProxy server. +// +// All requests return http.StatusNotImplemented. +func TestHTTPProxy() *HTTPProxy { + proxyListener, err := net.ListenTCP("tcp4", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + if err != nil { + ginkgo.Fail(fmt.Sprintf("could not create HTTP proxy listener: %s", err)) + } + + proxy := &HTTPProxy{ + Addr: proxyListener.Addr(), + } + + proxySrv := http.Server{ //nolint:gosec + Addr: "127.0.0.1:0", + Handler: proxy, + } + + go func() { _ = proxySrv.Serve(proxyListener) }() + ginkgo.DeferCleanup(proxySrv.Close) + + return proxy +} + +// URL returns the proxy's URL for use by clients. +func (p *HTTPProxy) URL() *url.URL { + return &url.URL{ + Scheme: "http", + Host: p.Addr.String(), + } +} + +// Check ReqURL has the right type signature for http.Transport.Proxy +var _ = http.Transport{Proxy: (*HTTPProxy)(nil).ReqURL} + +func (p *HTTPProxy) ReqURL(*http.Request) (*url.URL, error) { + return p.URL(), nil +} + +// RequestTarget returns the target of the last request. +func (p *HTTPProxy) RequestTarget() string { + val := p.requestTarget.Load() + if val == nil { + ginkgo.Fail(fmt.Sprintf("http proxy %s received no requests", p.Addr)) + } + + return val.(string) +} + +func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) { + p.requestTarget.Store(req.Host) + + w.WriteHeader(http.StatusNotImplemented) +} diff --git a/lists/downloader_test.go b/lists/downloader_test.go index 6cb59ed43..7b1e7866a 100644 --- a/lists/downloader_test.go +++ b/lists/downloader_test.go @@ -54,7 +54,7 @@ var _ = Describe("Downloader", func() { Describe("NewDownloader", func() { It("Should use provided parameters", func() { - transport := &http.Transport{} + transport := new(http.Transport) sut = NewDownloader( config.Downloader{ @@ -96,6 +96,7 @@ var _ = Describe("Downloader", func() { server = httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusNotFound) })) + DeferCleanup(server.Close) sutConfig.Attempts = 3 }) @@ -212,5 +213,17 @@ var _ = Describe("Downloader", func() { Expect(loggerHook.LastEntry().Message).Should(ContainSubstring("Name resolution err: ")) }) }) + When("a proxy is configured", func() { + It("should be used", func(ctx context.Context) { + proxy := TestHTTPProxy() + + sut.client.Transport = &http.Transport{Proxy: proxy.ReqURL} + + _, err := sut.DownloadFile(ctx, "http://example.com") + Expect(err).Should(HaveOccurred()) + + Expect(proxy.RequestTarget()).Should(Equal("example.com")) + }) + }) }) }) diff --git a/resolver/bootstrap.go b/resolver/bootstrap.go index ba88e0679..842b5c301 100644 --- a/resolver/bootstrap.go +++ b/resolver/bootstrap.go @@ -156,18 +156,17 @@ func (b *Bootstrap) resolveUpstream(ctx context.Context, r Resolver, host string // NewHTTPTransport returns a new http.Transport that uses b to resolve hostnames func (b *Bootstrap) NewHTTPTransport() *http.Transport { - if b.resolver == nil { - return &http.Transport{ - DialContext: b.dialer.DialContext, - } - } + transport := util.DefaultHTTPTransport() + transport.DialContext = b.dialContext - return &http.Transport{ - DialContext: b.dialContext, - } + return transport } func (b *Bootstrap) dialContext(ctx context.Context, network, addr string) (net.Conn, error) { + if b.resolver == nil { + return b.dialer.DialContext(ctx, network, addr) + } + ctx, logger := b.logWithFields(ctx, logrus.Fields{"network": network, "addr": addr}) host, port, err := net.SplitHostPort(addr) diff --git a/resolver/bootstrap_test.go b/resolver/bootstrap_test.go index 9c2339e54..fe5733cb1 100644 --- a/resolver/bootstrap_test.go +++ b/resolver/bootstrap_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "reflect" "sync/atomic" "github.com/0xERR0R/blocky/config" @@ -77,10 +78,15 @@ var _ = Describe("Bootstrap", Label("bootstrap"), func() { }) Describe("HTTP transport", func() { - It("should use the system resolver", func() { + It("should use Go default values", func() { transport := sut.NewHTTPTransport() - Expect(transport).ShouldNot(BeNil()) + + Expect( + reflect.ValueOf(transport.Proxy).Pointer(), + ).Should(Equal( + reflect.ValueOf(http.ProxyFromEnvironment).Pointer(), + )) }) }) diff --git a/resolver/upstream_resolver.go b/resolver/upstream_resolver.go index 51cf8d082..605066964 100644 --- a/resolver/upstream_resolver.go +++ b/resolver/upstream_resolver.go @@ -98,13 +98,13 @@ func createUpstreamClient(cfg upstreamConfig) upstreamClient { switch cfg.Net { case config.NetProtocolHttps: + transport := util.DefaultHTTPTransport() + transport.TLSClientConfig = &tlsConfig + return &httpUpstreamClient{ userAgent: cfg.UserAgent, client: &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tlsConfig, - ForceAttemptHTTP2: true, - }, + Transport: transport, }, host: cfg.Host, } diff --git a/resolver/upstream_resolver_test.go b/resolver/upstream_resolver_test.go index 9856dcda6..f51208340 100644 --- a/resolver/upstream_resolver_test.go +++ b/resolver/upstream_resolver_test.go @@ -2,9 +2,9 @@ package resolver import ( "context" - "crypto/tls" "errors" "fmt" + "net" "net/http" "sync/atomic" "time" @@ -195,7 +195,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { }) }) - Describe("Using Dns over HTTP (DOH) upstream", func() { + Describe("Using DNS over HTTPS (DoH) upstream", func() { var ( respFn func(request *dns.Msg) (response *dns.Msg) modifyHTTPRespFn func(w http.ResponseWriter) @@ -211,18 +211,34 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { } }) + transport := func() *http.Transport { + upstreamClient := sut.upstreamClient.(*httpUpstreamClient) + + return upstreamClient.client.Transport.(*http.Transport) + } + JustBeforeEach(func() { sutConfig.Upstream = newTestDOHUpstream(respFn, modifyHTTPRespFn) sut = newUpstreamResolverUnchecked(sutConfig, nil) - // use insecure certificates for test doh upstream - sut.upstreamClient.(*httpUpstreamClient).client.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - } + // use insecure certificates for test DoH upstream + transport().TLSClientConfig.InsecureSkipVerify = true + }) + + When("a proxy is configured", func() { + It("should use it", func() { + proxy := TestHTTPProxy() + + transport().Proxy = proxy.ReqURL + + _, err := sut.Resolve(ctx, newRequest("example.com.", A)) + Expect(err).Should(HaveOccurred()) + + upstreamHostPort := net.JoinHostPort(sutConfig.Upstream.Host, fmt.Sprint(sutConfig.Port)) + Expect(proxy.RequestTarget()).Should(Equal(upstreamHostPort)) + }) }) - When("Configured DOH resolver can resolve query", func() { + When("Configured DoH resolver can resolve query", func() { It("should return answer from DNS upstream", func() { Expect(sut.Resolve(ctx, newRequest("example.com.", A))). Should( @@ -235,7 +251,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { )) }) }) - When("Configured DOH resolver returns wrong http status code", func() { + When("Configured DoH resolver returns wrong http status code", func() { BeforeEach(func() { modifyHTTPRespFn = func(w http.ResponseWriter) { w.WriteHeader(http.StatusInternalServerError) @@ -247,7 +263,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { Expect(err.Error()).Should(ContainSubstring("http return code should be 200, but received 500")) }) }) - When("Configured DOH resolver returns wrong content type", func() { + When("Configured DoH resolver returns wrong content type", func() { BeforeEach(func() { modifyHTTPRespFn = func(w http.ResponseWriter) { w.Header().Set("content-type", "text") @@ -260,7 +276,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { ContainSubstring("http return content type should be 'application/dns-message', but was 'text'")) }) }) - When("Configured DOH resolver returns wrong content", func() { + When("Configured DoH resolver returns wrong content", func() { BeforeEach(func() { modifyHTTPRespFn = func(w http.ResponseWriter) { _, _ = w.Write([]byte("wrongcontent")) @@ -272,7 +288,7 @@ var _ = Describe("UpstreamResolver", Label("upstreamResolver"), func() { Expect(err.Error()).Should(ContainSubstring("can't unpack message")) }) }) - When("Configured DOH resolver does not respond", func() { + When("Configured DoH resolver does not respond", func() { JustBeforeEach(func() { sutConfig.Upstream = config.Upstream{ Net: config.NetProtocolHttps, diff --git a/util/http.go b/util/http.go index 736177239..989e271e0 100644 --- a/util/http.go +++ b/util/http.go @@ -1,10 +1,56 @@ package util import ( + "fmt" "net" "net/http" ) +//nolint:gochecknoglobals +var baseTransport *http.Transport + +//nolint:gochecknoinits +func init() { + base, ok := http.DefaultTransport.(*http.Transport) + if !ok { + panic(fmt.Errorf( + "unsupported Go version: http.DefaultTransport is not of type *http.Transport: it is a %T", + http.DefaultTransport, + )) + } + + baseTransport = base +} + +// DefaultHTTPTransport returns a new Transport with the same defaults as net/http. +func DefaultHTTPTransport() *http.Transport { + return &http.Transport{ + Dial: baseTransport.Dial, //nolint:staticcheck + DialContext: baseTransport.DialContext, + DialTLS: baseTransport.DialTLS, //nolint:staticcheck + DialTLSContext: baseTransport.DialTLSContext, + DisableCompression: baseTransport.DisableCompression, + DisableKeepAlives: baseTransport.DisableKeepAlives, + ExpectContinueTimeout: baseTransport.ExpectContinueTimeout, + ForceAttemptHTTP2: baseTransport.ForceAttemptHTTP2, + GetProxyConnectHeader: baseTransport.GetProxyConnectHeader, + IdleConnTimeout: baseTransport.IdleConnTimeout, + MaxConnsPerHost: baseTransport.MaxConnsPerHost, + MaxIdleConns: baseTransport.MaxIdleConns, + MaxIdleConnsPerHost: baseTransport.MaxConnsPerHost, + MaxResponseHeaderBytes: baseTransport.MaxResponseHeaderBytes, + OnProxyConnectResponse: baseTransport.OnProxyConnectResponse, + Proxy: baseTransport.Proxy, + ProxyConnectHeader: baseTransport.ProxyConnectHeader, + ReadBufferSize: baseTransport.ReadBufferSize, + ResponseHeaderTimeout: baseTransport.ResponseHeaderTimeout, + TLSClientConfig: baseTransport.TLSClientConfig, + TLSHandshakeTimeout: baseTransport.TLSHandshakeTimeout, + TLSNextProto: baseTransport.TLSNextProto, + WriteBufferSize: baseTransport.WriteBufferSize, + } +} + func HTTPClientIP(r *http.Request) net.IP { addr := r.Header.Get("X-FORWARDED-FOR") if addr == "" { diff --git a/util/http_test.go b/util/http_test.go index b1bf60842..52a7197ab 100644 --- a/util/http_test.go +++ b/util/http_test.go @@ -1,14 +1,39 @@ package util import ( + "context" "net" "net/http" + "net/url" + "reflect" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("HTTP Util", func() { + Describe("DefaultHTTPTransport", func() { + It("returns a new transport", func() { + a := DefaultHTTPTransport() + Expect(a).Should(BeIdenticalTo(a)) + + b := DefaultHTTPTransport() + Expect(a).ShouldNot(BeIdenticalTo(b)) + }) + + It("returns a copy of http.DefaultTransport", func() { + Expect(cmp.Diff( + DefaultHTTPTransport(), http.DefaultTransport, + cmpopts.IgnoreUnexported(http.Transport{}), + // Non nil func field comparers + cmp.Comparer(cmpAsPtrs[func(context.Context, string, string) (net.Conn, error)]), + cmp.Comparer(cmpAsPtrs[func(*http.Request) (*url.URL, error)]), + )).Should(BeEmpty()) + }) + }) + Describe("HTTPClientIP", func() { It("extracts the IP from RemoteAddr", func() { r, err := http.NewRequest(http.MethodGet, "http://example.com", nil) @@ -43,3 +68,10 @@ var _ = Describe("HTTP Util", func() { }) }) }) + +// Go and cmp don't define func comparisons, besides with nil. +// In practice we can just compare them as pointers. +// See https://github.com/google/go-cmp/issues/162 +func cmpAsPtrs[T any](x, y T) bool { + return reflect.ValueOf(x).Pointer() == reflect.ValueOf(y).Pointer() +}