Skip to content

Commit

Permalink
feat: Add DialOption for DialWithXXX method more universal and add us…
Browse files Browse the repository at this point in the history
…age for using SOCKS5 proxy to dial.
  • Loading branch information
Aaron-cdx committed Apr 26, 2024
1 parent 334a88f commit 362485a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 11 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,25 @@ if err != nil {
defer client.Close()
```

+ Dial by SOCKS5 proxy

```go
client, err := DialWithPasswd(addr, user, passwd, WithDialFuncOption(func(network string, address string) (net.Conn, error) {
// get proxy address from env or config
proxyAddress := os.Getenv("socks5_proxy")
dial, err := proxy.SOCKS5(network, proxyAddress, nil, nil)
if err != nil {
t.Fatal(err)
}
c, err := dial.Dial(network, address)
return c, err
}))
if err != nil {
handleErr(err)
}
defer client.Close()
```

## execute commmand

+ Don't care about output, calling Run
Expand Down
53 changes: 42 additions & 11 deletions sshclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ import (
type remoteScriptType byte
type remoteShellType byte

// DialOption represents an option for Dial connection.
type DialOption func(opt *dialOptions)

// dialOptions contains all the options set by WithXxxOpt func as flow.
type dialOptions struct {
dialFunc func(network string, address string) (net.Conn, error)
}

// WithDialFuncOption append dialFunc field to dialOptions
func WithDialFuncOption(dialFunc func(network string, address string) (net.Conn, error)) DialOption {
return func(opt *dialOptions) {
opt.dialFunc = dialFunc
}
}

const (
cmdLine remoteScriptType = iota
rawScript
Expand All @@ -36,7 +51,11 @@ type Client struct {
}

// DialWithPasswd starts a client connection to the given SSH server with passwd authmethod.
func DialWithPasswd(addr, user, passwd string) (*Client, error) {
func DialWithPasswd(addr, user, passwd string, options ...DialOption) (*Client, error) {
opts := &dialOptions{}
for _, option := range options {
option(opts)
}
config := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{
Expand All @@ -45,11 +64,15 @@ func DialWithPasswd(addr, user, passwd string) (*Client, error) {
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
}

return Dial("tcp", addr, config)
return Dial("tcp", addr, config, opts)
}

// DialWithKey starts a client connection to the given SSH server with key authmethod.
func DialWithKey(addr, user, keyfile string) (*Client, error) {
func DialWithKey(addr, user, keyfile string, options ...DialOption) (*Client, error) {
opts := &dialOptions{}
for _, option := range options {
option(opts)
}
key, err := ioutil.ReadFile(keyfile)
if err != nil {
return nil, err
Expand All @@ -67,12 +90,15 @@ func DialWithKey(addr, user, keyfile string) (*Client, error) {
},
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
}

return Dial("tcp", addr, config)
return Dial("tcp", addr, config, opts)
}

// DialWithKeyWithPassphrase same as DialWithKey but with a passphrase to decrypt the private key
func DialWithKeyWithPassphrase(addr, user, keyfile string, passphrase string) (*Client, error) {
func DialWithKeyWithPassphrase(addr, user, keyfile string, passphrase string, options ...DialOption) (*Client, error) {
opts := &dialOptions{}
for _, option := range options {
option(opts)
}
key, err := ioutil.ReadFile(keyfile)
if err != nil {
return nil, err
Expand All @@ -91,12 +117,19 @@ func DialWithKeyWithPassphrase(addr, user, keyfile string, passphrase string) (*
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
}

return Dial("tcp", addr, config)
return Dial("tcp", addr, config, opts)
}

// Dial starts a client connection to the given SSH server.
// This wraps ssh.Dial.
func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
func Dial(network, addr string, config *ssh.ClientConfig, opts *dialOptions) (*Client, error) {
if opts != nil && opts.dialFunc != nil {
conn, err := opts.dialFunc(network, addr)
if err != nil {
return nil, err
}
return DialWithConnection(conn, network, config)
}
sshClient, err := ssh.Dial(network, addr, config)
if err != nil {
return nil, err
Expand All @@ -114,9 +147,7 @@ func DialWithConnection(conn net.Conn, addr string, config *ssh.ClientConfig) (*

client := ssh.NewClient(ncc, chans, reqs)

return &Client{
client: client,
}, nil
return &Client{sshClient: client}, nil
}

// Close closes the underlying client network connection.
Expand Down

0 comments on commit 362485a

Please sign in to comment.