Skip to content

Commit

Permalink
refactor(server): simplify HTTP router setup
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Aug 30, 2024
1 parent c44bac4 commit 6afa03c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 47 deletions.
22 changes: 4 additions & 18 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ type Server struct {
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
httpsMux *chi.Mux
cert tls.Certificate
}

Expand Down Expand Up @@ -117,19 +116,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, fmt.Errorf("server creation failed: %w", err)
}

httpRouter := createHTTPRouter(cfg)
httpsRouter := createHTTPSRouter(cfg)

httpListeners, httpsListeners, err := createHTTPListeners(cfg)
if err != nil {
return nil, err
}

if len(httpListeners) != 0 || len(httpsListeners) != 0 {
metrics.Start(httpRouter, cfg.Prometheus)
metrics.Start(httpsRouter, cfg.Prometheus)
}

metrics.RegisterEventListeners()

bootstrap, err := resolver.NewBootstrap(ctx, cfg)
Expand All @@ -156,25 +147,20 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
httpMux: httpRouter,
httpsMux: httpsRouter,
cert: cert,
}

server.printConfiguration()

server.registerDNSHandlers(ctx)
err = server.registerAPIEndpoints(httpRouter)

openAPIImpl, err := server.createOpenAPIInterfaceImpl()
if err != nil {
return nil, err
}

err = server.registerAPIEndpoints(httpsRouter)

if err != nil {
return nil, err
}
server.httpMux = createHTTPRouter(cfg, openAPIImpl)
server.registerDoHEndpoints(server.httpMux)

return server, err
}
Expand Down Expand Up @@ -518,7 +504,7 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
logger().Infof("https server is up and running on addr/port %s", address)

server := http.Server{
Handler: s.httpsMux,
Handler: s.httpMux,
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
Expand Down
45 changes: 16 additions & 29 deletions server/server_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/http"
"time"

"github.com/0xERR0R/blocky/metrics"
"github.com/0xERR0R/blocky/resolver"

"github.com/0xERR0R/blocky/api"
Expand Down Expand Up @@ -37,10 +38,13 @@ const (

func secureHeader(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
if r.TLS != nil {
w.Header().Set("strict-transport-security", "max-age=63072000")
w.Header().Set("x-frame-options", "DENY")
w.Header().Set("x-content-type-options", "nosniff")
w.Header().Set("x-xss-protection", "1; mode=block")
}

next.ServeHTTP(w, r)
})
}
Expand All @@ -64,24 +68,15 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
}

func (s *Server) registerAPIEndpoints(router *chi.Mux) error {
func (s *Server) registerDoHEndpoints(router *chi.Mux) {
const pathDohQuery = "/dns-query"

openAPIImpl, err := s.createOpenAPIInterfaceImpl()
if err != nil {
return err
}

api.RegisterOpenAPIEndpoints(router, openAPIImpl)

router.Get(pathDohQuery, s.dohGetRequestHandler)
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
router.Post(pathDohQuery, s.dohPostRequestHandler)
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)

return nil
}

func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -177,34 +172,26 @@ func (s *Server) Query(
return s.resolve(ctx, req)
}

func createHTTPSRouter(cfg *config.Config) *chi.Mux {
func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux {
router := chi.NewRouter()

configureSecureHeaderHandler(router)

registerHandlers(cfg, router)

return router
}

func createHTTPRouter(cfg *config.Config) *chi.Mux {
router := chi.NewRouter()

registerHandlers(cfg, router)

return router
}

func registerHandlers(cfg *config.Config, router *chi.Mux) {
configureCorsHandler(router)

api.RegisterOpenAPIEndpoints(router, openAPIImpl)

configureDebugHandler(router)

configureDocsHandler(router)

configureStaticAssetsHandler(router)

configureRootHandler(cfg, router)

metrics.Start(router, cfg.Prometheus)

return router
}

func configureDocsHandler(router *chi.Mux) {
Expand Down

0 comments on commit 6afa03c

Please sign in to comment.