package storage import ( "database/sql" "encoding/json" "fmt" "log" "network-topology-discovery/pkg/models" "time" _ "github.com/mattn/go-sqlite3" ) // Storage 数据库存储 type Storage struct { db *sql.DB } // NewStorage 创建存储实例 func NewStorage(dbPath string) (*Storage, error) { db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } // 创建表 if err := createTables(db); err != nil { return nil, fmt.Errorf("failed to create tables: %w", err) } return &Storage{db: db}, nil } // createTables 创建数据表 func createTables(db *sql.DB) error { query := ` CREATE TABLE IF NOT EXISTS devices ( id TEXT PRIMARY KEY, ip TEXT NOT NULL UNIQUE, type TEXT NOT NULL, hostname TEXT, os_version TEXT, uptime TEXT, interfaces TEXT, neighbors TEXT, last_scanned DATETIME, scan_status TEXT, error_message TEXT, created_at DATETIME DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_devices_ip ON devices(ip); CREATE INDEX IF NOT EXISTS idx_devices_type ON devices(type); ` _, err := db.Exec(query) return err } // SaveDevice 保存设备 func (s *Storage) SaveDevice(device *models.Device) error { // 序列化接口和邻居数据 interfacesJSON, err := json.Marshal(device.Interfaces) if err != nil { return fmt.Errorf("failed to marshal interfaces: %w", err) } neighborsJSON, err := json.Marshal(device.Neighbors) if err != nil { return fmt.Errorf("failed to marshal neighbors: %w", err) } // 设置ID和扫描时间 if device.ID == "" { device.ID = device.IP } device.LastScanned = time.Now() query := ` INSERT OR REPLACE INTO devices (id, ip, type, hostname, os_version, uptime, interfaces, neighbors, last_scanned, scan_status, error_message, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ` _, err = s.db.Exec(query, device.ID, device.IP, string(device.Type), device.Hostname, device.OSVersion, device.Uptime, string(interfacesJSON), string(neighborsJSON), device.LastScanned, device.ScanStatus, device.ErrorMessage, time.Now(), ) if err != nil { return fmt.Errorf("failed to save device: %w", err) } log.Printf("Device saved: %s (%s)", device.IP, device.Hostname) return nil } // GetDevice 获取设备 func (s *Storage) GetDevice(id string) (*models.Device, error) { query := ` SELECT id, ip, type, hostname, os_version, uptime, interfaces, neighbors, last_scanned, scan_status, error_message FROM devices WHERE id = ? ` row := s.db.QueryRow(query, id) var device models.Device var typeStr string var interfacesJSON, neighborsJSON string var lastScanned sql.NullTime err := row.Scan( &device.ID, &device.IP, &typeStr, &device.Hostname, &device.OSVersion, &device.Uptime, &interfacesJSON, &neighborsJSON, &lastScanned, &device.ScanStatus, &device.ErrorMessage, ) if err != nil { return nil, fmt.Errorf("failed to get device: %w", err) } device.Type = models.DeviceType(typeStr) if lastScanned.Valid { device.LastScanned = lastScanned.Time } // 反序列化接口 if interfacesJSON != "" { if err := json.Unmarshal([]byte(interfacesJSON), &device.Interfaces); err != nil { log.Printf("Warning: failed to unmarshal interfaces for %s: %v", device.IP, err) } } // 反序列化邻居 if neighborsJSON != "" { if err := json.Unmarshal([]byte(neighborsJSON), &device.Neighbors); err != nil { log.Printf("Warning: failed to unmarshal neighbors for %s: %v", device.IP, err) } } return &device, nil } // GetAllDevices 获取所有设备 func (s *Storage) GetAllDevices() ([]models.Device, error) { query := ` SELECT id, ip, type, hostname, os_version, uptime, interfaces, neighbors, last_scanned, scan_status, error_message FROM devices ORDER BY created_at ` rows, err := s.db.Query(query) if err != nil { return nil, fmt.Errorf("failed to query devices: %w", err) } defer rows.Close() var devices []models.Device for rows.Next() { var device models.Device var typeStr string var interfacesJSON, neighborsJSON string var lastScanned sql.NullTime err := rows.Scan( &device.ID, &device.IP, &typeStr, &device.Hostname, &device.OSVersion, &device.Uptime, &interfacesJSON, &neighborsJSON, &lastScanned, &device.ScanStatus, &device.ErrorMessage, ) if err != nil { log.Printf("Warning: failed to scan device row: %v", err) continue } device.Type = models.DeviceType(typeStr) if lastScanned.Valid { device.LastScanned = lastScanned.Time } // 反序列化接口 if interfacesJSON != "" { if err := json.Unmarshal([]byte(interfacesJSON), &device.Interfaces); err != nil { log.Printf("Warning: failed to unmarshal interfaces for %s: %v", device.IP, err) } } // 反序列化邻居 if neighborsJSON != "" { if err := json.Unmarshal([]byte(neighborsJSON), &device.Neighbors); err != nil { log.Printf("Warning: failed to unmarshal neighbors for %s: %v", device.IP, err) } } devices = append(devices, device) } return devices, nil } // DeleteDevice 删除设备 func (s *Storage) DeleteDevice(id string) error { _, err := s.db.Exec("DELETE FROM devices WHERE id = ?", id) if err != nil { return fmt.Errorf("failed to delete device: %w", err) } return nil } // Close 关闭数据库连接 func (s *Storage) Close() error { if s.db != nil { return s.db.Close() } return nil }