diff --git a/cmd/main.go b/cmd/main.go index 4d4b923..c82a8d2 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -14,6 +14,7 @@ import ( "network-topology-discovery/internal/config" "network-topology-discovery/internal/device" "network-topology-discovery/internal/scanner" + "network-topology-discovery/internal/storage" "network-topology-discovery/internal/topology" "network-topology-discovery/pkg/models" ) @@ -22,6 +23,7 @@ import ( type App struct { config *config.Config builder *topology.Builder + storage *storage.Storage tasks map[string]*models.ScanTask mu sync.RWMutex httpServer *http.Server @@ -29,11 +31,33 @@ type App struct { // NewApp 创建应用 func NewApp(cfg *config.Config) *App { - return &App{ + // 初始化数据库 + store, err := storage.NewStorage("network-topology.db") + if err != nil { + log.Printf("Warning: failed to initialize database: %v", err) + } + + app := &App{ config: cfg, builder: topology.NewBuilder(), + storage: store, tasks: make(map[string]*models.ScanTask), } + + // 从数据库加载设备到拓扑构建器 + if store != nil { + devices, err := store.GetAllDevices() + if err != nil { + log.Printf("Warning: failed to load devices from database: %v", err) + } else { + log.Printf("Loaded %d devices from database", len(devices)) + for _, dev := range devices { + app.builder.AddDevice(dev) + } + } + } + + return app } // Start 启动应用 @@ -56,6 +80,7 @@ func (app *App) Start() error { mux.HandleFunc("/api/scan", app.handleScan) mux.HandleFunc("/api/scan/{id}", app.handleScanProgress) mux.HandleFunc("/api/topology", app.handleTopology) + mux.HandleFunc("/api/devices", app.handleGetDevices) mux.HandleFunc("/api/device", app.handleAddDevice) mux.HandleFunc("/api/device/{id}", app.handleDeviceDetail) @@ -183,6 +208,13 @@ func (app *App) runScan(task *models.ScanTask, cidr string, sshPort int, usernam if discoveredDevice != nil { devices = append(devices, *discoveredDevice) app.builder.AddDevice(*discoveredDevice) + + // 保存到数据库 + if app.storage != nil { + if err := app.storage.SaveDevice(discoveredDevice); err != nil { + log.Printf("Warning: failed to save device %s to database: %v", ip, err) + } + } } // 更新进度 @@ -221,6 +253,27 @@ func (app *App) handleTopology(w http.ResponseWriter, r *http.Request) { json.NewEncoder(w).Encode(graph) } +// 处理获取所有设备 +func (app *App) handleGetDevices(w http.ResponseWriter, r *http.Request) { + var devices []models.Device + + // 优先从数据库获取 + if app.storage != nil { + var err error + devices, err = app.storage.GetAllDevices() + if err != nil { + log.Printf("Warning: failed to get devices from database: %v", err) + // 降级到从builder获取 + devices = app.builder.GetDevices() + } + } else { + devices = app.builder.GetDevices() + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(devices) +} + // 处理添加设备 func (app *App) handleAddDevice(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -251,6 +304,13 @@ func (app *App) handleAddDevice(w http.ResponseWriter, r *http.Request) { app.builder.AddDevice(*dev) + // 保存到数据库 + if app.storage != nil { + if err := app.storage.SaveDevice(dev); err != nil { + log.Printf("Warning: failed to save device to database: %v", err) + } + } + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(dev) } diff --git a/go.mod b/go.mod index 6ef6dbe..8b60996 100644 --- a/go.mod +++ b/go.mod @@ -4,4 +4,7 @@ go 1.26.2 require golang.org/x/crypto v0.50.0 -require golang.org/x/sys v0.43.0 // indirect +require ( + github.com/mattn/go-sqlite3 v1.14.42 // indirect + golang.org/x/sys v0.43.0 // indirect +) diff --git a/go.sum b/go.sum index 3e36771..d79008f 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/mattn/go-sqlite3 v1.14.42 h1:MigqEP4ZmHw3aIdIT7T+9TLa90Z6smwcthx+Azv4Cgo= +github.com/mattn/go-sqlite3 v1.14.42/go.mod h1:pjEuOr8IwzLJP2MfGeTb0A35jauH+C2kbHKBr7yXKVQ= golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= diff --git a/internal/storage/storage.go b/internal/storage/storage.go new file mode 100644 index 0000000..502988a --- /dev/null +++ b/internal/storage/storage.go @@ -0,0 +1,232 @@ +package storage + +import ( + "database/sql" + "encoding/json" + "fmt" + "log" + "network-topology-discovery/pkg/models" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +// Storage 数据库存储 +type Storage struct { + db *sql.DB +} + +// NewStorage 创建存储实例 +func NewStorage(dbPath string) (*Storage, error) { + db, err := sql.Open("sqlite3", dbPath) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + // 创建表 + if err := createTables(db); err != nil { + return nil, fmt.Errorf("failed to create tables: %w", err) + } + + return &Storage{db: db}, nil +} + +// createTables 创建数据表 +func createTables(db *sql.DB) error { + query := ` + CREATE TABLE IF NOT EXISTS devices ( + id TEXT PRIMARY KEY, + ip TEXT NOT NULL UNIQUE, + type TEXT NOT NULL, + hostname TEXT, + os_version TEXT, + uptime TEXT, + interfaces TEXT, + neighbors TEXT, + last_scanned DATETIME, + scan_status TEXT, + error_message TEXT, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_devices_ip ON devices(ip); + CREATE INDEX IF NOT EXISTS idx_devices_type ON devices(type); + ` + + _, err := db.Exec(query) + return err +} + +// SaveDevice 保存设备 +func (s *Storage) SaveDevice(device *models.Device) error { + // 序列化接口和邻居数据 + interfacesJSON, err := json.Marshal(device.Interfaces) + if err != nil { + return fmt.Errorf("failed to marshal interfaces: %w", err) + } + + neighborsJSON, err := json.Marshal(device.Neighbors) + if err != nil { + return fmt.Errorf("failed to marshal neighbors: %w", err) + } + + // 设置ID和扫描时间 + if device.ID == "" { + device.ID = device.IP + } + device.LastScanned = time.Now() + + query := ` + INSERT OR REPLACE INTO devices + (id, ip, type, hostname, os_version, uptime, interfaces, neighbors, + last_scanned, scan_status, error_message, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + + _, err = s.db.Exec(query, + device.ID, + device.IP, + string(device.Type), + device.Hostname, + device.OSVersion, + device.Uptime, + string(interfacesJSON), + string(neighborsJSON), + device.LastScanned, + device.ScanStatus, + device.ErrorMessage, + time.Now(), + ) + + if err != nil { + return fmt.Errorf("failed to save device: %w", err) + } + + log.Printf("Device saved: %s (%s)", device.IP, device.Hostname) + return nil +} + +// GetDevice 获取设备 +func (s *Storage) GetDevice(id string) (*models.Device, error) { + query := ` + SELECT id, ip, type, hostname, os_version, uptime, interfaces, neighbors, + last_scanned, scan_status, error_message + FROM devices WHERE id = ? + ` + + row := s.db.QueryRow(query, id) + + var device models.Device + var typeStr string + var interfacesJSON, neighborsJSON string + var lastScanned sql.NullTime + + err := row.Scan( + &device.ID, &device.IP, &typeStr, &device.Hostname, + &device.OSVersion, &device.Uptime, &interfacesJSON, &neighborsJSON, + &lastScanned, &device.ScanStatus, &device.ErrorMessage, + ) + + if err != nil { + return nil, fmt.Errorf("failed to get device: %w", err) + } + + device.Type = models.DeviceType(typeStr) + + if lastScanned.Valid { + device.LastScanned = lastScanned.Time + } + + // 反序列化接口 + if interfacesJSON != "" { + if err := json.Unmarshal([]byte(interfacesJSON), &device.Interfaces); err != nil { + log.Printf("Warning: failed to unmarshal interfaces for %s: %v", device.IP, err) + } + } + + // 反序列化邻居 + if neighborsJSON != "" { + if err := json.Unmarshal([]byte(neighborsJSON), &device.Neighbors); err != nil { + log.Printf("Warning: failed to unmarshal neighbors for %s: %v", device.IP, err) + } + } + + return &device, nil +} + +// GetAllDevices 获取所有设备 +func (s *Storage) GetAllDevices() ([]models.Device, error) { + query := ` + SELECT id, ip, type, hostname, os_version, uptime, interfaces, neighbors, + last_scanned, scan_status, error_message + FROM devices ORDER BY created_at + ` + + rows, err := s.db.Query(query) + if err != nil { + return nil, fmt.Errorf("failed to query devices: %w", err) + } + defer rows.Close() + + var devices []models.Device + + for rows.Next() { + var device models.Device + var typeStr string + var interfacesJSON, neighborsJSON string + var lastScanned sql.NullTime + + err := rows.Scan( + &device.ID, &device.IP, &typeStr, &device.Hostname, + &device.OSVersion, &device.Uptime, &interfacesJSON, &neighborsJSON, + &lastScanned, &device.ScanStatus, &device.ErrorMessage, + ) + + if err != nil { + log.Printf("Warning: failed to scan device row: %v", err) + continue + } + + device.Type = models.DeviceType(typeStr) + + if lastScanned.Valid { + device.LastScanned = lastScanned.Time + } + + // 反序列化接口 + if interfacesJSON != "" { + if err := json.Unmarshal([]byte(interfacesJSON), &device.Interfaces); err != nil { + log.Printf("Warning: failed to unmarshal interfaces for %s: %v", device.IP, err) + } + } + + // 反序列化邻居 + if neighborsJSON != "" { + if err := json.Unmarshal([]byte(neighborsJSON), &device.Neighbors); err != nil { + log.Printf("Warning: failed to unmarshal neighbors for %s: %v", device.IP, err) + } + } + + devices = append(devices, device) + } + + return devices, nil +} + +// DeleteDevice 删除设备 +func (s *Storage) DeleteDevice(id string) error { + _, err := s.db.Exec("DELETE FROM devices WHERE id = ?", id) + if err != nil { + return fmt.Errorf("failed to delete device: %w", err) + } + return nil +} + +// Close 关闭数据库连接 +func (s *Storage) Close() error { + if s.db != nil { + return s.db.Close() + } + return nil +} diff --git a/web/js/app.js b/web/js/app.js index 2afd31f..dceb771 100644 --- a/web/js/app.js +++ b/web/js/app.js @@ -7,6 +7,7 @@ document.addEventListener('DOMContentLoaded', function() { initCytoscape(); initEventListeners(); loadTopology(); + loadDeviceList(); // 加载设备列表 }); // 初始化Cytoscape @@ -160,6 +161,7 @@ async function pollProgress() { // 如果完成,更新拓扑 if (task.status === 'completed' || task.status === 'failed') { loadTopology(); + loadDeviceList(); // 刷新设备列表 currentTaskId = null; return; } @@ -248,6 +250,43 @@ async function loadTopology() { } } +// 加载设备列表 +async function loadDeviceList() { + try { + const response = await fetch('/api/devices'); + const devices = await response.json(); + + const listContainer = document.getElementById('device-list'); + listContainer.innerHTML = ''; + + if (devices.length === 0) { + listContainer.innerHTML = '

暂无设备

'; + return; + } + + devices.forEach(device => { + const item = document.createElement('div'); + item.className = 'device-item'; + + const interfaceCount = device.interfaces ? device.interfaces.length : 0; + const neighborCount = device.neighbors ? device.neighbors.length : 0; + + item.innerHTML = ` +
${device.ip}
+
${device.type} - ${device.hostname || 'Unknown'}
+
+ 接口: ${interfaceCount} | 邻居: ${neighborCount} +
+
${device.scan_status || 'pending'}
+ `; + item.addEventListener('click', () => showDeviceDetail(device.id || device.ip)); + listContainer.appendChild(item); + }); + } catch (error) { + console.error('加载设备列表失败:', error); + } +} + // 显示设备详情 async function showDeviceDetail(deviceId) { try { @@ -328,6 +367,7 @@ async function addDevice(event) { document.getElementById('modal').classList.remove('active'); document.getElementById('add-device-form').reset(); loadTopology(); + loadDeviceList(); // 刷新设备列表 alert('设备添加成功'); } else { const error = await response.json();