4d3b737e2d
- 在Shell模式中插入分隔符标记命令边界 - 解决累积输出包含之前命令回显的问题 - display interface等待8秒,lldp等待3秒
346 lignes
8.8 KiB
Go
346 lignes
8.8 KiB
Go
package sshclient
|
||
|
||
import (
|
||
"bytes"
|
||
"fmt"
|
||
"net"
|
||
"os"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
|
||
"golang.org/x/crypto/ssh"
|
||
)
|
||
|
||
// Client SSH客户端
|
||
type Client struct {
|
||
client *ssh.Client
|
||
timeout time.Duration
|
||
host string
|
||
port int
|
||
username string
|
||
password string
|
||
keyFile string
|
||
insecureCiphers bool
|
||
}
|
||
|
||
// Config SSH客户端配置
|
||
type Config struct {
|
||
Host string
|
||
Port int
|
||
Username string
|
||
Password string
|
||
KeyFile string
|
||
Timeout time.Duration
|
||
InsecureCiphers bool // 启用不安全的加密算法(用于兼容老旧设备)
|
||
}
|
||
|
||
// NewClient 创建新的SSH客户端
|
||
func NewClient(config Config) *Client {
|
||
if config.Port == 0 {
|
||
config.Port = 22
|
||
}
|
||
if config.Timeout == 0 {
|
||
config.Timeout = 30 * time.Second
|
||
}
|
||
return &Client{
|
||
host: config.Host,
|
||
port: config.Port,
|
||
username: config.Username,
|
||
password: config.Password,
|
||
keyFile: config.KeyFile,
|
||
timeout: config.Timeout,
|
||
insecureCiphers: config.InsecureCiphers,
|
||
}
|
||
}
|
||
|
||
// Connect 连接到SSH服务器
|
||
func (c *Client) Connect() error {
|
||
config := &ssh.ClientConfig{
|
||
User: c.username,
|
||
Auth: []ssh.AuthMethod{},
|
||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||
Timeout: c.timeout,
|
||
}
|
||
|
||
// 添加密码认证
|
||
if c.password != "" {
|
||
config.Auth = append(config.Auth, ssh.Password(c.password))
|
||
}
|
||
|
||
// 添加密钥认证
|
||
if c.keyFile != "" {
|
||
key, err := loadPrivateKey(c.keyFile)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to load private key: %w", err)
|
||
}
|
||
config.Auth = append(config.Auth, ssh.PublicKeys(key))
|
||
}
|
||
|
||
// 如果启用不安全加密算法,添加旧版算法支持(用于兼容老旧设备)
|
||
if c.insecureCiphers {
|
||
config.Ciphers = []string{
|
||
"aes128-ctr", "aes192-ctr", "aes256-ctr",
|
||
"aes128-gcm@openssh.com", "aes256-gcm@openssh.com",
|
||
"chacha20-poly1305@openssh.com",
|
||
"aes128-cbc", "aes256-cbc", // 旧版CBC算法
|
||
}
|
||
config.KeyExchanges = []string{
|
||
"curve25519-sha256", "curve25519-sha256@libssh.org",
|
||
"ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521",
|
||
"diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512",
|
||
"diffie-hellman-group14-sha1", "diffie-hellman-group1-sha1", // 旧版KEX算法
|
||
}
|
||
config.MACs = []string{
|
||
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com",
|
||
"hmac-sha2-256", "hmac-sha2-512",
|
||
"hmac-sha1", "hmac-sha1-96", // 旧版MAC算法
|
||
}
|
||
}
|
||
|
||
// 连接
|
||
addr := fmt.Sprintf("%s:%d", c.host, c.port)
|
||
client, err := ssh.Dial("tcp", addr, config)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to connect to %s: %w", addr, err)
|
||
}
|
||
|
||
c.client = client
|
||
return nil
|
||
}
|
||
|
||
// Close 关闭SSH连接
|
||
func (c *Client) Close() error {
|
||
if c.client != nil {
|
||
return c.client.Close()
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// ExecuteCommand 执行命令并返回输出(每个命令使用新会话)
|
||
func (c *Client) ExecuteCommand(command string) (string, error) {
|
||
if c.client == nil {
|
||
return "", fmt.Errorf("not connected")
|
||
}
|
||
|
||
// 创建新会话
|
||
session, err := c.client.NewSession()
|
||
if err != nil {
|
||
return "", fmt.Errorf("failed to create session: %w", err)
|
||
}
|
||
defer session.Close()
|
||
|
||
var stdoutBuf bytes.Buffer
|
||
session.Stdout = &stdoutBuf
|
||
session.Stderr = &stdoutBuf
|
||
|
||
err = session.Run(command)
|
||
if err != nil {
|
||
return stdoutBuf.String(), fmt.Errorf("command '%s' failed: %w", command, err)
|
||
}
|
||
|
||
return stdoutBuf.String(), nil
|
||
}
|
||
|
||
// ExecuteCommands 执行多个命令(使用Shell模式,在同一个会话中顺序执行)
|
||
func (c *Client) ExecuteCommands(commands []string) ([]string, error) {
|
||
if c.client == nil {
|
||
return nil, fmt.Errorf("not connected")
|
||
}
|
||
|
||
// 创建一个shell会话
|
||
session, err := c.client.NewSession()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to create session: %w", err)
|
||
}
|
||
defer session.Close()
|
||
|
||
// 请求PTY
|
||
modes := ssh.TerminalModes{
|
||
ssh.ECHO: 0, // 禁用回显
|
||
ssh.TTY_OP_ISPEED: 14400, // 输入速度
|
||
ssh.TTY_OP_OSPEED: 14400, // 输出速度
|
||
}
|
||
if err := session.RequestPty("dumb", 200, 1000, modes); err != nil {
|
||
return nil, fmt.Errorf("failed to request pty: %w", err)
|
||
}
|
||
|
||
// 获取stdin管道
|
||
stdin, err := session.StdinPipe()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to get stdin pipe: %w", err)
|
||
}
|
||
|
||
// 捕获输出
|
||
var stdoutBuf bytes.Buffer
|
||
var stderrBuf bytes.Buffer
|
||
session.Stdout = &stdoutBuf
|
||
session.Stderr = &stderrBuf
|
||
|
||
// 启动shell
|
||
if err := session.Shell(); err != nil {
|
||
return nil, fmt.Errorf("failed to start shell: %w", err)
|
||
}
|
||
|
||
// 执行命令并收集输出(使用分隔符标记)
|
||
results := make([]string, 0, len(commands))
|
||
|
||
for i, cmd := range commands {
|
||
// 等待一段时间防止设备速率限制
|
||
if i > 0 {
|
||
time.Sleep(2 * time.Second)
|
||
}
|
||
|
||
// 生成分隔符标记命令边界
|
||
delimiter := fmt.Sprintf("===CMD_BOUNDARY_%d===", i)
|
||
|
||
// 发送命令
|
||
if _, err := stdin.Write([]byte(cmd + "\n")); err != nil {
|
||
return results, fmt.Errorf("failed to send command '%s': %w", cmd, err)
|
||
}
|
||
|
||
// 等待命令执行完成(不同命令需要不同等待时间)
|
||
sleepTime := 1 * time.Second
|
||
if cmd == "display interface" || strings.Contains(cmd, "display interface") {
|
||
sleepTime = 8 * time.Second // 大输出命令需要更多时间
|
||
} else if cmd == "display lldp neighbor-information" {
|
||
sleepTime = 3 * time.Second
|
||
}
|
||
time.Sleep(sleepTime)
|
||
|
||
// 发送分隔符(不执行任何操作,只是标记)
|
||
if _, err := stdin.Write([]byte("echo " + delimiter + "\n")); err != nil {
|
||
return results, fmt.Errorf("failed to send delimiter: %w", err)
|
||
}
|
||
|
||
// 等待分隔符输出
|
||
time.Sleep(500 * time.Millisecond)
|
||
|
||
// 获取完整输出并清理
|
||
rawOutput := stdoutBuf.String()
|
||
|
||
// 找到分隔符位置,提取当前命令的输出部分
|
||
parts := strings.Split(rawOutput, delimiter)
|
||
cmdOutput := ""
|
||
if len(parts) > 1 {
|
||
// 取最后一个分隔符之前的部分(当前命令的输出)
|
||
cmdOutput = parts[len(parts)-2] // 倒数第二部分是当前命令输出,最后一部分是分隔符后的空内容
|
||
} else {
|
||
cmdOutput = rawOutput
|
||
}
|
||
|
||
cleanOutput := cleanCommandOutput(cmdOutput, cmd)
|
||
results = append(results, cleanOutput)
|
||
|
||
fmt.Printf("[SSH] Command '%s' extracted %d bytes from %d bytes raw output\n",
|
||
cmd, len(cleanOutput), len(rawOutput))
|
||
}
|
||
|
||
// 退出shell
|
||
stdin.Write([]byte("exit\n"))
|
||
session.Wait()
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// cleanCommandOutput 清理命令输出,移除命令回显、分页提示和提示符
|
||
func cleanCommandOutput(output, command string) string {
|
||
// 清理\r\n为\n
|
||
output = strings.ReplaceAll(output, "\r\n", "\n")
|
||
|
||
lines := strings.Split(output, "\n")
|
||
var cleanLines []string
|
||
skipCommandEcho := true // 跳过命令本身的回显
|
||
|
||
for _, line := range lines {
|
||
trimmedLine := strings.TrimSpace(line)
|
||
|
||
// 跳过空行(如果是开头)
|
||
if trimmedLine == "" && len(cleanLines) == 0 {
|
||
continue
|
||
}
|
||
|
||
// 跳过命令回显(第一次出现)
|
||
if skipCommandEcho && trimmedLine == strings.TrimSpace(command) {
|
||
skipCommandEcho = false
|
||
continue
|
||
}
|
||
|
||
// 跳过分页提示
|
||
if strings.Contains(trimmedLine, "---- More ----") {
|
||
continue
|
||
}
|
||
|
||
// 跳过提示符行(如 <hostname> 或 [hostname])
|
||
if regexp.MustCompile(`^[<\[]\S+[>\]]$`).MatchString(trimmedLine) {
|
||
continue
|
||
}
|
||
|
||
// 跳过版权信息(开头)
|
||
if strings.HasPrefix(trimmedLine, "*********") {
|
||
continue
|
||
}
|
||
if strings.HasPrefix(trimmedLine, "* Copyright") {
|
||
continue
|
||
}
|
||
if strings.HasPrefix(trimmedLine, "* Without") {
|
||
continue
|
||
}
|
||
if strings.HasPrefix(trimmedLine, "* no decompiling") {
|
||
continue
|
||
}
|
||
|
||
cleanLines = append(cleanLines, trimmedLine)
|
||
}
|
||
|
||
return strings.Join(cleanLines, "\n")
|
||
}
|
||
|
||
// CheckSSH 检查主机是否开启SSH
|
||
func CheckSSH(host string, port int, timeout time.Duration) bool {
|
||
if port == 0 {
|
||
port = 22
|
||
}
|
||
if timeout == 0 {
|
||
timeout = 2 * time.Second
|
||
}
|
||
|
||
addr := fmt.Sprintf("%s:%d", host, port)
|
||
conn, err := net.DialTimeout("tcp", addr, timeout)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
defer conn.Close()
|
||
return true
|
||
}
|
||
|
||
// loadPrivateKey 加载私钥文件
|
||
func loadPrivateKey(keyFile string) (ssh.Signer, error) {
|
||
keyData, err := os.ReadFile(keyFile)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to read key file: %w", err)
|
||
}
|
||
|
||
signer, err := ssh.ParsePrivateKey(keyData)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||
}
|
||
|
||
return signer, nil
|
||
}
|
||
|
||
// Ping 检查主机是否可达 (使用ICMP)
|
||
func Ping(host string, timeout time.Duration) bool {
|
||
// 简单的TCP ping,实际项目可以使用专门的ICMP库
|
||
ports := []int{22, 80, 443, 3389}
|
||
for _, port := range ports {
|
||
addr := fmt.Sprintf("%s:%d", host, port)
|
||
conn, err := net.DialTimeout("tcp", addr, timeout)
|
||
if err == nil {
|
||
conn.Close()
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|