package scanner

import (
	"fmt"
	"net"
	"sync"
	"time"

	sshclient "network-topology-discovery/internal/ssh"
)

const (
	// defaultScanConcurrency is the number of probes run in parallel when the
	// caller supplies (or a zero Scanner holds) a non-positive concurrency.
	defaultScanConcurrency = 10
	// defaultProbeTimeout is the per-probe timeout used when the caller
	// supplies (or a zero Scanner holds) a non-positive timeout.
	defaultProbeTimeout = 2 * time.Second
	// maxScanHostBits caps how many host bits a CIDR given to ScanRange may
	// have: 2^24 addresses, i.e. an IPv4 /8. Larger ranges (an IPv4 /0, an
	// IPv6 /64, ...) cannot reasonably be materialized as a []string.
	maxScanHostBits = 24
)

// Scanner is a network scanner that enumerates an IP range and probes the
// resulting addresses concurrently. Create one with NewScanner; a zero
// Scanner falls back to the same defaults NewScanner applies.
type Scanner struct {
	concurrency int           // maximum number of probes in flight at once
	timeout     time.Duration // per-probe timeout handed to the ssh client helpers
}

// NewScanner creates a Scanner that runs at most concurrency probes at a
// time, each bounded by timeout. A non-positive concurrency defaults to 10
// and a non-positive timeout defaults to 2 seconds.
func NewScanner(concurrency int, timeout time.Duration) *Scanner {
	if concurrency <= 0 {
		concurrency = defaultScanConcurrency
	}
	// Negative durations are as meaningless as zero here; treat both as unset.
	if timeout <= 0 {
		timeout = defaultProbeTimeout
	}
	return &Scanner{
		concurrency: concurrency,
		timeout:     timeout,
	}
}

// ScanRange expands cidr into every address it contains, in ascending
// order, including the network and (for IPv4) broadcast addresses.
//
// It returns an error if cidr cannot be parsed, or if the range has more
// than 2^24 addresses (more than 24 host bits).
func (s *Scanner) ScanRange(cidr string) ([]string, error) {
	_, ipNet, err := net.ParseCIDR(cidr)
	if err != nil {
		return nil, fmt.Errorf("invalid CIDR: %w", err)
	}

	ones, bits := ipNet.Mask.Size()
	hostBits := bits - ones
	if hostBits > maxScanHostBits {
		return nil, fmt.Errorf("CIDR %q is too large: %d host bits exceeds the limit of %d",
			cidr, hostBits, maxScanHostBits)
	}

	// Iterate an exact count instead of testing ipNet.Contains: for a /0
	// range incIP wraps back to the network address, which Contains accepts,
	// so a containment-driven loop would never terminate.
	count := 1 << hostBits
	ips := make([]string, 0, count)
	ip := ipNet.IP.Mask(ipNet.Mask) // fresh copy; incIP mutates it in place
	for i := 0; i < count; i++ {
		ips = append(ips, ip.String())
		incIP(ip)
	}
	return ips, nil
}

// CheckHosts reports which of ips respond to sshclient.Ping within the
// scanner's timeout. Probes run concurrently (bounded by the scanner's
// concurrency). The result preserves the order of ips and is nil when no
// host responds.
func (s *Scanner) CheckHosts(ips []string) []string {
	return s.probeAll(ips, func(ip string, timeout time.Duration) bool {
		return sshclient.Ping(ip, timeout)
	})
}

// CheckSSHHosts reports which of ips pass sshclient.CheckSSH on port within
// the scanner's timeout. Probes run concurrently (bounded by the scanner's
// concurrency). The result preserves the order of ips and is nil when no
// host passes.
func (s *Scanner) CheckSSHHosts(ips []string, port int) []string {
	return s.probeAll(ips, func(ip string, timeout time.Duration) bool {
		return sshclient.CheckSSH(ip, port, timeout)
	})
}

// probeAll runs probe against every address in ips with bounded
// concurrency and returns, in input order, the addresses for which probe
// returned true. It blocks until every probe has finished.
func (s *Scanner) probeAll(ips []string, probe func(ip string, timeout time.Duration) bool) []string {
	// A zero Scanner has concurrency 0; an unbuffered semaphore would make
	// the first send below block forever, so fall back to the defaults.
	limit := s.concurrency
	if limit <= 0 {
		limit = defaultScanConcurrency
	}
	timeout := s.timeout
	if timeout <= 0 {
		timeout = defaultProbeTimeout
	}

	// Each goroutine writes only its own index, so no mutex is needed; the
	// reads after wg.Wait are ordered after every write.
	passed := make([]bool, len(ips))
	sem := make(chan struct{}, limit)
	var wg sync.WaitGroup

	for i, ip := range ips {
		sem <- struct{}{} // acquire a slot before spawning, bounding live goroutines
		wg.Add(1)
		go func(i int, ip string) {
			defer wg.Done()
			defer func() { <-sem }()
			passed[i] = probe(ip, timeout)
		}(i, ip)
	}
	wg.Wait()

	var matched []string
	for i, ip := range ips {
		if passed[i] {
			matched = append(matched, ip)
		}
	}
	return matched
}

// ScanAndDiscover expands cidr, keeps the hosts that respond to
// sshclient.Ping, and returns those that also pass the SSH check on
// sshPort. Hosts that do not answer the ping stage are never SSH-checked.
// The only error it returns comes from parsing or bounding cidr.
func (s *Scanner) ScanAndDiscover(cidr string, sshPort int) ([]string, error) {
	// Expand the IP range.
	ips, err := s.ScanRange(cidr)
	if err != nil {
		return nil, err
	}

	// Keep only hosts that respond.
	aliveHosts := s.CheckHosts(ips)

	// Of those, keep only hosts with SSH available.
	sshHosts := s.CheckSSHHosts(aliveHosts, sshPort)
	return sshHosts, nil
}

// incIP increments ip in place, treating its bytes as a big-endian unsigned
// integer. Incrementing the all-ones address wraps around to all zeros.
func incIP(ip net.IP) {
	for j := len(ip) - 1; j >= 0; j-- {
		ip[j]++
		if ip[j] > 0 {
			break
		}
	}
}