| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- package dns
- import (
- "fmt"
- "net"
- "dhcp-dns-manager/internal/config"
- "dhcp-dns-manager/internal/db"
- "github.com/miekg/dns"
- "sync"
- "time"
- )
- type Server struct {
- config *config.DNSConfig
- db *db.DB
- server *dns.Server
- cache map[string]*CacheEntry
- cacheMutex sync.RWMutex
- stopChan chan struct{}
- }
- type CacheEntry struct {
- Records []dns.RR
- Expires time.Time
- }
- func NewServer(cfg *config.DNSConfig, database *db.DB) *Server {
- return &Server{
- config: cfg,
- db: database,
- cache: make(map[string]*CacheEntry),
- stopChan: make(chan struct{}),
- }
- }
- func (s *Server) Start() error {
- if !s.config.Enabled {
- return nil
- }
- s.server = &dns.Server{
- Addr: fmt.Sprintf("%s:%d", s.config.ListenAddr, s.config.ListenPort),
- Net: "udp",
- Handler: dns.HandlerFunc(s.handleQuery),
- }
- go func() {
- if err := s.server.ListenAndServe(); err != nil {
- // Log error
- }
- }()
- // Start cache cleanup
- go s.cleanupCache()
- return nil
- }
- func (s *Server) Stop() {
- if s.server != nil {
- s.server.Shutdown()
- }
- close(s.stopChan)
- }
- func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
- m := new(dns.Msg)
- m.SetReply(r)
-
- if len(r.Question) == 0 {
- w.WriteMsg(m)
- return
- }
- q := r.Question[0]
-
- // Check cache first
- if records := s.getFromCache(q.Name, q.Qtype); records != nil {
- m.Answer = records
- w.WriteMsg(m)
- return
- }
- // Check local DNS records
- localRecords := s.getLocalRecords(q.Name, q.Qtype)
- if len(localRecords) > 0 {
- m.Answer = localRecords
- s.addToCache(q.Name, q.Qtype, localRecords)
- w.WriteMsg(m)
- return
- }
- // Forward to upstream DNS
- s.forwardQuery(w, r, m, q)
- }
- func (s *Server) getLocalRecords(name string, qtype uint16) []dns.RR {
- records, err := s.db.GetDNSRecords()
- if err != nil {
- return nil
- }
- var result []dns.RR
- for _, record := range records {
- if record.Name != name {
- continue
- }
-
- var rr dns.RR
- switch record.Type {
- case "A":
- rr = &dns.A{
- Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: uint32(record.TTL)},
- A: net.ParseIP(record.Value),
- }
- case "CNAME":
- rr = &dns.CNAME{
- Hdr: dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(record.TTL)},
- Target: record.Value,
- }
- }
-
- if rr != nil {
- result = append(result, rr)
- }
- }
-
- return result
- }
- func (s *Server) forwardQuery(w dns.ResponseWriter, r, m *dns.Msg, q dns.Question) {
- c := new(dns.Client)
-
- for _, upstream := range s.config.Upstream {
- resp, _, err := c.Exchange(r, upstream+":53")
- if err == nil && len(resp.Answer) > 0 {
- m.Answer = resp.Answer
- s.addToCache(q.Name, q.Qtype, resp.Answer)
- break
- }
- }
-
- w.WriteMsg(m)
-
- // Log query
- responseStr := "success"
- if len(m.Answer) == 0 {
- responseStr = "empty"
- }
- s.db.AddQueryLog(
- w.RemoteAddr().String(),
- q.Name,
- dns.TypeToString[q.Qtype],
- responseStr,
- )
- }
- func (s *Server) getFromCache(name string, qtype uint16) []dns.RR {
- s.cacheMutex.RLock()
- defer s.cacheMutex.RUnlock()
-
- key := cacheKey(name, qtype)
- entry, exists := s.cache[key]
- if !exists || time.Now().After(entry.Expires) {
- return nil
- }
-
- return entry.Records
- }
- func (s *Server) addToCache(name string, qtype uint16, records []dns.RR) {
- s.cacheMutex.Lock()
- defer s.cacheMutex.Unlock()
-
- key := cacheKey(name, qtype)
- ttl := uint32(300) // Default 5 minutes
- if len(records) > 0 {
- ttl = records[0].Header().Ttl
- }
-
- s.cache[key] = &CacheEntry{
- Records: records,
- Expires: time.Now().Add(time.Duration(ttl) * time.Second),
- }
- }
- func (s *Server) cleanupCache() {
- ticker := time.NewTicker(1 * time.Minute)
- defer ticker.Stop()
-
- for {
- select {
- case <-ticker.C:
- s.cacheMutex.Lock()
- now := time.Now()
- for key, entry := range s.cache {
- if now.After(entry.Expires) {
- delete(s.cache, key)
- }
- }
- s.cacheMutex.Unlock()
- case <-s.stopChan:
- return
- }
- }
- }
// cacheKey builds the cache map key for a query name and type as
// "<name>:<decimal qtype>", e.g. "example.com.:1" for an A query.
//
// The type is formatted as decimal text. The previous string(qtype)
// converted the number to a single rune, which go vet rejects
// (stringintconv). It also mapped every type in the surrogate range
// 0xD800-0xDFFF to the same U+FFFD key.
func cacheKey(name string, qtype uint16) string {
	return fmt.Sprintf("%s:%d", name, qtype)
}
- func (s *Server) CreateDNSRecord(name, rtype, value string, ttl int) error {
- record := db.DNSRecord{
- Name: name,
- Type: rtype,
- Value: value,
- TTL: ttl,
- Enabled: true,
- }
- return s.db.Create(&record).Error
- }
- func (s *Server) DeleteDNSRecord(id uint) error {
- return s.db.Delete(&db.DNSRecord{}, id).Error
- }
- func (s *Server) GetDNSRecords() ([]db.DNSRecord, error) {
- return s.db.GetDNSRecords()
- }
- func (s *Server) GetDNSZones() ([]db.DNSZone, error) {
- return s.db.GetDNSZones()
- }
- func (s *Server) CreateDNSZone(name, zoneType string) error {
- return s.db.CreateDNSZone(name, zoneType)
- }
- func (s *Server) DeleteDNSZone(id uint) error {
- return s.db.DeleteDNSZone(id)
- }
- func (s *Server) GetQueryLogs(limit int) ([]db.DNSQueryLog, error) {
- var logs []db.DNSQueryLog
- err := s.db.Order("timestamp DESC").Limit(limit).Find(&logs).Error
- return logs, err
- }
|