server.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. package dns
  import (
  	"fmt"
  	"net"
  	"strconv"
  	"strings"
  	"sync"
  	"time"

  	"dhcp-dns-manager/internal/config"
  	"dhcp-dns-manager/internal/db"
  	"github.com/miekg/dns"
  )
  11. type Server struct {
  12. config *config.DNSConfig
  13. db *db.DB
  14. server *dns.Server
  15. cache map[string]*CacheEntry
  16. cacheMutex sync.RWMutex
  17. stopChan chan struct{}
  18. }
  19. type CacheEntry struct {
  20. Records []dns.RR
  21. Expires time.Time
  22. }
  23. func NewServer(cfg *config.DNSConfig, database *db.DB) *Server {
  24. return &Server{
  25. config: cfg,
  26. db: database,
  27. cache: make(map[string]*CacheEntry),
  28. stopChan: make(chan struct{}),
  29. }
  30. }
  // Start launches the UDP DNS listener and then the cache cleanup goroutine.
  // It does nothing and returns nil when the server is disabled in the config.
  //
  // Start waits until the listener is bound. A bad address or a port that is
  // already in use is therefore returned to the caller, rather than being lost
  // inside the serving goroutine. Errors the listener returns after a
  // successful start are discarded, as they were before.
  func (s *Server) Start() error {
  	if !s.config.Enabled {
  		return nil
  	}
  	// JoinHostPort brackets IPv6 literals ("[::1]:53"); "%s:%d" does not.
  	addr := net.JoinHostPort(s.config.ListenAddr, fmt.Sprint(s.config.ListenPort))

  	started := make(chan struct{})
  	srv := &dns.Server{
  		Addr:              addr,
  		Net:               "udp",
  		Handler:           dns.HandlerFunc(s.handleQuery),
  		NotifyStartedFunc: func() { close(started) },
  	}
  	s.server = srv

  	// Buffered so the goroutine can always hand over its result and exit,
  	// even after Start has returned and nobody is receiving anymore.
  	serveErr := make(chan error, 1)
  	go func() {
  		serveErr <- srv.ListenAndServe()
  	}()

  	select {
  	case <-started:
  	case err := <-serveErr:
  		if err != nil {
  			return fmt.Errorf("starting DNS server on %s: %w", addr, err)
  		}
  		// The listener was shut down before Start observed it running.
  		return nil
  	}

  	go s.cleanupCache()
  	return nil
  }
  49. func (s *Server) Stop() {
  50. if s.server != nil {
  51. s.server.Shutdown()
  52. }
  53. close(s.stopChan)
  54. }
  55. func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
  56. m := new(dns.Msg)
  57. m.SetReply(r)
  58. if len(r.Question) == 0 {
  59. w.WriteMsg(m)
  60. return
  61. }
  62. q := r.Question[0]
  63. // Check cache first
  64. if records := s.getFromCache(q.Name, q.Qtype); records != nil {
  65. m.Answer = records
  66. w.WriteMsg(m)
  67. return
  68. }
  69. // Check local DNS records
  70. localRecords := s.getLocalRecords(q.Name, q.Qtype)
  71. if len(localRecords) > 0 {
  72. m.Answer = localRecords
  73. s.addToCache(q.Name, q.Qtype, localRecords)
  74. w.WriteMsg(m)
  75. return
  76. }
  77. // Forward to upstream DNS
  78. s.forwardQuery(w, r, m, q)
  79. }
  // getLocalRecords returns the enabled database records that answer a query
  // for name/qtype. Each is converted to a dns.RR owned by the query name.
  //
  // A stored record matches when its name equals the query name. The comparison
  // is case-insensitive and ignores a missing trailing dot. Of the matching
  // records, the result includes:
  //   - records whose type equals qtype;
  //   - CNAME records, since a CNAME answers a query of any type;
  //   - every supported record when qtype is ANY.
  // Supported types are A, AAAA, CNAME and PTR. Other types, and values that do
  // not parse for their type, are skipped. A database error yields nil, so the
  // caller falls through to the upstream resolvers.
  //
  // NOTE(review): every query loads the whole record table via GetDNSRecords;
  // a lookup by name in the db package would scale better.
  func (s *Server) getLocalRecords(name string, qtype uint16) []dns.RR {
  	records, err := s.db.GetDNSRecords()
  	if err != nil {
  		return nil
  	}
  	var result []dns.RR
  	for _, record := range records {
  		// GetDNSRecords may already filter disabled rows (not visible here);
  		// checking again keeps a disabled record from ever being served.
  		if !record.Enabled || !strings.EqualFold(dns.Fqdn(record.Name), name) {
  			continue
  		}
  		rr := localRecordToRR(name, record.Type, record.Value, record.TTL)
  		if rr == nil {
  			continue
  		}
  		if t := rr.Header().Rrtype; qtype == dns.TypeANY || t == qtype || t == dns.TypeCNAME {
  			result = append(result, rr)
  		}
  	}
  	return result
  }

  // localRecordToRR builds the dns.RR for one stored record owned by owner. It
  // returns nil for unsupported types and for values that do not parse. For
  // example, an A record whose value is not an IPv4 address yields nil.
  func localRecordToRR(owner, rtype, value string, ttl int) dns.RR {
  	// A negative TTL would wrap around to a huge uint32; clamp it to zero.
  	if ttl < 0 {
  		ttl = 0
  	}
  	hdr := func(t uint16) dns.RR_Header {
  		return dns.RR_Header{Name: owner, Rrtype: t, Class: dns.ClassINET, Ttl: uint32(ttl)}
  	}
  	switch rtype {
  	case "A":
  		ip := net.ParseIP(value).To4()
  		if ip == nil {
  			return nil
  		}
  		return &dns.A{Hdr: hdr(dns.TypeA), A: ip}
  	case "AAAA":
  		ip := net.ParseIP(value)
  		if ip == nil || ip.To4() != nil {
  			return nil
  		}
  		return &dns.AAAA{Hdr: hdr(dns.TypeAAAA), AAAA: ip}
  	case "CNAME":
  		// Domain names must be fully qualified or the reply fails to pack.
  		return &dns.CNAME{Hdr: hdr(dns.TypeCNAME), Target: dns.Fqdn(value)}
  	case "PTR":
  		return &dns.PTR{Hdr: hdr(dns.TypePTR), Ptr: dns.Fqdn(value)}
  	}
  	return nil
  }
  // forwardQuery resolves q by sending the original request r to each
  // configured upstream in turn, then writes the reply m to w.
  //
  // How m is filled:
  //   - The first upstream response that carries answers wins; its rcode,
  //     answers and authority section go into m, and the answers are cached.
  //   - If no upstream returns answers, m takes the rcode and authority section
  //     of the last NOERROR or NXDOMAIN response received. Negative answers are
  //     therefore relayed rather than flattened into an empty NOERROR.
  //   - If no upstream produced such a response at all, m is sent as SERVFAIL.
  // Every forwarded query is recorded through AddQueryLog.
  func (s *Server) forwardQuery(w dns.ResponseWriter, r, m *dns.Msg, q dns.Question) {
  	c := new(dns.Client)
  	m.Rcode = dns.RcodeServerFailure
  	for _, upstream := range s.config.Upstream {
  		resp, _, err := c.Exchange(r, upstreamHostPort(upstream))
  		if err != nil || resp == nil {
  			continue
  		}
  		if len(resp.Answer) > 0 {
  			m.Rcode = resp.Rcode
  			m.Answer = resp.Answer
  			m.Ns = resp.Ns
  			s.addToCache(q.Name, q.Qtype, resp.Answer)
  			break
  		}
  		// Keep definitive negative answers. A SERVFAIL/REFUSED from a later
  		// upstream must not mask them, so those rcodes are not recorded.
  		if resp.Rcode == dns.RcodeSuccess || resp.Rcode == dns.RcodeNameError {
  			m.Rcode = resp.Rcode
  			m.Ns = resp.Ns
  		}
  	}
  	w.WriteMsg(m)

  	// Log query
  	responseStr := "success"
  	if len(m.Answer) == 0 {
  		responseStr = "empty"
  	}
  	s.db.AddQueryLog(
  		w.RemoteAddr().String(),
  		q.Name,
  		dns.TypeToString[q.Qtype],
  		responseStr,
  	)
  }

  // upstreamHostPort returns upstream as a dialable "host:port" address. Port 53
  // is appended when upstream does not already name a port. Bare IPv6 literals
  // are bracketed by net.JoinHostPort.
  func upstreamHostPort(upstream string) string {
  	if _, _, err := net.SplitHostPort(upstream); err == nil {
  		return upstream // already includes a port
  	}
  	if len(upstream) > 0 && upstream[0] == '[' {
  		// Bracketed IPv6 literal without a port, e.g. "[2001:db8::1]".
  		return upstream + ":53"
  	}
  	return net.JoinHostPort(upstream, "53")
  }
  // getFromCache returns the cached answer for name/qtype. It returns nil when
  // there is no entry, the entry has expired, or it holds no records.
  //
  // The returned records are copies with TTLs lowered to the time left on the
  // entry. Without this, a cached answer would be re-advertised with its full
  // original TTL and clients could keep it long past its real lifetime. Copying
  // also means callers never share the RR values stored in the cache.
  func (s *Server) getFromCache(name string, qtype uint16) []dns.RR {
  	s.cacheMutex.RLock()
  	defer s.cacheMutex.RUnlock()

  	entry, exists := s.cache[cacheKey(name, qtype)]
  	now := time.Now()
  	if !exists || now.After(entry.Expires) || len(entry.Records) == 0 {
  		return nil
  	}

  	remaining := uint32(entry.Expires.Sub(now) / time.Second)
  	out := make([]dns.RR, len(entry.Records))
  	for i, rr := range entry.Records {
  		dup := dns.Copy(rr)
  		if hdr := dup.Header(); hdr.Ttl > remaining {
  			hdr.Ttl = remaining
  		}
  		out[i] = dup
  	}
  	return out
  }
  // addToCache stores records as the answer for name/qtype, replacing any
  // previous entry. The entry lives for the smallest TTL among the records, so
  // none of them is served past its own lifetime. This matters for
  // CNAME-plus-target answers whose TTLs differ. An empty record set falls back
  // to 300 seconds.
  func (s *Server) addToCache(name string, qtype uint16, records []dns.RR) {
  	ttl := uint32(300) // default: 5 minutes, used only when records is empty
  	for i, rr := range records {
  		if t := rr.Header().Ttl; i == 0 || t < ttl {
  			ttl = t
  		}
  	}
  	entry := &CacheEntry{
  		Records: records,
  		Expires: time.Now().Add(time.Duration(ttl) * time.Second),
  	}

  	s.cacheMutex.Lock()
  	defer s.cacheMutex.Unlock()
  	s.cache[cacheKey(name, qtype)] = entry
  }
  // cleanupCache runs until stopChan is closed. Once a minute it deletes every
  // cache entry whose expiry time has passed, so stale answers do not
  // accumulate in the map.
  func (s *Server) cleanupCache() {
  	t := time.NewTicker(time.Minute)
  	defer t.Stop()

  	prune := func() {
  		s.cacheMutex.Lock()
  		defer s.cacheMutex.Unlock()
  		cutoff := time.Now()
  		for k, e := range s.cache {
  			if cutoff.After(e.Expires) {
  				delete(s.cache, k)
  			}
  		}
  	}

  	for {
  		select {
  		case <-s.stopChan:
  			return
  		case <-t.C:
  			prune()
  		}
  	}
  }
  // cacheKey builds the cache map key for a question.
  //
  // DNS names compare case-insensitively (RFC 4343), so the name is lower-cased.
  // The query type is written in decimal, which gives every type a distinct
  // suffix. The old string(qtype) rune conversion is rejected by go vet, and it
  // mapped qtypes 0xD800-0xDFFF all to U+FFFD.
  func cacheKey(name string, qtype uint16) string {
  	return strings.ToLower(name) + ":" + strconv.FormatUint(uint64(qtype), 10)
  }
  177. func (s *Server) CreateDNSRecord(name, rtype, value string, ttl int) error {
  178. record := db.DNSRecord{
  179. Name: name,
  180. Type: rtype,
  181. Value: value,
  182. TTL: ttl,
  183. Enabled: true,
  184. }
  185. return s.db.Create(&record).Error
  186. }
  187. func (s *Server) DeleteDNSRecord(id uint) error {
  188. return s.db.Delete(&db.DNSRecord{}, id).Error
  189. }
  190. func (s *Server) GetDNSRecords() ([]db.DNSRecord, error) {
  191. return s.db.GetDNSRecords()
  192. }
  193. func (s *Server) GetDNSZones() ([]db.DNSZone, error) {
  194. return s.db.GetDNSZones()
  195. }
  196. func (s *Server) CreateDNSZone(name, zoneType string) error {
  197. return s.db.CreateDNSZone(name, zoneType)
  198. }
  199. func (s *Server) DeleteDNSZone(id uint) error {
  200. return s.db.DeleteDNSZone(id)
  201. }
  202. func (s *Server) GetQueryLogs(limit int) ([]db.DNSQueryLog, error) {
  203. var logs []db.DNSQueryLog
  204. err := s.db.Order("timestamp DESC").Limit(limit).Find(&logs).Error
  205. return logs, err
  206. }