Skip to content

Commit

Permalink
feat: add dns.propagation-rns option (#2284)
Browse files Browse the repository at this point in the history
  • Loading branch information
ldez authored Sep 26, 2024
1 parent d2898e1 commit c704ba5
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 60 deletions.
60 changes: 48 additions & 12 deletions challenge/dns01/precheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"net"
"strings"
"time"

"github.com/miekg/dns"
)
Expand All @@ -23,23 +24,47 @@ func WrapPreCheck(wrap WrapPreCheckFunc) ChallengeOption {
}
}

// DisableCompletePropagationRequirement obsolete.
// Deprecated: use DisableAuthoritativeNssPropagationRequirement instead.
func DisableCompletePropagationRequirement() ChallengeOption {
return DisableAuthoritativeNssPropagationRequirement()
}

func DisableAuthoritativeNssPropagationRequirement() ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.requireAuthoritativeNssPropagation = false
return nil
}
}

func RecursiveNSsPropagationRequirement() ChallengeOption {
return func(chlg *Challenge) error {
chlg.preCheck.requireCompletePropagation = false
chlg.preCheck.requireRecursiveNssPropagation = true
return nil
}
}

func PropagationWaitOnly(wait time.Duration) ChallengeOption {
return WrapPreCheck(func(domain, fqdn, value string, check PreCheckFunc) (bool, error) {
time.Sleep(wait)
return true, nil
})
}

type preCheck struct {
// checks DNS propagation before notifying ACME that the DNS challenge is ready.
checkFunc WrapPreCheckFunc

// require the TXT record to be propagated to all authoritative name servers
requireCompletePropagation bool
requireAuthoritativeNssPropagation bool

// require the TXT record to be propagated to all recursive name servers
requireRecursiveNssPropagation bool
}

func newPreCheck() preCheck {
return preCheck{
requireCompletePropagation: true,
requireAuthoritativeNssPropagation: true,
}
}

Expand All @@ -53,32 +78,43 @@ func (p preCheck) call(domain, fqdn, value string) (bool, error) {

// checkDNSPropagation checks if the expected TXT record has been propagated to all authoritative nameservers.
func (p preCheck) checkDNSPropagation(fqdn, value string) (bool, error) {
// Initial attempt to resolve at the recursive NS
// Initial attempt to resolve at the recursive NS (require to get CNAME)
r, err := dnsQuery(fqdn, dns.TypeTXT, recursiveNameservers, true)
if err != nil {
return false, err
}

if !p.requireCompletePropagation {
return true, nil
}

if r.Rcode == dns.RcodeSuccess {
fqdn = updateDomainWithCName(r, fqdn)
}

if p.requireRecursiveNssPropagation {
_, err = checkNameserversPropagation(fqdn, value, recursiveNameservers, false)
if err != nil {
return false, err
}
}

if !p.requireAuthoritativeNssPropagation {
return true, nil
}

authoritativeNss, err := lookupNameservers(fqdn)
if err != nil {
return false, err
}

return checkAuthoritativeNss(fqdn, value, authoritativeNss)
return checkNameserversPropagation(fqdn, value, authoritativeNss, true)
}

// checkAuthoritativeNss queries each of the given nameservers for the expected TXT record.
func checkAuthoritativeNss(fqdn, value string, nameservers []string) (bool, error) {
// checkNameserversPropagation queries each of the given nameservers for the expected TXT record.
func checkNameserversPropagation(fqdn, value string, nameservers []string, addPort bool) (bool, error) {
for _, ns := range nameservers {
r, err := dnsQuery(fqdn, dns.TypeTXT, []string{net.JoinHostPort(ns, "53")}, false)
if addPort {
ns = net.JoinHostPort(ns, "53")
}

r, err := dnsQuery(fqdn, dns.TypeTXT, []string{ns}, false)
if err != nil {
return false, err
}
Expand Down
4 changes: 2 additions & 2 deletions challenge/dns01/precheck_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestCheckAuthoritativeNss(t *testing.T) {
t.Parallel()
ClearFqdnCache()

ok, _ := checkAuthoritativeNss(test.fqdn, test.value, test.ns)
ok, _ := checkNameserversPropagation(test.fqdn, test.value, test.ns, true)
assert.Equal(t, test.expected, ok, test.fqdn)
})
}
Expand Down Expand Up @@ -106,7 +106,7 @@ func TestCheckAuthoritativeNssErr(t *testing.T) {
t.Parallel()
ClearFqdnCache()

_, err := checkAuthoritativeNss(test.fqdn, test.value, test.ns)
_, err := checkNameserversPropagation(test.fqdn, test.value, test.ns, true)
require.Error(t, err)
assert.Contains(t, err.Error(), test.error)
})
Expand Down
77 changes: 44 additions & 33 deletions cmd/flags.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"fmt"
"time"

"github.com/go-acme/lego/v4/certificate"
Expand All @@ -11,38 +12,40 @@ import (

// Flag names.
const (
flgDomains = "domains"
flgServer = "server"
flgAcceptTOS = "accept-tos"
flgEmail = "email"
flgCSR = "csr"
flgEAB = "eab"
flgKID = "kid"
flgHMAC = "hmac"
flgKeyType = "key-type"
flgFilename = "filename"
flgPath = "path"
flgHTTP = "http"
flgHTTPPort = "http.port"
flgHTTPProxyHeader = "http.proxy-header"
flgHTTPWebroot = "http.webroot"
flgHTTPMemcachedHost = "http.memcached-host"
flgHTTPS3Bucket = "http.s3-bucket"
flgTLS = "tls"
flgTLSPort = "tls.port"
flgDNS = "dns"
flgDNSDisableCP = "dns.disable-cp"
flgDNSPropagationWait = "dns.propagation-wait"
flgDNSResolvers = "dns.resolvers"
flgHTTPTimeout = "http-timeout"
flgDNSTimeout = "dns-timeout"
flgPEM = "pem"
flgPFX = "pfx"
flgPFXPass = "pfx.pass"
flgPFXFormat = "pfx.format"
flgCertTimeout = "cert.timeout"
flgOverallRequestLimit = "overall-request-limit"
flgUserAgent = "user-agent"
flgDomains = "domains"
flgServer = "server"
flgAcceptTOS = "accept-tos"
flgEmail = "email"
flgCSR = "csr"
flgEAB = "eab"
flgKID = "kid"
flgHMAC = "hmac"
flgKeyType = "key-type"
flgFilename = "filename"
flgPath = "path"
flgHTTP = "http"
flgHTTPPort = "http.port"
flgHTTPProxyHeader = "http.proxy-header"
flgHTTPWebroot = "http.webroot"
flgHTTPMemcachedHost = "http.memcached-host"
flgHTTPS3Bucket = "http.s3-bucket"
flgTLS = "tls"
flgTLSPort = "tls.port"
flgDNS = "dns"
flgDNSDisableCP = "dns.disable-cp"
flgDNSPropagationWait = "dns.propagation-wait"
flgDNSPropagationDisableANS = "dns.propagation-disable-ans"
flgDNSPropagationRNS = "dns.propagation-rns"
flgDNSResolvers = "dns.resolvers"
flgHTTPTimeout = "http-timeout"
flgDNSTimeout = "dns-timeout"
flgPEM = "pem"
flgPFX = "pfx"
flgPFXPass = "pfx.pass"
flgPFXFormat = "pfx.format"
flgCertTimeout = "cert.timeout"
flgOverallRequestLimit = "overall-request-limit"
flgUserAgent = "user-agent"
)

func CreateFlags(defaultPath string) []cli.Flag {
Expand Down Expand Up @@ -147,11 +150,19 @@ func CreateFlags(defaultPath string) []cli.Flag {
},
&cli.BoolFlag{
Name: flgDNSDisableCP,
Usage: fmt.Sprintf("(deprecated) use %s instead.", flgDNSPropagationDisableANS),
},
&cli.BoolFlag{
Name: flgDNSPropagationDisableANS,
Usage: "By setting this flag to true, disables the need to await propagation of the TXT record to all authoritative name servers.",
},
&cli.BoolFlag{
Name: flgDNSPropagationRNS,
Usage: "By setting this flag to true, use all the recursive nameservers to check the propagation of the TXT record.",
},
&cli.DurationFlag{
Name: flgDNSPropagationWait,
Usage: "By setting this flag, disables all the propagation checks and uses a wait duration instead.",
Usage: "By setting this flag, disables all the propagation checks of the TXT record and uses a wait duration instead.",
},
&cli.StringSliceFlag{
Name: flgDNSResolvers,
Expand Down
40 changes: 30 additions & 10 deletions cmd/setup_challenges.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ func setupTLSProvider(ctx *cli.Context) challenge.Provider {
}

func setupDNS(ctx *cli.Context, client *lego.Client) error {
if ctx.IsSet(flgDNSDisableCP) && ctx.Bool(flgDNSDisableCP) && ctx.IsSet(flgDNSPropagationWait) {
return fmt.Errorf("'%s' and '%s' are mutually exclusive", flgDNSDisableCP, flgDNSPropagationWait)
err := checkPropagationExclusiveOptions(ctx)
if err != nil {
return err
}

wait := ctx.Duration(flgDNSPropagationWait)
Expand All @@ -138,19 +139,38 @@ func setupDNS(ctx *cli.Context, client *lego.Client) error {
dns01.CondOption(len(servers) > 0,
dns01.AddRecursiveNameservers(dns01.ParseNameservers(ctx.StringSlice(flgDNSResolvers)))),

dns01.CondOption(ctx.Bool(flgDNSDisableCP),
dns01.DisableCompletePropagationRequirement()),
dns01.CondOption(ctx.Bool(flgDNSDisableCP) || ctx.Bool(flgDNSPropagationDisableANS),
dns01.DisableAuthoritativeNssPropagationRequirement()),

dns01.CondOption(ctx.IsSet(flgDNSPropagationWait), dns01.WrapPreCheck(
func(domain, fqdn, value string, check dns01.PreCheckFunc) (bool, error) {
time.Sleep(wait)
return true, nil
},
)),
dns01.CondOption(ctx.Duration(flgDNSPropagationWait) > 0,
dns01.PropagationWaitOnly(wait)),

dns01.CondOption(ctx.Bool(flgDNSPropagationRNS),
dns01.RecursiveNSsPropagationRequirement()),

dns01.CondOption(ctx.IsSet(flgDNSTimeout),
dns01.AddDNSTimeout(time.Duration(ctx.Int(flgDNSTimeout))*time.Second)),
)

return err
}

func checkPropagationExclusiveOptions(ctx *cli.Context) error {
if ctx.IsSet(flgDNSDisableCP) {
log.Println("The flag '%s' is deprecated use '%s' instead.", flgDNSDisableCP, flgDNSPropagationDisableANS)
}

if (isSetBool(ctx, flgDNSDisableCP) || isSetBool(ctx, flgDNSPropagationDisableANS)) && ctx.IsSet(flgDNSPropagationWait) {
return fmt.Errorf("'%s' and '%s' are mutually exclusive", flgDNSPropagationDisableANS, flgDNSPropagationWait)
}

if isSetBool(ctx, flgDNSPropagationRNS) && ctx.IsSet(flgDNSPropagationWait) {
return fmt.Errorf("'%s' and '%s' are mutually exclusive", flgDNSPropagationRNS, flgDNSPropagationWait)
}

return nil
}

func isSetBool(ctx *cli.Context, name string) bool {
return ctx.IsSet(name) && ctx.Bool(name)
}
6 changes: 4 additions & 2 deletions docs/data/zz_cli_help.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ GLOBAL OPTIONS:
--tls Use the TLS-ALPN-01 challenge to solve challenges. Can be mixed with other types of challenges. (default: false)
--tls.port value Set the port and interface to use for TLS-ALPN-01 based challenges to listen on. Supported: interface:port or :port. (default: ":443")
--dns value Solve a DNS-01 challenge using the specified provider. Can be mixed with other types of challenges. Run 'lego dnshelp' for help on usage.
--dns.disable-cp By setting this flag to true, disables the need to await propagation of the TXT record to all authoritative name servers. (default: false)
--dns.propagation-wait value By setting this flag, disables all the propagation checks and uses a wait duration instead. (default: 0s)
--dns.disable-cp (deprecated) use dns.propagation-disable-ans instead. (default: false)
--dns.propagation-disable-ans By setting this flag to true, disables the need to await propagation of the TXT record to all authoritative name servers. (default: false)
--dns.propagation-rns By setting this flag, use all the recursive nameservers to check the propagation of the TXT record. (default: false)
--dns.propagation-wait value By setting this flag, disables all the propagation checks of the TXT record and uses a wait duration instead. (default: 0s)
--dns.resolvers value [ --dns.resolvers value ] Set the resolvers to use for performing (recursive) CNAME resolving and apex domain determination. For DNS-01 challenge verification, the authoritative DNS server is queried directly. Supported: host:port. The default is to use the system resolvers, or Google's DNS resolvers if the system's cannot be determined.
--http-timeout value Set the HTTP timeout value to a specific value in seconds. (default: 0)
--dns-timeout value Set the DNS timeout value to a specific value in seconds. Used only when performing authoritative name server queries. (default: 10)
Expand Down
2 changes: 1 addition & 1 deletion e2e/dnschallenge/dns_challenges_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func TestChallengeDNS_Client_Obtain(t *testing.T) {

err = client.Challenge.SetDNS01Provider(provider,
dns01.AddRecursiveNameservers([]string{":8053"}),
dns01.DisableCompletePropagationRequirement())
dns01.DisableAuthoritativeNssPropagationRequirement())
require.NoError(t, err)

reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true})
Expand Down

0 comments on commit c704ba5

Please sign in to comment.