1
0
Fichiers
T
Your Name a2505cfe44 Fix: 统一日志术语并清理重复注释
- database -> storage
- 移除重复的Neighbor注释
2026-04-26 01:24:00 +08:00

371 lignes
9.1 KiB
Go

package main
import (
"crypto/rand"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"sync"
"time"
"network-topology-discovery/internal/config"
"network-topology-discovery/internal/device"
"network-topology-discovery/internal/scanner"
"network-topology-discovery/internal/storage"
"network-topology-discovery/internal/topology"
"network-topology-discovery/pkg/models"
)
// App is the application: it holds the configuration, the in-memory topology
// builder, the optional persistent storage, the scan tasks and the HTTP server.
type App struct {
	// config supplies the web listen address (Web.Host/Web.Port) and the
	// scanner settings (Scanner.Concurrency/Scanner.Timeout).
	config *config.Config
	// builder accumulates devices and builds the topology graph served by the API.
	builder *topology.Builder
	// storage persists devices; it may be nil, and every use site checks
	// app.storage != nil first.
	// NOTE(review): NewApp stores NewStorage's result even when it returned an
	// error — confirm the returned value is nil in that case.
	storage *storage.Storage
	// tasks maps scan task IDs to their tasks. Reads and writes of the map
	// are done while holding mu.
	tasks map[string]*models.ScanTask
	// mu guards the tasks map.
	mu sync.RWMutex
	// httpServer is created by Start.
	httpServer *http.Server
}
// NewApp creates the application. It opens the JSON-file storage
// ("devices.json") and seeds the topology builder with every device already
// persisted there.
//
// Storage is optional. If it cannot be initialized, the app still starts with
// app.storage set to nil: devices then live only in memory and nothing is
// persisted.
func NewApp(cfg *config.Config) *App {
	// Storage is backed by a JSON file in the working directory.
	store, err := storage.NewStorage("devices.json")
	if err != nil {
		log.Printf("Warning: failed to initialize storage: %v", err)
		// Never keep a value returned alongside an error: the handlers treat
		// any non-nil app.storage as usable.
		store = nil
	}

	app := &App{
		config:  cfg,
		builder: topology.NewBuilder(),
		storage: store,
		tasks:   make(map[string]*models.ScanTask),
	}

	if store == nil {
		return app
	}

	// Load previously persisted devices into the topology builder.
	devices, err := store.GetAllDevices()
	if err != nil {
		log.Printf("Warning: failed to load devices from storage: %v", err)
		return app
	}
	log.Printf("Loaded %d devices from storage", len(devices))
	for _, dev := range devices {
		app.builder.AddDevice(dev)
	}
	return app
}
// Start registers the HTTP routes and serves on Web.Host:Web.Port, blocking
// until the server stops.
//
// Static files are served from the directory returned by getWebDir. If that
// directory does not exist, a minimal placeholder page is served at "/".
// Route patterns such as "/api/scan/{id}" require the Go 1.22+ ServeMux.
//
// Start returns nil when the server was stopped via app.httpServer.Shutdown or
// Close, and the listen/serve error otherwise.
func (app *App) Start() error {
	mux := http.NewServeMux()

	// Static files are read from the file system rather than embedded.
	webDir := getWebDir()
	if _, err := os.Stat(webDir); err == nil {
		mux.Handle("/", http.FileServer(http.Dir(webDir)))
	} else {
		log.Printf("警告: web目录不存在,静态文件服务不可用")
		mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
			if _, err := w.Write([]byte("<h1>网络拓扑发现系统</h1><p>Web界面文件未找到</p>")); err != nil {
				log.Printf("Warning: failed to write placeholder page: %v", err)
			}
		})
	}

	// API routes.
	mux.HandleFunc("/api/scan", app.handleScan)
	mux.HandleFunc("/api/scan/{id}", app.handleScanProgress)
	mux.HandleFunc("/api/topology", app.handleTopology)
	mux.HandleFunc("/api/devices", app.handleGetDevices)
	mux.HandleFunc("/api/device", app.handleAddDevice)
	mux.HandleFunc("/api/device/{id}", app.handleDeviceDetail)

	addr := fmt.Sprintf("%s:%d", app.config.Web.Host, app.config.Web.Port)

	// Bound how long a client may take to send a request, so slow or stalled
	// connections cannot pile up indefinitely. WriteTimeout is intentionally
	// left unset: /api/device performs SSH discovery synchronously and may
	// legitimately take a long time before responding.
	app.httpServer = &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	log.Printf("服务启动在 %s", addr)
	err := app.httpServer.ListenAndServe()
	if err == http.ErrServerClosed {
		// Shutdown/Close was requested: a normal stop, not a failure.
		return nil
	}
	return err
}
// generateID returns a random 128-bit identifier, hex-encoded as five
// dash-separated groups of 8-4-4-4-12 characters (36 characters total).
//
// The bytes come from crypto/rand. If the system randomness source reports an
// error, the ID is derived from the current time instead. Otherwise the buffer
// could stay all zeros and every caller would receive the same ID.
func generateID() string {
	b := make([]byte, 16)
	if _, err := rand.Read(b); err != nil {
		// Only reachable on toolchains where rand.Read can return an error.
		// The time-based fallback is unique per nanosecond, not
		// cryptographically random.
		log.Printf("Warning: crypto/rand failed, using time-based ID: %v", err)
		now := uint64(time.Now().UnixNano())
		for i := range b {
			b[i] = byte(now >> (8 * (i % 8)))
		}
	}
	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:])
}
// getWebDir locates the directory holding the web UI's static files.
//
// Relative to the current working directory it tries "web", "cmd/web" and
// "../web", in that order. It returns the absolute path of the first candidate
// that exists and is a directory. If the absolute path cannot be computed, the
// relative candidate is returned rather than an empty string. When nothing
// matches, "web" is returned so the caller's own existence check decides what
// to do.
func getWebDir() string {
	candidates := []string{
		"web",
		filepath.Join("cmd", "web"),
		filepath.Join("..", "web"),
	}
	for _, candidate := range candidates {
		info, err := os.Stat(candidate)
		if err != nil || !info.IsDir() {
			// Missing, unreadable, or a plain file: not usable as a static root.
			continue
		}
		if abs, err := filepath.Abs(candidate); err == nil {
			return abs
		}
		return candidate
	}
	return "web"
}
// handleScan handles POST /api/scan. It registers a new scan task and runs the
// scan in a background goroutine.
//
// The JSON request body has these fields:
//   - scan_range: required, passed to the scanner as a CIDR.
//   - ssh_port: defaults to 22; must be 1-65535.
//   - username, password: credentials for device discovery.
//
// It responds with {"task_id": "..."}; progress is then polled via
// /api/scan/{id}.
func (app *App) handleScan(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		ScanRange string `json:"scan_range"`
		SSHPort   int    `json:"ssh_port"`
		Username  string `json:"username"`
		Password  string `json:"password"`
	}
	// The body only carries a few short strings. Cap it so a client cannot
	// make the decoder consume an unbounded stream.
	const maxBodyBytes = 1 << 20
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	if req.ScanRange == "" {
		http.Error(w, "scan_range is required", http.StatusBadRequest)
		return
	}
	if req.SSHPort == 0 {
		req.SSHPort = 22
	}
	if req.SSHPort < 1 || req.SSHPort > 65535 {
		http.Error(w, "ssh_port must be between 1 and 65535", http.StatusBadRequest)
		return
	}

	// Register the task before starting the scan so it can be polled at once.
	taskID := generateID()
	task := &models.ScanTask{
		ID:        taskID,
		Status:    "running",
		StartTime: time.Now(),
		Devices:   []models.Device{},
	}
	app.mu.Lock()
	app.tasks[taskID] = task
	app.mu.Unlock()

	go app.runScan(task, req.ScanRange, req.SSHPort, req.Username, req.Password)

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(map[string]string{"task_id": taskID}); err != nil {
		log.Printf("Warning: failed to encode scan response: %v", err)
	}
}

// runScan executes the scan for task and records progress on it. It is run in
// the goroutine started by handleScan.
//
// It first finds hosts in cidr with sshPort reachable. For each host it then
// tries every supported device type in turn until device.DiscoverDevice
// succeeds with ScanStatus "success". Discovered devices are added to the
// topology builder and, when storage is available, persisted.
//
// Every write to task happens while holding app.mu, because
// handleScanProgress reads the same task concurrently from HTTP handlers.
func (app *App) runScan(task *models.ScanTask, cidr string, sshPort int, username, password string) {
	// update applies fn to task while holding the write lock.
	update := func(fn func()) {
		app.mu.Lock()
		defer app.mu.Unlock()
		fn()
	}
	defer update(func() { task.EndTime = time.Now() })

	sc := scanner.NewScanner(app.config.Scanner.Concurrency, time.Duration(app.config.Scanner.Timeout)*time.Second)

	// Find hosts with the SSH port reachable.
	sshHosts, err := sc.ScanAndDiscover(cidr, sshPort)
	if err != nil {
		update(func() {
			task.Status = "failed"
			task.ErrorMessage = err.Error()
		})
		return
	}
	update(func() { task.TotalDevices = len(sshHosts) })

	// Device types to try for every host, in this order.
	deviceTypes := []models.DeviceType{
		models.DeviceTypeCisco,
		models.DeviceTypeHuawei,
		models.DeviceTypeH3C,
		models.DeviceTypeASA,
		models.DeviceTypeLinux,
		models.DeviceTypeWindows,
	}

	var devices []models.Device
	for i, ip := range sshHosts {
		var discovered *models.Device
		for _, dtype := range deviceTypes {
			dev, err := device.DiscoverDevice(ip, dtype, username, password)
			if err == nil && dev.ScanStatus == "success" {
				discovered = dev
				break
			}
		}

		if discovered != nil {
			devices = append(devices, *discovered)
			app.builder.AddDevice(*discovered)
			if app.storage != nil {
				if err := app.storage.SaveDevice(discovered); err != nil {
					log.Printf("Warning: failed to save device %s to storage: %v", ip, err)
				}
			}
		}

		// Publish progress. Later appends never overwrite elements already
		// visible through task.Devices, so readers' snapshots stay consistent.
		update(func() {
			task.ScannedDevices = i + 1
			task.Progress = (i + 1) * 100 / len(sshHosts)
			task.Devices = devices
		})
	}

	update(func() {
		task.Status = "completed"
		task.Progress = 100
		task.Devices = devices
	})
}

// handleScanProgress handles /api/scan/{id}. It returns the current state of
// the scan task with that ID as JSON, or 404 if no such task exists.
func (app *App) handleScanProgress(w http.ResponseWriter, r *http.Request) {
	id := r.PathValue("id")

	// Copy the task while holding the lock: runScan updates it concurrently.
	app.mu.RLock()
	task, exists := app.tasks[id]
	var snapshot models.ScanTask
	if exists {
		snapshot = *task
	}
	app.mu.RUnlock()

	if !exists {
		http.Error(w, "Task not found", http.StatusNotFound)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(&snapshot); err != nil {
		log.Printf("Warning: failed to encode scan task %s: %v", id, err)
	}
}
// handleTopology handles /api/topology. It asks the topology builder for the
// current graph and writes it to the response as JSON.
func (app *App) handleTopology(w http.ResponseWriter, r *http.Request) {
	topo := app.builder.Build()

	header := w.Header()
	header.Set("Content-Type", "application/json")

	enc := json.NewEncoder(w)
	enc.Encode(topo)
}
// handleGetDevices handles /api/devices and returns every known device as a
// JSON array.
//
// Devices are read from storage when it is configured. If that read fails, or
// storage is unavailable, the in-memory topology builder is used instead. The
// log line names the source actually used. The response is always a JSON
// array ("[]" when empty), never null.
func (app *App) handleGetDevices(w http.ResponseWriter, r *http.Request) {
	var (
		devices []models.Device
		source  string
	)
	if app.storage != nil {
		stored, err := app.storage.GetAllDevices()
		if err == nil {
			devices, source = stored, "storage"
		} else {
			// Degrade to the builder rather than failing the request.
			log.Printf("Error: failed to get devices from storage: %v", err)
		}
	}
	if source == "" {
		devices, source = app.builder.GetDevices(), "builder"
	}
	log.Printf("Returning %d devices from %s", len(devices), source)

	// A nil slice would encode as JSON null; clients expect an array.
	if devices == nil {
		devices = []models.Device{}
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(devices); err != nil {
		log.Printf("Warning: failed to encode device list: %v", err)
	}
}
// handleAddDevice handles POST /api/device. It discovers a single device and
// adds it to the topology.
//
// The JSON request body has these fields:
//   - ip: required.
//   - type: converted to models.DeviceType.
//   - username, password: credentials for device discovery.
//
// On success the discovered device is returned as JSON and, when storage is
// available, persisted. A storage failure is only logged and does not fail the
// request. A discovery error yields 500 with {"message": "<error>"}.
func (app *App) handleAddDevice(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	var req struct {
		IP       string `json:"ip"`
		Type     string `json:"type"`
		Username string `json:"username"`
		Password string `json:"password"`
	}
	// The body only carries a few short strings. Cap it so a client cannot
	// make the decoder consume an unbounded stream.
	const maxBodyBytes = 1 << 20
	r.Body = http.MaxBytesReader(w, r.Body, maxBodyBytes)
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}
	// Reject before attempting a (slow) discovery against an empty address.
	if req.IP == "" {
		http.Error(w, "ip is required", http.StatusBadRequest)
		return
	}

	deviceType := models.DeviceType(req.Type)
	log.Printf("Adding device: %s (type: %s)", req.IP, req.Type)
	dev, err := device.DiscoverDevice(req.IP, deviceType, req.Username, req.Password)
	if err != nil {
		log.Printf("Failed to discover device %s: %v", req.IP, err)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		if encErr := json.NewEncoder(w).Encode(map[string]string{"message": err.Error()}); encErr != nil {
			log.Printf("Warning: failed to encode error response: %v", encErr)
		}
		return
	}
	log.Printf("Device discovered: %s, interfaces: %d, neighbors: %d",
		dev.IP, len(dev.Interfaces), len(dev.Neighbors))

	app.builder.AddDevice(*dev)

	// Persist the device; a storage failure is logged but not fatal here.
	if app.storage != nil {
		if err := app.storage.SaveDevice(dev); err != nil {
			log.Printf("Error: failed to save device %s to storage: %v", req.IP, err)
		} else {
			log.Printf("Device %s saved to storage successfully", req.IP)
		}
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(dev); err != nil {
		log.Printf("Warning: failed to encode device %s: %v", req.IP, err)
	}
}
// handleDeviceDetail handles /api/device/{id}. It searches the topology
// builder's devices for one whose ID or IP equals the path value and returns
// the first match as JSON. If nothing matches it responds 404
// "Device not found".
func (app *App) handleDeviceDetail(w http.ResponseWriter, r *http.Request) {
	key := r.PathValue("id")
	all := app.builder.GetDevices()

	found := -1
	for i := range all {
		if all[i].ID == key || all[i].IP == key {
			found = i
			break
		}
	}

	if found < 0 {
		http.Error(w, "Device not found", http.StatusNotFound)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	json.NewEncoder(w).Encode(all[found])
}
// main loads the configuration and runs the application until the HTTP
// server stops.
//
// The config path defaults to "config.json" and may be overridden by the first
// command-line argument. If the file is missing or fails to load, the default
// configuration is used.
func main() {
	path := "config.json"
	if len(os.Args) > 1 {
		path = os.Args[1]
	}

	var cfg *config.Config
	if _, statErr := os.Stat(path); statErr != nil {
		log.Printf("配置文件不存在, 使用默认配置")
		cfg = config.DefaultConfig()
	} else if loaded, loadErr := config.LoadConfig(path); loadErr != nil {
		log.Printf("加载配置文件失败: %v, 使用默认配置", loadErr)
		cfg = config.DefaultConfig()
	} else {
		cfg = loaded
	}

	// Build the application, then block serving HTTP.
	application := NewApp(cfg)
	log.Println("网络拓扑发现系统启动...")
	err := application.Start()
	if err != nil {
		log.Fatalf("服务启动失败: %v", err)
	}
}