"""主机注册表 - 管理多台 libvirt 宿主机""" import json import os import time import uuid import logging from typing import Optional from pydantic import BaseModel, Field from app.config import settings logger = logging.getLogger(__name__) # 数据存储目录 DATA_DIR = os.environ.get("KVM_DATA_DIR", "/var/lib/kvm-manager") HOSTS_FILE = os.path.join(DATA_DIR, "hosts.json") class HostInfo(BaseModel): """主机信息模型""" id: str = Field(..., description="主机唯一ID") name: str = Field(..., description="主机名称") uri: str = Field(..., description="libvirt 连接 URI") type: str = Field("local", description="连接类型: local/tcp/ssh") ssh_key_path: Optional[str] = Field(None, description="SSH 私钥路径(ssh 模式)") status: str = "unknown" created_at: float = Field(default_factory=time.time) last_seen: Optional[float] = None class HostCreate(BaseModel): """创建主机请求""" name: str = Field(..., description="主机名称") uri: str = Field(..., description="libvirt 连接 URI,如 qemu+tcp://192.168.1.2/system") ssh_key_path: Optional[str] = Field(None, description="SSH 私钥路径(ssh 模式)") def _detect_type(uri: str) -> str: """根据 URI 判断连接类型""" if uri.startswith("qemu+ssh://"): return "ssh" elif uri.startswith("qemu+tcp://"): return "tcp" return "local" def _ensure_data_dir(): """确保数据目录存在""" os.makedirs(DATA_DIR, exist_ok=True) def _load_hosts() -> dict: """从文件加载主机列表""" if not os.path.exists(HOSTS_FILE): return {} try: with open(HOSTS_FILE, "r") as f: return json.load(f) except (json.JSONDecodeError, IOError): return {} def _save_hosts(data: dict): """保存主机列表到文件""" _ensure_data_dir() with open(HOSTS_FILE, "w") as f: json.dump(data, f, indent=2, ensure_ascii=False) def _init_local_host() -> dict: """初始化本机默认主机""" return HostInfo( id="local", name="本机", uri=settings.LIBVIRT_URI, type="local", status="unknown", ).model_dump() def list_hosts() -> list[HostInfo]: """列出所有已注册主机""" data = _load_hosts() if not data: # 首次运行,初始化本机 local = _init_local_host() data["local"] = local _save_hosts(data) return [HostInfo(**h) for h in data.values()] def get_host(host_id: str) -> Optional[HostInfo]: """获取单个主机信息""" data = _load_hosts() if host_id not in data: return None return HostInfo(**data[host_id]) def add_host(req: HostCreate) -> HostInfo: """添加新主机""" data = _load_hosts() if not data: data["local"] = _init_local_host() host_id = req.name.lower().replace(" ", "-").replace(".", "-") # 确保ID唯一 if host_id in data: host_id = f"{host_id}-{uuid.uuid4().hex[:6]}" host_type = _detect_type(req.uri) # 构建 SSH URI uri = req.uri if host_type == "ssh" and req.ssh_key_path: # 在 URI 中嵌入 key 提示,实际连接时由 libvirt ssh driver 使用 pass host = HostInfo( id=host_id, name=req.name, uri=uri, type=host_type, ssh_key_path=req.ssh_key_path, status="unknown", created_at=time.time(), ) data[host_id] = host.model_dump() _save_hosts(data) return host def remove_host(host_id: str) -> bool: """删除主机(local 不可删)""" if host_id == "local": return False data = _load_hosts() if host_id not in data: return False del data[host_id] _save_hosts(data) return True def update_host_status(host_id: str, status: str): """更新主机在线状态""" data = _load_hosts() if host_id in data: data[host_id]["status"] = status data[host_id]["last_seen"] = time.time() _save_hosts(data) def test_connection(uri: str) -> dict: """测试 libvirt 连接是否可用""" import libvirt try: conn = libvirt.openReadOnly(uri) if conn: info = conn.getInfo() result = { "success": True, "hostname": conn.getHostname(), "hypervisor": conn.getType(), "cpu_cores": info[2], "memory_mb": info[1], # getInfo()[1] 已经是 MiB 单位 "libvirt_version": conn.getLibVersion(), } conn.close() return result else: return {"success": False, "error": "无法建立连接"} except libvirt.libvirtError as e: return {"success": False, "error": str(e)} except Exception as e: return {"success": False, "error": str(e)}