package services import ( "bufio" "bytes" "context" "crypto/md5" "encoding/hex" "encoding/json" "fmt" "os" "os/exec" "path/filepath" "regexp" "strconv" "strings" "sync" "time" "github.com/ansible-deploy/internal/models" "gopkg.in/yaml.v3" ) // AnsibleService Ansible服务 type AnsibleService struct { config *Config hosts map[string]*models.Host groups map[string]*models.HostGroup inventoryPath string tasks map[string]*models.TaskExecution taskLock sync.RWMutex } // NewAnsibleService 创建Ansible服务 func NewAnsibleService(cfg *Config) *AnsibleService { svc := &AnsibleService{ config: cfg, hosts: make(map[string]*models.Host), groups: make(map[string]*models.HostGroup), inventoryPath: filepath.Join(cfg.InventoryDir, "hosts"), tasks: make(map[string]*models.TaskExecution), } // 初始化默认组 svc.groups["all"] = &models.HostGroup{Name: "all", Description: "所有主机"} svc.groups["ungrouped"] = &models.HostGroup{Name: "ungrouped", Description: "未分组主机"} // 加载现有数据 svc.loadHosts() return svc } // generateID 生成唯一ID func (s *AnsibleService) generateID() string { hash := md5.New() hash.Write([]byte(time.Now().String())) return hex.EncodeToString(hash.Sum(nil))[:8] } // loadInventory 加载资产清单 func (s *AnsibleService) loadInventory() { invFile := filepath.Join(s.config.InventoryDir, "hosts") data, err := os.ReadFile(invFile) if err != nil { return } // 解析INI格式的inventory scanner := bufio.NewScanner(bytes.NewReader(data)) var currentGroup string groupVars := make(map[string]map[string]string) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" || strings.HasPrefix(line, "#") { continue } // 组定义 if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") { currentGroup = strings.Trim(line, "[]") continue } // 变量定义 if strings.Contains(line, "=") { parts := strings.SplitN(line, "=", 2) if len(parts) == 2 { if groupVars[currentGroup] == nil { groupVars[currentGroup] = make(map[string]string) } groupVars[currentGroup][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) } } // 主机定义 if strings.Contains(line, "ansible_host") { re := regexp.MustCompile(`(\S+)\s+ansible_host=(\S+)`) if matches := re.FindStringSubmatch(line); len(matches) == 3 { host := &models.Host{ ID: s.generateID(), Name: matches[1], IP: matches[2], Status: "unknown", } s.hosts[host.ID] = host } } } } // loadHosts 加载主机列表 func (s *AnsibleService) loadHosts() { // 从hosts.json加载详细配置(唯一数据源) hostsFile := filepath.Join(s.config.InventoryDir, "hosts.json") data, err := os.ReadFile(hostsFile) if err != nil { return } var hosts []models.Host if err := json.Unmarshal(data, &hosts); err == nil { for _, h := range hosts { hcopy := h if hcopy.ID == "" { hcopy.ID = s.generateID() } if hcopy.Port == 0 { hcopy.Port = 22 } if hcopy.Username == "" { hcopy.Username = "root" } if hcopy.Status == "" { hcopy.Status = "pending" } s.hosts[hcopy.ID] = &hcopy } // 保存以持久化补全的字段 s.saveHosts() } } // saveHosts 保存主机列表 func (s *AnsibleService) saveHosts() error { hostsFile := filepath.Join(s.config.InventoryDir, "hosts.json") var hosts []models.Host for _, h := range s.hosts { hosts = append(hosts, *h) } data, _ := json.MarshalIndent(hosts, "", " ") if err := os.WriteFile(hostsFile, data, 0644); err != nil { return err } // 更新inventory文件 s.updateInventoryFile() return nil } // updateInventoryFile 更新inventory文件 func (s *AnsibleService) updateInventoryFile() { var lines []string lines = append(lines, "# Ansible Inventory File") lines = append(lines, "# Generated by ansible-deploy") lines = append(lines, "") // 按组分组主机 groupedHosts := make(map[string][]models.Host) for _, h := range s.hosts { if len(h.Groups) == 0 { groupedHosts["ungrouped"] = append(groupedHosts["ungrouped"], *h) } else { for _, g := range h.Groups { groupedHosts[g] = append(groupedHosts[g], *h) } } } // 输出每个组 for group, hosts := range groupedHosts { lines = append(lines, fmt.Sprintf("[%s]", group)) for _, h := range hosts { line := fmt.Sprintf(" %s ansible_host=%s", h.Name, h.IP) if h.Port != 0 && h.Port != 22 { line += fmt.Sprintf(" ansible_port=%d", h.Port) } if h.Username != "" { line += fmt.Sprintf(" ansible_user=%s", h.Username) } if h.AuthType == "sshkey" && h.SSHKey != "" { line += fmt.Sprintf(" ansible_ssh_private_key_file=%s", h.SSHKey) } lines = append(lines, line) } lines = append(lines, "") } invFile := filepath.Join(s.config.InventoryDir, "hosts") os.WriteFile(invFile, []byte(strings.Join(lines, "\n")), 0644) } // ListHosts 获取主机列表 func (s *AnsibleService) ListHosts() []models.Host { var hosts []models.Host for _, h := range s.hosts { hosts = append(hosts, *h) } return hosts } // AddHost 添加主机 func (s *AnsibleService) AddHost(host models.Host) error { host.ID = s.generateID() host.CreatedAt = time.Now() host.UpdatedAt = time.Now() host.Status = "pending" s.hosts[host.ID] = &host return s.saveHosts() } // DeleteHost 删除主机 func (s *AnsibleService) DeleteHost(id string) error { if _, ok := s.hosts[id]; !ok { return fmt.Errorf("主机不存在") } delete(s.hosts, id) return s.saveHosts() } // UpdateHost 更新主机 func (s *AnsibleService) UpdateHost(id string, host models.Host) error { if _, ok := s.hosts[id]; !ok { return fmt.Errorf("主机不存在") } host.UpdatedAt = time.Now() s.hosts[id] = &host return s.saveHosts() } // ListGroups 获取主机组列表 func (s *AnsibleService) ListGroups() []models.HostGroup { var groups []models.HostGroup for _, g := range s.groups { gcopy := *g // 展开组内主机的详细信息 var hostList []models.Host for _, hName := range g.Hosts { for _, h := range s.hosts { if h.Name == hName { hostList = append(hostList, *h) break } } } gcopy.HostList = hostList groups = append(groups, gcopy) } return groups } // CreateGroup 创建主机组 func (s *AnsibleService) CreateGroup(group models.HostGroup) error { if _, ok := s.groups[group.Name]; ok { return fmt.Errorf("组已存在") } s.groups[group.Name] = &group return s.saveGroups() } // DeleteGroup 删除主机组 func (s *AnsibleService) DeleteGroup(name string) error { if name == "all" || name == "ungrouped" { return fmt.Errorf("不能删除系统组") } delete(s.groups, name) return s.saveGroups() } // UpdateGroup 更新主机组 func (s *AnsibleService) UpdateGroup(name string, group models.HostGroup) error { if _, ok := s.groups[name]; !ok { return fmt.Errorf("组不存在") } s.groups[name] = &group return s.saveGroups() } // saveGroups 保存组信息 func (s *AnsibleService) saveGroups() error { groupsFile := filepath.Join(s.config.InventoryDir, "groups.json") data, _ := json.MarshalIndent(s.groups, "", " ") return os.WriteFile(groupsFile, data, 0644) } // TestConnection 测试主机连接 func (s *AnsibleService) TestConnection(hostID string) (*models.CommandResult, error) { host, ok := s.hosts[hostID] if !ok { return nil, fmt.Errorf("主机不存在") } start := time.Now() result := &models.CommandResult{ Host: host.Name, Success: false, } // 构建ansible命令 args := []string{ host.Name, "-i", s.inventoryPath, "-m", "ping", "-u", host.Username, } // 认证方式:SSH Key 或 密码 if host.AuthType == "sshkey" && host.SSHKey != "" { // SSH Key 认证 args = append(args, "--private-key", host.SSHKey) } else if host.Password != "" { // 密码认证 args = append(args, "--extra-vars", fmt.Sprintf("ansible_password=%s", host.Password)) } // 如果端口不是22,通过extra-vars传递 if host.Port != 0 && host.Port != 22 { args = append(args, "--extra-vars", fmt.Sprintf("ansible_port=%d", host.Port)) } cmd := exec.Command(s.config.AnsiblePath, args...) // 通过环境变量禁用SSH主机密钥检查 cmd.Env = append(os.Environ(), "ANSIBLE_HOST_KEY_CHECKING=False") output, err := cmd.CombinedOutput() result.Duration = time.Since(start).Milliseconds() result.Output = string(output) if err != nil { result.Error = err.Error() host.Status = "offline" } else { result.Success = true if strings.Contains(string(output), "SUCCESS") || strings.Contains(string(output), "pong") { host.Status = "online" } else { host.Status = "offline" } } host.LastCheck = time.Now() // 持久化状态 s.saveHosts() return result, nil } // ExecuteCommand 执行单个命令 func (s *AnsibleService) ExecuteCommand(req models.CommandRequest) ([]models.CommandResult, error) { var results []models.CommandResult for _, hostName := range req.Hosts { result := s.runCommand(hostName, req.Command, req.Timeout) results = append(results, result) } return results, nil } // runCommand 在主机上执行命令 func (s *AnsibleService) runCommand(hostName string, command string, timeout int) models.CommandResult { start := time.Now() result := models.CommandResult{ Host: hostName, Success: false, } // 查找主机获取认证信息 var host *models.Host for _, h := range s.hosts { if h.Name == hostName { host = h break } } if host == nil { result.Error = "主机不存在" return result } if timeout == 0 { timeout = s.config.SSHTimeout } ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) defer cancel() args := []string{ host.Name, "-i", s.inventoryPath, "-m", "shell", "-a", command, "-u", host.Username, } // 认证方式:SSH Key 或 密码 if host.AuthType == "sshkey" && host.SSHKey != "" { args = append(args, "--private-key", host.SSHKey) } else if host.Password != "" { args = append(args, "--extra-vars", fmt.Sprintf("ansible_password=%s", host.Password)) } // 如果端口不是22,通过extra-vars传递 if host.Port != 0 && host.Port != 22 { args = append(args, "--extra-vars", fmt.Sprintf("ansible_port=%d", host.Port)) } cmd := exec.CommandContext(ctx, s.config.AnsiblePath, args...) // 通过环境变量禁用SSH主机密钥检查 cmd.Env = append(os.Environ(), "ANSIBLE_HOST_KEY_CHECKING=False") output, err := cmd.CombinedOutput() result.Duration = time.Since(start).Milliseconds() result.Output = string(output) if err != nil { result.Error = err.Error() if exitErr, ok := err.(*exec.ExitError); ok { result.ExitCode = exitErr.ExitCode() } } else { result.Success = true result.ExitCode = 0 } return result } // BatchExecute 批量执行命令 func (s *AnsibleService) BatchExecute(req models.CommandRequest) *models.BatchCommandResult { result := &models.BatchCommandResult{ TaskID: s.generateID(), Total: len(req.Hosts), Results: make([]models.CommandResult, 0), } task := &models.TaskExecution{ ID: result.TaskID, Name: "批量命令执行", Hosts: req.Hosts, Status: "running", StartTime: time.Now(), TotalHosts: len(req.Hosts), } s.taskLock.Lock() s.tasks[result.TaskID] = task s.taskLock.Unlock() // 并行执行 if req.Parallel { var wg sync.WaitGroup results := make(chan models.CommandResult, len(req.Hosts)) parallelism := s.config.MaxParallelism if parallelism <= 0 { parallelism = 10 } semaphore := make(chan struct{}, parallelism) for _, host := range req.Hosts { wg.Add(1) go func(h string) { defer wg.Done() semaphore <- struct{}{} defer func() { <-semaphore }() r := s.runCommand(h, req.Command, req.Timeout) results <- r }(host) } go func() { wg.Wait() close(results) }() for r := range results { result.Results = append(result.Results, r) if r.Success { result.Success++ } else { result.Failed++ } s.updateTaskProgress(result.TaskID, 1) } } else { // 串行执行 for _, host := range req.Hosts { r := s.runCommand(host, req.Command, req.Timeout) result.Results = append(result.Results, r) if r.Success { result.Success++ } else { result.Failed++ } s.updateTaskProgress(result.TaskID, 1) } } task.Status = "completed" task.EndTime = time.Now() return result } // updateTaskProgress 更新任务进度 func (s *AnsibleService) updateTaskProgress(taskID string, increment int) { s.taskLock.Lock() defer s.taskLock.Unlock() if task, ok := s.tasks[taskID]; ok { task.Progress += increment task.SuccessHosts = task.Progress if task.Progress >= task.TotalHosts { task.Status = "completed" task.EndTime = time.Now() } } } // ListTasks 获取任务列表 func (s *AnsibleService) ListTasks() []*models.TaskExecution { s.taskLock.RLock() defer s.taskLock.RUnlock() var tasks []*models.TaskExecution for _, t := range s.tasks { tasks = append(tasks, t) } return tasks } // GetTask 获取单个任务 func (s *AnsibleService) GetTask(id string) *models.TaskExecution { s.taskLock.RLock() defer s.taskLock.RUnlock() return s.tasks[id] } // CancelTask 取消任务 func (s *AnsibleService) CancelTask(id string) error { s.taskLock.Lock() defer s.taskLock.Unlock() if task, ok := s.tasks[id]; ok { if task.Status == "running" { task.Status = "cancelled" task.EndTime = time.Now() return nil } return fmt.Errorf("任务无法取消") } return fmt.Errorf("任务不存在") } // ExecutePlaybook 执行Playbook func (s *AnsibleService) ExecutePlaybook(req models.PlaybookExecutionRequest) (*models.TaskExecution, error) { playbookPath := filepath.Join(s.config.PlaybookDir, req.Name+".yml") if _, err := os.Stat(playbookPath); os.IsNotExist(err) { return nil, fmt.Errorf("Playbook不存在: %s", req.Name) } task := &models.TaskExecution{ ID: s.generateID(), Name: req.Name, Playbook: playbookPath, Hosts: req.Hosts, Status: "running", StartTime: time.Now(), TotalHosts: len(req.Hosts), SuccessHosts: 0, FailedHosts: 0, } s.taskLock.Lock() s.tasks[task.ID] = task s.taskLock.Unlock() // 启动异步执行 go s.runPlaybook(task, playbookPath, req) return task, nil } // runPlaybook 运行Playbook func (s *AnsibleService) runPlaybook(task *models.TaskExecution, playbookPath string, req models.PlaybookExecutionRequest) { var args []string // 添加inventory args = append(args, "-i", s.inventoryPath) // 添加hosts限制 if len(req.Hosts) > 0 { args = append(args, "-l", strings.Join(req.Hosts, ",")) } // 添加extra-vars if len(req.ExtraVars) > 0 { varsJSON, _ := json.Marshal(req.ExtraVars) args = append(args, "-e", string(varsJSON)) } // 添加tags if len(req.Tags) > 0 { args = append(args, "-t", strings.Join(req.Tags, ",")) } // 添加skip-tags if len(req.SkipTags) > 0 { args = append(args, "--skip-tags", strings.Join(req.SkipTags, ",")) } // 添加verbose if req.Verbose != "" { args = append(args, "-"+req.Verbose) } // 显示文件差异 if req.Diff { args = append(args, "-D") } // dry-run模式 if req.Check { args = append(args, "-C") } // 是否提权 if req.Become != nil { if *req.Become { args = append(args, "-b") } else { args = append(args, "--no-become") } } // 并发数 if req.Forks > 0 { args = append(args, "-f", strconv.Itoa(req.Forks)) } // 自定义额外参数 if req.ExtraArgs != "" { extraParts := strings.Fields(req.ExtraArgs) args = append(args, extraParts...) } // playbook路径放最后 args = append(args, playbookPath) // 构建命令 cmd := exec.Command("ansible-playbook", args...) // 设置超时 if req.Timeout > 0 { ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second) defer cancel() cmd = exec.CommandContext(ctx, "ansible-playbook", args...) } var output bytes.Buffer cmd.Stdout = &output cmd.Stderr = &output err := cmd.Run() task.EndTime = time.Now() if err != nil { task.Status = "failed" task.Error = err.Error() } else { task.Status = "success" } task.Output = output.String() } // ListPlaybooks 列出可用Playbooks func (s *AnsibleService) ListPlaybooks() []models.Playbook { var playbooks []models.Playbook files, _ := os.ReadDir(s.config.PlaybookDir) for _, f := range files { if !f.IsDir() && strings.HasSuffix(f.Name(), ".yml") { name := strings.TrimSuffix(f.Name(), ".yml") playbookPath := filepath.Join(s.config.PlaybookDir, f.Name()) playbook := models.Playbook{ Name: name, Path: playbookPath, } // 解析YAML获取描述和变量信息 data, err := os.ReadFile(playbookPath) if err == nil { // 尝试解析为playbook列表 var playEntries []map[string]interface{} if yaml.Unmarshal(data, &playEntries) == nil && len(playEntries) > 0 { first := playEntries[0] // 提取注释中的描述(name字段) if nameVal, ok := first["name"]; ok { playbook.Description = fmt.Sprintf("%v", nameVal) } // 提取vars if varsVal, ok := first["vars"]; ok { if varsMap, ok := varsVal.(map[string]interface{}); ok { playbook.Variables = varsMap } } } } playbooks = append(playbooks, playbook) } } return playbooks } // GetPlaybook 获取Playbook详情 func (s *AnsibleService) GetPlaybook(name string) (*models.Playbook, error) { playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml") data, err := os.ReadFile(playbookPath) if err != nil { return nil, fmt.Errorf("Playbook不存在") } var playbook models.Playbook playbook.Name = name playbook.Path = playbookPath // 简单解析YAML if err := yaml.Unmarshal(data, &playbook); err != nil { return nil, fmt.Errorf("Playbook解析失败") } return &playbook, nil } // WebSocketLogs WebSocket日志流 func (s *AnsibleService) WebSocketLogs(taskID string) (<-chan models.LogEntry, error) { logChan := make(chan models.LogEntry, 100) go func() { defer close(logChan) ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() for { select { case <-ticker.C: s.taskLock.RLock() task, ok := s.tasks[taskID] s.taskLock.RUnlock() if !ok { return } entry := models.LogEntry{ Time: time.Now().Format("15:04:05"), Level: "info", Host: "system", Message: fmt.Sprintf("Progress: %d/%d", task.Progress, task.TotalHosts), } logChan <- entry if task.Status == "completed" || task.Status == "failed" { return } } } }() return logChan, nil } // ParseAnsibleOutput 解析Ansible输出 func (s *AnsibleService) ParseAnsibleOutput(output string) (map[string]interface{}, error) { var result map[string]interface{} if err := json.Unmarshal([]byte(output), &result); err != nil { return nil, err } return result, nil } // GetTaskOutput 获取任务输出 func (s *AnsibleService) GetTaskOutput(taskID string) string { s.taskLock.RLock() defer s.taskLock.RUnlock() if task, ok := s.tasks[taskID]; ok { return task.Output } return "" } // CreatePlaybook 创建Playbook(通过内容) func (s *AnsibleService) CreatePlaybook(name string, content string) error { if name == "" { return fmt.Errorf("Playbook名称不能为空") } // 检查名称是否含非法字符 if strings.Contains(name, "/") || strings.Contains(name, "..") { return fmt.Errorf("Playbook名称包含非法字符") } playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml") if _, err := os.Stat(playbookPath); err == nil { return fmt.Errorf("Playbook已存在: %s", name) } // 验证YAML格式 var dummy interface{} if err := yaml.Unmarshal([]byte(content), &dummy); err != nil { return fmt.Errorf("YAML格式错误: %v", err) } return os.WriteFile(playbookPath, []byte(content), 0644) } // DeletePlaybook 删除Playbook func (s *AnsibleService) DeletePlaybook(name string) error { if strings.Contains(name, "/") || strings.Contains(name, "..") { return fmt.Errorf("Playbook名称包含非法字符") } playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml") if _, err := os.Stat(playbookPath); os.IsNotExist(err) { return fmt.Errorf("Playbook不存在: %s", name) } return os.Remove(playbookPath) } // UpdatePlaybook 更新Playbook内容 func (s *AnsibleService) UpdatePlaybook(name string, content string) error { if strings.Contains(name, "/") || strings.Contains(name, "..") { return fmt.Errorf("Playbook名称包含非法字符") } playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml") if _, err := os.Stat(playbookPath); os.IsNotExist(err) { return fmt.Errorf("Playbook不存在: %s", name) } // 验证YAML格式 var dummy interface{} if err := yaml.Unmarshal([]byte(content), &dummy); err != nil { return fmt.Errorf("YAML格式错误: %v", err) } return os.WriteFile(playbookPath, []byte(content), 0644) } // GetPlaybookContent 获取Playbook原始内容 func (s *AnsibleService) GetPlaybookContent(name string) (string, error) { if strings.Contains(name, "/") || strings.Contains(name, "..") { return "", fmt.Errorf("Playbook名称包含非法字符") } playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml") data, err := os.ReadFile(playbookPath) if err != nil { return "", fmt.Errorf("Playbook不存在") } return string(data), nil } // CheckAnsibleInstalled 检查Ansible是否安装 func (s *AnsibleService) CheckAnsibleInstalled() bool { cmd := exec.Command("ansible", "--version") err := cmd.Run() return err == nil } // GetInventoryPath 获取inventory路径 func (s *AnsibleService) GetInventoryPath() string { return s.inventoryPath }