Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
Add more command line options, simplify code.
  • Loading branch information
cyfdecyf committed Dec 16, 2012
2 parents a677a85 + 7c1f7da commit 3ee5018
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 128 deletions.
157 changes: 82 additions & 75 deletions cmd/shadowsocks-local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,17 @@ import (
var debug ss.DebugLog

var (
errAddr = errors.New("socks addr type not supported")
errVer = errors.New("socks version not supported")
errMethod = errors.New("socks only support 1 method now")
errAuth = errors.New("socks authentication not required")
errCmd = errors.New("socks command not supported")
errAddrType = errors.New("socks addr type not supported")
errVer = errors.New("socks version not supported")
errMethod = errors.New("socks only support 1 method now")
errAuthExtraData = errors.New("socks authentication get extra data")
errReqExtraData = errors.New("socks request get extra data")
errCmd = errors.New("socks command not supported")
)

const (
socksVer5 = 5
socksCmdConnect = 1
)

func handShake(conn net.Conn) (err error) {
Expand All @@ -29,26 +35,36 @@ func handShake(conn net.Conn) (err error) {
)
// version identification and method selection message in theory can have
// at most 256 methods, plus version and nmethod field in total 258 bytes
// the current rfc defines only 3 authentication methods (plus 2 reserved)
// the current rfc defines only 3 authentication methods (plus 2 reserved),
// so it won't be such long in practice

buf := make([]byte, 258-2, 258-2) // reuse the buf to read nmethod field
buf := make([]byte, 258, 258)

if _, err = io.ReadFull(conn, buf[:2]); err != nil {
var n int
// make sure we get the nmethod field
if n, err = io.ReadAtLeast(conn, buf, idNmethod+1); err != nil {
return
}
if buf[idVer] != 5 {
if buf[idVer] != socksVer5 {
return errVer
}
nmethod := buf[idNmethod]
if _, err = io.ReadFull(conn, buf[:nmethod]); err != nil {
return
nmethod := int(buf[idNmethod])
msgLen := nmethod + 2
if n == msgLen { // handshake done, common case
// do nothing, jump directly to send confirmation
} else if n < msgLen { // has more methods to read, rare case
if _, err = io.ReadFull(conn, buf[n:msgLen]); err != nil {
return
}
} else { // error, should not get extra data
return errAuthExtraData
}
// version 5, no authentication required
_, err = conn.Write([]byte{5, 0})
// send confirmation: version 5, no authentication required
_, err = conn.Write([]byte{socksVer5, 0})
return
}

func getRequest(conn net.Conn) (rawaddr []byte, extra []byte, host string, err error) {
func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) {
const (
idVer = 0
idCmd = 1
Expand All @@ -65,71 +81,56 @@ func getRequest(conn net.Conn) (rawaddr []byte, extra []byte, host string, err e
)
// refer to getRequest in server.go for why set buffer size to 263
buf := make([]byte, 263, 263)
cur := 0 // current location in buf
reqLen := 0
var n int
// read till we get possible domain length field
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
return
}
// check version and cmd
if buf[idVer] != socksVer5 {
err = errVer
return
}
if buf[idCmd] != socksCmdConnect {
err = errCmd
return
}

for {
var n int
// usually need to read only once
if n, err = conn.Read(buf[cur:]); err != nil {
// debug.Println("read request error:", err)
return
}
cur += n
if cur < idType+1 { // read till we get addr type
continue
}
// check version and cmd
if buf[idVer] != 5 {
err = errVer
return
}
if buf[idCmd] != 1 {
err = errCmd
return
}
// TODO following code is copied from server.go, fix code duplication?
if buf[idType] == typeIP {
if cur >= lenIP {
// debug.Println("ip request complete, cur:", cur)
reqLen = lenIP
break
}
} else if buf[idType] == typeDm {
if cur < idDmLen+1 { // read until we get address length byte
continue
}
if cur >= lenDmBase+int(buf[idDmLen]) {
// debug.Println("domain request complete, cur:", cur)
reqLen = lenDmBase + int(buf[idDmLen])
break
}
} else {
err = errAddr
reqLen := lenIP
if buf[idType] == typeDm {
reqLen = int(buf[idDmLen]) + lenDmBase
} else if buf[idType] != typeIP {
err = errAddrType
return
}

if n == reqLen {
// common case, do nothing
} else if n < reqLen { // rare case
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
// debug.Println("request not complete, cur:", cur)
} else {
err = errReqExtraData
return
}

rawaddr = buf[idType:reqLen]
if cur > reqLen {
extra = buf[reqLen:cur]
// debug.Println("extra:", string(extra))
}

if debug {
if buf[idType] == typeIP {
if buf[idType] == typeDm {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
} else if buf[idType] == typeIP {
addrIp := make(net.IP, 4)
copy(addrIp, buf[idIP0:idIP0+4])
host = addrIp.String()
} else if buf[idType] == typeDm {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
var port int16
sb := bytes.NewBuffer(buf[reqLen-2 : reqLen])
binary.Read(sb, binary.BigEndian, &port)
host += ":" + strconv.Itoa(int(port))
}

return
}

Expand All @@ -144,7 +145,7 @@ func handleConnection(conn net.Conn, server string, encTbl *ss.EncryptTable) {
log.Println("socks handshack:", err)
return
}
rawaddr, extra, addr, err := getRequest(conn)
rawaddr, addr, err := getRequest(conn)
if err != nil {
log.Println("error getting request:", err)
return
Expand All @@ -163,13 +164,6 @@ func handleConnection(conn net.Conn, server string, encTbl *ss.EncryptTable) {
return
}
defer remote.Close()
if extra != nil {
debug.Println("writing extra content to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
}
}

c := make(chan byte, 2)
go ss.Pipe(conn, remote, c)
Expand All @@ -184,7 +178,7 @@ func run(port, password, server string) {
log.Fatal(err)
}
encTbl := ss.GetTable(password)
log.Printf("starting server at port %v ...\n", port)
log.Printf("starting local socks5 server at port %v, remote shadowsocks server %s...\n", port, server)
for {
conn, err := ln.Accept()
if err != nil {
Expand All @@ -197,11 +191,24 @@ func run(port, password, server string) {

func main() {
var configFile string
var cmdConfig ss.Config

flag.StringVar(&configFile, "c", "config.json", "specify config file")
flag.StringVar(&cmdConfig.Server, "s", "", "server address")
flag.StringVar(&cmdConfig.Password, "k", "", "password")
flag.IntVar(&cmdConfig.ServerPort, "p", 0, "server port")
flag.IntVar(&cmdConfig.LocalPort, "l", 0, "local socks5 proxy port")
flag.BoolVar((*bool)(&debug), "d", false, "print debug message")

flag.Parse()

config := ss.ParseConfig(configFile)
debug = ss.Debug

config, err := ss.ParseConfig(configFile)
if err != nil {
return
}
ss.UpdateConfig(config, &cmdConfig)
ss.SetDebug(debug)

run(strconv.Itoa(config.LocalPort), config.Password,
config.Server+":"+strconv.Itoa(config.ServerPort))
}
82 changes: 41 additions & 41 deletions cmd/shadowsocks-server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"flag"
ss "github.com/shadowsocks/shadowsocks-go/shadowsocks"
"io"
"log"
"net"
"strconv"
Expand All @@ -15,7 +16,7 @@ import (

var debug ss.DebugLog

var errAddr = errors.New("addr type not supported")
var errAddrType = errors.New("addr type not supported")

func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
const (
Expand All @@ -35,58 +36,45 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
// request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 260, 260)
cur := 0 // current location in buf
var n int
// read till we get possible domain length field
ss.SetReadTimeout(conn)
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
return
}

// first read the complete request, may read extra bytes
for {
// hopefully, we should only need one read to get the complete request
// this read normally will read just the request, no extra data
reqLen := lenIP
if buf[idType] == typeDm {
reqLen = int(buf[idDmLen]) + lenDmBase
} else if buf[idType] != typeIP {
err = errAddrType
return
}

if n < reqLen { // rare case
ss.SetReadTimeout(conn)
var n int
if n, err = conn.Read(buf[cur:]); err != nil {
// debug.Println("read request error:", err)
return
}
cur += n
if buf[idType] == typeIP {
if cur >= lenIP {
// debug.Println("ip request complete, cur:", cur)
break
}
} else if buf[idType] == typeDm {
if cur < idDmLen+1 { // read until we get address length byte
continue
}
if cur >= lenDmBase+int(buf[idDmLen]) {
// debug.Println("domain request complete, cur:", cur)
break
}
} else {
err = errAddr
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
// debug.Println("request not complete, cur:", cur)
} else if n > reqLen {
// it's possible to read more than just the request head
extra = buf[reqLen:n]
}

reqLen := lenIP // default to IP request length
if buf[idType] == typeIP {
// TODO add ipv6 support
if buf[idType] == typeDm {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
} else if buf[idType] == typeIP {
addrIp := make(net.IP, 4)
copy(addrIp, buf[idIP0:idIP0+4])
host = addrIp.String()
} else if buf[idType] == typeDm {
reqLen = lenDmBase + int(buf[idDmLen])
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
// parse port
var port int16
sb := bytes.NewBuffer(buf[reqLen-2 : reqLen])
binary.Read(sb, binary.BigEndian, &port)

// debug.Println("requesting:", host, "header len", reqLen)
host += ":" + strconv.Itoa(int(port))
if cur > reqLen {
extra = buf[reqLen:cur]
// debug.Println("extra:", string(extra))
}
return
}

Expand All @@ -112,7 +100,7 @@ func handleConnection(conn *ss.Conn) {
defer remote.Close()
// write extra bytes read from
if extra != nil {
debug.Println("writing extra content to remote, len", len(extra))
debug.Println("getRequest read extra data, writing to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
Expand Down Expand Up @@ -164,11 +152,23 @@ func run(port, password string) {

func main() {
var configFile string
var cmdConfig ss.Config

flag.StringVar(&configFile, "c", "config.json", "specify config file")
flag.StringVar(&cmdConfig.Password, "k", "", "password")
flag.IntVar(&cmdConfig.ServerPort, "p", 0, "server port")
flag.IntVar(&cmdConfig.Timeout, "t", 60, "connection timeout (in seconds)")
flag.BoolVar((*bool)(&debug), "d", false, "print debug message")

flag.Parse()

config := ss.ParseConfig(configFile)
debug = ss.Debug
config, err := ss.ParseConfig(configFile)
if err != nil {
return
}
ss.UpdateConfig(config, &cmdConfig)
ss.SetDebug(debug)

if len(config.PortPassword) == 0 {
run(strconv.Itoa(config.ServerPort), config.Password)
} else {
Expand Down
3 changes: 1 addition & 2 deletions config.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,5 @@
"8388": "barfoo!",
"8387": "foobar!"
},
"timeout":60,
"debug":true
"timeout":60
}
Loading

0 comments on commit 3ee5018

Please sign in to comment.