diff --git a/cmd/scan.go b/cmd/scan.go index 165a8bf..8d370a2 100644 --- a/cmd/scan.go +++ b/cmd/scan.go @@ -46,6 +46,15 @@ For example: log4jScanner scan --cidr "192.168.0.1/24`, cmd.Usage() return } + publicIPAllowed, err := cmd.Flags().GetBool("allow-public-ips") + if err != nil { + pterm.Error.Println("allow-public-ip flag error") + cmd.Usage() + return + } + if publicIPAllowed { + pterm.Warning.Println("Scanning public IPs should be done with care, use at your own risk") + } // TODO: add cancel context cidr, err := cmd.Flags().GetString("cidr") if err != nil { @@ -99,7 +108,7 @@ For example: log4jScanner scan --cidr "192.168.0.1/24`, if !disableServer { StartServer(ctx, serverUrl, serverTimeout) } - ScanCIDR(ctx, cidr, ports, serverUrl) + ScanCIDR(ctx, cidr, ports, serverUrl, publicIPAllowed) }, } @@ -113,6 +122,7 @@ func init() { scanCmd.Flags().String("cidr", "", "IP subnet to scan in CIDR notation (e.g. 192.168.1.0/24)") scanCmd.Flags().Bool("noserver", false, "Do not use the internal TCP server, this overrides the server flag if present") scanCmd.Flags().Bool("nocolor", false, "remove colors from output") + scanCmd.Flags().Bool("allow-public-ips",false,"allowing to scan public IPs") scanCmd.Flags().String("server", "", "Callback server IP and port (e.g. 192.168.1.100:5555)") scanCmd.Flags().String("ports", "top10", "Ports to scan. By default scans top 10 ports; 'top100' will scan the top 100 ports, 'slow' will scan all possible ports") @@ -122,8 +132,8 @@ func init() { createPrivateIPBlocks() } -func ScanCIDR(ctx context.Context, cidr string, portsFlag string, serverUrl string) { - hosts, err := Hosts(cidr) +func ScanCIDR(ctx context.Context, cidr string, portsFlag string, serverUrl string, allowPublicIPs bool) { + hosts, err := Hosts(cidr, allowPublicIPs) //if err is not nil cidr wasn't parse correctly or ip isn't private if err != nil { pterm.Error.Println("Failed to get hosts, what:", err) @@ -233,7 +243,7 @@ func ScanPorts(ip, server string, ports []int, resChan chan string, wg *sync.Wai } -func Hosts(cidr string) ([]string, error) { +func Hosts(cidr string, allowPublicIPs bool) ([]string, error) { ip, ipnet, err := net.ParseCIDR(cidr) if err != nil { return nil, err @@ -241,8 +251,9 @@ func Hosts(cidr string) ([]string, error) { var ips []string for ip := ip.Mask(ipnet.Mask); ipnet.Contains(ip); inc(ip) { - // Only scan for private IP addresses. If IP is not private, skip. - if !isPrivateIP(ip) { + + //if public ip scanning isn't allowed Only scan for private IP addresses. If IP is not private, terminate with error. + if !allowPublicIPs && !isPrivateIP(ip) { badIPStatus := ip.String() + " IP address is not private" pterm.Error.Println(badIPStatus) log.Fatal(badIPStatus)