diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 15cf5d7..4ec429a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -17,7 +17,7 @@ jobs: with: fetch-depth: 0 - name: Set up Go - uses: actions/setup-go@v4 + uses: actions/setup-go@v5 with: go-version: "1.22" - name: Log in to the Container registry diff --git a/cmd/api/handlers.go b/cmd/api/handlers.go index 3529a27..956b470 100644 --- a/cmd/api/handlers.go +++ b/cmd/api/handlers.go @@ -67,13 +67,14 @@ func handleLookup(w http.ResponseWriter, r *http.Request) { } // Load Nameservers. - err = app.LoadNameservers() - if err != nil { + if err := app.LoadNameservers(); err != nil { app.Logger.WithError(err).Error("error loading nameservers") sendErrorResponse(w, fmt.Sprintf("Error looking up for records."), http.StatusInternalServerError, nil) return } + app.Logger.WithField("nameservers", app.Nameservers).Debug("Loaded nameservers") + // Load Resolvers. rslvrs, err := resolvers.LoadResolvers(resolvers.Options{ Nameservers: app.Nameservers, @@ -91,6 +92,8 @@ func handleLookup(w http.ResponseWriter, r *http.Request) { } app.Resolvers = rslvrs + app.Logger.WithField("resolvers", app.Resolvers).Debug("Loaded resolvers") + var responses []resolvers.Response for _, q := range app.Questions { for _, rslv := range app.Resolvers { diff --git a/cmd/doggo/cli.go b/cmd/doggo/cli.go index 59fddeb..5c3f2c3 100644 --- a/cmd/doggo/cli.go +++ b/cmd/doggo/cli.go @@ -16,7 +16,6 @@ import ( ) var ( - // Version and date of the build. This is injected at build-time. buildVersion = "unknown" buildDate = "unknown" logger = utils.InitLogger() @@ -24,105 +23,40 @@ var ( ) func main() { - // Initialize app. app := app.New(logger, buildVersion) + f := setupFlags() - // Configure Flags. - f := flag.NewFlagSet("config", flag.ContinueOnError) - - // Custom Help Text. - f.Usage = renderCustomHelp - - // Query Options. - f.StringSliceP("query", "q", []string{}, "Domain name to query") - f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)") - f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)") - f.StringSliceP("nameserver", "n", []string{}, "Address of the nameserver to send packets to") - f.BoolP("reverse", "x", false, "Performs a DNS Lookup for an IPv4 or IPv6 address. Sets the query type and class to PTR and IN respectively.") - - // Resolver Options - f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.") - f.Bool("search", true, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overridden by `ndots` flag.") - f.Int("ndots", -1, "Specify the ndots parameter. Default value is taken from resolv.conf and fallbacks to 1 if ndots statement is missing in resolv.conf") - f.BoolP("ipv4", "4", false, "Use IPv4 only") - f.BoolP("ipv6", "6", false, "Use IPv6 only") - f.String("strategy", "all", "Strategy to query nameservers in resolv.conf file (`all`, `random`, `first`)") - f.String("tls-hostname", "", "Provide a hostname for doing verification of the certificate if the provided DoT nameserver is an IP") - f.Bool("skip-hostname-verification", false, "Skip TLS Hostname Verification") - - // Output Options - f.BoolP("json", "J", false, "Set the output format as JSON") - f.Bool("short", false, "Short output format") - f.Bool("time", false, "Display how long it took for the response to arrive") - f.Bool("color", true, "Show colored output") - f.Bool("debug", false, "Enable debug mode") - - f.Bool("version", false, "Show version of doggo") - - // Parse and Load Flags. - err := f.Parse(os.Args[1:]) - if err != nil { - app.Logger.WithError(err).Error("error parsing flags") - app.Logger.Exit(2) - } - if err = k.Load(posflag.Provider(f, ".", k), nil); err != nil { - app.Logger.WithError(err).Error("error loading flags") - f.Usage() + if err := parseAndLoadFlags(f); err != nil { + app.Logger.WithError(err).Error("Error parsing or loading flags") app.Logger.Exit(2) } - // If version flag is set, output version and quit. if k.Bool("version") { fmt.Printf("%s - %s\n", buildVersion, buildDate) app.Logger.Exit(0) } - // Set log level. - if k.Bool("debug") { - // Set logger level - app.Logger.SetLevel(logrus.DebugLevel) - } else { - app.Logger.SetLevel(logrus.InfoLevel) - } + setupLogging(&app) - // Unmarshall flags to the app. - err = k.Unmarshal("", &app.QueryFlags) - if err != nil { - app.Logger.WithError(err).Error("error loading args") + if err := k.Unmarshal("", &app.QueryFlags); err != nil { + app.Logger.WithError(err).Error("Error loading args") app.Logger.Exit(2) } - // Load all `non-flag` arguments - // which will be parsed separately. - nsvrs, qt, qc, qn := loadUnparsedArgs(f.Args()) - app.QueryFlags.Nameservers = append(app.QueryFlags.Nameservers, nsvrs...) - app.QueryFlags.QTypes = append(app.QueryFlags.QTypes, qt...) - app.QueryFlags.QClasses = append(app.QueryFlags.QClasses, qc...) - app.QueryFlags.QNames = append(app.QueryFlags.QNames, qn...) - - // Check if reverse flag is passed. If it is, then set - // query type as PTR and query class as IN. - // Modify query name like 94.2.0.192.in-addr.arpa if it's an IPv4 address. - // Use IP6.ARPA nibble format otherwise. + loadNameservers(&app, f.Args()) if app.QueryFlags.ReverseLookup { app.ReverseLookup() } - // Load fallbacks. app.LoadFallbacks() - - // Load Questions. app.PrepareQuestions() - // Load Nameservers. - err = app.LoadNameservers() - if err != nil { - app.Logger.WithError(err).Error("error loading nameservers") + if err := app.LoadNameservers(); err != nil { + app.Logger.WithError(err).Error("Error loading nameservers") app.Logger.Exit(2) } - // Load Resolvers. rslvrs, err := resolvers.LoadResolvers(resolvers.Options{ Nameservers: app.Nameservers, UseIPv4: app.QueryFlags.UseIPv4, @@ -136,23 +70,96 @@ func main() { TLSHostname: app.QueryFlags.TLSHostname, }) if err != nil { - app.Logger.WithError(err).Error("error loading resolver") + app.Logger.WithError(err).Error("Error loading resolver") app.Logger.Exit(2) } app.Resolvers = rslvrs - // Run the app. app.Logger.Debug("Starting doggo 🐶") if len(app.QueryFlags.QNames) == 0 { f.Usage() app.Logger.Exit(0) } - // Resolve Queries. - var ( - responses []resolvers.Response - responseErrors []error - ) + responses, responseErrors := resolveQueries(&app) + + outputResults(&app, responses, responseErrors) + + app.Logger.Exit(0) +} + +func setupFlags() *flag.FlagSet { + f := flag.NewFlagSet("config", flag.ContinueOnError) + f.Usage = renderCustomHelp + + f.StringSliceP("query", "q", []string{}, "Domain name to query") + f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)") + f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)") + f.StringSliceP("nameserver", "n", []string{}, "Address of the nameserver to send packets to") + f.BoolP("reverse", "x", false, "Performs a DNS Lookup for an IPv4 or IPv6 address") + + f.Int("timeout", 5, "Sets the timeout for a query to T seconds") + f.Bool("search", true, "Use the search list provided in resolv.conf") + f.Int("ndots", -1, "Specify the ndots parameter") + f.BoolP("ipv4", "4", false, "Use IPv4 only") + f.BoolP("ipv6", "6", false, "Use IPv6 only") + f.String("strategy", "all", "Strategy to query nameservers in resolv.conf file") + f.String("tls-hostname", "", "Hostname for certificate verification") + f.Bool("skip-hostname-verification", false, "Skip TLS Hostname Verification") + + f.BoolP("json", "J", false, "Set the output format as JSON") + f.Bool("short", false, "Short output format") + f.Bool("time", false, "Display how long the response took") + f.Bool("color", true, "Show colored output") + f.Bool("debug", false, "Enable debug mode") + + f.Bool("version", false, "Show version of doggo") + + return f +} + +func parseAndLoadFlags(f *flag.FlagSet) error { + if err := f.Parse(os.Args[1:]); err != nil { + return fmt.Errorf("error parsing flags: %w", err) + } + if err := k.Load(posflag.Provider(f, ".", k), nil); err != nil { + return fmt.Errorf("error loading flags: %w", err) + } + return nil +} + +func setupLogging(app *app.App) { + if k.Bool("debug") { + app.Logger.SetLevel(logrus.DebugLevel) + } else { + app.Logger.SetLevel(logrus.InfoLevel) + } +} + +func loadNameservers(app *app.App, args []string) { + flagNameservers := k.Strings("nameserver") + app.Logger.WithField("flagNameservers", flagNameservers).Debug("Nameservers from -n flag") + + unparsedNameservers, qt, qc, qn := loadUnparsedArgs(args) + app.Logger.WithField("unparsedNameservers", unparsedNameservers).Debug("Nameservers from unparsed arguments") + + if len(flagNameservers) > 0 { + app.QueryFlags.Nameservers = flagNameservers + } else { + app.QueryFlags.Nameservers = unparsedNameservers + } + + app.QueryFlags.QTypes = append(app.QueryFlags.QTypes, qt...) + app.QueryFlags.QClasses = append(app.QueryFlags.QClasses, qc...) + app.QueryFlags.QNames = append(app.QueryFlags.QNames, qn...) + + app.Logger.WithField("finalNameservers", app.QueryFlags.Nameservers).Debug("Final nameservers") +} + +func resolveQueries(app *app.App) ([]resolvers.Response, []error) { + var responses []resolvers.Response + var responseErrors []error + for _, q := range app.Questions { for _, rslv := range app.Resolvers { resp, err := rslv.Lookup(q) @@ -163,25 +170,12 @@ func main() { } } - // Output results - if app.QueryFlags.ShowJSON { - jsonOutput := struct { - Responses []resolvers.Response `json:"responses,omitempty"` - Error string `json:"error,omitempty"` - }{ - Responses: responses, - } - - if len(responseErrors) > 0 { - jsonOutput.Error = responseErrors[0].Error() - } + return responses, responseErrors +} - jsonData, err := json.MarshalIndent(jsonOutput, "", " ") - if err != nil { - app.Logger.WithError(err).Error("Error marshaling JSON") - app.Logger.Exit(1) - } - fmt.Println(string(jsonData)) +func outputResults(app *app.App, responses []resolvers.Response, responseErrors []error) { + if app.QueryFlags.ShowJSON { + outputJSON(responses, responseErrors) } else { if len(responseErrors) > 0 { app.Logger.WithError(responseErrors[0]).Error("Error looking up DNS records") @@ -189,7 +183,24 @@ func main() { } app.Output(responses) } +} - // Quitting. - app.Logger.Exit(0) +func outputJSON(responses []resolvers.Response, responseErrors []error) { + jsonOutput := struct { + Responses []resolvers.Response `json:"responses,omitempty"` + Error string `json:"error,omitempty"` + }{ + Responses: responses, + } + + if len(responseErrors) > 0 { + jsonOutput.Error = responseErrors[0].Error() + } + + jsonData, err := json.MarshalIndent(jsonOutput, "", " ") + if err != nil { + logger.WithError(err).Error("Error marshaling JSON") + logger.Exit(1) + } + fmt.Println(string(jsonData)) } diff --git a/cmd/doggo/parse.go b/cmd/doggo/parse.go index 52ac09a..6dd67f8 100644 --- a/cmd/doggo/parse.go +++ b/cmd/doggo/parse.go @@ -18,18 +18,17 @@ import ( // where `@1.1.1.1` and `AAAA` are "unparsed" args. // Returns a list of nameserver, queryTypes, queryClasses, queryNames. func loadUnparsedArgs(args []string) ([]string, []string, []string, []string) { - var ns, qt, qc, qn []string + var nameservers, queryTypes, queryClasses, queryNames []string for _, arg := range args { if strings.HasPrefix(arg, "@") { - ns = append(ns, strings.Trim(arg, "@")) - } else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok { - qt = append(qt, arg) - } else if _, ok := dns.StringToClass[strings.ToUpper(arg)]; ok { - qc = append(qc, arg) + nameservers = append(nameservers, strings.TrimPrefix(arg, "@")) + } else if qt, ok := dns.StringToType[strings.ToUpper(arg)]; ok { + queryTypes = append(queryTypes, dns.TypeToString[qt]) + } else if qc, ok := dns.StringToClass[strings.ToUpper(arg)]; ok { + queryClasses = append(queryClasses, dns.ClassToString[qc]) } else { - // if nothing matches, consider it's a query name. - qn = append(qn, arg) + queryNames = append(queryNames, arg) } } - return ns, qt, qc, qn + return nameservers, queryTypes, queryClasses, queryNames } diff --git a/internal/app/nameservers.go b/internal/app/nameservers.go index ff419ef..9942d2b 100644 --- a/internal/app/nameservers.go +++ b/internal/app/nameservers.go @@ -5,6 +5,7 @@ import ( "math/rand" "net" "net/url" + "strings" "time" "github.com/ameshkov/dnsstamps" @@ -12,117 +13,128 @@ import ( "github.com/mr-karan/doggo/pkg/models" ) -// LoadNameservers reads all the user given -// nameservers and loads to App. func (app *App) LoadNameservers() error { - for _, srv := range app.QueryFlags.Nameservers { - ns, err := initNameserver(srv) - if err != nil { - return fmt.Errorf("error parsing nameserver: %s", srv) - } - // check if properly initialised. - if ns.Address != "" && ns.Type != "" { - app.Nameservers = append(app.Nameservers, ns) + app.Logger.WithField("nameservers", app.QueryFlags.Nameservers).Debug("LoadNameservers: Initial nameservers") + + app.Nameservers = []models.Nameserver{} // Clear existing nameservers + + if len(app.QueryFlags.Nameservers) > 0 { + for _, srv := range app.QueryFlags.Nameservers { + ns, err := initNameserver(srv) + if err != nil { + app.Logger.WithError(err).Error("error parsing nameserver") + return fmt.Errorf("error parsing nameserver: %s", srv) + } + if ns.Address != "" && ns.Type != "" { + app.Nameservers = append(app.Nameservers, ns) + app.Logger.WithField("nameserver", ns).Debug("Added nameserver") + } } } - // Set `ndots` to the user specified value. - app.ResolverOpts.Ndots = app.QueryFlags.Ndots - // fallback to system nameserver - // in case no nameserver is specified by user. + // If no nameservers were successfully loaded, fall back to system nameservers if len(app.Nameservers) == 0 { - ns, ndots, search, err := getDefaultServers(app.QueryFlags.Strategy) - if err != nil { - return fmt.Errorf("error fetching system default nameserver") - } - // `-1` indicates the flag is not set. - // use from config if user hasn't specified any value. - if app.ResolverOpts.Ndots == -1 { - app.ResolverOpts.Ndots = ndots - } - if len(search) > 0 && app.QueryFlags.UseSearchList { - app.ResolverOpts.SearchList = search - } - app.Nameservers = append(app.Nameservers, ns...) + return app.loadSystemNameservers() } - // if the user hasn't given any override of `ndots` AND has - // given a custom nameserver. Set `ndots` to 1 as the fallback value + + app.Logger.WithField("nameservers", app.Nameservers).Debug("LoadNameservers: Final nameservers") + return nil +} + +func (app *App) loadSystemNameservers() error { + app.Logger.Debug("No user specified nameservers, falling back to system nameservers") + ns, ndots, search, err := getDefaultServers(app.QueryFlags.Strategy) + if err != nil { + app.Logger.WithError(err).Error("error fetching system default nameserver") + return fmt.Errorf("error fetching system default nameserver: %v", err) + } + if app.ResolverOpts.Ndots == -1 { - app.ResolverOpts.Ndots = 0 + app.ResolverOpts.Ndots = ndots + } + + if len(search) > 0 && app.QueryFlags.UseSearchList { + app.ResolverOpts.SearchList = search } + + app.Nameservers = append(app.Nameservers, ns...) + app.Logger.WithField("nameservers", app.Nameservers).Debug("Loaded system nameservers") return nil } func initNameserver(n string) (models.Nameserver, error) { - // Instantiate a UDP resolver with default port as a fallback. - ns := models.Nameserver{ - Type: models.UDPResolver, - Address: net.JoinHostPort(n, models.DefaultUDPPort), + // If the nameserver doesn't have a protocol, assume it's UDP + if !strings.Contains(n, "://") { + n = "udp://" + n } + u, err := url.Parse(n) if err != nil { - ip := net.ParseIP(n) - if ip == nil { - return ns, err - } - return ns, nil + return models.Nameserver{}, err } + + ns := models.Nameserver{ + Type: models.UDPResolver, + Address: getAddressWithDefaultPort(u, models.DefaultUDPPort), + } + switch u.Scheme { case "sdns": - stamp, err := dnsstamps.NewServerStampFromString(n) - if err != nil { - return ns, err - } - switch stamp.Proto { - case dnsstamps.StampProtoTypeDoH: - ns.Type = models.DOHResolver - address := url.URL{Scheme: "https", Host: stamp.ProviderName, Path: stamp.Path} - ns.Address = address.String() - case dnsstamps.StampProtoTypeDNSCrypt: - ns.Type = models.DNSCryptResolver - ns.Address = n - default: - return ns, fmt.Errorf("unsupported protocol: %v", stamp.Proto.String()) - } - + return handleSDNS(n) case "https": ns.Type = models.DOHResolver ns.Address = u.String() - case "tls": ns.Type = models.DOTResolver - if u.Port() == "" { - ns.Address = net.JoinHostPort(u.Hostname(), models.DefaultTLSPort) - } else { - ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) - } - + ns.Address = getAddressWithDefaultPort(u, models.DefaultTLSPort) case "tcp": ns.Type = models.TCPResolver - if u.Port() == "" { - ns.Address = net.JoinHostPort(u.Hostname(), models.DefaultTCPPort) - } else { - ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) - } - + ns.Address = getAddressWithDefaultPort(u, models.DefaultTCPPort) case "udp": ns.Type = models.UDPResolver - if u.Port() == "" { - ns.Address = net.JoinHostPort(u.Hostname(), models.DefaultUDPPort) - } else { - ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) - } + ns.Address = getAddressWithDefaultPort(u, models.DefaultUDPPort) case "quic": ns.Type = models.DOQResolver - if u.Port() == "" { - ns.Address = net.JoinHostPort(u.Hostname(), models.DefaultDOQPort) - } else { - ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) - } + ns.Address = getAddressWithDefaultPort(u, models.DefaultDOQPort) + default: + return ns, fmt.Errorf("unsupported protocol: %s", u.Scheme) } + return ns, nil } +func getAddressWithDefaultPort(u *url.URL, defaultPort string) string { + host := u.Hostname() + port := u.Port() + if port == "" { + port = defaultPort + } + return net.JoinHostPort(host, port) +} + +func handleSDNS(n string) (models.Nameserver, error) { + stamp, err := dnsstamps.NewServerStampFromString(n) + if err != nil { + return models.Nameserver{}, err + } + + switch stamp.Proto { + case dnsstamps.StampProtoTypeDoH: + address := url.URL{Scheme: "https", Host: stamp.ProviderName, Path: stamp.Path} + return models.Nameserver{ + Type: models.DOHResolver, + Address: address.String(), + }, nil + case dnsstamps.StampProtoTypeDNSCrypt: + return models.Nameserver{ + Type: models.DNSCryptResolver, + Address: n, + }, nil + default: + return models.Nameserver{}, fmt.Errorf("unsupported protocol: %v", stamp.Proto.String()) + } +} + func getDefaultServers(strategy string) ([]models.Nameserver, int, []string, error) { // Load nameservers from `/etc/resolv.conf`. dnsServers, ndots, search, err := config.GetDefaultServers() @@ -134,7 +146,7 @@ func getDefaultServers(strategy string) ([]models.Nameserver, int, []string, err switch strategy { case "random": // Choose a random server from the list. - rand.Seed(time.Now().Unix()) + rand.Seed(time.Now().UnixNano()) srv := dnsServers[rand.Intn(len(dnsServers))] ns := models.Nameserver{ Type: models.UDPResolver,