server.go 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731
  1. // DNS server implementation.
  2. package dns
  3. import (
  4. "bytes"
  5. "crypto/tls"
  6. "io"
  7. "net"
  8. "sync"
  9. "time"
  10. )
// maxTCPQueries is the threshold for the per-connection TCP query counter:
// once the counter exceeds this value the server closes the socket.
const maxTCPQueries = 128
// Handler is implemented by any value that implements ServeDNS.
//
// ServeDNS receives the decoded request r and a ResponseWriter w that can be
// used to send a reply to the client.
type Handler interface {
	ServeDNS(w ResponseWriter, r *Msg)
}
// A ResponseWriter interface is used by a DNS handler to
// construct a DNS response.
type ResponseWriter interface {
	// LocalAddr returns the net.Addr of the server.
	LocalAddr() net.Addr
	// RemoteAddr returns the net.Addr of the client that sent the current request.
	RemoteAddr() net.Addr
	// WriteMsg writes a reply back to the client.
	WriteMsg(*Msg) error
	// Write writes a raw buffer back to the client.
	Write([]byte) (int, error)
	// Close closes the connection.
	Close() error
	// TsigStatus returns the status of the Tsig.
	TsigStatus() error
	// TsigTimersOnly sets the tsig timers only boolean.
	TsigTimersOnly(bool)
	// Hijack lets the caller take over the connection.
	// After a call to Hijack(), the DNS package will not do anything with the connection.
	Hijack()
}
// response is the ResponseWriter handed to handlers. It carries the
// connection the request arrived on together with the TSIG state of the
// request currently being served.
type response struct {
	hijacked       bool              // connection has been hijacked by handler
	tsigStatus     error             // TSIG status recorded for the current request
	tsigTimersOnly bool              // passed to TsigGenerate when signing a reply
	tsigRequestMAC string            // MAC of the request's TSIG record, chained into the reply
	tsigSecret     map[string]string // the tsig secrets
	udp            *net.UDPConn      // i/o connection if UDP was used
	tcp            net.Conn          // i/o connection if TCP was used
	udpSession     *SessionUDP       // oob data to get egress interface right
	remoteAddr     net.Addr          // address of the client
	writer         Writer            // writer to output the raw DNS bits
}
// ServeMux is a DNS request multiplexer. It matches the
// zone name of each incoming request against a list of
// registered patterns and calls the handler for the pattern
// that most closely matches the zone name. ServeMux is DNSSEC aware, meaning
// that queries for the DS record are redirected to the parent zone (if that
// is also registered), otherwise the child gets the query.
// ServeMux is also safe for concurrent access from multiple goroutines.
type ServeMux struct {
	z map[string]Handler // registered handlers, keyed by Fqdn(pattern)
	m *sync.RWMutex      // guards z
}
  61. // NewServeMux allocates and returns a new ServeMux.
  62. func NewServeMux() *ServeMux { return &ServeMux{z: make(map[string]Handler), m: new(sync.RWMutex)} }
// DefaultServeMux is the default ServeMux used by Serve. It is the mux the
// package-level Handle, HandleFunc and HandleRemove functions operate on, and
// the handler a Server falls back to when its Handler field is nil.
var DefaultServeMux = NewServeMux()
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as DNS handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
// Handler object that calls f.
type HandlerFunc func(ResponseWriter, *Msg)
// ServeDNS calls f(w, r). It is what makes HandlerFunc satisfy Handler.
func (f HandlerFunc) ServeDNS(w ResponseWriter, r *Msg) {
	f(w, r)
}
  74. // HandleFailed returns a HandlerFunc that returns SERVFAIL for every request it gets.
  75. func HandleFailed(w ResponseWriter, r *Msg) {
  76. m := new(Msg)
  77. m.SetRcode(r, RcodeServerFailure)
  78. // does not matter if this write fails
  79. w.WriteMsg(m)
  80. }
  81. func failedHandler() Handler { return HandlerFunc(HandleFailed) }
  82. // ListenAndServe Starts a server on address and network specified Invoke handler
  83. // for incoming queries.
  84. func ListenAndServe(addr string, network string, handler Handler) error {
  85. server := &Server{Addr: addr, Net: network, Handler: handler}
  86. return server.ListenAndServe()
  87. }
  88. // ListenAndServeTLS acts like http.ListenAndServeTLS, more information in
  89. // http://golang.org/pkg/net/http/#ListenAndServeTLS
  90. func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
  91. cert, err := tls.LoadX509KeyPair(certFile, keyFile)
  92. if err != nil {
  93. return err
  94. }
  95. config := tls.Config{
  96. Certificates: []tls.Certificate{cert},
  97. }
  98. server := &Server{
  99. Addr: addr,
  100. Net: "tcp-tls",
  101. TLSConfig: &config,
  102. Handler: handler,
  103. }
  104. return server.ListenAndServe()
  105. }
  106. // ActivateAndServe activates a server with a listener from systemd,
  107. // l and p should not both be non-nil.
  108. // If both l and p are not nil only p will be used.
  109. // Invoke handler for incoming queries.
  110. func ActivateAndServe(l net.Listener, p net.PacketConn, handler Handler) error {
  111. server := &Server{Listener: l, PacketConn: p, Handler: handler}
  112. return server.ActivateAndServe()
  113. }
  114. func (mux *ServeMux) match(q string, t uint16) Handler {
  115. mux.m.RLock()
  116. defer mux.m.RUnlock()
  117. var handler Handler
  118. b := make([]byte, len(q)) // worst case, one label of length q
  119. off := 0
  120. end := false
  121. for {
  122. l := len(q[off:])
  123. for i := 0; i < l; i++ {
  124. b[i] = q[off+i]
  125. if b[i] >= 'A' && b[i] <= 'Z' {
  126. b[i] |= ('a' - 'A')
  127. }
  128. }
  129. if h, ok := mux.z[string(b[:l])]; ok { // 'causes garbage, might want to change the map key
  130. if t != TypeDS {
  131. return h
  132. }
  133. // Continue for DS to see if we have a parent too, if so delegeate to the parent
  134. handler = h
  135. }
  136. off, end = NextLabel(q, off)
  137. if end {
  138. break
  139. }
  140. }
  141. // Wildcard match, if we have found nothing try the root zone as a last resort.
  142. if h, ok := mux.z["."]; ok {
  143. return h
  144. }
  145. return handler
  146. }
  147. // Handle adds a handler to the ServeMux for pattern.
  148. func (mux *ServeMux) Handle(pattern string, handler Handler) {
  149. if pattern == "" {
  150. panic("dns: invalid pattern " + pattern)
  151. }
  152. mux.m.Lock()
  153. mux.z[Fqdn(pattern)] = handler
  154. mux.m.Unlock()
  155. }
  156. // HandleFunc adds a handler function to the ServeMux for pattern.
  157. func (mux *ServeMux) HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
  158. mux.Handle(pattern, HandlerFunc(handler))
  159. }
  160. // HandleRemove deregistrars the handler specific for pattern from the ServeMux.
  161. func (mux *ServeMux) HandleRemove(pattern string) {
  162. if pattern == "" {
  163. panic("dns: invalid pattern " + pattern)
  164. }
  165. mux.m.Lock()
  166. delete(mux.z, Fqdn(pattern))
  167. mux.m.Unlock()
  168. }
  169. // ServeDNS dispatches the request to the handler whose
  170. // pattern most closely matches the request message. If DefaultServeMux
  171. // is used the correct thing for DS queries is done: a possible parent
  172. // is sought.
  173. // If no handler is found a standard SERVFAIL message is returned
  174. // If the request message does not have exactly one question in the
  175. // question section a SERVFAIL is returned, unlesss Unsafe is true.
  176. func (mux *ServeMux) ServeDNS(w ResponseWriter, request *Msg) {
  177. var h Handler
  178. if len(request.Question) < 1 { // allow more than one question
  179. h = failedHandler()
  180. } else {
  181. if h = mux.match(request.Question[0].Name, request.Question[0].Qtype); h == nil {
  182. h = failedHandler()
  183. }
  184. }
  185. h.ServeDNS(w, request)
  186. }
  187. // Handle registers the handler with the given pattern
  188. // in the DefaultServeMux. The documentation for
  189. // ServeMux explains how patterns are matched.
  190. func Handle(pattern string, handler Handler) { DefaultServeMux.Handle(pattern, handler) }
  191. // HandleRemove deregisters the handle with the given pattern
  192. // in the DefaultServeMux.
  193. func HandleRemove(pattern string) { DefaultServeMux.HandleRemove(pattern) }
  194. // HandleFunc registers the handler function with the given pattern
  195. // in the DefaultServeMux.
  196. func HandleFunc(pattern string, handler func(ResponseWriter, *Msg)) {
  197. DefaultServeMux.HandleFunc(pattern, handler)
  198. }
// Writer writes raw DNS messages; each call to Write should send an entire message.
// It is the type a DecorateWriter hook receives and returns.
type Writer interface {
	io.Writer
}
// Reader reads raw DNS messages; each call to ReadTCP or ReadUDP should return an entire message.
// It is the type a DecorateReader hook receives and returns.
type Reader interface {
	// ReadTCP reads a raw message from a TCP connection. Implementations may alter
	// connection properties, for example the read-deadline.
	ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error)
	// ReadUDP reads a raw message from a UDP connection. Implementations may alter
	// connection properties, for example the read-deadline.
	ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error)
}
// defaultReader is an adapter for the Server struct that implements the Reader interface
// using the readTCP and readUDP func of the embedded Server. It is the Reader
// handed to Server.DecorateReader when that hook is set.
type defaultReader struct {
	*Server
}
  217. func (dr *defaultReader) ReadTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
  218. return dr.readTCP(conn, timeout)
  219. }
  220. func (dr *defaultReader) ReadUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
  221. return dr.readUDP(conn, timeout)
  222. }
// DecorateReader is a decorator hook for extending or supplanting the functionality of a Reader.
// The hook is given the server's default Reader and returns the Reader to use.
// Implementations should never return a nil Reader.
type DecorateReader func(Reader) Reader
// DecorateWriter is a decorator hook for extending or supplanting the functionality of a Writer.
// The hook is given the server's default Writer and returns the Writer to use.
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer
// A Server defines parameters for running a DNS server.
type Server struct {
	// Address to listen on, ":domain" if empty.
	Addr string
	// Network to listen on. ListenAndServe accepts "tcp", "tcp4", "tcp6" for
	// plain TCP, "tcp-tls", "tcp4-tls", "tcp6-tls" for DNS over TLS, and
	// "udp", "udp4", "udp6" for UDP; any other value is rejected.
	Net string
	// TCP Listener to use, this is to aid in systemd's socket activation.
	Listener net.Listener
	// TLS connection configuration
	TLSConfig *tls.Config
	// UDP "Listener" to use, this is to aid in systemd's socket activation.
	PacketConn net.PacketConn
	// Handler to invoke, dns.DefaultServeMux if nil.
	Handler Handler
	// Default buffer size to use to read incoming UDP messages. If not set
	// it defaults to MinMsgSize (512 B).
	UDPSize int
	// The net.Conn.SetReadTimeout value for new connections, defaults to 2 * time.Second.
	// It also bounds how long Shutdown waits for in-flight queries.
	ReadTimeout time.Duration
	// The net.Conn.SetWriteTimeout value for new connections, defaults to 2 * time.Second.
	WriteTimeout time.Duration
	// TCP idle timeout for multiple queries, if nil, defaults to 8 * time.Second (RFC 5966).
	IdleTimeout func() time.Duration
	// Secret(s) for Tsig map[<zonename>]<base64 secret>.
	TsigSecret map[string]string
	// Unsafe instructs the server to disregard any sanity checks and directly hand the message to
	// the handler. It will specifically not check if the query has the QR bit not set.
	Unsafe bool
	// If NotifyStartedFunc is set it is called once the server has started listening.
	NotifyStartedFunc func()
	// DecorateReader is optional, allows customization of the process that reads raw DNS messages.
	DecorateReader DecorateReader
	// DecorateWriter is optional, allows customization of the process that writes raw DNS messages.
	DecorateWriter DecorateWriter
	// Graceful shutdown handling
	inFlight sync.WaitGroup // counts serve goroutines still running
	lock     sync.RWMutex   // guards started
	started  bool           // true between a successful start and Shutdown
}
  268. // ListenAndServe starts a nameserver on the configured address in *Server.
  269. func (srv *Server) ListenAndServe() error {
  270. srv.lock.Lock()
  271. defer srv.lock.Unlock()
  272. if srv.started {
  273. return &Error{err: "server already started"}
  274. }
  275. addr := srv.Addr
  276. if addr == "" {
  277. addr = ":domain"
  278. }
  279. if srv.UDPSize == 0 {
  280. srv.UDPSize = MinMsgSize
  281. }
  282. switch srv.Net {
  283. case "tcp", "tcp4", "tcp6":
  284. a, e := net.ResolveTCPAddr(srv.Net, addr)
  285. if e != nil {
  286. return e
  287. }
  288. l, e := net.ListenTCP(srv.Net, a)
  289. if e != nil {
  290. return e
  291. }
  292. srv.Listener = l
  293. srv.started = true
  294. srv.lock.Unlock()
  295. e = srv.serveTCP(l)
  296. srv.lock.Lock() // to satisfy the defer at the top
  297. return e
  298. case "tcp-tls", "tcp4-tls", "tcp6-tls":
  299. network := "tcp"
  300. if srv.Net == "tcp4-tls" {
  301. network = "tcp4"
  302. } else if srv.Net == "tcp6" {
  303. network = "tcp6"
  304. }
  305. l, e := tls.Listen(network, addr, srv.TLSConfig)
  306. if e != nil {
  307. return e
  308. }
  309. srv.Listener = l
  310. srv.started = true
  311. srv.lock.Unlock()
  312. e = srv.serveTCP(l)
  313. srv.lock.Lock() // to satisfy the defer at the top
  314. return e
  315. case "udp", "udp4", "udp6":
  316. a, e := net.ResolveUDPAddr(srv.Net, addr)
  317. if e != nil {
  318. return e
  319. }
  320. l, e := net.ListenUDP(srv.Net, a)
  321. if e != nil {
  322. return e
  323. }
  324. if e := setUDPSocketOptions(l); e != nil {
  325. return e
  326. }
  327. srv.PacketConn = l
  328. srv.started = true
  329. srv.lock.Unlock()
  330. e = srv.serveUDP(l)
  331. srv.lock.Lock() // to satisfy the defer at the top
  332. return e
  333. }
  334. return &Error{err: "bad network"}
  335. }
  336. // ActivateAndServe starts a nameserver with the PacketConn or Listener
  337. // configured in *Server. Its main use is to start a server from systemd.
  338. func (srv *Server) ActivateAndServe() error {
  339. srv.lock.Lock()
  340. defer srv.lock.Unlock()
  341. if srv.started {
  342. return &Error{err: "server already started"}
  343. }
  344. pConn := srv.PacketConn
  345. l := srv.Listener
  346. if pConn != nil {
  347. if srv.UDPSize == 0 {
  348. srv.UDPSize = MinMsgSize
  349. }
  350. if t, ok := pConn.(*net.UDPConn); ok {
  351. if e := setUDPSocketOptions(t); e != nil {
  352. return e
  353. }
  354. srv.started = true
  355. srv.lock.Unlock()
  356. e := srv.serveUDP(t)
  357. srv.lock.Lock() // to satisfy the defer at the top
  358. return e
  359. }
  360. }
  361. if l != nil {
  362. srv.started = true
  363. srv.lock.Unlock()
  364. e := srv.serveTCP(l)
  365. srv.lock.Lock() // to satisfy the defer at the top
  366. return e
  367. }
  368. return &Error{err: "bad listeners"}
  369. }
  370. // Shutdown gracefully shuts down a server. After a call to Shutdown, ListenAndServe and
  371. // ActivateAndServe will return. All in progress queries are completed before the server
  372. // is taken down. If the Shutdown is taking longer than the reading timeout an error
  373. // is returned.
  374. func (srv *Server) Shutdown() error {
  375. srv.lock.Lock()
  376. if !srv.started {
  377. srv.lock.Unlock()
  378. return &Error{err: "server not started"}
  379. }
  380. srv.started = false
  381. srv.lock.Unlock()
  382. if srv.PacketConn != nil {
  383. srv.PacketConn.Close()
  384. }
  385. if srv.Listener != nil {
  386. srv.Listener.Close()
  387. }
  388. fin := make(chan bool)
  389. go func() {
  390. srv.inFlight.Wait()
  391. fin <- true
  392. }()
  393. select {
  394. case <-time.After(srv.getReadTimeout()):
  395. return &Error{err: "server shutdown is pending"}
  396. case <-fin:
  397. return nil
  398. }
  399. }
  400. // getReadTimeout is a helper func to use system timeout if server did not intend to change it.
  401. func (srv *Server) getReadTimeout() time.Duration {
  402. rtimeout := dnsTimeout
  403. if srv.ReadTimeout != 0 {
  404. rtimeout = srv.ReadTimeout
  405. }
  406. return rtimeout
  407. }
  408. // serveTCP starts a TCP listener for the server.
  409. // Each request is handled in a separate goroutine.
  410. func (srv *Server) serveTCP(l net.Listener) error {
  411. defer l.Close()
  412. if srv.NotifyStartedFunc != nil {
  413. srv.NotifyStartedFunc()
  414. }
  415. reader := Reader(&defaultReader{srv})
  416. if srv.DecorateReader != nil {
  417. reader = srv.DecorateReader(reader)
  418. }
  419. handler := srv.Handler
  420. if handler == nil {
  421. handler = DefaultServeMux
  422. }
  423. rtimeout := srv.getReadTimeout()
  424. // deadline is not used here
  425. for {
  426. rw, e := l.Accept()
  427. if e != nil {
  428. if neterr, ok := e.(net.Error); ok && neterr.Temporary() {
  429. continue
  430. }
  431. return e
  432. }
  433. m, e := reader.ReadTCP(rw, rtimeout)
  434. srv.lock.RLock()
  435. if !srv.started {
  436. srv.lock.RUnlock()
  437. return nil
  438. }
  439. srv.lock.RUnlock()
  440. if e != nil {
  441. continue
  442. }
  443. srv.inFlight.Add(1)
  444. go srv.serve(rw.RemoteAddr(), handler, m, nil, nil, rw)
  445. }
  446. }
  447. // serveUDP starts a UDP listener for the server.
  448. // Each request is handled in a separate goroutine.
  449. func (srv *Server) serveUDP(l *net.UDPConn) error {
  450. defer l.Close()
  451. if srv.NotifyStartedFunc != nil {
  452. srv.NotifyStartedFunc()
  453. }
  454. reader := Reader(&defaultReader{srv})
  455. if srv.DecorateReader != nil {
  456. reader = srv.DecorateReader(reader)
  457. }
  458. handler := srv.Handler
  459. if handler == nil {
  460. handler = DefaultServeMux
  461. }
  462. rtimeout := srv.getReadTimeout()
  463. // deadline is not used here
  464. for {
  465. m, s, e := reader.ReadUDP(l, rtimeout)
  466. srv.lock.RLock()
  467. if !srv.started {
  468. srv.lock.RUnlock()
  469. return nil
  470. }
  471. srv.lock.RUnlock()
  472. if e != nil {
  473. continue
  474. }
  475. srv.inFlight.Add(1)
  476. go srv.serve(s.RemoteAddr(), handler, m, l, s, nil)
  477. }
  478. }
  479. // Serve a new connection.
  480. func (srv *Server) serve(a net.Addr, h Handler, m []byte, u *net.UDPConn, s *SessionUDP, t net.Conn) {
  481. defer srv.inFlight.Done()
  482. w := &response{tsigSecret: srv.TsigSecret, udp: u, tcp: t, remoteAddr: a, udpSession: s}
  483. if srv.DecorateWriter != nil {
  484. w.writer = srv.DecorateWriter(w)
  485. } else {
  486. w.writer = w
  487. }
  488. q := 0 // counter for the amount of TCP queries we get
  489. reader := Reader(&defaultReader{srv})
  490. if srv.DecorateReader != nil {
  491. reader = srv.DecorateReader(reader)
  492. }
  493. Redo:
  494. req := new(Msg)
  495. err := req.Unpack(m)
  496. if err != nil { // Send a FormatError back
  497. x := new(Msg)
  498. x.SetRcodeFormatError(req)
  499. w.WriteMsg(x)
  500. goto Exit
  501. }
  502. if !srv.Unsafe && req.Response {
  503. goto Exit
  504. }
  505. w.tsigStatus = nil
  506. if w.tsigSecret != nil {
  507. if t := req.IsTsig(); t != nil {
  508. secret := t.Hdr.Name
  509. if _, ok := w.tsigSecret[secret]; !ok {
  510. w.tsigStatus = ErrKeyAlg
  511. }
  512. w.tsigStatus = TsigVerify(m, w.tsigSecret[secret], "", false)
  513. w.tsigTimersOnly = false
  514. w.tsigRequestMAC = req.Extra[len(req.Extra)-1].(*TSIG).MAC
  515. }
  516. }
  517. h.ServeDNS(w, req) // Writes back to the client
  518. Exit:
  519. if w.tcp == nil {
  520. return
  521. }
  522. // TODO(miek): make this number configurable?
  523. if q > maxTCPQueries { // close socket after this many queries
  524. w.Close()
  525. return
  526. }
  527. if w.hijacked {
  528. return // client calls Close()
  529. }
  530. if u != nil { // UDP, "close" and return
  531. w.Close()
  532. return
  533. }
  534. idleTimeout := tcpIdleTimeout
  535. if srv.IdleTimeout != nil {
  536. idleTimeout = srv.IdleTimeout()
  537. }
  538. m, e := reader.ReadTCP(w.tcp, idleTimeout)
  539. if e == nil {
  540. q++
  541. goto Redo
  542. }
  543. w.Close()
  544. return
  545. }
  546. func (srv *Server) readTCP(conn net.Conn, timeout time.Duration) ([]byte, error) {
  547. conn.SetReadDeadline(time.Now().Add(timeout))
  548. l := make([]byte, 2)
  549. n, err := conn.Read(l)
  550. if err != nil || n != 2 {
  551. if err != nil {
  552. return nil, err
  553. }
  554. return nil, ErrShortRead
  555. }
  556. length, _ := unpackUint16(l, 0)
  557. if length == 0 {
  558. return nil, ErrShortRead
  559. }
  560. m := make([]byte, int(length))
  561. n, err = conn.Read(m[:int(length)])
  562. if err != nil || n == 0 {
  563. if err != nil {
  564. return nil, err
  565. }
  566. return nil, ErrShortRead
  567. }
  568. i := n
  569. for i < int(length) {
  570. j, err := conn.Read(m[i:int(length)])
  571. if err != nil {
  572. return nil, err
  573. }
  574. i += j
  575. }
  576. n = i
  577. m = m[:n]
  578. return m, nil
  579. }
  580. func (srv *Server) readUDP(conn *net.UDPConn, timeout time.Duration) ([]byte, *SessionUDP, error) {
  581. conn.SetReadDeadline(time.Now().Add(timeout))
  582. m := make([]byte, srv.UDPSize)
  583. n, s, e := ReadFromSessionUDP(conn, m)
  584. if e != nil || n == 0 {
  585. if e != nil {
  586. return nil, nil, e
  587. }
  588. return nil, nil, ErrShortRead
  589. }
  590. m = m[:n]
  591. return m, s, nil
  592. }
  593. // WriteMsg implements the ResponseWriter.WriteMsg method.
  594. func (w *response) WriteMsg(m *Msg) (err error) {
  595. var data []byte
  596. if w.tsigSecret != nil { // if no secrets, dont check for the tsig (which is a longer check)
  597. if t := m.IsTsig(); t != nil {
  598. data, w.tsigRequestMAC, err = TsigGenerate(m, w.tsigSecret[t.Hdr.Name], w.tsigRequestMAC, w.tsigTimersOnly)
  599. if err != nil {
  600. return err
  601. }
  602. _, err = w.writer.Write(data)
  603. return err
  604. }
  605. }
  606. data, err = m.Pack()
  607. if err != nil {
  608. return err
  609. }
  610. _, err = w.writer.Write(data)
  611. return err
  612. }
  613. // Write implements the ResponseWriter.Write method.
  614. func (w *response) Write(m []byte) (int, error) {
  615. switch {
  616. case w.udp != nil:
  617. n, err := WriteToSessionUDP(w.udp, m, w.udpSession)
  618. return n, err
  619. case w.tcp != nil:
  620. lm := len(m)
  621. if lm < 2 {
  622. return 0, io.ErrShortBuffer
  623. }
  624. if lm > MaxMsgSize {
  625. return 0, &Error{err: "message too large"}
  626. }
  627. l := make([]byte, 2, 2+lm)
  628. l[0], l[1] = packUint16(uint16(lm))
  629. m = append(l, m...)
  630. n, err := io.Copy(w.tcp, bytes.NewReader(m))
  631. return int(n), err
  632. }
  633. panic("not reached")
  634. }
  635. // LocalAddr implements the ResponseWriter.LocalAddr method.
  636. func (w *response) LocalAddr() net.Addr {
  637. if w.tcp != nil {
  638. return w.tcp.LocalAddr()
  639. }
  640. return w.udp.LocalAddr()
  641. }
  642. // RemoteAddr implements the ResponseWriter.RemoteAddr method.
  643. func (w *response) RemoteAddr() net.Addr { return w.remoteAddr }
  644. // TsigStatus implements the ResponseWriter.TsigStatus method.
  645. func (w *response) TsigStatus() error { return w.tsigStatus }
  646. // TsigTimersOnly implements the ResponseWriter.TsigTimersOnly method.
  647. func (w *response) TsigTimersOnly(b bool) { w.tsigTimersOnly = b }
  648. // Hijack implements the ResponseWriter.Hijack method.
  649. func (w *response) Hijack() { w.hijacked = true }
  650. // Close implements the ResponseWriter.Close method
  651. func (w *response) Close() error {
  652. // Can't close the udp conn, as that is actually the listener.
  653. if w.tcp != nil {
  654. e := w.tcp.Close()
  655. w.tcp = nil
  656. return e
  657. }
  658. return nil
  659. }