Files
dhcp-dns-manager/internal/dns/server.go
T
CNBUGS AI 8ad4c3576d Fix DHCP client unable to get IP and config not persisting
- Fixed verifyAssignment being too strict for new clients
- Fixed parseRequestedIP string conversion bug
- Fixed response sent to 0.0.0.0 instead of broadcast address
- Added SO_BROADCAST support for UDP socket
- Fixed session persistence after page refresh (localStorage)
- Added in-memory session store for auth middleware
- Added config reloader so DHCP server picks up web UI changes dynamically
2026-04-24 16:03:54 +08:00

248 lines
4.8 KiB
Go

package dns
import (
	"fmt"
	"log"
	"net"
	"sync"
	"time"

	"dhcp-dns-manager/internal/config"
	"dhcp-dns-manager/internal/db"
	"github.com/miekg/dns"
)
// Server is the DNS front end of the manager. It answers UDP queries from the
// locally stored records (read through db) and forwards everything else to the
// configured upstream resolvers, keeping answers in an in-memory cache whose
// entries expire after a TTL.
type Server struct {
config *config.DNSConfig
db *db.DB
// server is the miekg/dns listener; it stays nil until Start runs with DNS enabled.
server *dns.Server
// cache maps cacheKey(name, qtype) to a cached answer; every access holds cacheMutex.
cache map[string]*CacheEntry
cacheMutex sync.RWMutex
// stopChan is closed by Stop to make the cleanupCache goroutine return.
stopChan chan struct{}
}
// CacheEntry is one cached answer: the resource records to return and the
// instant after which getFromCache stops serving them.
type CacheEntry struct {
Records []dns.RR
Expires time.Time
}
// NewServer creates a DNS server that reads its settings from cfg and its
// local records and query log from database. It only allocates the answer
// cache and the stop channel; no socket is opened until Start is called.
func NewServer(cfg *config.DNSConfig, database *db.DB) *Server {
	srv := new(Server)
	srv.config = cfg
	srv.db = database
	srv.cache = make(map[string]*CacheEntry)
	srv.stopChan = make(chan struct{})
	return srv
}
// Start launches the UDP DNS listener on ListenAddr:ListenPort and the
// background cache-cleanup loop. It does nothing and returns nil when DNS is
// disabled in the configuration.
//
// Start blocks until the listener is either serving or has failed to start.
// A startup failure (port in use, missing privileges, bad address) is
// returned to the caller instead of being lost inside a goroutine. An error
// that ends the listener after a successful start is logged, because nobody
// is waiting for it any more.
func (s *Server) Start() error {
	if !s.config.Enabled {
		return nil
	}

	addr := fmt.Sprintf("%s:%d", s.config.ListenAddr, s.config.ListenPort)
	started := make(chan struct{})
	// Buffered so the serving goroutine can always hand over its error and exit.
	startErr := make(chan error, 1)

	srv := &dns.Server{
		Addr:              addr,
		Net:               "udp",
		Handler:           dns.HandlerFunc(s.handleQuery),
		NotifyStartedFunc: func() { close(started) },
	}
	s.server = srv

	go func() {
		err := srv.ListenAndServe()
		select {
		case <-started:
			// Start already reported success; the error can only be logged.
			if err != nil {
				log.Printf("dns: server on %s stopped: %v", addr, err)
			}
		default:
			startErr <- err
		}
	}()

	select {
	case <-started:
	case err := <-startErr:
		// Nothing is listening, so Stop must not try to shut this server down.
		s.server = nil
		if err == nil {
			return fmt.Errorf("starting DNS server on %s: listener exited before it started", addr)
		}
		return fmt.Errorf("starting DNS server on %s: %w", addr, err)
	}

	go s.cleanupCache()
	return nil
}
// Stop shuts down the DNS listener, if one was created, and signals the
// cache-cleanup goroutine to exit. A Shutdown failure is logged because Stop
// has no error return.
//
// Stop may be called more than once in sequence. The stop channel is closed
// only the first time, since closing a closed channel panics. Concurrent
// calls from different goroutines are not synchronized.
func (s *Server) Stop() {
	if s.server != nil {
		if err := s.server.Shutdown(); err != nil {
			log.Printf("dns: shutting down server: %v", err)
		}
	}
	select {
	case <-s.stopChan:
		// Already closed by an earlier Stop.
	default:
		close(s.stopChan)
	}
}
// handleQuery answers a single DNS request and only looks at its first
// question. Sources are tried in order:
//  1. the in-memory answer cache;
//  2. the locally managed records (a local hit is cached before replying);
//  3. the upstream resolvers via forwardQuery, which writes the reply itself.
//
// A request that carries no question receives an empty reply.
func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
	reply := new(dns.Msg)
	reply.SetReply(r)

	if len(r.Question) == 0 {
		w.WriteMsg(reply)
		return
	}
	question := r.Question[0]

	answer := s.getFromCache(question.Name, question.Qtype)
	if answer == nil {
		answer = s.getLocalRecords(question.Name, question.Qtype)
		if len(answer) == 0 {
			s.forwardQuery(w, r, reply, question)
			return
		}
		s.addToCache(question.Name, question.Qtype, answer)
	}

	reply.Answer = answer
	w.WriteMsg(reply)
}
// getLocalRecords builds answer records for the query (name, qtype) from the
// locally managed DNS records in the database. It returns nil when nothing
// matches or the database read fails, so the caller falls through to the
// upstream resolvers.
//
// A stored record contributes only when all of the following hold:
//   - it is enabled;
//   - its name, in fully-qualified form, equals the query name, so "host.lan"
//     and "host.lan." both match a query for "host.lan.";
//   - it answers the query: it has the requested type, it is a CNAME (the
//     client follows the alias), or the query is ANY.
//
// Supported stored types are A, AAAA and CNAME. A record whose value is not a
// valid address for its type is skipped, because it would make the whole
// reply fail to pack. A negative stored TTL is treated as 0.
func (s *Server) getLocalRecords(name string, qtype uint16) []dns.RR {
	records, err := s.db.GetDNSRecords()
	if err != nil {
		return nil
	}

	var result []dns.RR
	for _, record := range records {
		if !record.Enabled || dns.Fqdn(record.Name) != name {
			continue
		}

		hdr := dns.RR_Header{Name: name, Class: dns.ClassINET}
		if record.TTL > 0 {
			hdr.Ttl = uint32(record.TTL)
		}

		var rr dns.RR
		switch record.Type {
		case "A":
			if ip := net.ParseIP(record.Value).To4(); ip != nil {
				hdr.Rrtype = dns.TypeA
				rr = &dns.A{Hdr: hdr, A: ip}
			}
		case "AAAA":
			if ip := net.ParseIP(record.Value); ip != nil && ip.To4() == nil {
				hdr.Rrtype = dns.TypeAAAA
				rr = &dns.AAAA{Hdr: hdr, AAAA: ip}
			}
		case "CNAME":
			if record.Value != "" {
				hdr.Rrtype = dns.TypeCNAME
				// miekg/dns refuses to pack a domain name without a trailing dot.
				rr = &dns.CNAME{Hdr: hdr, Target: dns.Fqdn(record.Value)}
			}
		}
		if rr == nil {
			continue
		}

		// Never put records of an unrelated type into the answer (e.g. an A
		// record for an AAAA query); that answer would also be cached under qtype.
		if t := hdr.Rrtype; t == qtype || t == dns.TypeCNAME || qtype == dns.TypeANY {
			result = append(result, rr)
		}
	}
	return result
}
// forwardQuery relays the request r to the configured upstream resolvers, in
// order, and writes the reply m to w. The caller has already prepared m with
// SetReply. q is the question being answered; it is used for the cache key
// and the query log.
//
// Outcomes:
//   - An upstream returns answer records: the first such reply wins, and its
//     answers are copied into m and cached.
//   - Upstreams reply, but none has answers: the response code and authority
//     section of the first NOERROR/NXDOMAIN reply are relayed. Clients then
//     see NXDOMAIN and the SOA used for negative caching, not a bare NOERROR.
//   - No usable reply arrives: the client gets SERVFAIL.
//
// Every forwarded query is written to the query log as "success" or "empty".
//
// An upstream entry may be a bare IPv4 or IPv6 address, in which case port 53
// is assumed, or an explicit host:port.
func (s *Server) forwardQuery(w dns.ResponseWriter, r, m *dns.Msg, q dns.Question) {
	c := new(dns.Client)
	var answered, negative *dns.Msg
	for _, upstream := range s.config.Upstream {
		addr := upstream
		if _, _, err := net.SplitHostPort(upstream); err != nil {
			// No port given; JoinHostPort also brackets IPv6 literals.
			addr = net.JoinHostPort(upstream, "53")
		}
		resp, _, err := c.Exchange(r, addr)
		if err != nil || resp == nil {
			continue // unreachable or timed out: try the next upstream
		}
		if len(resp.Answer) > 0 {
			answered = resp
			break
		}
		// Remember the first definitive "no data"/"no such name" reply, but keep
		// asking the remaining upstreams for a real answer.
		if negative == nil && (resp.Rcode == dns.RcodeSuccess || resp.Rcode == dns.RcodeNameError) {
			negative = resp
		}
	}

	switch {
	case answered != nil:
		m.Answer = answered.Answer
		s.addToCache(q.Name, q.Qtype, answered.Answer)
	case negative != nil:
		m.Rcode = negative.Rcode
		m.Ns = negative.Ns
	default:
		m.Rcode = dns.RcodeServerFailure
	}
	w.WriteMsg(m)

	// Log query
	responseStr := "success"
	if len(m.Answer) == 0 {
		responseStr = "empty"
	}
	s.db.AddQueryLog(
		w.RemoteAddr().String(),
		q.Name,
		dns.TypeToString[q.Qtype],
		responseStr,
	)
}
// getFromCache returns the cached answer for (name, qtype). It returns nil
// when no entry exists or the entry's expiry time has passed. Expired entries
// are left in the map for cleanupCache to purge.
func (s *Server) getFromCache(name string, qtype uint16) []dns.RR {
	key := cacheKey(name, qtype)

	s.cacheMutex.RLock()
	defer s.cacheMutex.RUnlock()

	if entry, found := s.cache[key]; found && !time.Now().After(entry.Expires) {
		return entry.Records
	}
	return nil
}
// addToCache stores records as the answer for (name, qtype) and replaces any
// existing entry.
//
// The entry expires after the smallest TTL among the records. An answer often
// mixes RRsets with different TTLs, such as a long-lived CNAME followed by a
// short-lived A record, and no record may be served past its own TTL. An
// empty record set falls back to a 300-second (5-minute) lifetime.
func (s *Server) addToCache(name string, qtype uint16, records []dns.RR) {
	ttl := uint32(300) // default when there is no record to take a TTL from
	for i, rr := range records {
		if t := rr.Header().Ttl; i == 0 || t < ttl {
			ttl = t
		}
	}
	entry := &CacheEntry{
		Records: records,
		Expires: time.Now().Add(time.Duration(ttl) * time.Second),
	}
	key := cacheKey(name, qtype)

	s.cacheMutex.Lock()
	defer s.cacheMutex.Unlock()
	s.cache[key] = entry
}
// cleanupCache runs until stopChan is closed. Once a minute it removes
// expired entries from the answer cache, so answers that are never asked for
// again do not pile up in memory.
func (s *Server) cleanupCache() {
	tick := time.NewTicker(time.Minute)
	defer tick.Stop()

	purgeExpired := func() {
		s.cacheMutex.Lock()
		defer s.cacheMutex.Unlock()
		current := time.Now()
		for k, e := range s.cache {
			if e.Expires.Before(current) {
				delete(s.cache, k)
			}
		}
	}

	for {
		select {
		case <-s.stopChan:
			return
		case <-tick.C:
			purgeExpired()
		}
	}
}
// cacheKey returns the answer-cache map key for a (name, qtype) pair, for
// example "example.com.:1" for an A query.
//
// The type is rendered as a decimal number. string(qtype) would instead yield
// a single Unicode rune, and every value in the UTF-16 surrogate range
// (0xD800-0xDFFF) becomes the same replacement character, so distinct types
// would share one key.
func cacheKey(name string, qtype uint16) string {
	return fmt.Sprintf("%s:%d", name, qtype)
}
// CreateDNSRecord stores a new, enabled local DNS record with the given name,
// type (e.g. "A" or "CNAME"), value and TTL. The TTL is later used as the DNS
// record TTL in seconds.
//
// On success the whole answer cache is flushed. A cached upstream or local
// answer for the same name would otherwise keep hiding the new record until
// its TTL expired. Cache keys are built from the query name, whose exact form
// (trailing dot, letter case) may differ from the stored name, so flushing
// everything is simpler and safer than a targeted invalidation.
func (s *Server) CreateDNSRecord(name, rtype, value string, ttl int) error {
	record := db.DNSRecord{
		Name:    name,
		Type:    rtype,
		Value:   value,
		TTL:     ttl,
		Enabled: true,
	}
	if err := s.db.Create(&record).Error; err != nil {
		return err
	}

	s.cacheMutex.Lock()
	s.cache = make(map[string]*CacheEntry)
	s.cacheMutex.Unlock()
	return nil
}
// DeleteDNSRecord removes the local DNS record with the given id.
//
// On success the whole answer cache is flushed. Local answers are cached when
// they are served, so a deleted record would otherwise keep being returned
// until its TTL expired.
func (s *Server) DeleteDNSRecord(id uint) error {
	if err := s.db.Delete(&db.DNSRecord{}, id).Error; err != nil {
		return err
	}

	s.cacheMutex.Lock()
	s.cache = make(map[string]*CacheEntry)
	s.cacheMutex.Unlock()
	return nil
}
// GetDNSRecords returns the local DNS records exactly as the database layer's
// GetDNSRecords reports them, together with its error.
func (s *Server) GetDNSRecords() ([]db.DNSRecord, error) {
	records, err := s.db.GetDNSRecords()
	return records, err
}
// GetDNSZones returns the DNS zones exactly as the database layer's
// GetDNSZones reports them, together with its error.
func (s *Server) GetDNSZones() ([]db.DNSZone, error) {
	zones, err := s.db.GetDNSZones()
	return zones, err
}
// CreateDNSZone asks the database layer to create a zone with the given name
// and zone type, and passes its error through unchanged.
func (s *Server) CreateDNSZone(name, zoneType string) error {
	err := s.db.CreateDNSZone(name, zoneType)
	return err
}
// DeleteDNSZone asks the database layer to delete the zone with the given id,
// and passes its error through unchanged.
func (s *Server) DeleteDNSZone(id uint) error {
	err := s.db.DeleteDNSZone(id)
	return err
}
// GetQueryLogs loads DNS query log entries ordered by the timestamp column,
// descending. That is presumably newest first; confirm that timestamp is the
// logging time. limit is passed straight to the query's Limit clause. Any
// rows already read are returned alongside the query error.
func (s *Server) GetQueryLogs(limit int) ([]db.DNSQueryLog, error) {
	var entries []db.DNSQueryLog
	result := s.db.Order("timestamp DESC").Limit(limit).Find(&entries)
	return entries, result.Error
}