Files
network-topology-discovery/internal/ssh/client.go
T
Your Name 7e21e60852 Debug: 添加H3C接口输出调试日志
- 在解析器中输出display interface原始数据
- 便于诊断接口为空的问题
2026-04-26 01:54:16 +08:00

305 lines
7.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package sshclient
import (
"bytes"
"fmt"
"net"
"os"
"strings"
"time"
"regexp"
"golang.org/x/crypto/ssh"
)
// Client is an SSH client bound to a single remote host. It keeps the
// connection parameters given to NewClient and, once Connect has succeeded,
// the live connection used by ExecuteCommand.
type Client struct {
	// client is the established SSH connection; it stays nil until Connect
	// succeeds.
	client *ssh.Client
	// timeout is passed to the SSH client configuration when connecting.
	timeout time.Duration
	// host and port identify the remote endpoint.
	host string
	port int
	// username, password and keyFile are the login credentials. The password
	// and the key file are each optional; Connect offers whichever is set.
	username, password, keyFile string
	// insecureCiphers makes Connect additionally offer legacy algorithms for
	// compatibility with old devices.
	insecureCiphers bool
}
// Config carries the settings NewClient uses to build a Client.
type Config struct {
	// Host is the address of the SSH server.
	Host string
	// Port is the TCP port of the SSH server; NewClient substitutes 22 when
	// it is 0.
	Port int
	// Username, Password and KeyFile are the login credentials. Password is
	// used for password authentication when non-empty; KeyFile is the path of
	// a private key file used for public-key authentication when non-empty.
	Username, Password, KeyFile string
	// Timeout is the connection timeout; NewClient substitutes 10 seconds
	// when it is 0.
	Timeout time.Duration
	// InsecureCiphers enables insecure legacy algorithms for compatibility
	// with old devices.
	InsecureCiphers bool
}
// NewClient builds a Client from config. It does not open a connection; call
// Connect for that. A zero Port is replaced by 22 and a zero Timeout by 10
// seconds.
func NewClient(config Config) *Client {
	port, timeout := config.Port, config.Timeout
	if port == 0 {
		port = 22
	}
	if timeout == 0 {
		timeout = 10 * time.Second
	}

	c := new(Client)
	c.host = config.Host
	c.port = port
	c.username = config.Username
	c.password = config.Password
	c.keyFile = config.KeyFile
	c.timeout = timeout
	c.insecureCiphers = config.InsecureCiphers
	return c
}
// Connect dials the configured host and port over TCP, performs the SSH
// handshake and stores the resulting connection on the client.
//
// Password authentication is offered when a password is configured and
// public-key authentication when a key file is configured. When
// insecureCiphers is set, explicit cipher, key-exchange and MAC lists that
// include legacy algorithms are used.
//
// It returns an error if the private key cannot be loaded or if dialing or
// the handshake fails.
func (c *Client) Connect() error {
	// NOTE(review): InsecureIgnoreHostKey accepts any server host key, so the
	// server's identity is not verified. Kept as-is because callers rely on
	// it, but consider a known-hosts based callback.
	config := &ssh.ClientConfig{
		User:            c.username,
		Auth:            []ssh.AuthMethod{},
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         c.timeout,
	}

	// Password authentication.
	if c.password != "" {
		config.Auth = append(config.Auth, ssh.Password(c.password))
	}

	// Public-key authentication.
	if c.keyFile != "" {
		key, err := loadPrivateKey(c.keyFile)
		if err != nil {
			return fmt.Errorf("failed to load private key: %w", err)
		}
		config.Auth = append(config.Auth, ssh.PublicKeys(key))
	}

	// Offer legacy algorithms in addition to modern ones so that old devices
	// can still negotiate a connection.
	if c.insecureCiphers {
		config.Ciphers = []string{
			"aes128-ctr", "aes192-ctr", "aes256-ctr",
			"aes128-gcm@openssh.com", "aes256-gcm@openssh.com",
			"chacha20-poly1305@openssh.com",
			"aes128-cbc", "aes256-cbc", // legacy CBC ciphers
		}
		config.KeyExchanges = []string{
			"curve25519-sha256", "curve25519-sha256@libssh.org",
			"ecdh-sha2-nistp256", "ecdh-sha2-nistp384", "ecdh-sha2-nistp521",
			"diffie-hellman-group14-sha256", "diffie-hellman-group16-sha512",
			"diffie-hellman-group14-sha1", "diffie-hellman-group1-sha1", // legacy key exchanges
		}
		config.MACs = []string{
			"hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com",
			"hmac-sha2-256", "hmac-sha2-512",
			"hmac-sha1", "hmac-sha1-96", // legacy MACs
		}
	}

	// Build the dial address with JoinHostPort so that IPv6 literals are
	// bracketed; plain "%s:%d" formatting produces an undialable address for
	// them. A host that already carries brackets is unwrapped first so it
	// keeps working.
	host := c.host
	if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	addr := net.JoinHostPort(host, fmt.Sprint(c.port))

	client, err := ssh.Dial("tcp", addr, config)
	if err != nil {
		return fmt.Errorf("failed to connect to %s: %w", addr, err)
	}
	c.client = client
	return nil
}
// Close shuts down the SSH connection and returns the error reported by the
// underlying connection. Calling it on a client that never connected is a
// no-op that returns nil.
func (c *Client) Close() error {
	if c.client == nil {
		return nil
	}
	return c.client.Close()
}
// promptLineRe matches a line that consists solely of a device prompt such as
// <hostname> or [hostname]. It is compiled once rather than once per output
// line.
var promptLineRe = regexp.MustCompile(`^[<\[]\S+[>\]]$`)

// ExecuteCommand runs command in an interactive shell on the connected device
// and returns the cleaned-up output.
//
// The method opens a new session (retrying up to three times), requests a PTY,
// starts a shell, sends "screen-length disable" to turn off paging on
// H3C/Huawei devices, sends command, waits a fixed two seconds, sends "exit"
// and then waits for the session to end. The captured output is filtered by
// cleanCommandOutput before being returned.
//
// It returns an error if the client is not connected, if no session could be
// created, or if setting up the PTY/shell or writing to it fails.
//
// NOTE(review): the wait after sending the command is a fixed 2 s; commands
// that take longer presumably race with the "exit" that follows.
func (c *Client) ExecuteCommand(command string) (string, error) {
	if c.client == nil {
		return "", fmt.Errorf("not connected")
	}

	// Session creation can fail transiently, so retry with a growing pause
	// (500 ms, then 1 s) between attempts.
	const maxRetries = 3
	var session *ssh.Session
	var err error
	for i := 0; i < maxRetries; i++ {
		session, err = c.client.NewSession()
		if err == nil {
			break
		}
		if i < maxRetries-1 {
			time.Sleep(time.Duration(i+1) * 500 * time.Millisecond)
		}
	}
	if err != nil {
		return "", fmt.Errorf("failed to create session after %d retries: %w", maxRetries, err)
	}
	defer session.Close()

	// H3C/Huawei devices need paging disabled first, which requires an
	// interactive shell rather than a one-shot Run.
	modes := ssh.TerminalModes{
		ssh.ECHO:          0,     // disable echo
		ssh.TTY_OP_ISPEED: 14400, // input speed
		ssh.TTY_OP_OSPEED: 14400, // output speed
	}
	// "dumb" terminal, 200 rows by 1000 columns.
	if err := session.RequestPty("dumb", 200, 1000, modes); err != nil {
		return "", fmt.Errorf("failed to request pty: %w", err)
	}

	shell, err := session.StdinPipe()
	if err != nil {
		return "", fmt.Errorf("failed to get stdin pipe: %w", err)
	}

	// NOTE(review): Stdout and Stderr share one bytes.Buffer. If the ssh
	// package copies the two streams from separate goroutines, these writes
	// can race; a mutex-guarded writer would be safer. Confirm against the
	// x/crypto/ssh Session documentation.
	var stdoutBuf bytes.Buffer
	session.Stdout = &stdoutBuf
	session.Stderr = &stdoutBuf

	if err := session.Shell(); err != nil {
		return "", fmt.Errorf("failed to start shell: %w", err)
	}

	// Disable paging first (H3C/Huawei) and give the device time to apply it.
	if _, err := shell.Write([]byte("screen-length disable\n")); err != nil {
		return "", fmt.Errorf("failed to send screen-length disable: %w", err)
	}
	time.Sleep(500 * time.Millisecond)

	// Send the actual command and give it time to finish.
	if _, err := shell.Write([]byte(command + "\n")); err != nil {
		return "", fmt.Errorf("failed to send command: %w", err)
	}
	time.Sleep(2 * time.Second)

	if _, err := shell.Write([]byte("exit\n")); err != nil {
		return "", fmt.Errorf("failed to send exit: %w", err)
	}

	// Wait for the session to end. The result is deliberately discarded, as
	// before: whatever output was collected is returned regardless of how the
	// remote shell terminated.
	_ = session.Wait()

	output := stdoutBuf.String()

	// Debug: raw output length and its first 100 bytes.
	fmt.Printf("[SSH DEBUG] Raw output length: %d, first 100 chars: %q\n", len(output), output[:min(len(output), 100)])

	cleanOutput := cleanCommandOutput(output, command)

	fmt.Printf("[SSH DEBUG] Clean output length: %d, first 100 chars: %q\n", len(cleanOutput), cleanOutput[:min(len(cleanOutput), 100)])

	return cleanOutput, nil
}

// cleanCommandOutput removes blank lines, pager markers, exact command echoes
// and bare prompt lines from raw shell output. The remaining lines are
// whitespace-trimmed and joined with "\n".
//
// The first data line following an echoed "screen-length disable" (its reply)
// is dropped as well. The echo of command itself never causes the following
// line to be dropped, and a prompt line counts as the reply, so the first line
// of the command's real output is preserved.
func cleanCommandOutput(output, command string) string {
	commandEcho := strings.TrimSpace(command)

	var cleanLines []string
	skipNext := false
	for _, line := range strings.Split(output, "\n") {
		trimmedLine := strings.TrimSpace(line)

		switch {
		case trimmedLine == "":
			// Blank line.
			continue

		case strings.Contains(trimmedLine, "---- More ----"):
			// Pager marker.
			continue

		case trimmedLine == "screen-length disable":
			// Echo of the paging command: its reply line follows.
			skipNext = true
			fmt.Printf("[SSH DEBUG] Skipping command echo: %s\n", trimmedLine)
			continue

		case trimmedLine == commandEcho:
			// Echo of the real command: what follows is wanted output, so
			// any pending skip is cancelled.
			skipNext = false
			fmt.Printf("[SSH DEBUG] Skipping command echo: %s\n", trimmedLine)
			continue

		case promptLineRe.MatchString(trimmedLine):
			// Bare prompt such as <hostname> or [hostname]; it also serves
			// as the reply to "screen-length disable".
			skipNext = false
			fmt.Printf("[SSH DEBUG] Skipping prompt: %s\n", trimmedLine)
			continue

		case skipNext:
			// First data line after "screen-length disable".
			skipNext = false
			fmt.Printf("[SSH DEBUG] Skipping line after command: %s\n", trimmedLine)
			continue
		}

		cleanLines = append(cleanLines, trimmedLine)
	}
	return strings.Join(cleanLines, "\n")
}
// ExecuteCommands runs each command in order through ExecuteCommand and
// collects the outputs. It stops at the first failing command and returns the
// outputs gathered so far together with an error naming that command.
func (c *Client) ExecuteCommands(commands []string) ([]string, error) {
	outputs := make([]string, 0, len(commands))
	for i := 0; i < len(commands); i++ {
		out, err := c.ExecuteCommand(commands[i])
		if err == nil {
			outputs = append(outputs, out)
			continue
		}
		return outputs, fmt.Errorf("failed to execute command '%s': %w", commands[i], err)
	}
	return outputs, nil
}
// CheckSSH reports whether a TCP connection to host on port can be opened
// within timeout. It only probes the port; no SSH handshake is performed.
// A zero port means 22 and a zero timeout means 2 seconds.
//
// host may be a host name, an IPv4 address, or an IPv6 literal with or
// without surrounding brackets.
func CheckSSH(host string, port int, timeout time.Duration) bool {
	if port == 0 {
		port = 22
	}
	if timeout == 0 {
		timeout = 2 * time.Second
	}

	// JoinHostPort brackets IPv6 literals, which plain "%s:%d" formatting
	// does not. A host that already carries brackets is unwrapped first so it
	// keeps working.
	if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	addr := net.JoinHostPort(host, fmt.Sprint(port))

	conn, err := net.DialTimeout("tcp", addr, timeout)
	if err != nil {
		return false
	}
	conn.Close()
	return true
}
// loadPrivateKey reads the private key stored at keyFile and parses it into
// an ssh.Signer. It returns a wrapped error if the file cannot be read or the
// key cannot be parsed.
func loadPrivateKey(keyFile string) (ssh.Signer, error) {
	pemBytes, readErr := os.ReadFile(keyFile)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read key file: %w", readErr)
	}

	signer, parseErr := ssh.ParsePrivateKey(pemBytes)
	if parseErr == nil {
		return signer, nil
	}
	return nil, fmt.Errorf("failed to parse private key: %w", parseErr)
}
// Ping reports whether host appears reachable. Despite the name it does not
// send ICMP: it attempts a TCP connection to ports 22, 80, 443 and 3389 in
// turn, each bounded by timeout, and returns true as soon as one succeeds.
// A host with none of those ports open is reported as unreachable.
//
// host may be a host name, an IPv4 address, or an IPv6 literal with or
// without surrounding brackets.
func Ping(host string, timeout time.Duration) bool {
	// JoinHostPort below brackets IPv6 literals, which plain "%s:%d"
	// formatting does not. A host that already carries brackets is unwrapped
	// first so it keeps working.
	if len(host) > 1 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}

	// Simple TCP probe; a dedicated ICMP library could be used instead.
	ports := []int{22, 80, 443, 3389}
	for _, port := range ports {
		addr := net.JoinHostPort(host, fmt.Sprint(port))
		conn, err := net.DialTimeout("tcp", addr, timeout)
		if err == nil {
			conn.Close()
			return true
		}
	}
	return false
}