Skip to content

Commit

Permalink
Allow custom dialer to be specified via opts
Browse files Browse the repository at this point in the history
  • Loading branch information
AlCutter authored and beevik committed May 30, 2023
1 parent 86be42a commit 036a5fe
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 21 deletions.
3 changes: 2 additions & 1 deletion CONTRIBUTORS
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ Anton Tolchanov (knyar)
Christopher Batey (chbatey)
Meng Zhuo (mengzhuo)
Leonid Evdokimov (darkk)
Ask Bjørn Hansen (abh)
Ask Bjørn Hansen (abh)
Al Cutter (AlCutter)
51 changes: 31 additions & 20 deletions ntp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"errors"
"fmt"
"net"
"strconv"
"time"

"golang.org/x/net/ipv4"
Expand Down Expand Up @@ -167,6 +168,9 @@ type QueryOptions struct {
LocalAddress string // IP address to use for the client address
Port int // Server port, defaults to 123
TTL int // IP TTL to use, defaults to system default

// Dial allows the user to override the default UDP dialer behavior when contacting the remote NTP server.
Dial func(localAddress string, localPort int, remoteAddress string, remotePort int) (net.Conn, error)
}

// A Response contains time data, some of which is returned by the NTP server
Expand Down Expand Up @@ -324,29 +328,14 @@ func getTime(host string, opt QueryOptions) (*msg, ntpTime, error) {
if opt.Version < 2 || opt.Version > 4 {
return nil, 0, errors.New("invalid protocol version requested")
}

// Resolve the remote NTP server address.
raddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(host, "123"))
if err != nil {
return nil, 0, err
}

// Resolve the local address if specified as an option.
var laddr *net.UDPAddr
if opt.LocalAddress != "" {
laddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(opt.LocalAddress, "0"))
if err != nil {
return nil, 0, err
}
if opt.Port == 0 {
opt.Port = 123
}

// Override the port if requested.
if opt.Port != 0 {
raddr.Port = opt.Port
if opt.Dial == nil {
opt.Dial = defaultDial
}

// Prepare a "connection" to the remote server.
con, err := net.DialUDP("udp", laddr, raddr)
con, err := opt.Dial(opt.LocalAddress, 0, host, opt.Port)
if err != nil {
return nil, 0, err
}
Expand Down Expand Up @@ -436,6 +425,28 @@ func getTime(host string, opt QueryOptions) (*msg, ntpTime, error) {
return recvMsg, recvTime, nil
}

// defaultDial provides a UDP dialer based on Go's built-in net stack.
func defaultDial(localAddress string, localPort int, remoteAddress string, remotePort int) (net.Conn, error) {
raddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(remoteAddress, strconv.Itoa(remotePort)))
if err != nil {
return nil, err
}

var laddr *net.UDPAddr
if localAddress != "" {
laddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(localAddress, strconv.Itoa(localPort)))
if err != nil {
return nil, err
}
}

con, err := net.DialUDP("udp", laddr, raddr)
if err != nil {
return nil, err
}
return con, err
}

// parseTime parses the NTP packet along with the packet receive time to
// generate a Response record.
func parseTime(m *msg, recvTime ntpTime) *Response {
Expand Down
25 changes: 25 additions & 0 deletions ntp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
package ntp

import (
"errors"
"net"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -326,3 +328,26 @@ func TestOfflineKissCode(t *testing.T) {
assert.Equal(t, kissCode(c.id), c.str)
}
}

func TestOfflineCustomDialer(t *testing.T) {
ntpHost := "remote"
localHost := "local"
dialerCalled := false

qo := QueryOptions{
LocalAddress: localHost,
Dial: func(la string, lp int, ra string, rp int) (net.Conn, error) {
assert.Equal(t, la, localHost)
assert.Equal(t, ra, ntpHost)
assert.Equal(t, rp, 123)
// Only expect to be called once:
assert.False(t, dialerCalled)

dialerCalled = true
return nil, errors.New("not dialing")
},
}
_, _ = QueryWithOptions(ntpHost, qo)

assert.True(t, dialerCalled)
}

0 comments on commit 036a5fe

Please sign in to comment.