| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345 |
- package sshclient
import (
	"bytes"
	"fmt"
	"net"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
- // Client SSH客户端
- type Client struct {
- client *ssh.Client
- timeout time.Duration
- host string
- port int
- username string
- password string
- keyFile string
- insecureCiphers bool
- }
// Config holds the settings used by NewClient to build a Client.
type Config struct {
	// Host is the remote host name or IP address.
	Host string
	// Port is the remote TCP port; NewClient uses 22 when it is zero.
	Port int
	// Username is the login user name.
	Username string
	// Password enables password authentication when non-empty.
	Password string
	// KeyFile is the path of a private key; when non-empty, public-key
	// authentication is offered as well.
	KeyFile string
	// Timeout is the connection timeout; NewClient uses 30 seconds when it is zero.
	Timeout time.Duration
	// InsecureCiphers enables weak legacy algorithms (CBC ciphers, SHA-1 key
	// exchange and MACs) for compatibility with old devices.
	InsecureCiphers bool
}
- // NewClient 创建新的SSH客户端
- func NewClient(config Config) *Client {
- if config.Port == 0 {
- config.Port = 22
- }
- if config.Timeout == 0 {
- config.Timeout = 30 * time.Second
- }
- return &Client{
- host: config.Host,
- port: config.Port,
- username: config.Username,
- password: config.Password,
- keyFile: config.KeyFile,
- timeout: config.Timeout,
- insecureCiphers: config.InsecureCiphers,
- }
- }
- // Connect 连接到SSH服务器
- func (c *Client) Connect() error {
- config := &ssh.ClientConfig{
- User: c.username,
- Auth: []ssh.AuthMethod{},
- HostKeyCallback: ssh.InsecureIgnoreHostKey(),
- Timeout: c.timeout,
- }
- // 添加密码认证
- if c.password != "" {
- config.Auth = append(config.Auth, ssh.Password(c.password))
- }
- // 添加密钥认证
- if c.keyFile != "" {
- key, err := loadPrivateKey(c.keyFile)
- if err != nil {
- return fmt.Errorf("failed to load private key: %w", err)
- }
- config.Auth = append(config.Auth, ssh.PublicKeys(key))
- }
- // 如果启用不安全加密算法,添加旧版算法支持(用于兼容老旧设备)
- if c.insecureCiphers {
- config.Ciphers = []string{
- "aes128-ctr", "aes192-ctr", "aes256-ctr",
- "aes128-gcm@openssh.com", "aes256-gcm@openssh.com",
- "chacha20-poly1305@openssh.com",
- "aes128-cbc", "aes256-cbc", // 旧版CBC算法
- }
- config.KeyExchanges = []string{
- "curve25519-sha256", "curve25519-sha256@libssh.org",
- "ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521",
- "diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512",
- "diffie-hellman-group14-sha1", "diffie-hellman-group1-sha1", // 旧版KEX算法
- }
- config.MACs = []string{
- "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com",
- "hmac-sha2-256", "hmac-sha2-512",
- "hmac-sha1", "hmac-sha1-96", // 旧版MAC算法
- }
- }
- // 连接
- addr := fmt.Sprintf("%s:%d", c.host, c.port)
- client, err := ssh.Dial("tcp", addr, config)
- if err != nil {
- return fmt.Errorf("failed to connect to %s: %w", addr, err)
- }
- c.client = client
- return nil
- }
- // Close 关闭SSH连接
- func (c *Client) Close() error {
- if c.client != nil {
- return c.client.Close()
- }
- return nil
- }
- // ExecuteCommand 执行命令并返回输出(每个命令使用新会话)
- func (c *Client) ExecuteCommand(command string) (string, error) {
- if c.client == nil {
- return "", fmt.Errorf("not connected")
- }
- // 创建新会话
- session, err := c.client.NewSession()
- if err != nil {
- return "", fmt.Errorf("failed to create session: %w", err)
- }
- defer session.Close()
- var stdoutBuf bytes.Buffer
- session.Stdout = &stdoutBuf
- session.Stderr = &stdoutBuf
- err = session.Run(command)
- if err != nil {
- return stdoutBuf.String(), fmt.Errorf("command '%s' failed: %w", command, err)
- }
- return stdoutBuf.String(), nil
- }
- // ExecuteCommands 执行多个命令(使用Shell模式,在同一个会话中顺序执行)
- func (c *Client) ExecuteCommands(commands []string) ([]string, error) {
- if c.client == nil {
- return nil, fmt.Errorf("not connected")
- }
- // 创建一个shell会话
- session, err := c.client.NewSession()
- if err != nil {
- return nil, fmt.Errorf("failed to create session: %w", err)
- }
- defer session.Close()
- // 请求PTY
- modes := ssh.TerminalModes{
- ssh.ECHO: 0, // 禁用回显
- ssh.TTY_OP_ISPEED: 14400, // 输入速度
- ssh.TTY_OP_OSPEED: 14400, // 输出速度
- }
- if err := session.RequestPty("dumb", 200, 1000, modes); err != nil {
- return nil, fmt.Errorf("failed to request pty: %w", err)
- }
- // 获取stdin管道
- stdin, err := session.StdinPipe()
- if err != nil {
- return nil, fmt.Errorf("failed to get stdin pipe: %w", err)
- }
- // 捕获输出
- var stdoutBuf bytes.Buffer
- var stderrBuf bytes.Buffer
- session.Stdout = &stdoutBuf
- session.Stderr = &stderrBuf
- // 启动shell
- if err := session.Shell(); err != nil {
- return nil, fmt.Errorf("failed to start shell: %w", err)
- }
- // 执行命令并收集输出(使用分隔符标记)
- results := make([]string, 0, len(commands))
-
- for i, cmd := range commands {
- // 等待一段时间防止设备速率限制
- if i > 0 {
- time.Sleep(2 * time.Second)
- }
- // 生成分隔符标记命令边界
- delimiter := fmt.Sprintf("===CMD_BOUNDARY_%d===", i)
-
- // 发送命令
- if _, err := stdin.Write([]byte(cmd + "\n")); err != nil {
- return results, fmt.Errorf("failed to send command '%s': %w", cmd, err)
- }
- // 等待命令执行完成(不同命令需要不同等待时间)
- sleepTime := 1 * time.Second
- if cmd == "display interface" || strings.Contains(cmd, "display interface") {
- sleepTime = 8 * time.Second // 大输出命令需要更多时间
- } else if cmd == "display lldp neighbor-information" {
- sleepTime = 3 * time.Second
- }
- time.Sleep(sleepTime)
-
- // 发送分隔符(不执行任何操作,只是标记)
- if _, err := stdin.Write([]byte("echo " + delimiter + "\n")); err != nil {
- return results, fmt.Errorf("failed to send delimiter: %w", err)
- }
-
- // 等待分隔符输出
- time.Sleep(500 * time.Millisecond)
- // 获取完整输出并清理
- rawOutput := stdoutBuf.String()
-
- // 找到分隔符位置,提取当前命令的输出部分
- parts := strings.Split(rawOutput, delimiter)
- cmdOutput := ""
- if len(parts) > 1 {
- // 取最后一个分隔符之前的部分(当前命令的输出)
- cmdOutput = parts[len(parts)-2] // 倒数第二部分是当前命令输出,最后一部分是分隔符后的空内容
- } else {
- cmdOutput = rawOutput
- }
-
- cleanOutput := cleanCommandOutput(cmdOutput, cmd)
- results = append(results, cleanOutput)
-
- fmt.Printf("[SSH] Command '%s' extracted %d bytes from %d bytes raw output\n",
- cmd, len(cleanOutput), len(rawOutput))
- }
- // 退出shell
- stdin.Write([]byte("exit\n"))
- session.Wait()
- return results, nil
- }
// Patterns used by cleanCommandOutput, compiled once rather than once per
// output line.
var (
	// cleanPromptLineRe matches a line consisting only of a device prompt,
	// e.g. <hostname> or [hostname].
	cleanPromptLineRe = regexp.MustCompile(`^[<\[]\S+[>\]]$`)
	// cleanPromptPrefixRe matches such a prompt at the start of a line, so an
	// echo of the form "<hostname>command" can be recognised.
	cleanPromptPrefixRe = regexp.MustCompile(`^[<\[]\S+?[>\]]\s*`)
	// cleanANSIEscapeRe matches ANSI CSI escape sequences (cursor movement,
	// erase, colour), which carry no text.
	cleanANSIEscapeRe = regexp.MustCompile(`\x1b\[[0-9;?]*[A-Za-z]`)
	// cleanBannerPrefixes start the lines of a login copyright banner.
	cleanBannerPrefixes = []string{"*********", "* Copyright", "* Without", "* no decompiling"}
)

// cleanPagerMarker is the pagination prompt removed from output lines.
const cleanPagerMarker = "---- More ----"

// cleanCommandOutput turns the raw terminal transcript of one command into
// plain output lines.
//
// It normalises CRLF to LF, strips ANSI escape sequences and trims each line.
// It then drops:
//   - leading and trailing blank lines;
//   - the first echo of command, bare or prefixed by a prompt;
//   - pagination markers;
//   - bare prompt lines;
//   - copyright-banner lines.
//
// Text that shares a line with a pagination marker is kept.
func cleanCommandOutput(output, command string) string {
	output = strings.ReplaceAll(output, "\r\n", "\n")
	wantEcho := strings.TrimSpace(command)
	skipCommandEcho := true

	var cleanLines []string
lines:
	for _, line := range strings.Split(output, "\n") {
		trimmedLine := strings.TrimSpace(cleanANSIEscapeRe.ReplaceAllString(line, ""))

		// When a pager redraws, the marker can share a physical line with the
		// next page's first line; remove the marker but keep that text.
		if strings.Contains(trimmedLine, cleanPagerMarker) {
			trimmedLine = strings.TrimSpace(strings.ReplaceAll(trimmedLine, cleanPagerMarker, ""))
			if trimmedLine == "" {
				continue
			}
		}

		if trimmedLine == "" && len(cleanLines) == 0 {
			continue
		}

		// Skip the first echo of the command, with or without a leading prompt.
		if skipCommandEcho {
			echoed := trimmedLine == wantEcho
			if !echoed {
				if loc := cleanPromptPrefixRe.FindStringIndex(trimmedLine); loc != nil {
					echoed = trimmedLine[loc[1]:] == wantEcho
				}
			}
			if echoed {
				skipCommandEcho = false
				continue
			}
		}

		if cleanPromptLineRe.MatchString(trimmedLine) {
			continue
		}
		for _, prefix := range cleanBannerPrefixes {
			if strings.HasPrefix(trimmedLine, prefix) {
				continue lines
			}
		}

		cleanLines = append(cleanLines, trimmedLine)
	}

	// Mirror the leading-blank rule: trailing blank lines are not output.
	for len(cleanLines) > 0 && cleanLines[len(cleanLines)-1] == "" {
		cleanLines = cleanLines[:len(cleanLines)-1]
	}
	return strings.Join(cleanLines, "\n")
}
// CheckSSH reports whether a TCP connection to host:port can be opened within
// timeout. It only proves the port accepts connections; no SSH handshake is
// attempted. A zero port means 22, and a non-positive timeout means 2 seconds.
func CheckSSH(host string, port int, timeout time.Duration) bool {
	if port == 0 {
		port = 22
	}
	if timeout <= 0 {
		timeout = 2 * time.Second
	}
	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:22"); a plain
	// "%s:%d" format yields an address that cannot be dialed.
	addr := net.JoinHostPort(host, fmt.Sprint(port))
	conn, err := net.DialTimeout("tcp", addr, timeout)
	if err != nil {
		return false
	}
	_ = conn.Close() // the probe only needs the connection to open
	return true
}
- // loadPrivateKey 加载私钥文件
- func loadPrivateKey(keyFile string) (ssh.Signer, error) {
- keyData, err := os.ReadFile(keyFile)
- if err != nil {
- return nil, fmt.Errorf("failed to read key file: %w", err)
- }
- signer, err := ssh.ParsePrivateKey(keyData)
- if err != nil {
- return nil, fmt.Errorf("failed to parse private key: %w", err)
- }
- return signer, nil
- }
// Ping reports whether host looks reachable. It tries TCP connections to
// ports 22, 80, 443 and 3389 in turn and returns true on the first success.
//
// Despite its name it does not use ICMP. A host that is up but has none of
// these ports open (or filters them) is reported as unreachable.
//
// timeout applies to each attempt separately, so a fully failing probe can
// take up to four times timeout. A non-positive timeout is treated as
// 2 seconds; otherwise net.DialTimeout would apply no timeout at all.
func Ping(host string, timeout time.Duration) bool {
	if timeout <= 0 {
		timeout = 2 * time.Second
	}
	for _, port := range []int{22, 80, 443, 3389} {
		// JoinHostPort brackets IPv6 literals; "%s:%d" would not.
		addr := net.JoinHostPort(host, fmt.Sprint(port))
		conn, err := net.DialTimeout("tcp", addr, timeout)
		if err != nil {
			continue
		}
		_ = conn.Close() // the probe only needs the connection to open
		return true
	}
	return false
}
|