983 lines
23 KiB
Go
983 lines
23 KiB
Go
package services
|
||
|
||
import (
|
||
"bufio"
|
||
"bytes"
|
||
"context"
|
||
"crypto/md5"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"os"
|
||
"os/exec"
|
||
"path/filepath"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/ansible-deploy/internal/models"
|
||
"gopkg.in/yaml.v3"
|
||
)
|
||
|
||
// AnsibleService Ansible服务
|
||
type AnsibleService struct {
|
||
config *Config
|
||
hosts map[string]*models.Host
|
||
groups map[string]*models.HostGroup
|
||
inventoryPath string
|
||
tasks map[string]*models.TaskExecution
|
||
taskLock sync.RWMutex
|
||
}
|
||
|
||
// NewAnsibleService 创建Ansible服务
|
||
func NewAnsibleService(cfg *Config) *AnsibleService {
|
||
svc := &AnsibleService{
|
||
config: cfg,
|
||
hosts: make(map[string]*models.Host),
|
||
groups: make(map[string]*models.HostGroup),
|
||
inventoryPath: filepath.Join(cfg.InventoryDir, "hosts"),
|
||
tasks: make(map[string]*models.TaskExecution),
|
||
}
|
||
|
||
// 初始化默认组
|
||
svc.groups["all"] = &models.HostGroup{Name: "all", Description: "所有主机"}
|
||
svc.groups["ungrouped"] = &models.HostGroup{Name: "ungrouped", Description: "未分组主机"}
|
||
|
||
// 加载现有数据
|
||
svc.loadHosts()
|
||
svc.loadGroups()
|
||
|
||
return svc
|
||
}
|
||
|
||
// loadGroups 加载主机组列表
|
||
func (s *AnsibleService) loadGroups() {
|
||
groupsFile := filepath.Join(s.config.InventoryDir, "groups.json")
|
||
data, err := os.ReadFile(groupsFile)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
var groups map[string]models.HostGroup
|
||
if err := json.Unmarshal(data, &groups); err == nil {
|
||
for name, g := range groups {
|
||
if name != "all" && name != "ungrouped" {
|
||
gcopy := g
|
||
s.groups[name] = &gcopy
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// generateID 生成唯一ID
|
||
func (s *AnsibleService) generateID() string {
|
||
hash := md5.New()
|
||
hash.Write([]byte(time.Now().String()))
|
||
return hex.EncodeToString(hash.Sum(nil))[:8]
|
||
}
|
||
|
||
// loadInventory 加载资产清单
|
||
func (s *AnsibleService) loadInventory() {
|
||
invFile := filepath.Join(s.config.InventoryDir, "hosts")
|
||
data, err := os.ReadFile(invFile)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 解析INI格式的inventory
|
||
scanner := bufio.NewScanner(bytes.NewReader(data))
|
||
var currentGroup string
|
||
groupVars := make(map[string]map[string]string)
|
||
|
||
for scanner.Scan() {
|
||
line := strings.TrimSpace(scanner.Text())
|
||
if line == "" || strings.HasPrefix(line, "#") {
|
||
continue
|
||
}
|
||
|
||
// 组定义
|
||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
|
||
currentGroup = strings.Trim(line, "[]")
|
||
continue
|
||
}
|
||
|
||
// 变量定义
|
||
if strings.Contains(line, "=") {
|
||
parts := strings.SplitN(line, "=", 2)
|
||
if len(parts) == 2 {
|
||
if groupVars[currentGroup] == nil {
|
||
groupVars[currentGroup] = make(map[string]string)
|
||
}
|
||
groupVars[currentGroup][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
||
}
|
||
}
|
||
|
||
// 主机定义
|
||
if strings.Contains(line, "ansible_host") {
|
||
re := regexp.MustCompile(`(\S+)\s+ansible_host=(\S+)`)
|
||
if matches := re.FindStringSubmatch(line); len(matches) == 3 {
|
||
host := &models.Host{
|
||
ID: s.generateID(),
|
||
Name: matches[1],
|
||
IP: matches[2],
|
||
Status: "unknown",
|
||
}
|
||
s.hosts[host.ID] = host
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// loadHosts 加载主机列表
|
||
func (s *AnsibleService) loadHosts() {
|
||
// 从hosts.json加载详细配置(唯一数据源)
|
||
hostsFile := filepath.Join(s.config.InventoryDir, "hosts.json")
|
||
data, err := os.ReadFile(hostsFile)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
var hosts []models.Host
|
||
if err := json.Unmarshal(data, &hosts); err == nil {
|
||
for _, h := range hosts {
|
||
host := h // 避免循环变量指针问题
|
||
if host.ID == "" {
|
||
host.ID = s.generateID()
|
||
}
|
||
if host.Port == 0 {
|
||
host.Port = 22
|
||
}
|
||
if host.Username == "" {
|
||
host.Username = "root"
|
||
}
|
||
if host.Status == "" {
|
||
host.Status = "pending"
|
||
}
|
||
s.hosts[host.ID] = &host
|
||
}
|
||
// 保存以持久化补全的字段
|
||
s.saveHosts()
|
||
}
|
||
}
|
||
|
||
// saveHosts 保存主机列表
|
||
func (s *AnsibleService) saveHosts() error {
|
||
hostsFile := filepath.Join(s.config.InventoryDir, "hosts.json")
|
||
var hosts []models.Host
|
||
for _, h := range s.hosts {
|
||
hcopy := *h
|
||
// 确保每个主机都有ID,并更新map中的指针
|
||
if hcopy.ID == "" {
|
||
hcopy.ID = s.generateID()
|
||
h.ID = hcopy.ID // 更新map中的指针
|
||
}
|
||
hosts = append(hosts, hcopy)
|
||
}
|
||
|
||
data, _ := json.MarshalIndent(hosts, "", " ")
|
||
if err := os.WriteFile(hostsFile, data, 0644); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 更新inventory文件
|
||
s.updateInventoryFile()
|
||
return nil
|
||
}
|
||
|
||
// updateInventoryFile 更新inventory文件
|
||
func (s *AnsibleService) updateInventoryFile() {
|
||
var lines []string
|
||
lines = append(lines, "# Ansible Inventory File")
|
||
lines = append(lines, "# Generated by ansible-deploy")
|
||
lines = append(lines, "")
|
||
|
||
// 按组分组主机
|
||
groupedHosts := make(map[string][]models.Host)
|
||
for _, h := range s.hosts {
|
||
if len(h.Groups) == 0 {
|
||
groupedHosts["ungrouped"] = append(groupedHosts["ungrouped"], *h)
|
||
} else {
|
||
for _, g := range h.Groups {
|
||
groupedHosts[g] = append(groupedHosts[g], *h)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 输出每个组
|
||
for group, hosts := range groupedHosts {
|
||
lines = append(lines, fmt.Sprintf("[%s]", group))
|
||
for _, h := range hosts {
|
||
line := fmt.Sprintf(" %s ansible_host=%s", h.Name, h.IP)
|
||
if h.Port != 0 && h.Port != 22 {
|
||
line += fmt.Sprintf(" ansible_port=%d", h.Port)
|
||
}
|
||
if h.Username != "" {
|
||
line += fmt.Sprintf(" ansible_user=%s", h.Username)
|
||
}
|
||
if h.AuthType == "sshkey" && h.SSHKey != "" {
|
||
line += fmt.Sprintf(" ansible_ssh_private_key_file=%s", h.SSHKey)
|
||
}
|
||
lines = append(lines, line)
|
||
}
|
||
lines = append(lines, "")
|
||
}
|
||
|
||
invFile := filepath.Join(s.config.InventoryDir, "hosts")
|
||
os.WriteFile(invFile, []byte(strings.Join(lines, "\n")), 0644)
|
||
}
|
||
|
||
// ListHosts 获取主机列表
|
||
func (s *AnsibleService) ListHosts() []models.Host {
|
||
var hosts []models.Host
|
||
for _, h := range s.hosts {
|
||
hosts = append(hosts, *h)
|
||
}
|
||
return hosts
|
||
}
|
||
|
||
// AddHost 添加主机
|
||
func (s *AnsibleService) AddHost(host models.Host) error {
|
||
host.ID = s.generateID()
|
||
host.CreatedAt = time.Now()
|
||
host.UpdatedAt = time.Now()
|
||
host.Status = "pending"
|
||
|
||
s.hosts[host.ID] = &host
|
||
return s.saveHosts()
|
||
}
|
||
|
||
// DeleteHost 删除主机
|
||
func (s *AnsibleService) DeleteHost(id string) error {
|
||
if _, ok := s.hosts[id]; !ok {
|
||
return fmt.Errorf("主机不存在")
|
||
}
|
||
delete(s.hosts, id)
|
||
return s.saveHosts()
|
||
}
|
||
|
||
// UpdateHost 更新主机
|
||
func (s *AnsibleService) UpdateHost(id string, host models.Host) error {
|
||
if _, ok := s.hosts[id]; !ok {
|
||
return fmt.Errorf("主机不存在")
|
||
}
|
||
host.UpdatedAt = time.Now()
|
||
s.hosts[id] = &host
|
||
return s.saveHosts()
|
||
}
|
||
|
||
// ListGroups 获取主机组列表
|
||
func (s *AnsibleService) ListGroups() []models.HostGroup {
|
||
var groups []models.HostGroup
|
||
for _, g := range s.groups {
|
||
gcopy := *g
|
||
// 动态展开组内主机(通过 host.Groups 字段关联,而非 group.Hosts)
|
||
var hostList []models.Host
|
||
for _, h := range s.hosts {
|
||
if gcopy.Name == "all" {
|
||
// all 组包含所有主机
|
||
hcopy := *h
|
||
hostList = append(hostList, hcopy)
|
||
continue
|
||
}
|
||
for _, hGroup := range h.Groups {
|
||
if hGroup == gcopy.Name {
|
||
hcopy := *h
|
||
hostList = append(hostList, hcopy)
|
||
break
|
||
}
|
||
}
|
||
// 也检查主机的默认组(ungrouped)
|
||
if len(h.Groups) == 0 && gcopy.Name == "ungrouped" {
|
||
hcopy := *h
|
||
hostList = append(hostList, hcopy)
|
||
}
|
||
}
|
||
gcopy.HostList = hostList
|
||
groups = append(groups, gcopy)
|
||
}
|
||
return groups
|
||
}
|
||
|
||
// CreateGroup 创建主机组
|
||
func (s *AnsibleService) CreateGroup(group models.HostGroup) error {
|
||
if _, ok := s.groups[group.Name]; ok {
|
||
return fmt.Errorf("组已存在")
|
||
}
|
||
s.groups[group.Name] = &group
|
||
return s.saveGroups()
|
||
}
|
||
|
||
// DeleteGroup 删除主机组
|
||
func (s *AnsibleService) DeleteGroup(name string) error {
|
||
if name == "all" || name == "ungrouped" {
|
||
return fmt.Errorf("不能删除系统组")
|
||
}
|
||
delete(s.groups, name)
|
||
return s.saveGroups()
|
||
}
|
||
|
||
// UpdateGroup 更新主机组
|
||
func (s *AnsibleService) UpdateGroup(name string, group models.HostGroup) error {
|
||
if _, ok := s.groups[name]; !ok {
|
||
return fmt.Errorf("组不存在")
|
||
}
|
||
s.groups[name] = &group
|
||
return s.saveGroups()
|
||
}
|
||
|
||
// saveGroups 保存组信息
|
||
func (s *AnsibleService) saveGroups() error {
|
||
groupsFile := filepath.Join(s.config.InventoryDir, "groups.json")
|
||
data, _ := json.MarshalIndent(s.groups, "", " ")
|
||
return os.WriteFile(groupsFile, data, 0644)
|
||
}
|
||
|
||
// TestConnection 测试主机连接
|
||
func (s *AnsibleService) TestConnection(hostID string) (*models.CommandResult, error) {
|
||
host, ok := s.hosts[hostID]
|
||
if !ok {
|
||
return nil, fmt.Errorf("主机不存在")
|
||
}
|
||
|
||
start := time.Now()
|
||
result := &models.CommandResult{
|
||
Host: host.Name,
|
||
Success: false,
|
||
}
|
||
|
||
// 构建ansible命令
|
||
args := []string{
|
||
host.Name,
|
||
"-i", s.inventoryPath,
|
||
"-m", "ping",
|
||
"-u", host.Username,
|
||
}
|
||
|
||
// 认证方式:SSH Key 或 密码
|
||
if host.AuthType == "sshkey" && host.SSHKey != "" {
|
||
// SSH Key 认证
|
||
args = append(args, "--private-key", host.SSHKey)
|
||
} else if host.Password != "" {
|
||
// 密码认证
|
||
args = append(args, "--extra-vars", fmt.Sprintf("ansible_password=%s", host.Password))
|
||
}
|
||
|
||
// 如果端口不是22,通过extra-vars传递
|
||
if host.Port != 0 && host.Port != 22 {
|
||
args = append(args, "--extra-vars", fmt.Sprintf("ansible_port=%d", host.Port))
|
||
}
|
||
|
||
cmd := exec.Command(s.config.AnsiblePath, args...)
|
||
// 通过环境变量禁用SSH主机密钥检查
|
||
cmd.Env = append(os.Environ(), "ANSIBLE_HOST_KEY_CHECKING=False")
|
||
output, err := cmd.CombinedOutput()
|
||
result.Duration = time.Since(start).Milliseconds()
|
||
result.Output = string(output)
|
||
|
||
if err != nil {
|
||
result.Error = err.Error()
|
||
host.Status = "offline"
|
||
} else {
|
||
result.Success = true
|
||
if strings.Contains(string(output), "SUCCESS") || strings.Contains(string(output), "pong") {
|
||
host.Status = "online"
|
||
} else {
|
||
host.Status = "offline"
|
||
}
|
||
}
|
||
host.LastCheck = time.Now()
|
||
|
||
// 持久化状态
|
||
s.saveHosts()
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// ExecuteCommand 执行单个命令
|
||
func (s *AnsibleService) ExecuteCommand(req models.CommandRequest) ([]models.CommandResult, error) {
|
||
var results []models.CommandResult
|
||
|
||
for _, hostName := range req.Hosts {
|
||
result := s.runCommand(hostName, req.Command, req.Timeout)
|
||
results = append(results, result)
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
// runCommand 在主机上执行命令
|
||
func (s *AnsibleService) runCommand(hostName string, command string, timeout int) models.CommandResult {
|
||
start := time.Now()
|
||
result := models.CommandResult{
|
||
Host: hostName,
|
||
Success: false,
|
||
}
|
||
|
||
// 查找主机获取认证信息
|
||
var host *models.Host
|
||
for _, h := range s.hosts {
|
||
if h.Name == hostName {
|
||
host = h
|
||
break
|
||
}
|
||
}
|
||
if host == nil {
|
||
result.Error = "主机不存在"
|
||
return result
|
||
}
|
||
|
||
if timeout == 0 {
|
||
timeout = s.config.SSHTimeout
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||
defer cancel()
|
||
|
||
args := []string{
|
||
host.Name,
|
||
"-i", s.inventoryPath,
|
||
"-m", "shell",
|
||
"-a", command,
|
||
"-u", host.Username,
|
||
}
|
||
|
||
// 认证方式:SSH Key 或 密码
|
||
if host.AuthType == "sshkey" && host.SSHKey != "" {
|
||
args = append(args, "--private-key", host.SSHKey)
|
||
} else if host.Password != "" {
|
||
args = append(args, "--extra-vars", fmt.Sprintf("ansible_password=%s", host.Password))
|
||
}
|
||
|
||
// 如果端口不是22,通过extra-vars传递
|
||
if host.Port != 0 && host.Port != 22 {
|
||
args = append(args, "--extra-vars", fmt.Sprintf("ansible_port=%d", host.Port))
|
||
}
|
||
|
||
cmd := exec.CommandContext(ctx, s.config.AnsiblePath, args...)
|
||
// 通过环境变量禁用SSH主机密钥检查
|
||
cmd.Env = append(os.Environ(), "ANSIBLE_HOST_KEY_CHECKING=False")
|
||
output, err := cmd.CombinedOutput()
|
||
result.Duration = time.Since(start).Milliseconds()
|
||
result.Output = string(output)
|
||
|
||
if err != nil {
|
||
result.Error = err.Error()
|
||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||
result.ExitCode = exitErr.ExitCode()
|
||
}
|
||
} else {
|
||
result.Success = true
|
||
result.ExitCode = 0
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// BatchExecute 批量执行命令
|
||
func (s *AnsibleService) BatchExecute(req models.CommandRequest) *models.BatchCommandResult {
|
||
result := &models.BatchCommandResult{
|
||
TaskID: s.generateID(),
|
||
Total: len(req.Hosts),
|
||
Results: make([]models.CommandResult, 0),
|
||
}
|
||
|
||
task := &models.TaskExecution{
|
||
ID: result.TaskID,
|
||
Name: "批量命令执行",
|
||
Hosts: req.Hosts,
|
||
Status: "running",
|
||
StartTime: time.Now(),
|
||
TotalHosts: len(req.Hosts),
|
||
}
|
||
|
||
s.taskLock.Lock()
|
||
s.tasks[result.TaskID] = task
|
||
s.taskLock.Unlock()
|
||
|
||
// 并行执行
|
||
if req.Parallel {
|
||
var wg sync.WaitGroup
|
||
results := make(chan models.CommandResult, len(req.Hosts))
|
||
|
||
parallelism := s.config.MaxParallelism
|
||
if parallelism <= 0 {
|
||
parallelism = 10
|
||
}
|
||
semaphore := make(chan struct{}, parallelism)
|
||
|
||
for _, host := range req.Hosts {
|
||
wg.Add(1)
|
||
go func(h string) {
|
||
defer wg.Done()
|
||
semaphore <- struct{}{}
|
||
defer func() { <-semaphore }()
|
||
|
||
r := s.runCommand(h, req.Command, req.Timeout)
|
||
results <- r
|
||
}(host)
|
||
}
|
||
|
||
go func() {
|
||
wg.Wait()
|
||
close(results)
|
||
}()
|
||
|
||
for r := range results {
|
||
result.Results = append(result.Results, r)
|
||
if r.Success {
|
||
result.Success++
|
||
} else {
|
||
result.Failed++
|
||
}
|
||
s.updateTaskProgress(result.TaskID, 1)
|
||
}
|
||
} else {
|
||
// 串行执行
|
||
for _, host := range req.Hosts {
|
||
r := s.runCommand(host, req.Command, req.Timeout)
|
||
result.Results = append(result.Results, r)
|
||
if r.Success {
|
||
result.Success++
|
||
} else {
|
||
result.Failed++
|
||
}
|
||
s.updateTaskProgress(result.TaskID, 1)
|
||
}
|
||
}
|
||
|
||
task.Status = "completed"
|
||
task.EndTime = time.Now()
|
||
|
||
return result
|
||
}
|
||
|
||
// updateTaskProgress 更新任务进度
|
||
func (s *AnsibleService) updateTaskProgress(taskID string, increment int) {
|
||
s.taskLock.Lock()
|
||
defer s.taskLock.Unlock()
|
||
|
||
if task, ok := s.tasks[taskID]; ok {
|
||
task.Progress += increment
|
||
task.SuccessHosts = task.Progress
|
||
if task.Progress >= task.TotalHosts {
|
||
task.Status = "completed"
|
||
task.EndTime = time.Now()
|
||
}
|
||
}
|
||
}
|
||
|
||
// ListTasks 获取任务列表
|
||
func (s *AnsibleService) ListTasks() []*models.TaskExecution {
|
||
s.taskLock.RLock()
|
||
defer s.taskLock.RUnlock()
|
||
|
||
var tasks []*models.TaskExecution
|
||
for _, t := range s.tasks {
|
||
tasks = append(tasks, t)
|
||
}
|
||
return tasks
|
||
}
|
||
|
||
// GetTask 获取单个任务
|
||
func (s *AnsibleService) GetTask(id string) *models.TaskExecution {
|
||
s.taskLock.RLock()
|
||
defer s.taskLock.RUnlock()
|
||
return s.tasks[id]
|
||
}
|
||
|
||
// CancelTask 取消任务
|
||
func (s *AnsibleService) CancelTask(id string) error {
|
||
s.taskLock.Lock()
|
||
defer s.taskLock.Unlock()
|
||
|
||
if task, ok := s.tasks[id]; ok {
|
||
if task.Status == "running" {
|
||
task.Status = "cancelled"
|
||
task.EndTime = time.Now()
|
||
return nil
|
||
}
|
||
return fmt.Errorf("任务无法取消")
|
||
}
|
||
return fmt.Errorf("任务不存在")
|
||
}
|
||
|
||
// ExecutePlaybook 执行Playbook
|
||
func (s *AnsibleService) ExecutePlaybook(req models.PlaybookExecutionRequest) (*models.TaskExecution, error) {
|
||
playbookPath := filepath.Join(s.config.PlaybookDir, req.Name+".yml")
|
||
|
||
if _, err := os.Stat(playbookPath); os.IsNotExist(err) {
|
||
return nil, fmt.Errorf("Playbook不存在: %s", req.Name)
|
||
}
|
||
|
||
task := &models.TaskExecution{
|
||
ID: s.generateID(),
|
||
Name: req.Name,
|
||
Playbook: playbookPath,
|
||
Hosts: req.Hosts,
|
||
Status: "running",
|
||
StartTime: time.Now(),
|
||
TotalHosts: len(req.Hosts),
|
||
SuccessHosts: 0,
|
||
FailedHosts: 0,
|
||
}
|
||
|
||
s.taskLock.Lock()
|
||
s.tasks[task.ID] = task
|
||
s.taskLock.Unlock()
|
||
|
||
// 启动异步执行
|
||
go s.runPlaybook(task, playbookPath, req)
|
||
|
||
return task, nil
|
||
}
|
||
|
||
// runPlaybook 运行Playbook
|
||
func (s *AnsibleService) runPlaybook(task *models.TaskExecution, playbookPath string, req models.PlaybookExecutionRequest) {
|
||
var args []string
|
||
|
||
// 添加inventory
|
||
args = append(args, "-i", s.inventoryPath)
|
||
|
||
// 添加hosts限制
|
||
if len(req.Hosts) > 0 {
|
||
args = append(args, "-l", strings.Join(req.Hosts, ","))
|
||
}
|
||
|
||
// 添加extra-vars
|
||
if len(req.ExtraVars) > 0 {
|
||
varsJSON, _ := json.Marshal(req.ExtraVars)
|
||
args = append(args, "-e", string(varsJSON))
|
||
}
|
||
|
||
// 添加tags
|
||
if len(req.Tags) > 0 {
|
||
args = append(args, "-t", strings.Join(req.Tags, ","))
|
||
}
|
||
|
||
// 添加skip-tags
|
||
if len(req.SkipTags) > 0 {
|
||
args = append(args, "--skip-tags", strings.Join(req.SkipTags, ","))
|
||
}
|
||
|
||
// 添加verbose
|
||
if req.Verbose != "" {
|
||
args = append(args, "-"+req.Verbose)
|
||
}
|
||
|
||
// 显示文件差异
|
||
if req.Diff {
|
||
args = append(args, "-D")
|
||
}
|
||
|
||
// dry-run模式
|
||
if req.Check {
|
||
args = append(args, "-C")
|
||
}
|
||
|
||
// 是否提权
|
||
if req.Become != nil {
|
||
if *req.Become {
|
||
args = append(args, "-b")
|
||
} else {
|
||
args = append(args, "--no-become")
|
||
}
|
||
}
|
||
|
||
// 并发数
|
||
if req.Forks > 0 {
|
||
args = append(args, "-f", strconv.Itoa(req.Forks))
|
||
}
|
||
|
||
// 自定义额外参数
|
||
if req.ExtraArgs != "" {
|
||
extraParts := strings.Fields(req.ExtraArgs)
|
||
args = append(args, extraParts...)
|
||
}
|
||
|
||
// playbook路径放最后
|
||
args = append(args, playbookPath)
|
||
|
||
// 构建命令
|
||
cmd := exec.Command("ansible-playbook", args...)
|
||
|
||
// 设置超时
|
||
if req.Timeout > 0 {
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(req.Timeout)*time.Second)
|
||
defer cancel()
|
||
cmd = exec.CommandContext(ctx, "ansible-playbook", args...)
|
||
}
|
||
|
||
// 实时写入日志的 Writer
|
||
sw := &syncWriter{buf: bytes.NewBuffer(nil)}
|
||
cmd.Stdout = sw
|
||
cmd.Stderr = sw
|
||
|
||
// 启动 goroutine 实时搬运日志到 task.Output
|
||
done := make(chan struct{})
|
||
go func() {
|
||
ticker := time.NewTicker(200 * time.Millisecond)
|
||
defer ticker.Stop()
|
||
var lastLen int
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
s.taskLock.Lock()
|
||
sw.mu.Lock()
|
||
currentLen := sw.buf.Len()
|
||
if currentLen > lastLen {
|
||
task.Output = sw.buf.String()
|
||
lastLen = currentLen
|
||
}
|
||
sw.mu.Unlock()
|
||
s.taskLock.Unlock()
|
||
case <-done:
|
||
return
|
||
}
|
||
}
|
||
}()
|
||
|
||
err := cmd.Run()
|
||
close(done) // 通知 goroutine 退出
|
||
|
||
// 最终同步一次完整日志
|
||
sw.mu.Lock()
|
||
finalOutput := sw.buf.String()
|
||
sw.mu.Unlock()
|
||
|
||
s.taskLock.Lock()
|
||
task.Output = finalOutput
|
||
task.EndTime = time.Now()
|
||
if err != nil {
|
||
task.Status = "failed"
|
||
task.Error = err.Error()
|
||
} else {
|
||
task.Status = "success"
|
||
}
|
||
s.taskLock.Unlock()
|
||
}
|
||
|
||
// syncWriter 线程安全的 Writer
|
||
type syncWriter struct {
|
||
buf *bytes.Buffer
|
||
mu sync.Mutex
|
||
}
|
||
|
||
func (w *syncWriter) Write(p []byte) (n int, err error) {
|
||
w.mu.Lock()
|
||
defer w.mu.Unlock()
|
||
return w.buf.Write(p)
|
||
}
|
||
|
||
func (w *syncWriter) String() string {
|
||
w.mu.Lock()
|
||
defer w.mu.Unlock()
|
||
return w.buf.String()
|
||
}
|
||
|
||
// ListPlaybooks 列出可用Playbooks
|
||
func (s *AnsibleService) ListPlaybooks() []models.Playbook {
|
||
var playbooks []models.Playbook
|
||
|
||
files, _ := os.ReadDir(s.config.PlaybookDir)
|
||
for _, f := range files {
|
||
if !f.IsDir() && strings.HasSuffix(f.Name(), ".yml") {
|
||
name := strings.TrimSuffix(f.Name(), ".yml")
|
||
playbookPath := filepath.Join(s.config.PlaybookDir, f.Name())
|
||
playbook := models.Playbook{
|
||
Name: name,
|
||
Path: playbookPath,
|
||
}
|
||
|
||
// 解析YAML获取描述和变量信息
|
||
data, err := os.ReadFile(playbookPath)
|
||
if err == nil {
|
||
// 尝试解析为playbook列表
|
||
var playEntries []map[string]interface{}
|
||
if yaml.Unmarshal(data, &playEntries) == nil && len(playEntries) > 0 {
|
||
first := playEntries[0]
|
||
// 提取注释中的描述(name字段)
|
||
if nameVal, ok := first["name"]; ok {
|
||
playbook.Description = fmt.Sprintf("%v", nameVal)
|
||
}
|
||
// 提取vars
|
||
if varsVal, ok := first["vars"]; ok {
|
||
if varsMap, ok := varsVal.(map[string]interface{}); ok {
|
||
playbook.Variables = varsMap
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
playbooks = append(playbooks, playbook)
|
||
}
|
||
}
|
||
|
||
return playbooks
|
||
}
|
||
|
||
// GetPlaybook 获取Playbook详情
|
||
func (s *AnsibleService) GetPlaybook(name string) (*models.Playbook, error) {
|
||
playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml")
|
||
|
||
data, err := os.ReadFile(playbookPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("Playbook不存在")
|
||
}
|
||
|
||
var playbook models.Playbook
|
||
playbook.Name = name
|
||
playbook.Path = playbookPath
|
||
|
||
// 简单解析YAML
|
||
if err := yaml.Unmarshal(data, &playbook); err != nil {
|
||
return nil, fmt.Errorf("Playbook解析失败")
|
||
}
|
||
|
||
return &playbook, nil
|
||
}
|
||
|
||
// WebSocketLogs WebSocket日志流
|
||
func (s *AnsibleService) WebSocketLogs(taskID string) (<-chan models.LogEntry, error) {
|
||
logChan := make(chan models.LogEntry, 100)
|
||
|
||
go func() {
|
||
defer close(logChan)
|
||
|
||
ticker := time.NewTicker(500 * time.Millisecond)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
s.taskLock.RLock()
|
||
task, ok := s.tasks[taskID]
|
||
s.taskLock.RUnlock()
|
||
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
entry := models.LogEntry{
|
||
Time: time.Now().Format("15:04:05"),
|
||
Level: "info",
|
||
Host: "system",
|
||
Message: fmt.Sprintf("Progress: %d/%d", task.Progress, task.TotalHosts),
|
||
}
|
||
logChan <- entry
|
||
|
||
if task.Status == "completed" || task.Status == "failed" {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
}()
|
||
|
||
return logChan, nil
|
||
}
|
||
|
||
// ParseAnsibleOutput 解析Ansible输出
|
||
func (s *AnsibleService) ParseAnsibleOutput(output string) (map[string]interface{}, error) {
|
||
var result map[string]interface{}
|
||
if err := json.Unmarshal([]byte(output), &result); err != nil {
|
||
return nil, err
|
||
}
|
||
return result, nil
|
||
}
|
||
|
||
// GetTaskOutput 获取任务输出
|
||
func (s *AnsibleService) GetTaskOutput(taskID string) string {
|
||
s.taskLock.RLock()
|
||
defer s.taskLock.RUnlock()
|
||
|
||
if task, ok := s.tasks[taskID]; ok {
|
||
return task.Output
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// CreatePlaybook 创建Playbook(通过内容)
|
||
func (s *AnsibleService) CreatePlaybook(name string, content string) error {
|
||
if name == "" {
|
||
return fmt.Errorf("Playbook名称不能为空")
|
||
}
|
||
// 检查名称是否含非法字符
|
||
if strings.Contains(name, "/") || strings.Contains(name, "..") {
|
||
return fmt.Errorf("Playbook名称包含非法字符")
|
||
}
|
||
|
||
playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml")
|
||
if _, err := os.Stat(playbookPath); err == nil {
|
||
return fmt.Errorf("Playbook已存在: %s", name)
|
||
}
|
||
|
||
// 验证YAML格式
|
||
var dummy interface{}
|
||
if err := yaml.Unmarshal([]byte(content), &dummy); err != nil {
|
||
return fmt.Errorf("YAML格式错误: %v", err)
|
||
}
|
||
|
||
return os.WriteFile(playbookPath, []byte(content), 0644)
|
||
}
|
||
|
||
// DeletePlaybook 删除Playbook
|
||
func (s *AnsibleService) DeletePlaybook(name string) error {
|
||
if strings.Contains(name, "/") || strings.Contains(name, "..") {
|
||
return fmt.Errorf("Playbook名称包含非法字符")
|
||
}
|
||
|
||
playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml")
|
||
if _, err := os.Stat(playbookPath); os.IsNotExist(err) {
|
||
return fmt.Errorf("Playbook不存在: %s", name)
|
||
}
|
||
|
||
return os.Remove(playbookPath)
|
||
}
|
||
|
||
// UpdatePlaybook 更新Playbook内容
|
||
func (s *AnsibleService) UpdatePlaybook(name string, content string) error {
|
||
if strings.Contains(name, "/") || strings.Contains(name, "..") {
|
||
return fmt.Errorf("Playbook名称包含非法字符")
|
||
}
|
||
|
||
playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml")
|
||
if _, err := os.Stat(playbookPath); os.IsNotExist(err) {
|
||
return fmt.Errorf("Playbook不存在: %s", name)
|
||
}
|
||
|
||
// 验证YAML格式
|
||
var dummy interface{}
|
||
if err := yaml.Unmarshal([]byte(content), &dummy); err != nil {
|
||
return fmt.Errorf("YAML格式错误: %v", err)
|
||
}
|
||
|
||
return os.WriteFile(playbookPath, []byte(content), 0644)
|
||
}
|
||
|
||
// GetPlaybookContent 获取Playbook原始内容
|
||
func (s *AnsibleService) GetPlaybookContent(name string) (string, error) {
|
||
if strings.Contains(name, "/") || strings.Contains(name, "..") {
|
||
return "", fmt.Errorf("Playbook名称包含非法字符")
|
||
}
|
||
|
||
playbookPath := filepath.Join(s.config.PlaybookDir, name+".yml")
|
||
data, err := os.ReadFile(playbookPath)
|
||
if err != nil {
|
||
return "", fmt.Errorf("Playbook不存在")
|
||
}
|
||
return string(data), nil
|
||
}
|
||
|
||
// CheckAnsibleInstalled 检查Ansible是否安装
|
||
func (s *AnsibleService) CheckAnsibleInstalled() bool {
|
||
cmd := exec.Command("ansible", "--version")
|
||
err := cmd.Run()
|
||
return err == nil
|
||
}
|
||
|
||
// GetInventoryPath 获取inventory路径
|
||
func (s *AnsibleService) GetInventoryPath() string {
|
||
return s.inventoryPath
|
||
}
|