diff --git a/cmd/shadowsocks-local/local.go b/cmd/shadowsocks-local/local.go index d09fe38..bd21ff4 100644 --- a/cmd/shadowsocks-local/local.go +++ b/cmd/shadowsocks-local/local.go @@ -15,11 +15,17 @@ import ( var debug ss.DebugLog var ( - errAddr = errors.New("socks addr type not supported") - errVer = errors.New("socks version not supported") - errMethod = errors.New("socks only support 1 method now") - errAuth = errors.New("socks authentication not required") - errCmd = errors.New("socks command not supported") + errAddrType = errors.New("socks addr type not supported") + errVer = errors.New("socks version not supported") + errMethod = errors.New("socks only support 1 method now") + errAuthExtraData = errors.New("socks authentication get extra data") + errReqExtraData = errors.New("socks request get extra data") + errCmd = errors.New("socks command not supported") +) + +const ( + socksVer5 = 5 + socksCmdConnect = 1 ) func handShake(conn net.Conn) (err error) { @@ -29,26 +35,36 @@ func handShake(conn net.Conn) (err error) { ) // version identification and method selection message in theory can have // at most 256 methods, plus version and nmethod field in total 258 bytes - // the current rfc defines only 3 authentication methods (plus 2 reserved) + // the current rfc defines only 3 authentication methods (plus 2 reserved), + // so it won't be such long in practice - buf := make([]byte, 258-2, 258-2) // reuse the buf to read nmethod field + buf := make([]byte, 258, 258) - if _, err = io.ReadFull(conn, buf[:2]); err != nil { + var n int + // make sure we get the nmethod field + if n, err = io.ReadAtLeast(conn, buf, idNmethod+1); err != nil { return } - if buf[idVer] != 5 { + if buf[idVer] != socksVer5 { return errVer } - nmethod := buf[idNmethod] - if _, err = io.ReadFull(conn, buf[:nmethod]); err != nil { - return + nmethod := int(buf[idNmethod]) + msgLen := nmethod + 2 + if n == msgLen { // handshake done, common case + // do nothing, jump directly to send confirmation + } else if n < msgLen { // has more methods to read, rare case + if _, err = io.ReadFull(conn, buf[n:msgLen]); err != nil { + return + } + } else { // error, should not get extra data + return errAuthExtraData } - // version 5, no authentication required - _, err = conn.Write([]byte{5, 0}) + // send confirmation: version 5, no authentication required + _, err = conn.Write([]byte{socksVer5, 0}) return } -func getRequest(conn net.Conn) (rawaddr []byte, extra []byte, host string, err error) { +func getRequest(conn net.Conn) (rawaddr []byte, host string, err error) { const ( idVer = 0 idCmd = 1 @@ -65,71 +81,56 @@ func getRequest(conn net.Conn) (rawaddr []byte, extra []byte, host string, err e ) // refer to getRequest in server.go for why set buffer size to 263 buf := make([]byte, 263, 263) - cur := 0 // current location in buf - reqLen := 0 + var n int + // read till we get possible domain length field + if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil { + return + } + // check version and cmd + if buf[idVer] != socksVer5 { + err = errVer + return + } + if buf[idCmd] != socksCmdConnect { + err = errCmd + return + } - for { - var n int - // usually need to read only once - if n, err = conn.Read(buf[cur:]); err != nil { - // debug.Println("read request error:", err) - return - } - cur += n - if cur < idType+1 { // read till we get addr type - continue - } - // check version and cmd - if buf[idVer] != 5 { - err = errVer - return - } - if buf[idCmd] != 1 { - err = errCmd - return - } - // TODO following code is copied from server.go, fix code duplication? - if buf[idType] == typeIP { - if cur >= lenIP { - // debug.Println("ip request complete, cur:", cur) - reqLen = lenIP - break - } - } else if buf[idType] == typeDm { - if cur < idDmLen+1 { // read until we get address length byte - continue - } - if cur >= lenDmBase+int(buf[idDmLen]) { - // debug.Println("domain request complete, cur:", cur) - reqLen = lenDmBase + int(buf[idDmLen]) - break - } - } else { - err = errAddr + reqLen := lenIP + if buf[idType] == typeDm { + reqLen = int(buf[idDmLen]) + lenDmBase + } else if buf[idType] != typeIP { + err = errAddrType + return + } + + if n == reqLen { + // common case, do nothing + } else if n < reqLen { // rare case + if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil { return } - // debug.Println("request not complete, cur:", cur) + } else { + err = errReqExtraData + return } rawaddr = buf[idType:reqLen] - if cur > reqLen { - extra = buf[reqLen:cur] - // debug.Println("extra:", string(extra)) - } if debug { - if buf[idType] == typeIP { + if buf[idType] == typeDm { + host = string(buf[idDm0 : idDm0+buf[idDmLen]]) + } else if buf[idType] == typeIP { addrIp := make(net.IP, 4) copy(addrIp, buf[idIP0:idIP0+4]) host = addrIp.String() - } else if buf[idType] == typeDm { - host = string(buf[idDm0 : idDm0+buf[idDmLen]]) } var port int16 sb := bytes.NewBuffer(buf[reqLen-2 : reqLen]) binary.Read(sb, binary.BigEndian, &port) host += ":" + strconv.Itoa(int(port)) } + return } @@ -144,7 +145,7 @@ func handleConnection(conn net.Conn, server string, encTbl *ss.EncryptTable) { log.Println("socks handshack:", err) return } - rawaddr, extra, addr, err := getRequest(conn) + rawaddr, addr, err := getRequest(conn) if err != nil { log.Println("error getting request:", err) return @@ -163,13 +164,6 @@ func handleConnection(conn net.Conn, server string, encTbl *ss.EncryptTable) { return } defer remote.Close() - if extra != nil { - debug.Println("writing extra content to remote, len", len(extra)) - if _, err = remote.Write(extra); err != nil { - debug.Println("write request extra error:", err) - return - } - } c := make(chan byte, 2) go ss.Pipe(conn, remote, c) @@ -199,7 +193,7 @@ func main() { var configFile string flag.StringVar(&configFile, "c", "config.json", "specify config file") flag.Parse() - + config := ss.ParseConfig(configFile) debug = ss.Debug run(strconv.Itoa(config.LocalPort), config.Password,