client.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. package sshclient
import (
	"bytes"
	"fmt"
	"net"
	"os"
	"sync"
	"time"

	"golang.org/x/crypto/ssh"
)
  10. // Client SSH客户端
  11. type Client struct {
  12. client *ssh.Client
  13. timeout time.Duration
  14. host string
  15. port int
  16. username string
  17. password string
  18. keyFile string
  19. insecureCiphers bool
  20. }
  21. // Config SSH客户端配置
  22. type Config struct {
  23. Host string
  24. Port int
  25. Username string
  26. Password string
  27. KeyFile string
  28. Timeout time.Duration
  29. InsecureCiphers bool // 启用不安全的加密算法(用于兼容老旧设备)
  30. }
  31. // NewClient 创建新的SSH客户端
  32. func NewClient(config Config) *Client {
  33. if config.Port == 0 {
  34. config.Port = 22
  35. }
  36. if config.Timeout == 0 {
  37. config.Timeout = 30 * time.Second
  38. }
  39. return &Client{
  40. host: config.Host,
  41. port: config.Port,
  42. username: config.Username,
  43. password: config.Password,
  44. keyFile: config.KeyFile,
  45. timeout: config.Timeout,
  46. insecureCiphers: config.InsecureCiphers,
  47. }
  48. }
  49. // Connect 连接到SSH服务器
  50. func (c *Client) Connect() error {
  51. config := &ssh.ClientConfig{
  52. User: c.username,
  53. Auth: []ssh.AuthMethod{},
  54. HostKeyCallback: ssh.InsecureIgnoreHostKey(),
  55. Timeout: c.timeout,
  56. }
  57. // 添加密码认证
  58. if c.password != "" {
  59. config.Auth = append(config.Auth, ssh.Password(c.password))
  60. }
  61. // 添加密钥认证
  62. if c.keyFile != "" {
  63. key, err := loadPrivateKey(c.keyFile)
  64. if err != nil {
  65. return fmt.Errorf("failed to load private key: %w", err)
  66. }
  67. config.Auth = append(config.Auth, ssh.PublicKeys(key))
  68. }
  69. // 如果启用不安全加密算法,添加旧版算法支持(用于兼容老旧设备)
  70. if c.insecureCiphers {
  71. config.Ciphers = []string{
  72. "aes128-ctr", "aes192-ctr", "aes256-ctr",
  73. "aes128-gcm@openssh.com", "aes256-gcm@openssh.com",
  74. "chacha20-poly1305@openssh.com",
  75. "aes128-cbc", "aes256-cbc", // 旧版CBC算法
  76. }
  77. config.KeyExchanges = []string{
  78. "curve25519-sha256", "curve25519-sha256@libssh.org",
  79. "ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521",
  80. "diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512",
  81. "diffie-hellman-group14-sha1", "diffie-hellman-group1-sha1", // 旧版KEX算法
  82. }
  83. config.MACs = []string{
  84. "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com",
  85. "hmac-sha2-256", "hmac-sha2-512",
  86. "hmac-sha1", "hmac-sha1-96", // 旧版MAC算法
  87. }
  88. }
  89. // 连接
  90. addr := fmt.Sprintf("%s:%d", c.host, c.port)
  91. client, err := ssh.Dial("tcp", addr, config)
  92. if err != nil {
  93. return fmt.Errorf("failed to connect to %s: %w", addr, err)
  94. }
  95. c.client = client
  96. return nil
  97. }
  98. // Close 关闭SSH连接
  99. func (c *Client) Close() error {
  100. if c.client != nil {
  101. return c.client.Close()
  102. }
  103. return nil
  104. }
  105. // ExecuteCommand 执行命令并返回输出(每个命令使用新会话)
  106. func (c *Client) ExecuteCommand(command string) (string, error) {
  107. if c.client == nil {
  108. return "", fmt.Errorf("not connected")
  109. }
  110. // 创建新会话
  111. session, err := c.client.NewSession()
  112. if err != nil {
  113. return "", fmt.Errorf("failed to create session: %w", err)
  114. }
  115. defer session.Close()
  116. var stdoutBuf bytes.Buffer
  117. session.Stdout = &stdoutBuf
  118. session.Stderr = &stdoutBuf
  119. err = session.Run(command)
  120. if err != nil {
  121. return stdoutBuf.String(), fmt.Errorf("command '%s' failed: %w", command, err)
  122. }
  123. return stdoutBuf.String(), nil
  124. }
  125. // ExecuteCommands 执行多个命令
  126. func (c *Client) ExecuteCommands(commands []string) ([]string, error) {
  127. results := make([]string, 0, len(commands))
  128. for _, cmd := range commands {
  129. result, err := c.ExecuteCommand(cmd)
  130. if err != nil {
  131. return results, fmt.Errorf("failed to execute command '%s': %w", cmd, err)
  132. }
  133. results = append(results, result)
  134. }
  135. return results, nil
  136. }
  137. // CheckSSH 检查主机是否开启SSH
  138. func CheckSSH(host string, port int, timeout time.Duration) bool {
  139. if port == 0 {
  140. port = 22
  141. }
  142. if timeout == 0 {
  143. timeout = 2 * time.Second
  144. }
  145. addr := fmt.Sprintf("%s:%d", host, port)
  146. conn, err := net.DialTimeout("tcp", addr, timeout)
  147. if err != nil {
  148. return false
  149. }
  150. defer conn.Close()
  151. return true
  152. }
  153. // loadPrivateKey 加载私钥文件
  154. func loadPrivateKey(keyFile string) (ssh.Signer, error) {
  155. keyData, err := os.ReadFile(keyFile)
  156. if err != nil {
  157. return nil, fmt.Errorf("failed to read key file: %w", err)
  158. }
  159. signer, err := ssh.ParsePrivateKey(keyData)
  160. if err != nil {
  161. return nil, fmt.Errorf("failed to parse private key: %w", err)
  162. }
  163. return signer, nil
  164. }
  165. // Ping 检查主机是否可达 (使用ICMP)
  166. func Ping(host string, timeout time.Duration) bool {
  167. // 简单的TCP ping,实际项目可以使用专门的ICMP库
  168. ports := []int{22, 80, 443, 3389}
  169. for _, port := range ports {
  170. addr := fmt.Sprintf("%s:%d", host, port)
  171. conn, err := net.DialTimeout("tcp", addr, timeout)
  172. if err == nil {
  173. conn.Close()
  174. return true
  175. }
  176. }
  177. return false
  178. }