client.go 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. package sshclient
import (
	"bytes"
	"fmt"
	"net"
	"os"
	"regexp"
	"strings"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
  12. // Client SSH客户端
  13. type Client struct {
  14. client *ssh.Client
  15. timeout time.Duration
  16. host string
  17. port int
  18. username string
  19. password string
  20. keyFile string
  21. insecureCiphers bool
  22. }
  23. // Config SSH客户端配置
  24. type Config struct {
  25. Host string
  26. Port int
  27. Username string
  28. Password string
  29. KeyFile string
  30. Timeout time.Duration
  31. InsecureCiphers bool // 启用不安全的加密算法(用于兼容老旧设备)
  32. }
  33. // NewClient 创建新的SSH客户端
  34. func NewClient(config Config) *Client {
  35. if config.Port == 0 {
  36. config.Port = 22
  37. }
  38. if config.Timeout == 0 {
  39. config.Timeout = 30 * time.Second
  40. }
  41. return &Client{
  42. host: config.Host,
  43. port: config.Port,
  44. username: config.Username,
  45. password: config.Password,
  46. keyFile: config.KeyFile,
  47. timeout: config.Timeout,
  48. insecureCiphers: config.InsecureCiphers,
  49. }
  50. }
  51. // Connect 连接到SSH服务器
  52. func (c *Client) Connect() error {
  53. config := &ssh.ClientConfig{
  54. User: c.username,
  55. Auth: []ssh.AuthMethod{},
  56. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  57. Timeout: c.timeout,
  58. }
  59. // 添加密码认证
  60. if c.password != "" {
  61. config.Auth = append(config.Auth, ssh.Password(c.password))
  62. }
  63. // 添加密钥认证
  64. if c.keyFile != "" {
  65. key, err := loadPrivateKey(c.keyFile)
  66. if err != nil {
  67. return fmt.Errorf("failed to load private key: %w", err)
  68. }
  69. config.Auth = append(config.Auth, ssh.PublicKeys(key))
  70. }
  71. // 如果启用不安全加密算法,添加旧版算法支持(用于兼容老旧设备)
  72. if c.insecureCiphers {
  73. config.Ciphers = []string{
  74. "aes128-ctr", "aes192-ctr", "aes256-ctr",
  75. "aes128-gcm@openssh.com", "aes256-gcm@openssh.com",
  76. "chacha20-poly1305@openssh.com",
  77. "aes128-cbc", "aes256-cbc", // 旧版CBC算法
  78. }
  79. config.KeyExchanges = []string{
  80. "curve25519-sha256", "curve25519-sha256@libssh.org",
  81. "ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521",
  82. "diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512",
  83. "diffie-hellman-group14-sha1", "diffie-hellman-group1-sha1", // 旧版KEX算法
  84. }
  85. config.MACs = []string{
  86. "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com",
  87. "hmac-sha2-256", "hmac-sha2-512",
  88. "hmac-sha1", "hmac-sha1-96", // 旧版MAC算法
  89. }
  90. }
  91. // 连接
  92. addr := fmt.Sprintf("%s:%d", c.host, c.port)
  93. client, err := ssh.Dial("tcp", addr, config)
  94. if err != nil {
  95. return fmt.Errorf("failed to connect to %s: %w", addr, err)
  96. }
  97. c.client = client
  98. return nil
  99. }
  100. // Close 关闭SSH连接
  101. func (c *Client) Close() error {
  102. if c.client != nil {
  103. return c.client.Close()
  104. }
  105. return nil
  106. }
  107. // ExecuteCommand 执行命令并返回输出(每个命令使用新会话)
  108. func (c *Client) ExecuteCommand(command string) (string, error) {
  109. if c.client == nil {
  110. return "", fmt.Errorf("not connected")
  111. }
  112. // 创建新会话
  113. session, err := c.client.NewSession()
  114. if err != nil {
  115. return "", fmt.Errorf("failed to create session: %w", err)
  116. }
  117. defer session.Close()
  118. var stdoutBuf bytes.Buffer
  119. session.Stdout = &stdoutBuf
  120. session.Stderr = &stdoutBuf
  121. err = session.Run(command)
  122. if err != nil {
  123. return stdoutBuf.String(), fmt.Errorf("command '%s' failed: %w", command, err)
  124. }
  125. return stdoutBuf.String(), nil
  126. }
  127. // ExecuteCommands 执行多个命令(使用Shell模式,在同一个会话中顺序执行)
  128. func (c *Client) ExecuteCommands(commands []string) ([]string, error) {
  129. if c.client == nil {
  130. return nil, fmt.Errorf("not connected")
  131. }
  132. // 创建一个shell会话
  133. session, err := c.client.NewSession()
  134. if err != nil {
  135. return nil, fmt.Errorf("failed to create session: %w", err)
  136. }
  137. defer session.Close()
  138. // 请求PTY
  139. modes := ssh.TerminalModes{
  140. ssh.ECHO: 0, // 禁用回显
  141. ssh.TTY_OP_ISPEED: 14400, // 输入速度
  142. ssh.TTY_OP_OSPEED: 14400, // 输出速度
  143. }
  144. if err := session.RequestPty("dumb", 200, 1000, modes); err != nil {
  145. return nil, fmt.Errorf("failed to request pty: %w", err)
  146. }
  147. // 获取stdin管道
  148. stdin, err := session.StdinPipe()
  149. if err != nil {
  150. return nil, fmt.Errorf("failed to get stdin pipe: %w", err)
  151. }
  152. // 捕获输出
  153. var stdoutBuf bytes.Buffer
  154. var stderrBuf bytes.Buffer
  155. session.Stdout = &stdoutBuf
  156. session.Stderr = &stderrBuf
  157. // 启动shell
  158. if err := session.Shell(); err != nil {
  159. return nil, fmt.Errorf("failed to start shell: %w", err)
  160. }
  161. // 执行命令并收集输出(使用分隔符标记)
  162. results := make([]string, 0, len(commands))
  163. for i, cmd := range commands {
  164. // 等待一段时间防止设备速率限制
  165. if i > 0 {
  166. time.Sleep(2 * time.Second)
  167. }
  168. // 生成分隔符标记命令边界
  169. delimiter := fmt.Sprintf("===CMD_BOUNDARY_%d===", i)
  170. // 发送命令
  171. if _, err := stdin.Write([]byte(cmd + "\n")); err != nil {
  172. return results, fmt.Errorf("failed to send command '%s': %w", cmd, err)
  173. }
  174. // 等待命令执行完成(不同命令需要不同等待时间)
  175. sleepTime := 1 * time.Second
  176. if cmd == "display interface" || strings.Contains(cmd, "display interface") {
  177. sleepTime = 8 * time.Second // 大输出命令需要更多时间
  178. } else if cmd == "display lldp neighbor-information" {
  179. sleepTime = 3 * time.Second
  180. }
  181. time.Sleep(sleepTime)
  182. // 发送分隔符(不执行任何操作,只是标记)
  183. if _, err := stdin.Write([]byte("echo " + delimiter + "\n")); err != nil {
  184. return results, fmt.Errorf("failed to send delimiter: %w", err)
  185. }
  186. // 等待分隔符输出
  187. time.Sleep(500 * time.Millisecond)
  188. // 获取完整输出并清理
  189. rawOutput := stdoutBuf.String()
  190. // 找到分隔符位置,提取当前命令的输出部分
  191. parts := strings.Split(rawOutput, delimiter)
  192. cmdOutput := ""
  193. if len(parts) > 1 {
  194. // 取最后一个分隔符之前的部分(当前命令的输出)
  195. cmdOutput = parts[len(parts)-2] // 倒数第二部分是当前命令输出,最后一部分是分隔符后的空内容
  196. } else {
  197. cmdOutput = rawOutput
  198. }
  199. cleanOutput := cleanCommandOutput(cmdOutput, cmd)
  200. results = append(results, cleanOutput)
  201. fmt.Printf("[SSH] Command '%s' extracted %d bytes from %d bytes raw output\n",
  202. cmd, len(cleanOutput), len(rawOutput))
  203. }
  204. // 退出shell
  205. stdin.Write([]byte("exit\n"))
  206. session.Wait()
  207. return results, nil
  208. }
  209. // cleanCommandOutput 清理命令输出,移除命令回显、分页提示和提示符
  210. func cleanCommandOutput(output, command string) string {
  211. // 清理\r\n为\n
  212. output = strings.ReplaceAll(output, "\r\n", "\n")
  213. lines := strings.Split(output, "\n")
  214. var cleanLines []string
  215. skipCommandEcho := true // 跳过命令本身的回显
  216. for _, line := range lines {
  217. trimmedLine := strings.TrimSpace(line)
  218. // 跳过空行(如果是开头)
  219. if trimmedLine == "" && len(cleanLines) == 0 {
  220. continue
  221. }
  222. // 跳过命令回显(第一次出现)
  223. if skipCommandEcho && trimmedLine == strings.TrimSpace(command) {
  224. skipCommandEcho = false
  225. continue
  226. }
  227. // 跳过分页提示
  228. if strings.Contains(trimmedLine, "---- More ----") {
  229. continue
  230. }
  231. // 跳过提示符行(如 <hostname> 或 [hostname])
  232. if regexp.MustCompile(`^[<\[]\S+[>\]]$`).MatchString(trimmedLine) {
  233. continue
  234. }
  235. // 跳过版权信息(开头)
  236. if strings.HasPrefix(trimmedLine, "*********") {
  237. continue
  238. }
  239. if strings.HasPrefix(trimmedLine, "* Copyright") {
  240. continue
  241. }
  242. if strings.HasPrefix(trimmedLine, "* Without") {
  243. continue
  244. }
  245. if strings.HasPrefix(trimmedLine, "* no decompiling") {
  246. continue
  247. }
  248. cleanLines = append(cleanLines, trimmedLine)
  249. }
  250. return strings.Join(cleanLines, "\n")
  251. }
  252. // CheckSSH 检查主机是否开启SSH
  253. func CheckSSH(host string, port int, timeout time.Duration) bool {
  254. if port == 0 {
  255. port = 22
  256. }
  257. if timeout == 0 {
  258. timeout = 2 * time.Second
  259. }
  260. addr := fmt.Sprintf("%s:%d", host, port)
  261. conn, err := net.DialTimeout("tcp", addr, timeout)
  262. if err != nil {
  263. return false
  264. }
  265. defer conn.Close()
  266. return true
  267. }
  268. // loadPrivateKey 加载私钥文件
  269. func loadPrivateKey(keyFile string) (ssh.Signer, error) {
  270. keyData, err := os.ReadFile(keyFile)
  271. if err != nil {
  272. return nil, fmt.Errorf("failed to read key file: %w", err)
  273. }
  274. signer, err := ssh.ParsePrivateKey(keyData)
  275. if err != nil {
  276. return nil, fmt.Errorf("failed to parse private key: %w", err)
  277. }
  278. return signer, nil
  279. }
  280. // Ping 检查主机是否可达 (使用ICMP)
  281. func Ping(host string, timeout time.Duration) bool {
  282. // 简单的TCP ping,实际项目可以使用专门的ICMP库
  283. ports := []int{22, 80, 443, 3389}
  284. for _, port := range ports {
  285. addr := fmt.Sprintf("%s:%d", host, port)
  286. conn, err := net.DialTimeout("tcp", addr, timeout)
  287. if err == nil {
  288. conn.Close()
  289. return true
  290. }
  291. }
  292. return false
  293. }