package sshclient import ( "bytes" "fmt" "net" "os" "regexp" "strings" "time" "golang.org/x/crypto/ssh" ) // Client SSH客户端 type Client struct { client *ssh.Client timeout time.Duration host string port int username string password string keyFile string insecureCiphers bool } // Config SSH客户端配置 type Config struct { Host string Port int Username string Password string KeyFile string Timeout time.Duration InsecureCiphers bool // 启用不安全的加密算法(用于兼容老旧设备) } // NewClient 创建新的SSH客户端 func NewClient(config Config) *Client { if config.Port == 0 { config.Port = 22 } if config.Timeout == 0 { config.Timeout = 30 * time.Second } return &Client{ host: config.Host, port: config.Port, username: config.Username, password: config.Password, keyFile: config.KeyFile, timeout: config.Timeout, insecureCiphers: config.InsecureCiphers, } } // Connect 连接到SSH服务器 func (c *Client) Connect() error { config := &ssh.ClientConfig{ User: c.username, Auth: []ssh.AuthMethod{}, HostKeyCallback: ssh.InsecureIgnoreHostKey(), Timeout: c.timeout, } // 添加密码认证 if c.password != "" { config.Auth = append(config.Auth, ssh.Password(c.password)) } // 添加密钥认证 if c.keyFile != "" { key, err := loadPrivateKey(c.keyFile) if err != nil { return fmt.Errorf("failed to load private key: %w", err) } config.Auth = append(config.Auth, ssh.PublicKeys(key)) } // 如果启用不安全加密算法,添加旧版算法支持(用于兼容老旧设备) if c.insecureCiphers { config.Ciphers = []string{ "aes128-ctr", "aes192-ctr", "aes256-ctr", "aes128-gcm@openssh.com", "aes256-gcm@openssh.com", "chacha20-poly1305@openssh.com", "aes128-cbc", "aes256-cbc", // 旧版CBC算法 } config.KeyExchanges = []string{ "curve25519-sha256", "curve25519-sha256@libssh.org", "ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521", "diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512", "diffie-hellman-group14-sha1", "diffie-hellman-group1-sha1", // 旧版KEX算法 } config.MACs = []string{ "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96", // 旧版MAC算法 } } // 连接 addr := fmt.Sprintf("%s:%d", c.host, c.port) client, err := ssh.Dial("tcp", addr, config) if err != nil { return fmt.Errorf("failed to connect to %s: %w", addr, err) } c.client = client return nil } // Close 关闭SSH连接 func (c *Client) Close() error { if c.client != nil { return c.client.Close() } return nil } // ExecuteCommand 执行命令并返回输出(每个命令使用新会话) func (c *Client) ExecuteCommand(command string) (string, error) { if c.client == nil { return "", fmt.Errorf("not connected") } // 创建新会话 session, err := c.client.NewSession() if err != nil { return "", fmt.Errorf("failed to create session: %w", err) } defer session.Close() var stdoutBuf bytes.Buffer session.Stdout = &stdoutBuf session.Stderr = &stdoutBuf err = session.Run(command) if err != nil { return stdoutBuf.String(), fmt.Errorf("command '%s' failed: %w", command, err) } return stdoutBuf.String(), nil } // ExecuteCommands 执行多个命令(使用Shell模式,在同一个会话中顺序执行) func (c *Client) ExecuteCommands(commands []string) ([]string, error) { if c.client == nil { return nil, fmt.Errorf("not connected") } // 创建一个shell会话 session, err := c.client.NewSession() if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } defer session.Close() // 请求PTY modes := ssh.TerminalModes{ ssh.ECHO: 0, // 禁用回显 ssh.TTY_OP_ISPEED: 14400, // 输入速度 ssh.TTY_OP_OSPEED: 14400, // 输出速度 } if err := session.RequestPty("dumb", 200, 1000, modes); err != nil { return nil, fmt.Errorf("failed to request pty: %w", err) } // 获取stdin管道 stdin, err := session.StdinPipe() if err != nil { return nil, fmt.Errorf("failed to get stdin pipe: %w", err) } // 捕获输出 var stdoutBuf bytes.Buffer var stderrBuf bytes.Buffer session.Stdout = &stdoutBuf session.Stderr = &stderrBuf // 启动shell if err := session.Shell(); err != nil { return nil, fmt.Errorf("failed to start shell: %w", err) } // 执行命令并收集输出(使用分隔符标记) results := make([]string, 0, len(commands)) for i, cmd := range commands { // 等待一段时间防止设备速率限制 if i > 0 { time.Sleep(2 * time.Second) } // 生成分隔符标记命令边界 delimiter := fmt.Sprintf("===CMD_BOUNDARY_%d===", i) // 发送命令 if _, err := stdin.Write([]byte(cmd + "\n")); err != nil { return results, fmt.Errorf("failed to send command '%s': %w", cmd, err) } // 等待命令执行完成(不同命令需要不同等待时间) sleepTime := 1 * time.Second if cmd == "display interface" || strings.Contains(cmd, "display interface") { sleepTime = 8 * time.Second // 大输出命令需要更多时间 } else if cmd == "display lldp neighbor-information" { sleepTime = 3 * time.Second } time.Sleep(sleepTime) // 发送分隔符(不执行任何操作,只是标记) if _, err := stdin.Write([]byte("echo " + delimiter + "\n")); err != nil { return results, fmt.Errorf("failed to send delimiter: %w", err) } // 等待分隔符输出 time.Sleep(500 * time.Millisecond) // 获取完整输出并清理 rawOutput := stdoutBuf.String() // 找到分隔符位置,提取当前命令的输出部分 parts := strings.Split(rawOutput, delimiter) cmdOutput := "" if len(parts) > 1 { // 取最后一个分隔符之前的部分(当前命令的输出) cmdOutput = parts[len(parts)-2] // 倒数第二部分是当前命令输出,最后一部分是分隔符后的空内容 } else { cmdOutput = rawOutput } cleanOutput := cleanCommandOutput(cmdOutput, cmd) results = append(results, cleanOutput) fmt.Printf("[SSH] Command '%s' extracted %d bytes from %d bytes raw output\n", cmd, len(cleanOutput), len(rawOutput)) } // 退出shell stdin.Write([]byte("exit\n")) session.Wait() return results, nil } // cleanCommandOutput 清理命令输出,移除命令回显、分页提示和提示符 func cleanCommandOutput(output, command string) string { // 清理\r\n为\n output = strings.ReplaceAll(output, "\r\n", "\n") lines := strings.Split(output, "\n") var cleanLines []string skipCommandEcho := true // 跳过命令本身的回显 for _, line := range lines { trimmedLine := strings.TrimSpace(line) // 跳过空行(如果是开头) if trimmedLine == "" && len(cleanLines) == 0 { continue } // 跳过命令回显(第一次出现) if skipCommandEcho && trimmedLine == strings.TrimSpace(command) { skipCommandEcho = false continue } // 跳过分页提示 if strings.Contains(trimmedLine, "---- More ----") { continue } // 跳过提示符行(如 或 [hostname]) if regexp.MustCompile(`^[<\[]\S+[>\]]$`).MatchString(trimmedLine) { continue } // 跳过版权信息(开头) if strings.HasPrefix(trimmedLine, "*********") { continue } if strings.HasPrefix(trimmedLine, "* Copyright") { continue } if strings.HasPrefix(trimmedLine, "* Without") { continue } if strings.HasPrefix(trimmedLine, "* no decompiling") { continue } cleanLines = append(cleanLines, trimmedLine) } return strings.Join(cleanLines, "\n") } // CheckSSH 检查主机是否开启SSH func CheckSSH(host string, port int, timeout time.Duration) bool { if port == 0 { port = 22 } if timeout == 0 { timeout = 2 * time.Second } addr := fmt.Sprintf("%s:%d", host, port) conn, err := net.DialTimeout("tcp", addr, timeout) if err != nil { return false } defer conn.Close() return true } // loadPrivateKey 加载私钥文件 func loadPrivateKey(keyFile string) (ssh.Signer, error) { keyData, err := os.ReadFile(keyFile) if err != nil { return nil, fmt.Errorf("failed to read key file: %w", err) } signer, err := ssh.ParsePrivateKey(keyData) if err != nil { return nil, fmt.Errorf("failed to parse private key: %w", err) } return signer, nil } // Ping 检查主机是否可达 (使用ICMP) func Ping(host string, timeout time.Duration) bool { // 简单的TCP ping,实际项目可以使用专门的ICMP库 ports := []int{22, 80, 443, 3389} for _, port := range ports { addr := fmt.Sprintf("%s:%d", host, port) conn, err := net.DialTimeout("tcp", addr, timeout) if err == nil { conn.Close() return true } } return false }