package dns import ( "fmt" "net" "dhcp-dns-manager/internal/config" "dhcp-dns-manager/internal/db" "github.com/miekg/dns" "sync" "time" ) type Server struct { config *config.DNSConfig db *db.DB server *dns.Server cache map[string]*CacheEntry cacheMutex sync.RWMutex stopChan chan struct{} } type CacheEntry struct { Records []dns.RR Expires time.Time } func NewServer(cfg *config.DNSConfig, database *db.DB) *Server { return &Server{ config: cfg, db: database, cache: make(map[string]*CacheEntry), stopChan: make(chan struct{}), } } func (s *Server) Start() error { if !s.config.Enabled { return nil } s.server = &dns.Server{ Addr: fmt.Sprintf("%s:%d", s.config.ListenAddr, s.config.ListenPort), Net: "udp", Handler: dns.HandlerFunc(s.handleQuery), } go func() { if err := s.server.ListenAndServe(); err != nil { // Log error } }() // Start cache cleanup go s.cleanupCache() return nil } func (s *Server) Stop() { if s.server != nil { s.server.Shutdown() } close(s.stopChan) } func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) if len(r.Question) == 0 { w.WriteMsg(m) return } q := r.Question[0] // Check cache first if records := s.getFromCache(q.Name, q.Qtype); records != nil { m.Answer = records w.WriteMsg(m) return } // Check local DNS records localRecords := s.getLocalRecords(q.Name, q.Qtype) if len(localRecords) > 0 { m.Answer = localRecords s.addToCache(q.Name, q.Qtype, localRecords) w.WriteMsg(m) return } // Forward to upstream DNS s.forwardQuery(w, r, m, q) } func (s *Server) getLocalRecords(name string, qtype uint16) []dns.RR { records, err := s.db.GetDNSRecords() if err != nil { return nil } var result []dns.RR for _, record := range records { if record.Name != name { continue } var rr dns.RR switch record.Type { case "A": rr = &dns.A{ Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(record.TTL)}, A: net.ParseIP(record.Value), } case "CNAME": rr = &dns.CNAME{ Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(record.TTL)}, Target: record.Value, } } if rr != nil { result = append(result, rr) } } return result } func (s *Server) forwardQuery(w dns.ResponseWriter, r, m *dns.Msg, q dns.Question) { c := new(dns.Client) for _, upstream := range s.config.Upstream { resp, _, err := c.Exchange(r, upstream+":53") if err == nil && len(resp.Answer) > 0 { m.Answer = resp.Answer s.addToCache(q.Name, q.Qtype, resp.Answer) break } } w.WriteMsg(m) // Log query responseStr := "success" if len(m.Answer) == 0 { responseStr = "empty" } s.db.AddQueryLog( w.RemoteAddr().String(), q.Name, dns.TypeToString[q.Qtype], responseStr, ) } func (s *Server) getFromCache(name string, qtype uint16) []dns.RR { s.cacheMutex.RLock() defer s.cacheMutex.RUnlock() key := cacheKey(name, qtype) entry, exists := s.cache[key] if !exists || time.Now().After(entry.Expires) { return nil } return entry.Records } func (s *Server) addToCache(name string, qtype uint16, records []dns.RR) { s.cacheMutex.Lock() defer s.cacheMutex.Unlock() key := cacheKey(name, qtype) ttl := uint32(300) // Default 5 minutes if len(records) > 0 { ttl = records[0].Header().Ttl } s.cache[key] = &CacheEntry{ Records: records, Expires: time.Now().Add(time.Duration(ttl) * time.Second), } } func (s *Server) cleanupCache() { ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: s.cacheMutex.Lock() now := time.Now() for key, entry := range s.cache { if now.After(entry.Expires) { delete(s.cache, key) } } s.cacheMutex.Unlock() case <-s.stopChan: return } } } func cacheKey(name string, qtype uint16) string { return name + ":" + string(qtype) } func (s *Server) CreateDNSRecord(name, rtype, value string, ttl int) error { record := db.DNSRecord{ Name: name, Type: rtype, Value: value, TTL: ttl, Enabled: true, } return s.db.Create(&record).Error } func (s *Server) DeleteDNSRecord(id uint) error { return s.db.Delete(&db.DNSRecord{}, id).Error } func (s *Server) GetDNSRecords() ([]db.DNSRecord, error) { return s.db.GetDNSRecords() } func (s *Server) GetDNSZones() ([]db.DNSZone, error) { return s.db.GetDNSZones() } func (s *Server) CreateDNSZone(name, zoneType string) error { return s.db.CreateDNSZone(name, zoneType) } func (s *Server) DeleteDNSZone(id uint) error { return s.db.DeleteDNSZone(id) } func (s *Server) GetQueryLogs(limit int) ([]db.DNSQueryLog, error) { var logs []db.DNSQueryLog err := s.db.Order("timestamp DESC").Limit(limit).Find(&logs).Error return logs, err }