Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a hook to catch invalid messages #1568

Merged
merged 2 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions acceptfunc_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package dns

import (
"encoding/binary"
"net"
"testing"
)

Expand Down Expand Up @@ -33,3 +35,86 @@ func handleNotify(w ResponseWriter, req *Msg) {
m.SetReply(req)
w.WriteMsg(m)
}

func TestInvalidMsg(t *testing.T) {
HandleFunc("example.org.", func(ResponseWriter, *Msg) {
t.Fatal("the handler must not be called in any of these tests")
})
s, addrstr, _, err := RunLocalTCPServer(":0")
if err != nil {
t.Fatalf("unable to run test server: %v", err)
}
defer s.Shutdown()

s.MsgAcceptFunc = func(dh Header) MsgAcceptAction {
switch dh.Id {
case 0x0001:
return MsgAccept
case 0x0002:
return MsgReject
case 0x0003:
return MsgIgnore
case 0x0004:
return MsgRejectNotImplemented
default:
t.Errorf("unexpected ID %x", dh.Id)
return -1
}
}

invalidErrors := make(chan error)
s.MsgInvalidFunc = func(m []byte, err error) {
invalidErrors <- err
}

c, err := net.Dial("tcp", addrstr)
if err != nil {
t.Fatalf("cannot connect to test server: %v", err)
}

write := func(m []byte) {
var length [2]byte
binary.BigEndian.PutUint16(length[:], uint16(len(m)))
_, err := c.Write(length[:])
if err != nil {
t.Fatalf("length write failed: %v", err)
}
_, err = c.Write(m)
if err != nil {
t.Fatalf("content write failed: %v", err)
}
}

/* Message is too short, so there is no header to accept or reject. */

tooShortMessage := make([]byte, 11)
tooShortMessage[1] = 0x3 // ID = 3, would be ignored if it were parsable.

write(tooShortMessage)
// Expect an error to be reported.
<-invalidErrors

/* Message is accepted but is actually invalid. */

badMessage := make([]byte, 13)
badMessage[1] = 0x1 // ID = 1, Accept.
badMessage[5] = 1 // QDCOUNT = 1
badMessage[12] = 99 // Bad question section. Invalid!

write(badMessage)
// Expect an error to be reported.
<-invalidErrors

/* Message is rejected before it can be determined to be invalid. */

close(invalidErrors) // A call to InvalidMsgFunc would panic due to the closed chan.

badMessage[1] = 0x2 // ID = 2, Reject
write(badMessage)

badMessage[1] = 0x3 // ID = 3, Ignore
write(badMessage)

badMessage[1] = 0x4 // ID = 4, RejectNotImplemented
write(badMessage)
}
19 changes: 18 additions & 1 deletion server.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,14 @@ type DecorateReader func(Reader) Reader
// Implementations should never return a nil Writer.
type DecorateWriter func(Writer) Writer

// InvalidMsgFunc is a listener hook for observing incoming messages that were discarded
// because they could not be parsed.
// Every message that is read by a Reader will eventually be provided to the Handler,
// rejected (or ignored) by the MsgAcceptFunc, or passed to this function.
type InvalidMsgFunc func(m []byte, err error)

func DefaultMsgInvalidFunc(m []byte, err error) {}

// A Server defines parameters for running an DNS server.
type Server struct {
// Address to listen on, ":dns" if empty.
Expand Down Expand Up @@ -233,6 +241,8 @@ type Server struct {
// AcceptMsgFunc will check the incoming message and will reject it early in the process.
// By default DefaultMsgAcceptFunc will be used.
MsgAcceptFunc MsgAcceptFunc
// MsgInvalidFunc is optional, will be called if a message is received but cannot be parsed.
MsgInvalidFunc InvalidMsgFunc

// Shutdown handling
lock sync.RWMutex
Expand Down Expand Up @@ -277,6 +287,9 @@ func (srv *Server) init() {
if srv.MsgAcceptFunc == nil {
srv.MsgAcceptFunc = DefaultMsgAcceptFunc
}
if srv.MsgInvalidFunc == nil {
srv.MsgInvalidFunc = DefaultMsgInvalidFunc
}
if srv.Handler == nil {
srv.Handler = DefaultServeMux
}
Expand Down Expand Up @@ -531,6 +544,7 @@ func (srv *Server) serveUDP(l net.PacketConn) error {
if cap(m) == srv.UDPSize {
srv.udpPool.Put(m[:srv.UDPSize])
}
srv.MsgInvalidFunc(m, ErrShortRead)
continue
}
wg.Add(1)
Expand Down Expand Up @@ -611,6 +625,7 @@ func (srv *Server) serveUDPPacket(wg *sync.WaitGroup, m []byte, u net.PacketConn
func (srv *Server) serveDNS(m []byte, w *response) {
dh, off, err := unpackMsgHdr(m, 0)
if err != nil {
srv.MsgInvalidFunc(m, err)
// Let client hang, they are sending crap; any reply can be used to amplify.
return
}
Expand All @@ -620,10 +635,12 @@ func (srv *Server) serveDNS(m []byte, w *response) {

switch action := srv.MsgAcceptFunc(dh); action {
case MsgAccept:
if req.unpack(dh, m, off) == nil {
err := req.unpack(dh, m, off)
if err == nil {
break
}

srv.MsgInvalidFunc(m, err)
fallthrough
case MsgReject, MsgRejectNotImplemented:
opcode := req.Opcode
Expand Down
Loading