package storage

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"os"
	"sort"
	"sync"
	"time"

	"network-topology-discovery/pkg/models"
)

// Storage is a device store backed by a single JSON file.
//
// All devices are held in memory, keyed by device ID, and the whole set is
// rewritten to filePath on every mutation. mu guards both filePath and
// devices.
type Storage struct {
	mu       sync.RWMutex
	filePath string
	devices  map[string]models.Device
}

// NewStorage creates a Storage backed by filePath. It is kept for
// compatibility with older callers and simply delegates to
// NewStorageForTopology.
func NewStorage(filePath string) (*Storage, error) {
	return NewStorageForTopology(filePath)
}

// NewStorageForTopology creates a Storage for one topology, backed by
// filePath.
//
// Devices already present in the file are loaded. A missing file is not an
// error: the store starts empty and the file is written on the first save.
// Any other read or parse failure is returned.
func NewStorageForTopology(filePath string) (*Storage, error) {
	s := &Storage{
		filePath: filePath,
		devices:  make(map[string]models.Device),
	}

	// s is not shared with any other goroutine yet, so no locking is needed.
	if err := s.load(); err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			return nil, fmt.Errorf("failed to load storage: %w", err)
		}
		// A missing file is expected for a new topology.
		log.Printf("Creating new storage file: %s", filePath)
	}

	return s, nil
}

// SetFilePath switches the store to another file (used when switching
// topologies) and replaces the in-memory devices with that file's contents.
//
// A missing file yields an empty store. If the file exists but cannot be read
// or parsed, an error is returned and the store keeps its previous path and
// devices, so a later save can never overwrite a file that failed to load.
func (s *Storage) SetFilePath(filePath string) error {
	// Hold the write lock for the whole switch so that readers and writers
	// never observe a new path paired with the old device set, or vice versa.
	s.mu.Lock()
	defer s.mu.Unlock()

	devices, err := readDevices(filePath)
	missing := errors.Is(err, os.ErrNotExist)
	if err != nil && !missing {
		return fmt.Errorf("failed to load storage: %w", err)
	}

	loaded := make(map[string]models.Device, len(devices))
	for _, dev := range devices {
		loaded[dev.ID] = dev
	}

	// Commit only after the new file has been read successfully.
	s.filePath = filePath
	s.devices = loaded

	if !missing {
		log.Printf("Loaded %d devices from storage", len(devices))
	}
	log.Printf("Storage switched to: %s", filePath)
	return nil
}

// readDevices reads and decodes the device list stored at path.
//
// The error from os.ReadFile is returned as-is so callers can detect a missing
// file. An empty or whitespace-only file is treated as containing no devices.
func readDevices(path string) ([]models.Device, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	if len(bytes.TrimSpace(data)) == 0 {
		return nil, nil
	}

	var devices []models.Device
	if err := json.Unmarshal(data, &devices); err != nil {
		return nil, fmt.Errorf("failed to parse storage file: %w", err)
	}
	return devices, nil
}

// load reads the file at s.filePath and adds its devices to s.devices,
// overwriting entries that share an ID.
//
// The caller must either hold s.mu for writing or own s exclusively.
func (s *Storage) load() error {
	devices, err := readDevices(s.filePath)
	if err != nil {
		return err
	}

	for _, dev := range devices {
		s.devices[dev.ID] = dev
	}

	log.Printf("Loaded %d devices from storage", len(devices))
	return nil
}

// sortedDevices returns a copy of all devices ordered by ID, so that output
// does not depend on map iteration order.
//
// The caller must hold s.mu (read or write).
func (s *Storage) sortedDevices() []models.Device {
	devices := make([]models.Device, 0, len(s.devices))
	for _, dev := range s.devices {
		devices = append(devices, dev)
	}
	sort.Slice(devices, func(i, j int) bool {
		return devices[i].ID < devices[j].ID
	})
	return devices
}

// save writes every device to s.filePath as indented JSON.
//
// The caller must hold s.mu.
func (s *Storage) save() error {
	data, err := json.MarshalIndent(s.sortedDevices(), "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal devices: %w", err)
	}

	if err := os.WriteFile(s.filePath, data, 0644); err != nil {
		return fmt.Errorf("failed to write storage file: %w", err)
	}

	return nil
}

// SaveDevice stores device and persists the whole device set to disk.
//
// The passed device is modified: an empty ID is replaced by the device's IP,
// and LastScanned is set to the current time. A nil device is rejected with an
// error.
func (s *Storage) SaveDevice(device *models.Device) error {
	if device == nil {
		return errors.New("failed to save device: device is nil")
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Fill in the ID and the scan timestamp before storing.
	if device.ID == "" {
		device.ID = device.IP
	}
	device.LastScanned = time.Now()

	s.devices[device.ID] = *device

	if err := s.save(); err != nil {
		return fmt.Errorf("failed to save device: %w", err)
	}

	log.Printf("Device saved: %s (%s)", device.IP, device.Hostname)
	return nil
}

// GetDevice returns a copy of the device stored under id, or an error if no
// such device exists.
func (s *Storage) GetDevice(id string) (*models.Device, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	device, exists := s.devices[id]
	if !exists {
		return nil, fmt.Errorf("device not found: %s", id)
	}

	// device is a local copy, so the caller cannot mutate the stored entry.
	return &device, nil
}

// GetAllDevices returns a copy of every stored device, ordered by ID. The
// returned error is always nil.
func (s *Storage) GetAllDevices() ([]models.Device, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	return s.sortedDevices(), nil
}

// DeleteDevice removes the device stored under id, if any, and persists the
// remaining devices to disk. Deleting an unknown id is not an error.
func (s *Storage) DeleteDevice(id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	delete(s.devices, id)

	if err := s.save(); err != nil {
		return fmt.Errorf("failed to delete device: %w", err)
	}

	return nil
}

// Close is a no-op: the store keeps no open file handle between operations.
func (s *Storage) Close() error {
	return nil
}