Skip to content

Commit

Permalink
Extract request parsing in server.go
Browse files Browse the repository at this point in the history
  • Loading branch information
cyfdecyf committed Dec 13, 2012
1 parent 5686c46 commit 35e1e06
Showing 1 changed file with 82 additions and 47 deletions.
129 changes: 82 additions & 47 deletions cmd/shadowsocks-server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,68 @@ var debug ss.DebugLog

var errAddr = errors.New("addr type not supported")

func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
// buf size should at least have the same size with the largest possible
// request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 4096, 4096) // use 4096 to read more if possible
cur := 0 // current location in buf

// first read the complete request, may read extra bytes
var n int
for {
// hopefully, we should only need one read to get the complete request
if n, err = conn.Read(buf[cur:]); err != nil {
// debug.Println("read request error:", err)
return
}
cur += n
// buf[0] is address type
if buf[0] == 1 { // ip address
// request need: 1(addrType) + 4(IP) + 2(port)
if n >= (1 + 4 + 2) {
// debug.Println("ip request complete, cur:", cur)
break
}
} else if buf[0] == 3 { // domain name
if cur == 1 { // need at least the addrLen byte
continue
}
// request need: 2(addrType & addrLen) + addrLen + 2(port)
// buf[1] is addrLen
if n >= (2 + int(buf[1]) + 2) {
// debug.Println("domain request complete, cur:", cur)
break
}
} else {
err = errAddr
return
}
// debug.Println("request not complete, cur:", cur)
}

reqLen := 1 + 4 + 2 // default to IP addr length
if buf[0] == 1 {
addrIp := make(net.IP, 4)
copy(addrIp, buf[1:5])
host = addrIp.String()
} else if buf[0] == 3 {
reqLen = 2 + int(buf[1]) + 2
host = string(buf[2 : 2+buf[1]])
}
var port int16
sb := bytes.NewBuffer(buf[reqLen-2 : reqLen])
binary.Read(sb, binary.BigEndian, &port)

// debug.Println("requesting:", host, "header len", reqLen)
host += ":" + strconv.Itoa(int(port))
if cur > reqLen {
extra = buf[reqLen:cur]
// debug.Println("extra:", string(extra))
}
return
}

func handleConnection(conn *ss.Conn) {
if debug {
// function arguments are always evaluated, so surround debug
Expand All @@ -24,60 +86,33 @@ func handleConnection(conn *ss.Conn) {
}
defer conn.Close()

var addr string
var port int16
var addrType byte
var remote net.Conn
var c chan byte
var err error

buf := make([]byte, 1)
if _, err = conn.Read(buf); err != nil {
goto onError
}
addrType = buf[0]
if addrType == 1 {
buf = make([]byte, 6)
if _, err = conn.Read(buf); err != nil {
goto onError
}
sb := bytes.NewBuffer(buf[4:6])
binary.Read(sb, binary.BigEndian, &port)
addrIp := make(net.IP, 4)
copy(addrIp, buf[0:4])
addr = addrIp.String()
} else if addrType == 3 {
if _, err = conn.Read(buf); err != nil {
goto onError
}
addrLen := buf[0]
buf = make([]byte, addrLen+2)
if _, err = conn.Read(buf); err != nil {
goto onError
}
sb := bytes.NewBuffer(buf[addrLen : addrLen+2])
binary.Read(sb, binary.BigEndian, &port)
addr = string(buf[0:addrLen])
} else {
log.Println("unsurpported addr type")
err = errAddr
goto onError
host, extra, err := getRequest(conn)
if err != nil {
debug.Println("error getting request:", err)
return
}
debug.Println("connecting", addr)
if remote, err = net.Dial("tcp", addr+":"+strconv.Itoa(int(port))); err != nil {
goto onError
debug.Println("connecting", host)
remote, err := net.Dial("tcp", host)
if err != nil {
debug.Println("error connecting to:", host, err)
return
}
defer remote.Close()
debug.Println("piping", addr)
c = make(chan byte, 2)
// write extra bytes read from
if extra != nil {
debug.Println("writing extra content to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
}
}
debug.Println("piping", host)
c := make(chan byte, 2)
go ss.Pipe(conn, remote, c)
go ss.Pipe(remote, conn, c)
<-c // close the other connection whenever one connection is closed
debug.Println("closing", addr)
debug.Println("closing", host)
return

onError:
debug.Println("error", addr, err)
}

// Add a encrypt table cache to save memory and startup time in case of many
Expand Down

0 comments on commit 35e1e06

Please sign in to comment.