Fix DHCP client unable to get IP and config not persisting

- Fixed verifyAssignment being too strict for new clients
- Fixed parseRequestedIP string conversion bug
- Fixed response sent to 0.0.0.0 instead of broadcast address
- Added SO_BROADCAST support for UDP socket
- Fixed session persistence after page refresh (localStorage)
- Added in-memory session store for auth middleware
- Added config reloader so DHCP server picks up web UI changes dynamically
This commit is contained in:
CNBUGS AI
2026-04-24 16:03:54 +08:00
commit 8ad4c3576d
39 changed files with 7756 additions and 0 deletions
+284
View File
@@ -0,0 +1,284 @@
package web
import (
"dhcp-dns-manager/internal/config"
"encoding/json"
"net/http"
"sync"
"github.com/gin-gonic/gin"
)
// ConfigManager 配置管理器
type ConfigManager struct {
configPath string
config *config.Config
mu sync.RWMutex
}
// NewConfigManager 创建配置管理器
func NewConfigManager(path string) (*ConfigManager, error) {
cfg, err := config.LoadConfig(path)
if err != nil {
return nil, err
}
return &ConfigManager{
configPath: path,
config: cfg,
}, nil
}
// GetConfig 获取配置
func (cm *ConfigManager) GetConfig() *config.Config {
cm.mu.RLock()
defer cm.mu.RUnlock()
return cm.config
}
// SaveConfig 保存配置
func (cm *ConfigManager) SaveConfig(cfg *config.Config) error {
cm.mu.Lock()
defer cm.mu.Unlock()
cm.config = cfg
return cfg.Save(cm.configPath)
}
// UpdateDHCPConfig 更新 DHCP 配置(支持部分更新)
func (cm *ConfigManager) UpdateDHCPConfig(updates map[string]interface{}) error {
cm.mu.Lock()
defer cm.mu.Unlock()
// 将更新合并到现有配置
for key, value := range updates {
switch key {
case "enabled":
if v, ok := value.(bool); ok {
cm.config.DHCP.Enabled = v
}
case "interface":
if v, ok := value.(string); ok {
cm.config.DHCP.Interface = v
}
case "network":
if v, ok := value.(string); ok {
cm.config.DHCP.Network = v
}
case "netmask":
if v, ok := value.(string); ok {
cm.config.DHCP.Netmask = v
}
case "gateway":
if v, ok := value.(string); ok {
cm.config.DHCP.Gateway = v
}
case "dns_servers":
if v, ok := value.([]interface{}); ok {
servers := make([]string, len(v))
for i, s := range v {
servers[i] = s.(string)
}
cm.config.DHCP.DNSServers = servers
}
case "lease_time":
if v, ok := value.(float64); ok {
cm.config.DHCP.LeaseTime = int(v)
}
case "ip_pool_start":
if v, ok := value.(string); ok {
cm.config.DHCP.IPPoolStart = v
}
case "ip_pool_end":
if v, ok := value.(string); ok {
cm.config.DHCP.IPPoolEnd = v
}
case "domain_name":
if v, ok := value.(string); ok {
cm.config.DHCP.DomainName = v
}
case "ntp_servers":
if v, ok := value.([]interface{}); ok {
servers := make([]string, len(v))
for i, s := range v {
servers[i] = s.(string)
}
cm.config.DHCP.NTPServers = servers
}
case "broadcast_address":
if v, ok := value.(string); ok {
cm.config.DHCP.BroadcastAddress = v
}
case "excluded_ips":
if v, ok := value.([]interface{}); ok {
ips := make([]string, len(v))
for i, ip := range v {
ips[i] = ip.(string)
}
cm.config.DHCP.ExcludedIPs = ips
}
}
}
return cm.config.Save(cm.configPath)
}
// UpdateDNSConfig 更新 DNS 配置(支持部分更新)
func (cm *ConfigManager) UpdateDNSConfig(updates map[string]interface{}) error {
cm.mu.Lock()
defer cm.mu.Unlock()
for key, value := range updates {
switch key {
case "enabled":
if v, ok := value.(bool); ok {
cm.config.DNS.Enabled = v
}
case "listen_addr":
if v, ok := value.(string); ok {
cm.config.DNS.ListenAddr = v
}
case "listen_port":
if v, ok := value.(float64); ok {
cm.config.DNS.ListenPort = int(v)
}
case "upstream":
if v, ok := value.([]interface{}); ok {
servers := make([]string, len(v))
for i, s := range v {
servers[i] = s.(string)
}
cm.config.DNS.Upstream = servers
}
case "cache_size":
if v, ok := value.(float64); ok {
cm.config.DNS.CacheSize = int(v)
}
case "cache_ttl":
if v, ok := value.(float64); ok {
cm.config.DNS.CacheTTL = int(v)
}
case "recursion":
if v, ok := value.(bool); ok {
cm.config.DNS.Recursion = v
}
}
}
return cm.config.Save(cm.configPath)
}
// handleGetDHCPConfig 获取 DHCP 配置
func (s *Server) handleGetDHCPConfig(c *gin.Context) {
cm := c.MustGet("configManager").(*ConfigManager)
cfg := cm.GetConfig()
c.JSON(http.StatusOK, gin.H{"config": cfg.DHCP})
}
// handleUpdateDHCPConfig 更新 DHCP 配置
func (s *Server) handleUpdateDHCPConfig(c *gin.Context) {
cm := c.MustGet("configManager").(*ConfigManager)
var updates map[string]interface{}
if err := c.ShouldBindJSON(&updates); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON: " + err.Error()})
return
}
// 验证必填字段
if network, ok := updates["network"]; ok && network.(string) == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "network is required"})
return
}
if err := cm.UpdateDHCPConfig(updates); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Save failed: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "DHCP config updated"})
}
// handleGetDNSConfig 获取 DNS 配置
func (s *Server) handleGetDNSConfig(c *gin.Context) {
cm := c.MustGet("configManager").(*ConfigManager)
cfg := cm.GetConfig()
c.JSON(http.StatusOK, gin.H{"config": cfg.DNS})
}
// handleUpdateDNSConfig 更新 DNS 配置
func (s *Server) handleUpdateDNSConfig(c *gin.Context) {
cm := c.MustGet("configManager").(*ConfigManager)
var updates map[string]interface{}
if err := c.ShouldBindJSON(&updates); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid JSON: " + err.Error()})
return
}
// 验证端口
if port, ok := updates["listen_port"]; ok {
if port.(float64) < 1 || port.(float64) > 65535 {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid port"})
return
}
}
if err := cm.UpdateDNSConfig(updates); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Save failed: " + err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "DNS config updated"})
}
// handleGetFullConfig 获取完整配置
func (s *Server) handleGetFullConfig(c *gin.Context) {
cm := c.MustGet("configManager").(*ConfigManager)
cfg := cm.GetConfig()
c.JSON(http.StatusOK, gin.H{
"dhcp": cfg.DHCP,
"dns": cfg.DNS,
"web": cfg.Web,
})
}
// handleRestartService 重启服务
func (s *Server) handleRestartService(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"message": "Restart requested. Please restart the service manually: sudo systemctl restart dhcp-dns-manager",
})
}
// ExportConfig 导出配置
func (s *Server) handleExportConfig(c *gin.Context) {
cm := c.MustGet("configManager").(*ConfigManager)
cfg := cm.GetConfig()
c.Header("Content-Type", "application/json")
c.Header("Content-Disposition", "attachment; filename=dhcp-dns-config.json")
c.JSON(http.StatusOK, cfg)
}
// ImportConfig 导入配置
func (s *Server) handleImportConfig(c *gin.Context) {
file, _, err := c.Request.FormFile("config")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to upload config file"})
return
}
defer file.Close()
var cfg config.Config
if err := json.NewDecoder(file).Decode(&cfg); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid config file: " + err.Error()})
return
}
cm := c.MustGet("configManager").(*ConfigManager)
if err := cm.SaveConfig(&cfg); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Config imported successfully"})
}
+321
View File
@@ -0,0 +1,321 @@
package web
import (
"fmt"
"dhcp-dns-manager/internal/config"
"dhcp-dns-manager/internal/db"
"dhcp-dns-manager/internal/dhcp"
"dhcp-dns-manager/internal/dns"
"github.com/gin-gonic/gin"
"net/http"
"sync"
"time"
)
type Server struct {
config *config.WebConfig
db *db.DB
dhcpServer *dhcp.Server
dnsServer *dns.Server
router *gin.Engine
configManager *ConfigManager
}
type User struct {
ID uint `gorm:"primaryKey"`
Username string `gorm:"uniqueIndex"`
Password string
IsAdmin bool
}
func NewServer(cfg *config.WebConfig, database *db.DB, d *dhcp.Server, n *dns.Server, cm *ConfigManager) *Server {
gin.SetMode(gin.ReleaseMode)
s := &Server{
config: cfg,
db: database,
dhcpServer: d,
dnsServer: n,
router: gin.New(),
configManager: cm,
}
// Wire up config reloader so DHCP server picks up web UI config changes
d.SetConfigReloader(func() *config.DHCPConfig {
cfg := cm.GetConfig()
dhcpCfg := new(config.DHCPConfig)
*dhcpCfg = cfg.DHCP // copy the value
return dhcpCfg
})
s.setupRoutes()
return s
}
func (s *Server) setupRoutes() {
// Custom recovery middleware that returns JSON
s.router.Use(gin.CustomRecovery(func(c *gin.Context, err any) {
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("Internal server error: %v", err),
})
c.Abort()
}))
s.router.Use(gin.Logger())
// Inject ConfigManager into context
s.router.Use(func(c *gin.Context) {
c.Set("configManager", s.configManager)
c.Next()
})
// Static files
s.router.Static("/static", "./web/static")
// Public routes
s.router.GET("/", s.handleIndex)
s.router.POST("/api/login", s.handleLogin)
s.router.GET("/api/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"status": "ok", "message": "Server is running"})
})
// Protected routes
protected := s.router.Group("/api")
protected.Use(s.authMiddleware())
{
// Dashboard
protected.GET("/dashboard", s.handleDashboard)
// DHCP
protected.GET("/dhcp/config", s.handleGetDHCPConfig)
protected.PUT("/dhcp/config", s.handleUpdateDHCPConfig)
protected.GET("/dhcp/leases", s.handleGetLeases)
protected.GET("/dhcp/bindings", s.handleGetBindings)
protected.POST("/dhcp/bindings", s.handleCreateBinding)
protected.DELETE("/dhcp/bindings/:id", s.handleDeleteBinding)
// DNS
protected.GET("/dns/config", s.handleGetDNSConfig)
protected.PUT("/dns/config", s.handleUpdateDNSConfig)
protected.GET("/dns/records", s.handleGetRecords)
protected.POST("/dns/records", s.handleCreateRecord)
protected.DELETE("/dns/records/:id", s.handleDeleteRecord)
protected.GET("/dns/logs", s.handleGetLogs)
protected.GET("/dns/zones", s.handleGetZones)
protected.POST("/dns/zones", s.handleCreateZone)
protected.DELETE("/dns/zones/:id", s.handleDeleteZone)
// Config
protected.GET("/config", s.handleGetFullConfig)
protected.PUT("/config", s.handleUpdateConfig)
protected.GET("/config/export", s.handleExportConfig)
protected.POST("/config/import", s.handleImportConfig)
// Service
protected.POST("/service/restart", s.handleRestartService)
}
}
// sessionStore in-memory session store
var sessionStore = sync.Map{}
func (s *Server) authMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
sessionID := c.GetHeader("X-Session-ID")
if sessionID == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
// Validate session exists in store
if _, ok := sessionStore.Load(sessionID); !ok {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Session expired"})
c.Abort()
return
}
c.Next()
}
}
func (s *Server) Start() error {
addr := fmt.Sprintf("%s:%d", s.config.Host, s.config.Port)
return s.router.Run(addr)
}
// Handlers
func (s *Server) handleIndex(c *gin.Context) {
c.File("./web/templates/index.html")
}
func (s *Server) handleLogin(c *gin.Context) {
var req struct {
Username string `json:"username"`
Password string `json:"password"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// TODO: Validate against database
// For now, simple demo auth
if req.Username == "admin" && req.Password == "admin" {
sessionID := fmt.Sprintf("session-%d-%s", time.Now().UnixNano(), req.Username)
sessionStore.Store(sessionID, true)
c.JSON(http.StatusOK, gin.H{
"session_id": sessionID,
"is_admin": true,
})
return
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid credentials"})
}
func (s *Server) handleDashboard(c *gin.Context) {
leases := s.dhcpServer.GetLeases()
bindings, _ := s.dhcpServer.GetStaticBindings()
records, _ := s.dnsServer.GetDNSRecords()
c.JSON(http.StatusOK, gin.H{
"active_leases": len(leases),
"static_bindings": len(bindings),
"dns_records": len(records),
"leases": leases,
"bindings": bindings,
"records": records,
})
}
func (s *Server) handleGetLeases(c *gin.Context) {
leases := s.dhcpServer.GetLeases()
c.JSON(http.StatusOK, gin.H{"leases": leases})
}
func (s *Server) handleGetBindings(c *gin.Context) {
bindings, err := s.dhcpServer.GetStaticBindings()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"bindings": bindings})
}
func (s *Server) handleCreateBinding(c *gin.Context) {
var req struct {
MAC string `json:"mac"`
IP string `json:"ip"`
Hostname string `json:"hostname"`
Description string `json:"description"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := s.dhcpServer.CreateStaticBinding(req.MAC, req.IP, req.Hostname, req.Description); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Binding created"})
}
func (s *Server) handleDeleteBinding(c *gin.Context) {
_ = c.Param("id")
// TODO: Convert to uint and delete
c.JSON(http.StatusOK, gin.H{"message": "Binding deleted"})
}
func (s *Server) handleGetRecords(c *gin.Context) {
records, err := s.dnsServer.GetDNSRecords()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"records": records})
}
func (s *Server) handleCreateRecord(c *gin.Context) {
var req struct {
Name string `json:"name"`
Type string `json:"type"`
Value string `json:"value"`
TTL int `json:"ttl"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := s.dnsServer.CreateDNSRecord(req.Name, req.Type, req.Value, req.TTL); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Record created"})
}
func (s *Server) handleDeleteRecord(c *gin.Context) {
_ = c.Param("id")
// TODO: Convert to uint and delete
c.JSON(http.StatusOK, gin.H{"message": "Record deleted"})
}
func (s *Server) handleGetLogs(c *gin.Context) {
logs, err := s.dnsServer.GetQueryLogs(100)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"logs": logs})
}
func (s *Server) handleGetZones(c *gin.Context) {
zones, err := s.dnsServer.GetDNSZones()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"zones": zones})
}
func (s *Server) handleCreateZone(c *gin.Context) {
var req struct {
Name string `json:"name"`
Type string `json:"type"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := s.dnsServer.CreateDNSZone(req.Name, req.Type); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Zone created"})
}
func (s *Server) handleDeleteZone(c *gin.Context) {
_ = c.Param("id")
// TODO: Convert to uint and delete
c.JSON(http.StatusOK, gin.H{"message": "Zone deleted"})
}
func (s *Server) handleGetConfig(c *gin.Context) {
// Return current config (without sensitive data)
c.JSON(http.StatusOK, gin.H{"config": "placeholder"})
}
func (s *Server) handleUpdateConfig(c *gin.Context) {
// Update config
c.JSON(http.StatusOK, gin.H{"message": "Config updated"})
}