Files
Your Name 4d3b737e2d Fix: 使用分隔符标记提取单个命令输出
- 在Shell模式中插入分隔符标记命令边界
- 解决累积输出包含之前命令回显的问题
- display interface等待8秒,lldp等待3秒
2026-04-26 03:23:28 +08:00

346 rader
8.8 KiB
Go
Permalänk Blame Historik

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sshclient
import (
"bytes"
"fmt"
"net"
"os"
"regexp"
"strings"
"time"
"golang.org/x/crypto/ssh"
)
// Client is an SSH client for a single remote host. It keeps the connection
// settings copied from Config and, once Connect has succeeded, the live
// *ssh.Client used to open sessions.
type Client struct {
	// Remote endpoint.
	host string
	port int

	// Credentials: a password and/or the path of a private key file.
	username string
	password string
	keyFile  string

	// Connection behaviour.
	timeout         time.Duration
	insecureCiphers bool // offer legacy ciphers, key exchanges and MACs for old devices

	// Live connection; nil until Connect succeeds.
	client *ssh.Client
}
// Config holds the settings NewClient uses to build a Client.
type Config struct {
	Host            string        // remote host name or IP address
	Port            int           // SSH port; NewClient substitutes 22 when zero
	Username        string        // login user name
	Password        string        // password for password authentication; empty to skip
	KeyFile         string        // path of a private key file for public-key authentication; empty to skip
	Timeout         time.Duration // connection timeout; NewClient substitutes 30s when zero
	InsecureCiphers bool          // enable insecure (legacy) algorithms for compatibility with old devices
}
// NewClient returns a Client built from config. It does not connect; call
// Connect for that. A zero Port becomes 22 and a zero Timeout becomes 30
// seconds.
func NewClient(config Config) *Client {
	cl := &Client{
		host:            config.Host,
		port:            config.Port,
		username:        config.Username,
		password:        config.Password,
		keyFile:         config.KeyFile,
		timeout:         config.Timeout,
		insecureCiphers: config.InsecureCiphers,
	}
	if cl.port == 0 {
		cl.port = 22
	}
	if cl.timeout == 0 {
		cl.timeout = 30 * time.Second
	}
	return cl
}
// Connect dials c.host:c.port and performs the SSH handshake.
//
// Password authentication is offered when a password is set, followed by
// public-key authentication when keyFile is set. Host keys are not verified.
// When insecureCiphers is enabled, legacy ciphers, key exchanges and MACs are
// offered for compatibility with old devices.
//
// c.timeout bounds the TCP dial and, separately, the SSH handshake. ssh.Dial
// would only bound the dial, so a peer that accepts TCP but stalls could block
// forever. IPv6 literals may be given with or without brackets.
//
// On success, any connection previously held by c is closed and replaced.
func (c *Client) Connect() error {
	config := &ssh.ClientConfig{
		User: c.username,
		Auth: []ssh.AuthMethod{},
		// NOTE(review): any host key is accepted, which leaves the connection
		// open to man-in-the-middle attacks; consider a known_hosts callback.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         c.timeout,
	}

	// Password authentication.
	if c.password != "" {
		config.Auth = append(config.Auth, ssh.Password(c.password))
	}
	// Public-key authentication.
	if c.keyFile != "" {
		key, err := loadPrivateKey(c.keyFile)
		if err != nil {
			return fmt.Errorf("failed to load private key: %w", err)
		}
		config.Auth = append(config.Auth, ssh.PublicKeys(key))
	}

	// Optionally widen the algorithm lists with legacy entries for old devices.
	if c.insecureCiphers {
		// NOTE(review): golang.org/x/crypto/ssh silently drops algorithm names
		// it does not implement. Its CBC support appears limited to aes128-cbc
		// and 3des-cbc, so "aes256-cbc" is likely ignored. Confirm against the
		// vendored version.
		config.Ciphers = []string{
			"aes128-ctr", "aes192-ctr", "aes256-ctr",
			"aes128-gcm@openssh.com", "aes256-gcm@openssh.com",
			"chacha20-poly1305@openssh.com",
			"aes128-cbc", "aes256-cbc", // legacy CBC ciphers
		}
		config.KeyExchanges = []string{
			"curve25519-sha256", "curve25519-sha256@libssh.org",
			"ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521",
			"diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512",
			"diffie-hellman-group14-sha1", "diffie-hellman-group1-sha1", // legacy KEX
		}
		config.MACs = []string{
			"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com",
			"hmac-sha2-256", "hmac-sha2-512",
			"hmac-sha1", "hmac-sha1-96", // legacy MACs
		}
	}

	// JoinHostPort brackets IPv6 literals; Trim tolerates already-bracketed input.
	addr := net.JoinHostPort(strings.Trim(c.host, "[]"), fmt.Sprint(c.port))
	conn, err := net.DialTimeout("tcp", addr, c.timeout)
	if err != nil {
		return fmt.Errorf("failed to connect to %s: %w", addr, err)
	}

	// Bound the SSH handshake, which has no timeout of its own.
	if c.timeout > 0 {
		_ = conn.SetDeadline(time.Now().Add(c.timeout))
	}
	sshConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
	if err != nil {
		_ = conn.Close() // may already be closed by NewClientConn; a second close is harmless
		return fmt.Errorf("failed to connect to %s: %w", addr, err)
	}
	// Clear the deadline so established sessions are not cut off later.
	_ = conn.SetDeadline(time.Time{})

	if c.client != nil {
		_ = c.client.Close() // best effort: the old connection is being replaced
	}
	c.client = ssh.NewClient(sshConn, chans, reqs)
	return nil
}
// Close closes the SSH connection, if one is open, and forgets it.
// Afterwards the Client reports "not connected" until Connect is called
// again, and calling Close a second time is a no-op returning nil.
func (c *Client) Close() error {
	if c.client == nil {
		return nil
	}
	err := c.client.Close()
	c.client = nil
	return err
}
// ExecuteCommand runs command in a new SSH session (no PTY is requested; each
// call uses its own session) and returns the command's stdout and stderr
// combined.
//
// If the command fails, the output captured so far is returned together with
// an error wrapping the session error.
func (c *Client) ExecuteCommand(command string) (string, error) {
	if c.client == nil {
		return "", fmt.Errorf("not connected")
	}

	session, err := c.client.NewSession()
	if err != nil {
		return "", fmt.Errorf("failed to create session: %w", err)
	}
	defer session.Close()

	// The session copies stdout and stderr in separate goroutines. Pointing
	// both at one bytes.Buffer is a data race; CombinedOutput serializes the
	// writes.
	output, err := session.CombinedOutput(command)
	if err != nil {
		return string(output), fmt.Errorf("command '%s' failed: %w", command, err)
	}
	return string(output), nil
}
// ExecuteCommands runs commands one after another in a single interactive
// shell session (PTY requested, echo disabled). It returns one cleaned output
// string per command, in order.
//
// To find command boundaries it records how much output had arrived before
// each command was sent, then sends "echo <delimiter>" after the command. The
// command's output is the text between that offset and the line on which the
// delimiter first appears.
//
// Waits are fixed:
//   - 2s before every command after the first, to avoid device rate limiting;
//   - after sending the command, 8s for commands containing
//     "display interface", 3s for "display lldp neighbor-information", 1s
//     otherwise;
//   - 500ms after sending the delimiter.
//
// If sending fails, the outputs collected so far are returned with the error.
func (c *Client) ExecuteCommands(commands []string) ([]string, error) {
	if c.client == nil {
		return nil, fmt.Errorf("not connected")
	}

	session, err := c.client.NewSession()
	if err != nil {
		return nil, fmt.Errorf("failed to create session: %w", err)
	}
	defer session.Close()

	// Request a PTY with local echo disabled.
	modes := ssh.TerminalModes{
		ssh.ECHO:          0,     // disable echo
		ssh.TTY_OP_ISPEED: 14400, // input speed
		ssh.TTY_OP_OSPEED: 14400, // output speed
	}
	if err := session.RequestPty("dumb", 200, 1000, modes); err != nil {
		return nil, fmt.Errorf("failed to request pty: %w", err)
	}

	stdin, err := session.StdinPipe()
	if err != nil {
		return nil, fmt.Errorf("failed to get stdin pipe: %w", err)
	}

	// The session writes remote stdout from its own goroutine while this loop
	// reads it, so stdout goes into a guarded buffer. stderr is written by a
	// single goroutine and never read here.
	stdoutBuf := newShellBuffer()
	var stderrBuf bytes.Buffer
	session.Stdout = stdoutBuf
	session.Stderr = &stderrBuf

	if err := session.Shell(); err != nil {
		return nil, fmt.Errorf("failed to start shell: %w", err)
	}

	results := make([]string, 0, len(commands))
	for i, cmd := range commands {
		if i > 0 {
			time.Sleep(2 * time.Second) // pace commands to avoid device rate limiting
		}
		delimiter := fmt.Sprintf("===CMD_BOUNDARY_%d===", i)

		// Everything before this offset belongs to earlier commands.
		start := stdoutBuf.Len()

		if _, err := stdin.Write([]byte(cmd + "\n")); err != nil {
			return results, fmt.Errorf("failed to send command '%s': %w", cmd, err)
		}
		time.Sleep(shellCommandWait(cmd))

		// The delimiter only marks the end of this command's output.
		if _, err := stdin.Write([]byte("echo " + delimiter + "\n")); err != nil {
			return results, fmt.Errorf("failed to send delimiter: %w", err)
		}
		time.Sleep(500 * time.Millisecond)

		rawOutput := stdoutBuf.String()
		cleanOutput := cleanCommandOutput(commandSegment(rawOutput, start, delimiter), cmd)
		results = append(results, cleanOutput)
		fmt.Printf("[SSH] Command '%s' extracted %d bytes from %d bytes raw output\n",
			cmd, len(cleanOutput), len(rawOutput))
	}

	// Leave the shell. "exit" may only leave a nested view on some CLIs, so the
	// wait for the session to end is bounded rather than indefinite. When it
	// times out, the deferred session.Close tears the channel down, which is
	// expected to let Wait (and its goroutine) return. Errors here are ignored,
	// as before.
	_, _ = stdin.Write([]byte("exit\n"))
	waitDone := make(chan struct{})
	go func() {
		_ = session.Wait()
		close(waitDone)
	}()
	exitWait := c.timeout
	if exitWait <= 0 {
		exitWait = 30 * time.Second
	}
	select {
	case <-waitDone:
	case <-time.After(exitWait):
	}

	return results, nil
}

// shellCommandWait returns how long to wait after sending cmd before sending
// the boundary delimiter; commands with large output get more time.
func shellCommandWait(cmd string) time.Duration {
	switch {
	case strings.Contains(cmd, "display interface"):
		return 8 * time.Second
	case cmd == "display lldp neighbor-information":
		return 3 * time.Second
	default:
		return time.Second
	}
}

// commandSegment returns the part of the accumulated shell output that belongs
// to one command. That is everything after byte offset start, up to the
// beginning of the line on which delimiter first appears. The delimiter line
// itself (for example an echoed "<host>echo ===…===") is dropped. If the
// delimiter has not arrived yet, everything after start is returned.
func commandSegment(output string, start int, delimiter string) string {
	if start < 0 || start > len(output) {
		start = 0
	}
	segment := output[start:]
	idx := strings.Index(segment, delimiter)
	if idx < 0 {
		return segment
	}
	segment = segment[:idx]
	if nl := strings.LastIndex(segment, "\n"); nl >= 0 {
		return segment[:nl+1]
	}
	return ""
}

// shellBuffer is an io.Writer that may be written by the SSH session's copy
// goroutine while another goroutine reads its contents. A one-slot channel
// serves as the lock that serializes every access.
type shellBuffer struct {
	lock chan struct{}
	buf  bytes.Buffer
}

// newShellBuffer returns an empty, ready-to-use shellBuffer.
func newShellBuffer() *shellBuffer {
	return &shellBuffer{lock: make(chan struct{}, 1)}
}

// Write appends p to the buffer.
func (b *shellBuffer) Write(p []byte) (int, error) {
	b.lock <- struct{}{}
	defer func() { <-b.lock }()
	return b.buf.Write(p)
}

// Len returns the number of bytes written so far.
func (b *shellBuffer) Len() int {
	b.lock <- struct{}{}
	defer func() { <-b.lock }()
	return b.buf.Len()
}

// String returns a copy of everything written so far.
func (b *shellBuffer) String() string {
	b.lock <- struct{}{}
	defer func() { <-b.lock }()
	return b.buf.String()
}
// promptLine matches a line that is only a CLI prompt, e.g. <hostname> or
// [hostname]. It is compiled once rather than once per output line.
var promptLine = regexp.MustCompile(`^[<\[]\S+[>\]]$`)

// promptPrefix matches the shortest leading CLI prompt on a line. It is used
// to recognise a command echo such as "<hostname>display version".
var promptPrefix = regexp.MustCompile(`^[<\[]\S+?[>\]]`)

// cleanCommandOutput strips shell noise from one command's raw output:
//   - CRLF is normalised to LF, and every kept line is whitespace-trimmed;
//   - leading blank lines are dropped;
//   - the first command echo is dropped, either the bare command or the
//     command preceded by a prompt;
//   - "---- More ----" pager lines and bare prompt lines are dropped;
//   - copyright-banner lines are dropped.
//
// The remaining lines are joined with "\n". Blank lines after the first kept
// line, including a trailing one, are preserved.
func cleanCommandOutput(output, command string) string {
	output = strings.ReplaceAll(output, "\r\n", "\n")
	echo := strings.TrimSpace(command)
	bannerPrefixes := []string{"*********", "* Copyright", "* Without", "* no decompiling"}

	var cleanLines []string
	skipCommandEcho := true // only the first echo of the command is removed
	for _, line := range strings.Split(output, "\n") {
		trimmedLine := strings.TrimSpace(line)

		// Leading blank lines.
		if trimmedLine == "" && len(cleanLines) == 0 {
			continue
		}
		// Command echo, bare or prefixed by a prompt ("<host>cmd").
		if skipCommandEcho && isCommandEcho(trimmedLine, echo) {
			skipCommandEcho = false
			continue
		}
		// Pager prompt.
		if strings.Contains(trimmedLine, "---- More ----") {
			continue
		}
		// Bare prompt line such as <hostname> or [hostname].
		if promptLine.MatchString(trimmedLine) {
			continue
		}
		// Copyright banner.
		if hasAnyPrefix(trimmedLine, bannerPrefixes) {
			continue
		}
		cleanLines = append(cleanLines, trimmedLine)
	}
	return strings.Join(cleanLines, "\n")
}

// isCommandEcho reports whether line is the echo of command: either exactly
// the command, or the command preceded by a CLI prompt.
func isCommandEcho(line, command string) bool {
	if line == command {
		return true
	}
	if loc := promptPrefix.FindStringIndex(line); loc != nil {
		return strings.TrimSpace(line[loc[1]:]) == command
	}
	return false
}

// hasAnyPrefix reports whether s starts with any of prefixes.
func hasAnyPrefix(s string, prefixes []string) bool {
	for _, p := range prefixes {
		if strings.HasPrefix(s, p) {
			return true
		}
	}
	return false
}
// CheckSSH reports whether a TCP connection to host:port can be opened within
// timeout. It does not perform an SSH handshake.
//
// A zero port defaults to 22 and a zero timeout to 2 seconds. IPv6 literals
// may be given with or without brackets; net.JoinHostPort adds them as needed.
func CheckSSH(host string, port int, timeout time.Duration) bool {
	if port == 0 {
		port = 22
	}
	if timeout == 0 {
		timeout = 2 * time.Second
	}
	addr := net.JoinHostPort(strings.Trim(host, "[]"), fmt.Sprint(port))
	conn, err := net.DialTimeout("tcp", addr, timeout)
	if err != nil {
		return false
	}
	conn.Close()
	return true
}
// loadPrivateKey reads the private key stored at keyFile and parses it with
// ssh.ParsePrivateKey (no passphrase is supplied), returning a signer for
// public-key authentication.
func loadPrivateKey(keyFile string) (ssh.Signer, error) {
	raw, readErr := os.ReadFile(keyFile)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read key file: %w", readErr)
	}

	parsed, parseErr := ssh.ParsePrivateKey(raw)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse private key: %w", parseErr)
	}
	return parsed, nil
}
// Ping reports whether host is reachable. It tries a TCP connection to ports
// 22, 80, 443 and 3389, in that order, and returns true on the first success.
// It does not send ICMP; a dedicated ICMP library would be needed for that.
//
// timeout applies to each port attempt. A zero timeout defaults to 2 seconds,
// matching CheckSSH, rather than dialing with no timeout. IPv6 literals may be
// given with or without brackets.
func Ping(host string, timeout time.Duration) bool {
	if timeout == 0 {
		timeout = 2 * time.Second
	}
	host = strings.Trim(host, "[]")
	for _, port := range []int{22, 80, 443, 3389} {
		addr := net.JoinHostPort(host, fmt.Sprint(port))
		conn, err := net.DialTimeout("tcp", addr, timeout)
		if err == nil {
			conn.Close()
			return true
		}
	}
	return false
}