Skip to content

Commit

Permalink
Refactor code to isolate packet structure from UDP connection logic
Browse files Browse the repository at this point in the history
  • Loading branch information
sabhiram committed Apr 16, 2018
1 parent b83cc3a commit 4fd002b
Show file tree
Hide file tree
Showing 5 changed files with 159 additions and 122 deletions.
77 changes: 76 additions & 1 deletion cmd/wol/wol.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package main
import (
"errors"
"fmt"
"net"
"os"
"os/user"
"path"
Expand Down Expand Up @@ -34,6 +35,38 @@ var (

////////////////////////////////////////////////////////////////////////////////

// ipFromInterface returns a `*net.UDPAddr` from a network interface name.
func ipFromInterface(iface string) (*net.UDPAddr, error) {
ief, err := net.InterfaceByName(iface)
if err != nil {
return nil, err
}

addrs, err := ief.Addrs()
if err == nil && len(addrs) <= 0 {
err = fmt.Errorf("no address associated with interface %s", iface)
}
if err != nil {
return nil, err
}

// Validate that one of the addr's is a valid network IP address.
for _, addr := range addrs {
switch ip := addr.(type) {
case *net.IPNet:
// Verify that the DefaultMask for the address we want to use exists.
if ip.IP.DefaultMask() != nil {
return &net.UDPAddr{
IP: ip.IP,
}, nil
}
}
}
return nil, fmt.Errorf("no address associated with interface %s", iface)
}

////////////////////////////////////////////////////////////////////////////////

// Run the alias command.
func aliasCmd(args []string, aliases *Aliases) error {
if len(args) >= 2 {
Expand Down Expand Up @@ -99,7 +132,49 @@ func wakeCmd(args []string, aliases *Aliases) error {
bcastInterface = cliFlags.BroadcastInterface
}

err = wol.SendMagicPacket(macAddr, cliFlags.BroadcastIP+":"+cliFlags.UDPPort, bcastInterface)
// Populate the local address in the event that the broadcast interface has
// been set.
var localAddr *net.UDPAddr
if bcastInterface != "" {
localAddr, err = ipFromInterface(bcastInterface)
if err != nil {
return err
}
}

// The address to broadcast to is usually the default `255.255.255.255` but
// can be overloaded by specifying an override in the CLI arguments.
bcastAddr := fmt.Sprintf("%s:%s", cliFlags.BroadcastIP, cliFlags.UDPPort)
udpAddr, err := net.ResolveUDPAddr("udp", bcastAddr)
if err != nil {
return err
}

// Build the magic packet.
mp, err := wol.New(macAddr)
if err != nil {
return err
}

// Grab a stream of bytes to send.
bs, err := mp.Marshal()
if err != nil {
return err
}

// Grab a UDP connection to send our packet of bytes.
conn, err := net.DialUDP("udp", localAddr, udpAddr)
if err != nil {
return err
}
defer conn.Close()

fmt.Printf("Attempting to send a magic packet to MAC %s\n", macAddr)
fmt.Printf("... Broadcasting to: %s\n", bcastAddr)
n, err := conn.Write(bs)
if err == nil && n != 102 {
err = fmt.Errorf("magic packet sent was %d bytes (expected 102 bytes sent)", n)
}
if err != nil {
return err
}
Expand Down
44 changes: 44 additions & 0 deletions cmd/wol/wol_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package main

////////////////////////////////////////////////////////////////////////////////

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

////////////////////////////////////////////////////////////////////////////////

func TestIPFromInterface(t *testing.T) {
interfaces, err := net.Interfaces()
assert.Nil(t, err)

// We can't actually enforce that we get a valid IP, but either the error
// or the pointer should be nil.
for _, i := range interfaces {
addr, err := ipFromInterface(i.Name)
if err == nil {
assert.NotNil(t, addr)
} else {
assert.Nil(t, addr)
}
}
}

func TestIPFromInterfaceNegative(t *testing.T) {
// Test some fake interfaces.
var NegativeTestCases = []struct {
iface string
}{
{"fake-interface-0"},
{"fake-interface-1"},
}

for _, tc := range NegativeTestCases {
addr, err := ipFromInterface(tc.iface)
assert.Nil(t, addr)
assert.NotNil(t, err)
}
}
93 changes: 11 additions & 82 deletions magic_packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,17 @@ func New(mac string) (*MagicPacket, error) {
var packet MagicPacket
var macAddr MACAddress

// We only support 6 byte MAC addresses since it is much harder to use the
// binary.Write(...) interface when the size of the MagicPacket is dynamic.
if !reMAC.MatchString(mac) {
return nil, fmt.Errorf("invalid mac-address %s", mac)
}

hwAddr, err := net.ParseMAC(mac)
if err != nil {
return nil, err
}

// We only support 6 byte MAC addresses since it is much harder to use the
// binary.Write(...) interface when the size of the MagicPacket is dynamic.
if !reMAC.MatchString(mac) {
return nil, fmt.Errorf("%s is not a IEEE 802 MAC-48 address", mac)
}

// Copy bytes from the returned HardwareAddr -> a fixed size MACAddress.
for idx := range macAddr {
macAddr[idx] = hwAddr[idx]
Expand All @@ -63,83 +63,12 @@ func New(mac string) (*MagicPacket, error) {
return &packet, nil
}

////////////////////////////////////////////////////////////////////////////////

// GetIPFromInterface returns a `*net.UDPAddr` from a network interface name.
func GetIPFromInterface(iface string) (*net.UDPAddr, error) {
ief, err := net.InterfaceByName(iface)
if err != nil {
return nil, err
}

addrs, err := ief.Addrs()
if err != nil {
return nil, err
} else if len(addrs) <= 0 {
return nil, fmt.Errorf("no address associated with interface %s", iface)
}

// Validate that one of the addr's is a valid network IP address.
for _, addr := range addrs {
switch ip := addr.(type) {
case *net.IPNet:
// Verify that the DefaultMask for the address we want to use exists.
if ip.IP.DefaultMask() != nil {
return &net.UDPAddr{
IP: ip.IP,
}, nil
}
}
}
return nil, fmt.Errorf("no address associated with interface %s", iface)
}

// SendMagicPacket sends a magic packet UDP broadcast to the specified `macAddr`.
// The broadcast is sent to `bcastAddr` via the `iface`. An empty `iface`
// implies a nil local address to dial.
func SendMagicPacket(macAddr, bcastAddr, iface string) error {
// Construct a MagicPacket for the given MAC Address.
magicPacket, err := New(macAddr)
if err != nil {
return err
}

// Fill our byte buffer with the bytes in our MagicPacket.
// Marshal serializes the magic packet structure into a 102 byte slice.
func (mp *MagicPacket) Marshal() ([]byte, error) {
var buf bytes.Buffer
binary.Write(&buf, binary.BigEndian, magicPacket)
fmt.Printf("Attempting to send a magic packet to MAC %s\n", macAddr)
fmt.Printf("... Broadcasting to: %s\n", bcastAddr)

// Get a UDPAddr to send the broadcast to.
udpAddr, err := net.ResolveUDPAddr("udp", bcastAddr)
if err != nil {
return err
}

// If an interface was specified, get the address associated with it.
var localAddr *net.UDPAddr
if iface != "" {
var err error
localAddr, err = GetIPFromInterface(iface)
if err != nil {
return err
}
}

// Open a UDP connection, and defer it's cleanup.
connection, err := net.DialUDP("udp", localAddr, udpAddr)
if err != nil {
return err
}
defer connection.Close()

// Write the bytes of the MagicPacket to the connection.
n, err := connection.Write(buf.Bytes())
if err != nil {
return err
} else if n != 102 {
fmt.Printf("Warning: %d bytes written, %d expected!\n", n, 102)
if err := binary.Write(&buf, binary.BigEndian, mp); err != nil {
return nil, err
}

return nil
return buf.Bytes(), nil
}
61 changes: 25 additions & 36 deletions magic_packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package wol
////////////////////////////////////////////////////////////////////////////////

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -12,16 +11,14 @@ import (
////////////////////////////////////////////////////////////////////////////////

func TestNewMagicPacket(t *testing.T) {
var PositiveTestCases = []struct {
for _, tc := range []struct {
mac string
expected MACAddress
}{
{"00:00:00:00:00:00", MACAddress{0, 0, 0, 0, 0, 0}},
{"00:ff:01:03:00:00", MACAddress{0, 255, 1, 3, 0, 0}},
{"00-ff-01-03-00-00", MACAddress{0, 255, 1, 3, 0, 0}},
}

for _, tc := range PositiveTestCases {
} {
pkt, err := New(tc.mac)
for _, v := range pkt.header {
assert.Equal(t, int(v), 255)
Expand All @@ -34,47 +31,39 @@ func TestNewMagicPacket(t *testing.T) {
}

func TestNewMagicPacketNegative(t *testing.T) {
var NegativeTestCases = []struct {
for _, tc := range []struct {
mac string
}{
{"00x00:00:00:00:00"},
{"00:00:Z0:00:00:00"},
}

for _, tc := range NegativeTestCases {
{"01:23:45:67:89:ab:cd:ef"},
{"01:23:45:67:89:ab:cd:ef:00:00:01:23:45:67:89:ab:cd:ef:00:00"},
{"01-23-45-67-89-ab-cd-ef"},
{"01-23-45-67-89-ab-cd-ef-00-00-01-23-45-67-89-ab-cd-ef-00-00"},
{"0123.4567.89ab"},
{"0123.4567.89ab.cdef"},
{"0123.4567.89ab.cdef.0000.0123.4567.89ab.cdef.0000"},
} {
_, err := New(tc.mac)
assert.NotNil(t, err)
}
}

func TestGetIPFromInterface(t *testing.T) {
interfaces, err := net.Interfaces()
assert.Nil(t, err)

// We can't actually enforce that we get a valid IP, but either the error
// or the pointer should be nil.
for _, i := range interfaces {
addr, err := GetIPFromInterface(i.Name)
if err == nil {
assert.NotNil(t, addr)
} else {
assert.Nil(t, addr)
}
}
}

func TestGetIPFromInterfaceNegative(t *testing.T) {
// Test some fake interfaces.
var NegativeTestCases = []struct {
iface string
func TestMagicPacketMarshal(t *testing.T) {
for _, tc := range []struct {
mac string
count int
}{
{"fake-interface-0"},
{"fake-interface-1"},
}
{"00:00:00:00:00:00", 102},
{"00:ff:01:03:00:00", 102},
{"00-ff-01-03-00-00", 102},
} {
pkt, err := New(tc.mac)
assert.Equal(t, err, nil)

for _, tc := range NegativeTestCases {
addr, err := GetIPFromInterface(tc.iface)
assert.Nil(t, addr)
assert.NotNil(t, err)
bs, err := pkt.Marshal()
assert.Equal(t, err, nil)

assert.Equal(t, len(bs), tc.count)
}
}
6 changes: 3 additions & 3 deletions version_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ package wol

// WARNING: Auto generated version file. Do not edit this file by hand.
// WARNING: go get github.com/sabhiram/gover to manage this file.
// Version: 1.1.1
// Version: 1.1.2
const (
Major = 1
Minor = 1
Patch = 1
Patch = 2

Version = "1.1.1"
Version = "1.1.2"
)

0 comments on commit 4fd002b

Please sign in to comment.