Fix: 替换SQLite为JSON文件存储,无需CGO支持

- 移除go-sqlite3依赖,改用纯Go的JSON文件存储
- 解决Windows上CGO_ENABLED=0导致SQLite无法使用的问题
- 添加线程安全的读写锁保护
- 支持数据持久化,重启后数据不丢失
- 简化存储逻辑,提高可靠性
This commit is contained in:
Your Name
2026-04-26 00:46:37 +08:00
parent e5e624d72e
commit 8b7dbf2886
5 changed files with 92 additions and 174 deletions
+78 -167
View File
@@ -1,75 +1,84 @@
package storage
import (
"database/sql"
"encoding/json"
"fmt"
"log"
"os"
"sync"
"network-topology-discovery/pkg/models"
"time"
_ "github.com/mattn/go-sqlite3"
)
// Storage 数据库存储
// Storage 存储(使用JSON文件)
type Storage struct {
db *sql.DB
mu sync.RWMutex
filePath string
devices map[string]models.Device
}
// NewStorage 创建存储实例
func NewStorage(dbPath string) (*Storage, error) {
db, err := sql.Open("sqlite3", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
func NewStorage(filePath string) (*Storage, error) {
s := &Storage{
filePath: filePath,
devices: make(map[string]models.Device),
}
// 创建表
if err := createTables(db); err != nil {
return nil, fmt.Errorf("failed to create tables: %w", err)
// 从文件加载数据
if err := s.load(); err != nil {
if !os.IsNotExist(err) {
return nil, fmt.Errorf("failed to load storage: %w", err)
}
// 文件不存在是正常的,创建新文件
log.Printf("Creating new storage file: %s", filePath)
}
return &Storage{db: db}, nil
return s, nil
}
// createTables 创建数据
func createTables(db *sql.DB) error {
query := `
CREATE TABLE IF NOT EXISTS devices (
id TEXT PRIMARY KEY,
ip TEXT NOT NULL UNIQUE,
type TEXT NOT NULL,
hostname TEXT,
os_version TEXT,
uptime TEXT,
interfaces TEXT,
neighbors TEXT,
last_scanned DATETIME,
scan_status TEXT,
error_message TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_devices_ip ON devices(ip);
CREATE INDEX IF NOT EXISTS idx_devices_type ON devices(type);
`
_, err := db.Exec(query)
return err
// load 从文件加载数据
func (s *Storage) load() error {
data, err := os.ReadFile(s.filePath)
if err != nil {
return err
}
var devices []models.Device
if err := json.Unmarshal(data, &devices); err != nil {
return fmt.Errorf("failed to parse storage file: %w", err)
}
for _, dev := range devices {
s.devices[dev.ID] = dev
}
log.Printf("Loaded %d devices from storage", len(devices))
return nil
}
// save 保存数据到文件
func (s *Storage) save() error {
devices := make([]models.Device, 0, len(s.devices))
for _, dev := range s.devices {
devices = append(devices, dev)
}
data, err := json.MarshalIndent(devices, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal devices: %w", err)
}
if err := os.WriteFile(s.filePath, data, 0644); err != nil {
return fmt.Errorf("failed to write storage file: %w", err)
}
return nil
}
// SaveDevice 保存设备
func (s *Storage) SaveDevice(device *models.Device) error {
// 序列化接口和邻居数据
interfacesJSON, err := json.Marshal(device.Interfaces)
if err != nil {
return fmt.Errorf("failed to marshal interfaces: %w", err)
}
neighborsJSON, err := json.Marshal(device.Neighbors)
if err != nil {
return fmt.Errorf("failed to marshal neighbors: %w", err)
}
s.mu.Lock()
defer s.mu.Unlock()
// 设置ID和扫描时间
if device.ID == "" {
@@ -77,29 +86,10 @@ func (s *Storage) SaveDevice(device *models.Device) error {
}
device.LastScanned = time.Now()
query := `
INSERT OR REPLACE INTO devices
(id, ip, type, hostname, os_version, uptime, interfaces, neighbors,
last_scanned, scan_status, error_message, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
s.devices[device.ID] = *device
_, err = s.db.Exec(query,
device.ID,
device.IP,
string(device.Type),
device.Hostname,
device.OSVersion,
device.Uptime,
string(interfacesJSON),
string(neighborsJSON),
device.LastScanned,
device.ScanStatus,
device.ErrorMessage,
time.Now(),
)
if err != nil {
// 保存到文件
if err := s.save(); err != nil {
return fmt.Errorf("failed to save device: %w", err)
}
@@ -109,47 +99,12 @@ func (s *Storage) SaveDevice(device *models.Device) error {
// GetDevice 获取设备
func (s *Storage) GetDevice(id string) (*models.Device, error) {
query := `
SELECT id, ip, type, hostname, os_version, uptime, interfaces, neighbors,
last_scanned, scan_status, error_message
FROM devices WHERE id = ?
`
s.mu.RLock()
defer s.mu.RUnlock()
row := s.db.QueryRow(query, id)
var device models.Device
var typeStr string
var interfacesJSON, neighborsJSON string
var lastScanned sql.NullTime
err := row.Scan(
&device.ID, &device.IP, &typeStr, &device.Hostname,
&device.OSVersion, &device.Uptime, &interfacesJSON, &neighborsJSON,
&lastScanned, &device.ScanStatus, &device.ErrorMessage,
)
if err != nil {
return nil, fmt.Errorf("failed to get device: %w", err)
}
device.Type = models.DeviceType(typeStr)
if lastScanned.Valid {
device.LastScanned = lastScanned.Time
}
// 反序列化接口
if interfacesJSON != "" {
if err := json.Unmarshal([]byte(interfacesJSON), &device.Interfaces); err != nil {
log.Printf("Warning: failed to unmarshal interfaces for %s: %v", device.IP, err)
}
}
// 反序列化邻居
if neighborsJSON != "" {
if err := json.Unmarshal([]byte(neighborsJSON), &device.Neighbors); err != nil {
log.Printf("Warning: failed to unmarshal neighbors for %s: %v", device.IP, err)
}
device, exists := s.devices[id]
if !exists {
return nil, fmt.Errorf("device not found: %s", id)
}
return &device, nil
@@ -157,58 +112,12 @@ func (s *Storage) GetDevice(id string) (*models.Device, error) {
// GetAllDevices 获取所有设备
func (s *Storage) GetAllDevices() ([]models.Device, error) {
query := `
SELECT id, ip, type, hostname, os_version, uptime, interfaces, neighbors,
last_scanned, scan_status, error_message
FROM devices ORDER BY created_at
`
s.mu.RLock()
defer s.mu.RUnlock()
rows, err := s.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to query devices: %w", err)
}
defer rows.Close()
var devices []models.Device
for rows.Next() {
var device models.Device
var typeStr string
var interfacesJSON, neighborsJSON string
var lastScanned sql.NullTime
err := rows.Scan(
&device.ID, &device.IP, &typeStr, &device.Hostname,
&device.OSVersion, &device.Uptime, &interfacesJSON, &neighborsJSON,
&lastScanned, &device.ScanStatus, &device.ErrorMessage,
)
if err != nil {
log.Printf("Warning: failed to scan device row: %v", err)
continue
}
device.Type = models.DeviceType(typeStr)
if lastScanned.Valid {
device.LastScanned = lastScanned.Time
}
// 反序列化接口
if interfacesJSON != "" {
if err := json.Unmarshal([]byte(interfacesJSON), &device.Interfaces); err != nil {
log.Printf("Warning: failed to unmarshal interfaces for %s: %v", device.IP, err)
}
}
// 反序列化邻居
if neighborsJSON != "" {
if err := json.Unmarshal([]byte(neighborsJSON), &device.Neighbors); err != nil {
log.Printf("Warning: failed to unmarshal neighbors for %s: %v", device.IP, err)
}
}
devices = append(devices, device)
devices := make([]models.Device, 0, len(s.devices))
for _, dev := range s.devices {
devices = append(devices, dev)
}
return devices, nil
@@ -216,17 +125,19 @@ func (s *Storage) GetAllDevices() ([]models.Device, error) {
// DeleteDevice 删除设备
func (s *Storage) DeleteDevice(id string) error {
_, err := s.db.Exec("DELETE FROM devices WHERE id = ?", id)
if err != nil {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.devices, id)
// 保存到文件
if err := s.save(); err != nil {
return fmt.Errorf("failed to delete device: %w", err)
}
return nil
}
// Close 关闭数据库连接
// Close 关闭存储(不需要操作,JSON文件不需要关闭)
func (s *Storage) Close() error {
if s.db != nil {
return s.db.Close()
}
return nil
}