From 61bf5f271bae03cadc03de46ed2cb126d4ac0052 Mon Sep 17 00:00:00 2001
From: Dmitriy Matrenichev
Date: Tue, 29 Oct 2024 22:00:44 +0300
Subject: [PATCH] chore: clean dns code

Split from #9596 (without IPv6 stuff). This PR does the following:

- Refactors `DNSResolveCacheController`: most of the logic moves into `dns` package types, simplifying and streamlining the controller logic.
- Replaces most of the goroutine orchestration with the suture package.
- Supports per-item reaction to the dns listeners/servers failing to start. This allows us to ignore IPv6 errors when IPv6 is disabled.
- Supports per-item reaction to the dns listeners/servers failing to stop.

A rough usage sketch of the resulting `dns.Manager` API is attached after the diff.

Signed-off-by: Dmitriy Matrenichev
---
 go.mod                                        |   1 +
 go.sum                                        |   2 +
 .../controllers/network/dns_resolve_cache.go  | 350 ++++++------------
 .../pkg/controllers/network/etcfile.go        |  11 +-
 .../pkg/controllers/network/hostdns_config.go |  84 ++---
 .../app/machined/pkg/xcontext/xcontext.go     |  28 ++
 internal/pkg/dns/dns.go                       | 102 +----
 internal/pkg/dns/dns_test.go                  | 157 +++++---
 internal/pkg/dns/manager.go                   | 318 ++++++++++++++++
 internal/pkg/dns/runnner.go                   | 108 ++++++
 10 files changed, 728 insertions(+), 433 deletions(-)
 create mode 100644 internal/app/machined/pkg/xcontext/xcontext.go
 create mode 100644 internal/pkg/dns/manager.go
 create mode 100644 internal/pkg/dns/runnner.go

diff --git a/go.mod b/go.mod
index c1698563abb..9abc89e24b5 100644
--- a/go.mod
+++ b/go.mod
@@ -171,6 +171,7 @@ require (
 	github.com/spf13/cobra v1.8.1
 	github.com/spf13/pflag v1.0.5
 	github.com/stretchr/testify v1.9.0
+	github.com/thejerf/suture/v4 v4.0.5
 	github.com/u-root/u-root v0.14.0
 	github.com/ulikunitz/xz v0.5.12
 	github.com/vmware/vmw-guestinfo v0.0.0-20220317130741-510905f0efa3
diff --git a/go.sum b/go.sum
index 8be3bebe5c5..0a140339c48 100644
--- a/go.sum
+++ b/go.sum
@@ -732,6 +732,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635 h1:kdXcSzyDtseVEc4yCz2qF8ZrQvIDBJLl4S1c3GCXmoI=
 github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
+github.com/thejerf/suture/v4 v4.0.5 h1:F1E/4FZwXWqvlWDKEUo6/ndLtxGAUzMmNqkrMknZbAA=
+github.com/thejerf/suture/v4 v4.0.5/go.mod h1:gu9Y4dXNUWFrByqRt30Rm9/UZ0wzRSt9AJS6xu/ZGxU=
 github.com/u-root/u-root v0.14.0 h1:Ka4T10EEML7dQ5XDvO9c3MBN8z4nuSnGjcd1jmU2ivg=
 github.com/u-root/u-root v0.14.0/go.mod h1:hAyZorapJe4qzbLWlAkmSVCJGbfoU9Pu4jpJ1WMluqE=
 github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM=
diff --git a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go
index 98210340ba7..3496673d12a 100644
--- a/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go
+++ b/internal/app/machined/pkg/controllers/network/dns_resolve_cache.go
@@ -5,24 +5,20 @@ package network
 
 import (
+	"cmp"
 	"context"
-	"errors"
 	"fmt"
-	"net"
+	"iter"
 	"net/netip"
-	"slices"
-	"strings"
 	"sync"
-	"time"
 
 	"github.com/coredns/coredns/plugin/pkg/proxy"
 	"github.com/cosi-project/runtime/pkg/controller"
 	"github.com/cosi-project/runtime/pkg/safe"
 	"github.com/cosi-project/runtime/pkg/state"
-	dnssrv "github.com/miekg/dns"
 	"github.com/siderolabs/gen/optional"
-	"github.com/siderolabs/gen/pair"
 	"github.com/siderolabs/gen/xiter"
+	"github.com/thejerf/suture/v4"
 	"go.uber.org/zap"
 	"go.uber.org/zap/zapcore"
 
@@ -36,13 
+32,9 @@ type DNSResolveCacheController struct { State state.State Logger *zap.Logger - mx sync.Mutex - handler *dns.Handler - nodeHandler *dns.NodeHandler - rootHandler dnssrv.Handler - runners map[runnerConfig]pair.Pair[func(), <-chan struct{}] - reconcile chan struct{} - originalCtx context.Context //nolint:containedctx + mx sync.Mutex + manager *dns.Manager + reconcile chan struct{} } // Name implements controller.Controller interface. @@ -74,15 +66,21 @@ func (ctrl *DNSResolveCacheController) Outputs() []controller.Output { } // Run implements controller.Controller interface. -// -//nolint:gocyclo,cyclop -func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Runtime, logger *zap.Logger) error { +func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Runtime, _ *zap.Logger) error { ctrl.init(ctx) ctrl.mx.Lock() defer ctrl.mx.Unlock() - defer ctrl.stopRunners(ctx, false) + defer func() { + if err := ctrl.manager.ClearAll(ctx.Err() == nil); err != nil { + ctrl.Logger.Error("error stopping dns runners", zap.Error(err)) + } + + if ctx.Err() != nil { + ctrl.Logger.Info("manager finished", zap.Error(<-ctrl.manager.Done())) + } + }() for { select { @@ -90,264 +88,128 @@ func (ctrl *DNSResolveCacheController) Run(ctx context.Context, r controller.Run return nil case <-r.EventCh(): case <-ctrl.reconcile: - for cfg, stop := range ctrl.runners { - select { - default: - continue - case <-stop.F2: - } - - stop.F1() - delete(ctrl.runners, cfg) - } } - cfg, err := safe.ReaderGetByID[*network.HostDNSConfig](ctx, r, network.HostDNSConfigID) - if err != nil { - if state.IsNotFoundError(err) { - continue - } - - return fmt.Errorf("error getting host dns config: %w", err) - } - - r.StartTrackingOutputs() - - if !cfg.TypedSpec().Enabled { - ctrl.stopRunners(ctx, true) - - if err = safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil { - return fmt.Errorf("error cleaning up dns status on disable: %w", err) - } - - continue - } - - ctrl.nodeHandler.SetEnabled(cfg.TypedSpec().ResolveMemberNames) - - touchedRunners := make(map[runnerConfig]struct{}, len(ctrl.runners)) - - for _, addr := range cfg.TypedSpec().ListenAddresses { - for _, netwk := range []string{"udp", "tcp"} { - runnerCfg := runnerConfig{net: netwk, addr: addr} - - if _, ok := ctrl.runners[runnerCfg]; !ok { - runner, rErr := newDNSRunner(runnerCfg, ctrl.rootHandler, ctrl.Logger, cfg.TypedSpec().ServiceHostDNSAddress.IsValid()) - if rErr != nil { - return fmt.Errorf("error creating dns runner: %w", rErr) - } - - ctrl.runners[runnerCfg] = pair.MakePair(runner.Start(ctrl.handleDone(ctx, logger))) - } - - if err = ctrl.writeDNSStatus(ctx, r, runnerCfg); err != nil { - return fmt.Errorf("error writing dns status: %w", err) - } - - touchedRunners[runnerCfg] = struct{}{} - } - } - - for runnerCfg, stop := range ctrl.runners { - if _, ok := touchedRunners[runnerCfg]; !ok { - stop.F1() - delete(ctrl.runners, runnerCfg) - - continue - } - } - - upstreams, err := safe.ReaderListAll[*network.DNSUpstream](ctx, r) - if err != nil { - return fmt.Errorf("error getting resolver status: %w", err) - } - - prxs := xiter.Map( - // We are using iterator here to preserve finalizer on - func(upstream *network.DNSUpstream) *proxy.Proxy { - return upstream.TypedSpec().Value.Conn.Proxy().(*proxy.Proxy) - }, - upstreams.All(), - ) - - if ctrl.handler.SetProxy(prxs) { - ctrl.Logger.Info("updated dns server nameservers", zap.Array("addrs", addrsArr(upstreams))) - } - - if err = 
safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil { - return fmt.Errorf("error cleaning up dns status: %w", err) + if err := ctrl.run(ctx, r); err != nil { + return err } } } -func (ctrl *DNSResolveCacheController) writeDNSStatus(ctx context.Context, r controller.Runtime, config runnerConfig) error { - return safe.WriterModify(ctx, r, network.NewDNSResolveCache(fmt.Sprintf("%s-%s", config.net, config.addr)), func(drc *network.DNSResolveCache) error { - drc.TypedSpec().Status = "running" - - return nil - }) -} +//nolint:gocyclo +func (ctrl *DNSResolveCacheController) run(ctx context.Context, r controller.Runtime) (resErr error) { + r.StartTrackingOutputs() + defer cleanupOutputs(ctx, r, &resErr) -func (ctrl *DNSResolveCacheController) init(ctx context.Context) { - if ctrl.runners != nil { - if ctrl.originalCtx != ctx { - // This should not happen, but if it does, it's a bug. - panic("DNSResolveCacheController is called with a different context") - } + cfg, err := safe.ReaderGetByID[*network.HostDNSConfig](ctx, r, network.HostDNSConfigID) - return + switch { + case state.IsNotFoundError(err): + return nil + case err != nil: + return fmt.Errorf("error getting host dns config: %w", err) } - ctrl.originalCtx = ctx - ctrl.handler = dns.NewHandler(ctrl.Logger) - ctrl.nodeHandler = dns.NewNodeHandler(ctrl.handler, &stateMapper{state: ctrl.State}, ctrl.Logger) - ctrl.rootHandler = dns.NewCache(ctrl.nodeHandler, ctrl.Logger) - ctrl.runners = map[runnerConfig]pair.Pair[func(), <-chan struct{}]{} - ctrl.reconcile = make(chan struct{}, 1) - - // Ensure we stop all runners when the context is canceled, no matter where we are currently. - // For example if we are in Controller runtime sleeping after error and ctx is canceled, we should stop all runners - // but, we will never call Run method again, so we need to ensure this happens regardless of the current state. 
-	context.AfterFunc(ctx, func() {
-		ctrl.mx.Lock()
-		defer ctrl.mx.Unlock()
-
-		ctrl.stopRunners(ctx, true)
-	})
-}
-
-func (ctrl *DNSResolveCacheController) stopRunners(ctx context.Context, ignoreCtx bool) {
-	if !ignoreCtx && ctx.Err() == nil {
-		// context not yet canceled, preserve runners, cache and handler
-		return
-	}
+	ctrl.manager.AllowNodeResolving(cfg.TypedSpec().ResolveMemberNames)
 
-	for _, stop := range ctrl.runners {
-		stop.F1()
+	if !cfg.TypedSpec().Enabled {
+		return ctrl.manager.ClearAll(false)
 	}
 
-	clear(ctrl.runners)
-
-	ctrl.handler.Stop()
-}
+	pairs := allAddressPairs(cfg.TypedSpec().ListenAddresses)
+	forwardKubeDNSToHost := cfg.TypedSpec().ServiceHostDNSAddress.IsValid()
 
-func (ctrl *DNSResolveCacheController) handleDone(ctx context.Context, logger *zap.Logger) func(err error) {
-	return func(err error) {
-		if ctx.Err() != nil {
-			if err != nil && !errors.Is(err, net.ErrClosed) {
-				logger.Error("controller is closing, but error running dns server", zap.Error(err))
-			}
-
-			return
-		}
-
-		if err != nil {
-			logger.Error("error running dns server", zap.Error(err))
+	for runCfg, runErr := range ctrl.manager.RunAll(pairs, forwardKubeDNSToHost) {
+		switch {
+		case runErr != nil && (runCfg.Network == "tcp6" || runCfg.Network == "udp6"):
+			// Ignore IPv6 errors
+			ctrl.Logger.Warn("ignoring ipv6 dns runner error", zap.Error(runErr))
+		case runErr != nil:
+			return fmt.Errorf("error updating dns runner '%v': %w", runCfg, runErr)
+		case runCfg.Status == dns.StatusRemoved:
+			// Removed runner, no reason to update its status
+			continue
 		}
 
-		select {
-		case ctrl.reconcile <- struct{}{}:
-		default:
+		if err = ctrl.writeDNSStatus(ctx, r, runCfg.AddressPair); err != nil {
+			return fmt.Errorf("error writing dns status: %w", err)
 		}
 	}
-}
-
-type runnerConfig struct {
-	net  string
-	addr netip.AddrPort
-}
 
-func newDNSRunner(cfg runnerConfig, rootHandler dnssrv.Handler, logger *zap.Logger, forwardEnabled bool) (*dns.Server, error) {
-	if cfg.addr.Addr().Is6() {
-		cfg.net += "6"
-	}
-
-	logger = logger.With(zap.String("net", cfg.net), zap.Stringer("addr", cfg.addr))
-
-	var serverOpts dns.ServerOptions
-
-	controlFn, ctrlErr := dns.MakeControl(cfg.net, forwardEnabled)
-	if ctrlErr != nil {
-		return nil, fmt.Errorf("error creating %q control function: %w", cfg.net, ctrlErr)
+	upstreams, err := safe.ReaderListAll[*network.DNSUpstream](ctx, r)
+	if err != nil {
+		return fmt.Errorf("error getting resolver status: %w", err)
 	}
 
-	switch cfg.net {
-	case "udp", "udp6":
-		packetConn, err := dns.NewUDPPacketConn(cfg.net, cfg.addr.String(), controlFn)
-		if err != nil {
-			return nil, fmt.Errorf("error creating %q packet conn: %w", cfg.net, err)
-		}
-
-		serverOpts = dns.ServerOptions{
-			PacketConn: packetConn,
-			Handler:    rootHandler,
-			Logger:     logger,
-		}
-
-	case "tcp", "tcp6":
-		listener, err := dns.NewTCPListener(cfg.net, cfg.addr.String(), controlFn)
-		if err != nil {
-			return nil, fmt.Errorf("error creating %q listener: %w", cfg.net, err)
-		}
+	prxs := xiter.Map(
+		// We are using iterator here to preserve finalizer on
+		func(upstream *network.DNSUpstream) *proxy.Proxy {
+			return upstream.TypedSpec().Value.Conn.Proxy().(*proxy.Proxy)
+		},
+		upstreams.All(),
+	)
 
-		serverOpts = dns.ServerOptions{
-			Listener:      listener,
-			Handler:       rootHandler,
-			ReadTimeout:   3 * time.Second,
-			WriteTimeout:  5 * time.Second,
-			IdleTimeout:   func() time.Duration { return 10 * time.Second },
-			MaxTCPQueries: -1,
-			Logger:        logger,
-		}
+	if ctrl.manager.SetUpstreams(prxs) {
+		ctrl.Logger.Info("updated dns server nameservers", zap.Array("addrs", addrsArr(upstreams)))
zap.Array("addrs", addrsArr(upstreams))) } - return dns.NewServer(serverOpts), nil + return nil } -type stateMapper struct { - state state.State +func cleanupOutputs(ctx context.Context, r controller.Runtime, resErr *error) { + if err := safe.CleanupOutputs[*network.DNSResolveCache](ctx, r); err != nil { + *resErr = cmp.Or(*resErr, fmt.Errorf("error cleaning up dns resolve cache: %w", err)) + } } -func (s *stateMapper) ResolveAddr(ctx context.Context, qType uint16, name string) []netip.Addr { - name = strings.TrimRight(name, ".") +func (ctrl *DNSResolveCacheController) writeDNSStatus(ctx context.Context, r controller.Runtime, config dns.AddressPair) error { + res := network.NewDNSResolveCache(fmt.Sprintf("%s-%s", config.Network, config.Addr)) - list, err := safe.ReaderListAll[*cluster.Member](ctx, s.state) - if err != nil { - return nil - } + return safe.WriterModify(ctx, r, res, func(drc *network.DNSResolveCache) error { + drc.TypedSpec().Status = "running" - elem, ok := list.Find(func(res *cluster.Member) bool { - return fqdnMatch(name, res.TypedSpec().Hostname) || fqdnMatch(name, res.Metadata().ID()) - }) - if !ok { return nil - } - - result := slices.DeleteFunc(slices.Clone(elem.TypedSpec().Addresses), func(addr netip.Addr) bool { - return !((qType == dnssrv.TypeA && addr.Is4()) || (qType == dnssrv.TypeAAAA && addr.Is6())) }) +} - if len(result) == 0 { - return nil +func (ctrl *DNSResolveCacheController) init(ctx context.Context) { + if ctrl.manager == nil { + ctrl.manager = dns.NewManager(&memberReader{st: ctrl.State}, ctrl.eventHook, ctrl.Logger) + + // Ensure we stop all runners when the context is canceled, no matter where we are currently. + // For example if we are in Controller runtime sleeping after error and ctx is canceled, we should stop all runners + // but, we will never call Run method again, so we need to ensure this happens regardless of the current state. 
+		context.AfterFunc(ctx, func() {
+			ctrl.mx.Lock()
+			defer ctrl.mx.Unlock()
+
+			if err := ctrl.manager.ClearAll(false); err != nil {
+				ctrl.Logger.Error("error stopping dns runners on context cancel", zap.Error(err))
+			}
+		})
+	}
 
-	return result
+	ctrl.manager.ServeBackground(ctx)
 }
 
-func fqdnMatch(what, where string) bool {
-	what = strings.TrimRight(what, ".")
-	where = strings.TrimRight(where, ".")
+func (ctrl *DNSResolveCacheController) eventHook(event suture.Event) {
+	ctrl.Logger.Info("dns-resolve-cache-runners event", zap.String("event", event.String()))
 
-	if what == where {
-		return true
+	select {
+	case ctrl.reconcile <- struct{}{}:
+	default:
 	}
+}
+
+type memberReader struct{ st state.State }
 
-	first, _, found := strings.Cut(where, ".")
-	if !found {
-		return false
+func (m *memberReader) ReadMembers(ctx context.Context) (iter.Seq[*cluster.Member], error) {
+	list, err := safe.ReaderListAll[*cluster.Member](ctx, m.st)
+	if err != nil {
+		return nil, err
 	}
 
-	return what == first
+	return list.All(), nil
 }
 
 type addrsArr safe.List[*network.DNSUpstream]
@@ -361,3 +223,23 @@ func (a addrsArr) MarshalLogArray(encoder zapcore.ArrayEncoder) error {
 
 	return nil
 }
+
+func allAddressPairs(addresses []netip.AddrPort) iter.Seq[dns.AddressPair] {
+	return func(yield func(dns.AddressPair) bool) {
+		for _, addr := range addresses {
+			networks := [...]string{"udp", "tcp"}
+			if addr.Addr().Is6() {
+				networks = [...]string{"udp6", "tcp6"}
+			}
+
+			for _, netwk := range networks {
+				if !yield(dns.AddressPair{
+					Network: netwk,
+					Addr:    addr,
+				}) {
+					return
+				}
+			}
+		}
+	}
+}
diff --git a/internal/app/machined/pkg/controllers/network/etcfile.go b/internal/app/machined/pkg/controllers/network/etcfile.go
index 60c5a863b58..0cb0f245f4b 100644
--- a/internal/app/machined/pkg/controllers/network/etcfile.go
+++ b/internal/app/machined/pkg/controllers/network/etcfile.go
@@ -20,8 +20,8 @@ import (
 	"github.com/cosi-project/runtime/pkg/safe"
 	"github.com/cosi-project/runtime/pkg/state"
 	"github.com/siderolabs/gen/optional"
-	"github.com/siderolabs/gen/value"
 	"github.com/siderolabs/gen/xiter"
+	"github.com/siderolabs/gen/xslices"
 	"go.uber.org/zap"
 
 	efiles "github.com/siderolabs/talos/internal/app/machined/pkg/controllers/files"
@@ -158,10 +158,13 @@ func (ctrl *EtcFileController) Run(ctx context.Context, r controller.Runtime, _
 	}
 
 	if resolverStatus != nil && hostDNSCfg != nil {
-		dnsServers := resolverStatus.TypedSpec().DNSServers
+		dnsServers := xslices.FilterInPlace(
+			[]netip.Addr{hostDNSCfg.TypedSpec().ServiceHostDNSAddress},
+			netip.Addr.IsValid,
+		)
 
-		if !value.IsZero(hostDNSCfg.TypedSpec().ServiceHostDNSAddress) {
-			dnsServers = []netip.Addr{hostDNSCfg.TypedSpec().ServiceHostDNSAddress}
+		if len(dnsServers) == 0 {
+			dnsServers = resolverStatus.TypedSpec().DNSServers
 		}
 
 		conf := renderResolvConf(slices.All(dnsServers), hostnameStatusSpec, cfgProvider)
diff --git a/internal/app/machined/pkg/controllers/network/hostdns_config.go b/internal/app/machined/pkg/controllers/network/hostdns_config.go
index 4fad1d70a1d..3e4a45ef2d9 100644
--- a/internal/app/machined/pkg/controllers/network/hostdns_config.go
+++ b/internal/app/machined/pkg/controllers/network/hostdns_config.go
@@ -8,13 +8,12 @@ import (
 	"context"
 	"fmt"
 	"net/netip"
+	"slices"
 
 	"github.com/cosi-project/runtime/pkg/controller"
-	"github.com/cosi-project/runtime/pkg/resource"
 	"github.com/cosi-project/runtime/pkg/safe"
 	"github.com/cosi-project/runtime/pkg/state"
 	"github.com/siderolabs/gen/optional"
-	"github.com/siderolabs/gen/value"
 	"github.com/siderolabs/go-procfs/procfs"
"go.uber.org/zap" @@ -74,6 +73,8 @@ func (ctrl *HostDNSConfigController) Run(ctx context.Context, r controller.Runti var cfgProvider talosconfig.Config + r.StartTrackingOutputs() + cfg, err := safe.ReaderGetByID[*config.MachineConfig](ctx, r, config.V1Alpha1ID) if err != nil { if !state.IsNotFoundError(err) { @@ -83,7 +84,7 @@ func (ctrl *HostDNSConfigController) Run(ctx context.Context, r controller.Runti cfgProvider = cfg.Config() } - var newServiceAddr netip.Addr + newServiceAddrs := make([]netip.Addr, 0, 2) if err := safe.WriterModify(ctx, r, network.NewHostDNSConfig(network.HostDNSConfigID), func(res *network.HostDNSConfig) error { res.TypedSpec().ListenAddresses = []netip.AddrPort{ @@ -101,11 +102,19 @@ func (ctrl *HostDNSConfigController) Run(ctx context.Context, r controller.Runti res.TypedSpec().Enabled = cfgProvider.Machine().Features().HostDNS().Enabled() res.TypedSpec().ResolveMemberNames = cfgProvider.Machine().Features().HostDNS().ResolveMemberNames() - if cfgProvider.Machine().Features().HostDNS().ForwardKubeDNSToHost() { - newServiceAddr = netip.MustParseAddr(constants.HostDNSAddress) + if !cfgProvider.Machine().Features().HostDNS().ForwardKubeDNSToHost() { + return nil + } + + if slices.ContainsFunc( + cfgProvider.Cluster().Network().PodCIDRs(), + func(cidr string) bool { return netip.MustParsePrefix(cidr).Addr().Is4() }, + ) { + parsed := netip.MustParseAddr(constants.HostDNSAddress) + newServiceAddrs = append(newServiceAddrs, parsed) - res.TypedSpec().ListenAddresses = append(res.TypedSpec().ListenAddresses, netip.AddrPortFrom(newServiceAddr, 53)) - res.TypedSpec().ServiceHostDNSAddress = newServiceAddr + res.TypedSpec().ListenAddresses = append(res.TypedSpec().ListenAddresses, netip.AddrPortFrom(parsed, 53)) + res.TypedSpec().ServiceHostDNSAddress = parsed } return nil @@ -113,65 +122,25 @@ func (ctrl *HostDNSConfigController) Run(ctx context.Context, r controller.Runti return fmt.Errorf("error writing host dns config: %w", err) } - var touched *network.AddressSpec - - if !value.IsZero(newServiceAddr) { - touched, err = updateSpec(ctx, r, newServiceAddr, logger) + for _, newServiceAddr := range newServiceAddrs { + err := updateSpec(ctx, r, newServiceAddr, logger) if err != nil { return err } } - if err = ctrl.cleanupAddressSpecs( - ctx, - r, - func(id resource.ID) bool { - if touched == nil { - return false - } - - return id == touched.Metadata().ID() - }, - logger, - ); err != nil { + if err = safe.CleanupOutputs[*network.HostDNSConfig](ctx, r); err != nil { return err } - - r.ResetRestartBackoff() } } -func (ctrl *HostDNSConfigController) cleanupAddressSpecs(ctx context.Context, r controller.Runtime, checkResource func(id resource.ID) bool, logger *zap.Logger) error { - list, err := safe.ReaderList[*network.AddressSpec](ctx, r, network.NewAddressSpec(network.ConfigNamespaceName, "").Metadata()) - if err != nil { - return err - } - - for address := range list.All() { - if address.Metadata().Owner() != ctrl.Name() { - continue - } - - if checkResource(address.Metadata().ID()) { - continue - } - - if err = r.Destroy(ctx, address.Metadata()); err != nil && !state.IsNotFoundError(err) { - return err - } - - logger.Info("destroyed address spec", zap.String("address_id", address.Metadata().ID())) - } - - return nil -} - -func updateSpec(ctx context.Context, r controller.Runtime, newServiceAddr netip.Addr, logger *zap.Logger) (*network.AddressSpec, error) { +func updateSpec(ctx context.Context, r controller.Runtime, newServiceAddr netip.Addr, logger *zap.Logger) error { 
newDNSAddrPrefix := netip.PrefixFrom(newServiceAddr, newServiceAddr.BitLen())
 
 	logger.Debug("creating new host dns address spec", zap.String("address", newServiceAddr.String()))
 
-	res, err := safe.WriterModifyWithResult(
+	err := safe.WriterModify(
 		ctx,
 		r,
 		network.NewAddressSpec(
@@ -192,14 +161,19 @@ func updateSpec(ctx context.Context, r controller.Runtime, newServiceAddr netip.
 
 			spec.Flags = nethelpers.AddressFlags(nethelpers.AddressPermanent)
 			spec.LinkName = "lo"
-			spec.Scope = nethelpers.ScopeHost
+
+			if newServiceAddr.Is6() && newServiceAddr.IsPrivate() {
+				spec.Scope = nethelpers.ScopeGlobal
+			} else {
+				spec.Scope = nethelpers.ScopeHost
+			}
 
 			return nil
 		},
 	)
 	if err != nil {
-		return nil, fmt.Errorf("error modifying address: %w", err)
+		return fmt.Errorf("error modifying address: %w", err)
 	}
 
-	return res, nil
+	return nil
 }
diff --git a/internal/app/machined/pkg/xcontext/xcontext.go b/internal/app/machined/pkg/xcontext/xcontext.go
new file mode 100644
index 00000000000..32dca45de85
--- /dev/null
+++ b/internal/app/machined/pkg/xcontext/xcontext.go
@@ -0,0 +1,28 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+// Package xcontext provides small utilities for the context package.
+package xcontext
+
+import "context"
+
+// AfterFuncSync is like [context.AfterFunc], but the returned stop function also waits for fn to finish if it has already started.
+func AfterFuncSync(ctx context.Context, fn func()) func() bool {
+	stopChan := make(chan struct{})
+
+	stop := context.AfterFunc(ctx, func() {
+		defer close(stopChan)
+
+		fn()
+	})
+
+	return func() bool {
+		result := stop()
+		if !result {
+			<-stopChan
+		}
+
+		return result
+	}
+}
diff --git a/internal/pkg/dns/dns.go b/internal/pkg/dns/dns.go
index 6a8be932fcc..5eaa7619f4d 100644
--- a/internal/pkg/dns/dns.go
+++ b/internal/pkg/dns/dns.go
@@ -9,12 +9,10 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"io"
 	"iter"
 	"net"
 	"net/netip"
 	"slices"
-	"strings"
 	"sync"
 	"sync/atomic"
 	"syscall"
@@ -207,7 +205,7 @@ func NewNodeHandler(next plugin.Handler, hostMapper HostMapper, logger *zap.Logg
 
 // HostMapper is a name to node mapper.
 type HostMapper interface {
-	ResolveAddr(ctx context.Context, qType uint16, name string) []netip.Addr
+	ResolveAddr(ctx context.Context, qType uint16, name string) (iter.Seq[netip.Addr], bool)
 }
 
 // NodeHandler try to resolve dns request to a node. If required node is not found, it will move to the next handler.
@@ -238,14 +236,19 @@ func (h *NodeHandler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg
 	req := request.Request{W: wrt, Req: msg}
 
 	// Check if the request is for a node.
- result := h.mapper.ResolveAddr(ctx, req.QType(), req.Name()) - if len(result) == 0 { + result, ok := h.mapper.ResolveAddr(ctx, req.QType(), req.Name()) + if !ok { + return h.next.ServeDNS(ctx, wrt, msg) + } + + answers := mapAnswers(result, req.Name()) + if len(answers) == 0 { return h.next.ServeDNS(ctx, wrt, msg) } resp := new(dns.Msg).SetReply(req.Req) resp.Authoritative = true - resp.Answer = mapAnswers(result, req.Name()) + resp.Answer = answers err := wrt.WriteMsg(resp) if err != nil { @@ -256,10 +259,10 @@ func (h *NodeHandler) ServeDNS(ctx context.Context, wrt dns.ResponseWriter, msg return dns.RcodeSuccess, nil } -func mapAnswers(addrs []netip.Addr, name string) []dns.RR { +func mapAnswers(addrs iter.Seq[netip.Addr], name string) []dns.RR { var result []dns.RR - for _, addr := range addrs { + for addr := range addrs { switch { case addr.Is4(): result = append(result, &dns.A{ @@ -295,89 +298,6 @@ func (h *NodeHandler) SetEnabled(enabled bool) { h.enabled.Store(enabled) } -// ServerOptions is a Server options. -type ServerOptions struct { - Listener net.Listener - PacketConn net.PacketConn - Handler dns.Handler - ReadTimeout time.Duration - WriteTimeout time.Duration - IdleTimeout func() time.Duration - MaxTCPQueries int - Logger *zap.Logger -} - -// NewServer creates a new Server. -func NewServer(opts ServerOptions) *Server { - return &Server{ - srv: &dns.Server{ - Listener: opts.Listener, - PacketConn: opts.PacketConn, - Handler: opts.Handler, - UDPSize: dns.DefaultMsgSize, // 4096 since default is [dns.MinMsgSize] = 512 bytes, which is too small. - ReadTimeout: opts.ReadTimeout, - WriteTimeout: opts.WriteTimeout, - IdleTimeout: opts.IdleTimeout, - MaxTCPQueries: opts.MaxTCPQueries, - }, - logger: opts.Logger, - } -} - -// Server is a dns server. -type Server struct { - srv *dns.Server - logger *zap.Logger -} - -// Start starts the dns server. Returns a function to stop the server. -func (s *Server) Start(onDone func(err error)) (stop func(), stopped <-chan struct{}) { - done := make(chan struct{}) - - fn := sync.OnceFunc(func() { - for { - err := s.srv.Shutdown() - if err != nil { - if strings.Contains(err.Error(), "server not started") { - // There a possible scenario where `go func()` not yet reached `ActivateAndServe` and yielded CPU - // time to another goroutine and then this closure reached `Shutdown`. In that case - // `ActivateAndServe` will actually start after `Shutdown` and this closure will block forever - // because `go func()` will never exit and close `done` channel. - continue - } - - s.logger.Error("error shutting down dns server", zap.Error(err)) - } - - break - } - - closer := io.Closer(s.srv.Listener) - if closer == nil { - closer = s.srv.PacketConn - } - - if closer != nil { - err := closer.Close() - if err != nil && !errors.Is(err, net.ErrClosed) { - s.logger.Error("error closing dns server listener", zap.Error(err)) - } else { - s.logger.Debug("dns server listener closed") - } - } - - <-done - }) - - go func() { - defer close(done) - - onDone(s.srv.ActivateAndServe()) - }() - - return fn, done -} - // NewTCPListener creates a new TCP listener. 
func NewTCPListener(network, addr string, control ControlFn) (net.Listener, error) { network, ok := networkNames[network] diff --git a/internal/pkg/dns/dns_test.go b/internal/pkg/dns/dns_test.go index e15ac7992e2..c579594c091 100644 --- a/internal/pkg/dns/dns_test.go +++ b/internal/pkg/dns/dns_test.go @@ -6,22 +6,26 @@ package dns_test import ( "context" + "iter" "net" "net/netip" + "runtime" "slices" + "strings" "testing" "time" "github.com/coredns/coredns/plugin/pkg/proxy" dnssrv "github.com/miekg/dns" - "github.com/siderolabs/gen/ensure" - "github.com/siderolabs/gen/xiter" + "github.com/siderolabs/gen/maps" "github.com/siderolabs/gen/xslices" "github.com/siderolabs/gen/xtesting/check" "github.com/stretchr/testify/require" + "github.com/thejerf/suture/v4" "go.uber.org/zap/zaptest" "github.com/siderolabs/talos/internal/pkg/dns" + "github.com/siderolabs/talos/pkg/machinery/resources/cluster" ) func TestDNS(t *testing.T) { @@ -76,47 +80,88 @@ func TestDNS(t *testing.T) { }, } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - stop := newServer(t, test.nameservers...) - t.Cleanup(stop) + for _, dnsAddr := range []string{"127.0.0.1:10700"} { + for _, test := range tests { + t.Run(dnsAddr+"/"+test.name, func(t *testing.T) { + stop := newManager(t, test.nameservers...) + t.Cleanup(stop) - time.Sleep(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) - r, err := dnssrv.Exchange(createQuery(test.hostname), "127.0.0.53:10700") - test.errCheck(t, err) + r, err := dnssrv.Exchange(createQuery(test.hostname), dnsAddr) + test.errCheck(t, err) - if r != nil { - require.Equal(t, test.expectedCode, r.Rcode, r) - } + if r != nil { + require.Equal(t, test.expectedCode, r.Rcode, r) + } - t.Logf("r: %s", r) - }) + t.Logf("r: %s", r) + }) + } } } func TestDNSEmptyDestinations(t *testing.T) { - stop := newServer(t) + stop := newManager(t) defer stop() time.Sleep(10 * time.Millisecond) - r, err := dnssrv.Exchange(createQuery("google.com"), "127.0.0.53:10700") + r, err := dnssrv.Exchange(createQuery("google.com"), "127.0.0.1:10700") require.NoError(t, err) require.Equal(t, dnssrv.RcodeServerFailure, r.Rcode, r) - r, err = dnssrv.Exchange(createQuery("google.com"), "127.0.0.53:10700") + r, err = dnssrv.Exchange(createQuery("google.com"), "127.0.0.1:10700") require.NoError(t, err) require.Equal(t, dnssrv.RcodeServerFailure, r.Rcode, r) stop() } -func newServer(t *testing.T, nameservers ...string) func() { - l := zaptest.NewLogger(t) +func TestGC_NOGC(t *testing.T) { + tests := map[string]bool{ + "ClearAll": false, + "No ClearAll": true, + } + + for name, f := range tests { + t.Run(name, func(t *testing.T) { + m := dns.NewManager(&testReader{}, func(e suture.Event) { t.Log("dns-runners event:", e) }, zaptest.NewLogger(t)) + + m.ServeBackground(context.Background()) + m.ServeBackground(context.Background()) + require.Panics(t, func() { m.ServeBackground(context.TODO()) }) + + for _, err := range m.RunAll(slices.Values([]dns.AddressPair{ + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.1:10700")}, + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.1:10701")}, + }), false) { + require.NoError(t, err) + } + + require.NoError(t, m.ClearAll(f)) + + m = nil + + for range 100 { + runtime.GC() + } + }) + } +} + +func newManager(t *testing.T, nameservers ...string) func() { + m := dns.NewManager(&testReader{}, func(e suture.Event) { + t.Log("dns-runners event:", e) + }, zaptest.NewLogger(t)) - handler := dns.NewHandler(l) - t.Cleanup(handler.Stop) + m.AllowNodeResolving(true) + + 
t.Cleanup(func() { + if err := m.ClearAll(false); err != nil { + t.Logf("error stopping dns runners: %v", err) + } + }) pxs := xslices.Map(nameservers, func(ns string) *proxy.Proxy { p := proxy.NewProxy(ns, net.JoinHostPort(ns, "53"), "dns") @@ -127,30 +172,42 @@ func newServer(t *testing.T, nameservers ...string) func() { return p }) - handler.SetProxy(xiter.Values(slices.All(pxs))) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) - pc, err := dns.NewUDPPacketConn("udp", "127.0.0.53:10700", ensure.Value(dns.MakeControl("udp", false))) - require.NoError(t, err) + m.SetUpstreams(slices.Values(pxs)) - nodeHandler := dns.NewNodeHandler(handler, &testResolver{}, l) + m.ServeBackground(ctx) + m.ServeBackground(ctx) - nodeHandler.SetEnabled(true) + for _, err := range m.RunAll(slices.Values([]dns.AddressPair{ + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.1:10700")}, + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.1:10701")}, + {Network: "tcp", Addr: netip.MustParseAddrPort("127.0.0.1:10700")}, + }), false) { + if err != nil && strings.Contains(err.Error(), "failed to set TCP_FASTOPEN") { + continue + } - srv := dns.NewServer(dns.ServerOptions{ - PacketConn: pc, - Handler: dns.NewCache(nodeHandler, l), - Logger: l, - }) + require.NoError(t, err) + } - stop, _ := srv.Start(func(err error) { - if err != nil { - t.Errorf("error running dns server: %v", err) + for _, err := range m.RunAll(slices.Values([]dns.AddressPair{ + {Network: "udp", Addr: netip.MustParseAddrPort("127.0.0.1:10700")}, + {Network: "tcp", Addr: netip.MustParseAddrPort("127.0.0.1:10700")}, + }), false) { + if err != nil && strings.Contains(err.Error(), "failed to set TCP_FASTOPEN") { + continue } - t.Logf("dns server stopped") - }) + require.NoError(t, err) + } - return stop + return func() { + if err := m.ClearAll(false); err != nil { + t.Logf("error stopping dns runners: %v", err) + } + } } func createQuery(name string) *dnssrv.Msg { @@ -169,19 +226,21 @@ func createQuery(name string) *dnssrv.Msg { } } -type testResolver struct{} +type testReader struct{} -func (*testResolver) ResolveAddr(_ context.Context, qType uint16, name string) []netip.Addr { - if qType != dnssrv.TypeA { - return nil +func (r *testReader) ReadMembers(context.Context) (iter.Seq[*cluster.Member], error) { + namesToAddresses := map[string][]netip.Addr{ + "talos-default-controlplane-1": {netip.MustParseAddr("172.20.0.2")}, + "talos-default-worker-1": {netip.MustParseAddr("172.20.0.3")}, } - switch name { - case "talos-default-controlplane-1.": - return []netip.Addr{netip.MustParseAddr("172.20.0.2")} - case "talos-default-worker-1.": - return []netip.Addr{netip.MustParseAddr("172.20.0.3")} - default: - return nil - } + result := maps.ToSlice(namesToAddresses, func(k string, v []netip.Addr) *cluster.Member { + result := cluster.NewMember(cluster.NamespaceName, k) + + result.TypedSpec().Addresses = v + + return result + }) + + return slices.Values(result), nil } diff --git a/internal/pkg/dns/manager.go b/internal/pkg/dns/manager.go new file mode 100644 index 00000000000..bd8510ccc01 --- /dev/null +++ b/internal/pkg/dns/manager.go @@ -0,0 +1,318 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+
+package dns
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"iter"
+	"net/netip"
+	"runtime"
+	"slices"
+	"strings"
+	"time"
+
+	"github.com/coredns/coredns/plugin/pkg/proxy"
+	"github.com/hashicorp/go-multierror"
+	dnssrv "github.com/miekg/dns"
+	"github.com/siderolabs/gen/xiter"
+	"github.com/thejerf/suture/v4"
+	"go.uber.org/zap"
+
+	"github.com/siderolabs/talos/pkg/machinery/resources/cluster"
+)
+
+// ErrCreatingRunner is an error that occurs when creating a runner.
+var ErrCreatingRunner = errors.New("error creating runner")
+
+// Manager manages DNS runners.
+type Manager struct {
+	originalCtx  context.Context //nolint:containedctx
+	handler      *Handler
+	nodeHandler  *NodeHandler
+	rootHandler  *Cache
+	s            *suture.Supervisor
+	supervisorCh <-chan error
+	logger       *zap.Logger
+	runners      map[AddressPair]suture.ServiceToken
+}
+
+// NewManager creates a new manager.
+func NewManager(mr MemberReader, hook suture.EventHook, logger *zap.Logger) *Manager {
+	handler := NewHandler(logger)
+	nodeHandler := NewNodeHandler(handler, &addrResolver{mr: mr}, logger)
+	rootHandler := NewCache(nodeHandler, logger)
+
+	m := &Manager{
+		handler:     handler,
+		nodeHandler: nodeHandler,
+		rootHandler: rootHandler,
+		s:           suture.New("dns-resolve-cache-runners", suture.Spec{EventHook: hook}),
+		logger:      logger,
+		runners:     map[AddressPair]suture.ServiceToken{},
+	}
+
+	// If we lose the reference to the manager, ensure the finalizer is called and all upstreams are collected.
+	runtime.SetFinalizer(m, (*Manager).finalize)
+
+	return m
+}
+
+// ServeBackground starts the manager in the background. It is a no-op when called again with the same context,
+// and panics when called with a different one.
+func (m *Manager) ServeBackground(ctx context.Context) {
+	switch {
+	case m.originalCtx == nil:
+		m.originalCtx = ctx
+	case m.originalCtx != ctx:
+		panic("Manager.ServeBackground is called with a different context")
+	case m.originalCtx == ctx:
+		return
+	}
+
+	m.supervisorCh = m.s.ServeBackground(ctx)
+}
+
+// AddressPair represents a network and address with port.
+type AddressPair struct {
+	Network string
+	Addr    netip.AddrPort
+}
+
+// String returns a string representation of the address pair.
+func (a AddressPair) String() string { return "Network: " + a.Network + ", Addr: " + a.Addr.String() }
+
+// RunAll updates and runs the runners managed by the manager. It returns an iterator which yields the address pairs for
+// all running and attempted-to-run configurations. It's mandatory to range over the iterator to ensure all runners are updated.
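+//
+// A minimal consumption sketch (`pairs` and `fwd` stand in for caller-provided values):
+//
+//	for res, err := range m.RunAll(pairs, fwd) {
+//		if err != nil {
+//			continue // react per item, e.g. tolerate IPv6 bind errors
+//		}
+//
+//		_ = res.Status // StatusNew, StatusRunning or StatusRemoved
+//	}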
+func (m *Manager) RunAll(pairs iter.Seq[AddressPair], forwardEnabled bool) iter.Seq2[RunResult, error] {
+	return func(yield func(RunResult, error) bool) {
+		preserve := make(map[AddressPair]struct{}, len(m.runners))
+
+		for cfg := range pairs {
+			preserve[cfg] = struct{}{}
+
+			if _, ok := m.runners[cfg]; ok {
+				if !yield(makeResult(cfg, StatusRunning), nil) {
+					return
+				}
+
+				continue
+			}
+
+			opts, err := newDNSRunnerOpts(cfg, m.rootHandler, forwardEnabled)
+			if err != nil {
+				err = fmt.Errorf("%w: %w", ErrCreatingRunner, err)
+			} else {
+				m.runners[cfg] = m.s.Add(NewRunner(opts, m.logger))
+			}
+
+			if !yield(makeResult(cfg, StatusNew), err) {
+				return
+			}
+		}
+
+		for cfg, token := range m.runners {
+			if _, ok := preserve[cfg]; ok {
+				continue
+			}
+
+			err := m.s.RemoveAndWait(token, 0)
+			if err != nil {
+				err = fmt.Errorf("error removing runner: %w", err)
+			}
+
+			if !yield(makeResult(cfg, StatusRemoved), err) {
+				return
+			}
+
+			delete(m.runners, cfg)
+		}
+	}
+}
+
+func makeResult(cfg AddressPair, s Status) RunResult { return RunResult{AddressPair: cfg, Status: s} }
+
+// AllowNodeResolving enables or disables the node resolving feature.
+func (m *Manager) AllowNodeResolving(enabled bool) { m.nodeHandler.SetEnabled(enabled) }
+
+// SetUpstreams sets the upstreams for the DNS handler. It returns true if the upstreams were updated, false otherwise.
+func (m *Manager) SetUpstreams(prxs iter.Seq[*proxy.Proxy]) bool { return m.handler.SetProxy(prxs) }
+
+// ClearAll stops and removes all runners, aggregating any errors that occurred during the removal process.
+// It's a no-op if dry is true.
+func (m *Manager) ClearAll(dry bool) error {
+	if dry {
+		return nil
+	}
+
+	var multiErr *multierror.Error
+
+	for _, err := range m.clearAll() {
+		if err != nil {
+			multiErr = multierror.Append(multiErr, err)
+		}
+	}
+
+	return multiErr.ErrorOrNil()
+}
+
+func (m *Manager) clearAll() iter.Seq2[AddressPair, error] {
+	return func(yield func(AddressPair, error) bool) {
+		if len(m.runners) == 0 {
+			return
+		}
+
+		defer m.handler.Stop()
+
+		removeAndWait := m.s.RemoveAndWait
+		if m.originalCtx.Err() != nil {
+			// ctx canceled, no reason to remove runners from Supervisor since they are already dropped
+			removeAndWait = func(id suture.ServiceToken, timeout time.Duration) error { return nil }
+		}
+
+		for runData, token := range m.runners {
+			err := removeAndWait(token, 0)
+			if err != nil {
+				err = fmt.Errorf("error removing runner: %w", err)
+			}
+
+			if !yield(runData, err) {
+				return
+			}
+
+			delete(m.runners, runData)
+		}
+	}
+}
+
+func (m *Manager) finalize() {
+	for data, err := range m.clearAll() {
+		if err != nil {
+			m.logger.Error("error stopping dns runner", zap.Error(err))
+		}
+
+		m.logger.Info(
+			"dns runner stopped from finalizer!",
+			zap.String("address", data.Addr.String()),
+			zap.String("network", data.Network),
+		)
+	}
+}
+
+// Done reports when the supervisor has finished execution.
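+// The returned channel is nil until ServeBackground has been called.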
+func (m *Manager) Done() <-chan error {
+	return m.supervisorCh
+}
+
+type addrResolver struct {
+	mr MemberReader
+}
+
+func (s *addrResolver) ResolveAddr(ctx context.Context, qType uint16, name string) (iter.Seq[netip.Addr], bool) {
+	name = strings.TrimRight(name, ".")
+
+	items, err := s.mr.ReadMembers(ctx)
+	if err != nil {
+		return nil, false
+	}
+
+	found, ok := xiter.Find(func(res *cluster.Member) bool {
+		return fqdnMatch(name, res.TypedSpec().Hostname) || fqdnMatch(name, res.Metadata().ID())
+	}, items)
+	if !ok {
+		return nil, false
+	}
+
+	return xiter.Filter(
+		func(addr netip.Addr) bool {
+			return (qType == dnssrv.TypeA && addr.Is4()) || (qType == dnssrv.TypeAAAA && addr.Is6())
+		},
+		slices.Values(found.TypedSpec().Addresses),
+	), true
+}
+
+func fqdnMatch(what, where string) bool {
+	what = strings.TrimRight(what, ".")
+	where = strings.TrimRight(where, ".")
+
+	if what == where {
+		return true
+	}
+
+	first, _, found := strings.Cut(where, ".")
+	if !found {
+		return false
+	}
+
+	return what == first
+}
+
+// MemberReader is an interface to read members.
+type MemberReader interface {
+	ReadMembers(ctx context.Context) (iter.Seq[*cluster.Member], error)
+}
+
+func newDNSRunnerOpts(cfg AddressPair, rootHandler dnssrv.Handler, forwardEnabled bool) (RunnerOptions, error) {
+	if cfg.Addr.Addr().Is6() && !strings.HasSuffix(cfg.Network, "6") {
+		cfg.Network += "6"
+	}
+
+	var serverOpts RunnerOptions
+
+	controlFn, ctrlErr := MakeControl(cfg.Network, forwardEnabled)
+	if ctrlErr != nil {
+		return serverOpts, fmt.Errorf("error creating %q control function: %w", cfg.Network, ctrlErr)
+	}
+
+	switch cfg.Network {
+	case "udp", "udp6":
+		packetConn, err := NewUDPPacketConn(cfg.Network, cfg.Addr.String(), controlFn)
+		if err != nil {
+			return serverOpts, fmt.Errorf("error creating %q packet conn: %w", cfg.Network, err)
+		}
+
+		serverOpts = RunnerOptions{
+			PacketConn: packetConn,
+			Handler:    rootHandler,
+		}
+
+	case "tcp", "tcp6":
+		listener, err := NewTCPListener(cfg.Network, cfg.Addr.String(), controlFn)
+		if err != nil {
+			return serverOpts, fmt.Errorf("error creating %q listener: %w", cfg.Network, err)
+		}
+
+		serverOpts = RunnerOptions{
+			Listener:      listener,
+			Handler:       rootHandler,
+			ReadTimeout:   3 * time.Second,
+			WriteTimeout:  5 * time.Second,
+			IdleTimeout:   func() time.Duration { return 10 * time.Second },
+			MaxTCPQueries: -1,
+		}
+	}
+
+	return serverOpts, nil
+}
+
+// RunResult represents the result of a RunAll iteration.
+type RunResult struct {
+	AddressPair
+	Status Status
+}
+
+// Status represents the status of a runner.
+type Status int
+
+const (
+	// StatusNew represents a new runner.
+	StatusNew Status = iota
+	// StatusRunning represents an already running runner.
+	StatusRunning
+	// StatusRemoved represents a removed runner.
+	StatusRemoved
+)
diff --git a/internal/pkg/dns/runnner.go b/internal/pkg/dns/runnner.go
new file mode 100644
index 00000000000..f0725de97d7
--- /dev/null
+++ b/internal/pkg/dns/runnner.go
@@ -0,0 +1,108 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+package dns
+
+import (
+	"context"
+	"errors"
+	"io"
+	"net"
+	"strings"
+	"time"
+
+	"github.com/miekg/dns"
+	"go.uber.org/zap"
+
+	"github.com/siderolabs/talos/internal/app/machined/pkg/xcontext"
+)
+
+// RunnerOptions are the options for a [Runner].
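+// Exactly one of Listener (for tcp/tcp6) or PacketConn (for udp/udp6) is expected to be set,
+// mirroring what newDNSRunnerOpts produces.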
+type RunnerOptions struct {
+	Listener      net.Listener
+	PacketConn    net.PacketConn
+	Handler       dns.Handler
+	ReadTimeout   time.Duration
+	WriteTimeout  time.Duration
+	IdleTimeout   func() time.Duration
+	MaxTCPQueries int
+}
+
+// NewRunner creates a new [Runner].
+func NewRunner(opts RunnerOptions, l *zap.Logger) *Runner {
+	return &Runner{
+		srv: &dns.Server{
+			Listener:      opts.Listener,
+			PacketConn:    opts.PacketConn,
+			Handler:       opts.Handler,
+			UDPSize:       dns.DefaultMsgSize, // 4096 since default is [dns.MinMsgSize] = 512 bytes, which is too small.
+			ReadTimeout:   opts.ReadTimeout,
+			WriteTimeout:  opts.WriteTimeout,
+			IdleTimeout:   opts.IdleTimeout,
+			MaxTCPQueries: opts.MaxTCPQueries,
+		},
+		logger: l,
+	}
+}
+
+// Runner is a DNS server runner.
+type Runner struct {
+	srv    *dns.Server
+	logger *zap.Logger
+}
+
+// Serve starts the DNS server.
+func (r *Runner) Serve(ctx context.Context) error {
+	detach := xcontext.AfterFuncSync(ctx, r.close)
+	defer func() {
+		if !detach() {
+			return
+		}
+
+		r.close()
+	}()
+
+	return r.srv.ActivateAndServe()
+}
+
+func (r *Runner) close() {
+	l := r.logger
+
+	if r.srv.Listener != nil {
+		l = l.With(zap.String("net", "tcp"), zap.String("local_addr", r.srv.Listener.Addr().String()))
+	} else if r.srv.PacketConn != nil {
+		l = l.With(zap.String("net", "udp"), zap.String("local_addr", r.srv.PacketConn.LocalAddr().String()))
+	}
+
+	for {
+		err := r.srv.Shutdown()
+		if err != nil {
+			if strings.Contains(err.Error(), "server not started") {
+				// There is a possible scenario where the goroutine running Serve has not yet reached
+				// `ActivateAndServe` and yielded CPU time to another goroutine, and then this closure reached
+				// `Shutdown`. In that case `dns.Server.ActivateAndServe` will actually start after `Shutdown`,
+				// and Serve would block forever, so retry the shutdown until it succeeds.
+				continue
+			}
+
+			l.Error("error shutting down dns server", zap.Error(err))
+		}
+
+		closer := io.Closer(r.srv.Listener)
+		if closer == nil {
+			closer = r.srv.PacketConn
+		}
+
+		if closer != nil {
+			err = closer.Close()
+			if err != nil && !errors.Is(err, net.ErrClosed) {
+				l.Error("error closing dns server listener", zap.Error(err))
+			} else {
+				l.Debug("dns server listener closed")
+			}
+		}
+
+		break
+	}
+}
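
Usage sketch (reviewer aid, not part of the patch): roughly how the new
`dns.Manager` is wired together; `memberReader`, `eventHook`, `pairs`,
`forwardKubeDNSToHost` and `logger` stand in for caller-provided values,
as in `DNSResolveCacheController` above.

    m := dns.NewManager(memberReader, eventHook, logger)
    m.ServeBackground(ctx) // no-op when repeated with the same ctx, panics on a different one

    for res, err := range m.RunAll(pairs, forwardKubeDNSToHost) {
    	if err != nil {
    		// per-item reaction: e.g. tolerate IPv6 bind errors when IPv6 is disabled
    		continue
    	}

    	_ = res.Status // StatusNew, StatusRunning or StatusRemoved
    }

    _ = m.ClearAll(false) // real (non-dry) shutdown of all runners
    <-m.Done()            // wait for the supervisor to finish once ctx is canceled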