Skip to content

Commit

Permalink
bug with error handling, rebase
Browse files Browse the repository at this point in the history
Signed-off-by: Hamza El-Saawy <[email protected]>
  • Loading branch information
helsaawy committed May 5, 2022
1 parent 634e02f commit df6977a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
20 changes: 15 additions & 5 deletions hvsock.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package winio

import (
"context"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -437,6 +438,10 @@ func canRedial(err error) bool {
}

func (conn *HvsockConn) opErr(op string, err error) error {
// translate from "file closed" to "socket closed"
if errors.Is(err, ErrFileClosed) {
err = sockets.ErrSocketClosed
}
return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
}

Expand All @@ -451,8 +456,8 @@ func (conn *HvsockConn) Read(b []byte) (int, error) {
err = syscall.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.readDeadline, bytes, err)
if err != nil {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("wsarecv", err)
if eno := windows.Errno(0); errors.As(err, &eno) {
err = os.NewSyscallError("wsarecv", eno)
}
return 0, conn.opErr("read", err)
} else if n == 0 {
Expand Down Expand Up @@ -485,8 +490,8 @@ func (conn *HvsockConn) write(b []byte) (int, error) {
err = syscall.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
n, err := conn.sock.asyncIo(c, &conn.sock.writeDeadline, bytes, err)
if err != nil {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError("wsasend", err)
if eno := windows.Errno(0); errors.As(err, &eno) {
err = os.NewSyscallError("wsasend", eno)
}
return 0, conn.opErr("write", err)
}
Expand All @@ -505,11 +510,16 @@ func (conn *HvsockConn) IsClosed() bool {
// shutdown disables sending or receiving on a socket
func (conn *HvsockConn) shutdown(how int) error {
if conn.IsClosed() {
return ErrFileClosed
return sockets.ErrSocketClosed
}

err := syscall.Shutdown(conn.sock.handle, how)
if err != nil {
// If the connection was closed, shutdowns fail with "not connected"
if errors.Is(err, windows.WSAENOTCONN) ||
errors.Is(err, windows.WSAESHUTDOWN) {
err = sockets.ErrSocketClosed
}
return os.NewSyscallError("shutdown", err)
}
return nil
Expand Down
8 changes: 4 additions & 4 deletions pkg/sockets/rawaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ var (
// https://docs.microsoft.com/en-us/windows/win32/winsock/sockaddr-2
type RawSockaddr interface {
// Sockaddr returns a pointer to the RawSockaddr and the length of the struct.
Sockaddr() (ptr unsafe.Pointer, len int32, err error)
Sockaddr() (unsafe.Pointer, int32, error)

// FromBytes populates the RawsockAddr with the data in the byte array.
// Implementers should check the buffer is correctly sized and the address family
Expand All @@ -30,12 +30,12 @@ type RawSockaddr interface {
FromBytes([]byte) error
}

func validateSockAddr(ptr unsafe.Pointer, len int32) error {
func validateSockAddr(ptr unsafe.Pointer, n int32) error {
if ptr == nil {
return fmt.Errorf("pointer is %p: %w", ptr, ErrInvalidPointer)
}
if len < 1 {
return fmt.Errorf("buffer size %d < 1: %w", len, ErrBufferSize)
if n < 1 {
return fmt.Errorf("buffer size %d < 1: %w", n, ErrBufferSize)
}
return nil
}
8 changes: 4 additions & 4 deletions pkg/sockets/sockets.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (

const socketError = uintptr(^uint32(0))

var ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)

// CloseWriter is a connection that can disable writing to itself.
type CloseWriter interface {
net.Conn
Expand Down Expand Up @@ -125,7 +127,6 @@ func (f *runtimeFunc) Load() error {
)
})
return f.err

}

var (
Expand All @@ -141,9 +142,8 @@ var (
)

func ConnectEx(fd windows.Handle, rsa RawSockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *windows.Overlapped) error {
err := connectExFunc.Load()
if err != nil {
return fmt.Errorf("failed to load ConnectEx function pointer: %e", err)
if err := connectExFunc.Load(); err != nil {
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
}
ptr, n, err := rsa.Sockaddr()
if err != nil {
Expand Down

0 comments on commit df6977a

Please sign in to comment.