Skip to content
This repository has been archived by the owner on Apr 9, 2020. It is now read-only.

support one time auth in client & server #113

Merged
merged 3 commits into from
Dec 14, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
language: go
go:
- 1.4.2
- 1.4.3
install:
- go get golang.org/x/crypto/blowfish
- go get golang.org/x/crypto/cast5
Expand All @@ -10,3 +10,4 @@ install:
- go install ./cmd/shadowsocks-server
script:
- PATH=$PATH:$HOME/gopath/bin bash -x ./script/test.sh
sudo: false
6 changes: 5 additions & 1 deletion cmd/shadowsocks-local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,12 @@ func parseServerConfig(config *ss.Config) {
}

if len(config.ServerPassword) == 0 {
method := config.Method
if config.Auth {
method += "-ota"
}
// only one encryption table
cipher, err := ss.NewCipher(config.Method, config.Password)
cipher, err := ss.NewCipher(method, config.Password)
if err != nil {
log.Fatal("Failed generating ciphers:", err)
}
Expand Down
119 changes: 67 additions & 52 deletions cmd/shadowsocks-server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bytes"
"encoding/binary"
"errors"
"flag"
Expand All @@ -17,61 +18,61 @@ import (
"syscall"
)

var debug ss.DebugLog
const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index

typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address

func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index
lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
lenDmBase = 2 // 1addrLen + 2port, plus addrLen
lenHmacSha1 = 10
)

typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address
var debug ss.DebugLog

lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port
lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port
lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen
)
func getRequest(conn *ss.Conn, auth bool) (host string, ota bool, err error) {
ss.SetReadTimeout(conn)

// buf size should at least have the same size with the largest possible
// request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 260)
var n int
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1)
buf := make([]byte, 270)
// read till we get possible domain length field
ss.SetReadTimeout(conn)
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
return
}

reqLen := -1
switch buf[idType] {
var reqStart, reqEnd int
addrType := buf[idType]
switch addrType & ss.AddrMask {
case typeIPv4:
reqLen = lenIPv4
reqStart, reqEnd = idIP0, idIP0+lenIPv4
case typeIPv6:
reqLen = lenIPv6
reqStart, reqEnd = idIP0, idIP0+lenIPv6
case typeDm:
reqLen = int(buf[idDmLen]) + lenDmBase
if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
return
}
reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase)
default:
err = fmt.Errorf("addr type %d not supported", buf[idType])
err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask)
return
}

if n < reqLen { // rare case
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
} else if n > reqLen {
// it's possible to read more than just the request head
extra = buf[reqLen:n]
if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
return
}

// Return string for typeIP is not most efficient, but browsers (Chrome,
// Safari, Firefox) all seems using typeDm exclusively. So this is not a
// big problem.
switch buf[idType] {
switch addrType & ss.AddrMask {
case typeIPv4:
host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
case typeIPv6:
Expand All @@ -80,8 +81,22 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
// parse port
port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen])
port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
host = net.JoinHostPort(host, strconv.Itoa(int(port)))
// if specified one time auth enabled, we should verify this
if auth || addrType&ss.OneTimeAuthMask > 0 {
ota = true
if _, err = io.ReadFull(conn, buf[reqEnd:reqEnd+lenHmacSha1]); err != nil {
return
}
iv := conn.GetIv()
key := conn.GetKey()
actualHmacSha1Buf := ss.HmacSha1(append(iv, key...), buf[:reqEnd])
if !bytes.Equal(buf[reqEnd:reqEnd+lenHmacSha1], actualHmacSha1Buf) {
err = fmt.Errorf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, buf[:reqEnd])
return
}
}
return
}

Expand All @@ -90,7 +105,11 @@ const logCntDelta = 100
var connCnt int
var nextLogConnCnt int = logCntDelta

func handleConnection(conn *ss.Conn) {
type isClosed struct {
isClosed bool
}

func handleConnection(conn *ss.Conn, auth bool) {
var host string

connCnt++ // this maybe not accurate, but should be enough
Expand Down Expand Up @@ -118,7 +137,7 @@ func handleConnection(conn *ss.Conn) {
}
}()

host, extra, err := getRequest(conn)
host, ota, err := getRequest(conn, auth)
if err != nil {
log.Println("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err)
return
Expand All @@ -140,18 +159,14 @@ func handleConnection(conn *ss.Conn) {
remote.Close()
}
}()
// write extra bytes read from
if extra != nil {
// debug.Println("getRequest read extra data, writing to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
}
}
if debug {
debug.Printf("piping %s<->%s", conn.RemoteAddr(), host)
debug.Printf("piping %s<->%s ota=%v connOta=%v", conn.RemoteAddr(), host, ota, conn.IsOta())
}
if ota {
go ss.PipeThenCloseOta(conn, remote)
} else {
go ss.PipeThenClose(conn, remote)
}
go ss.PipeThenClose(conn, remote)
ss.PipeThenClose(remote, conn)
closed = true
return
Expand Down Expand Up @@ -195,7 +210,7 @@ func (pm *PasswdManager) del(port string) {
// port. A different approach would be directly change the password used by
// that port, but that requires **sharing** password between the port listener
// and password manager.
func (pm *PasswdManager) updatePortPasswd(port, password string) {
func (pm *PasswdManager) updatePortPasswd(port, password string, auth bool) {
pl, ok := pm.get(port)
if !ok {
log.Printf("new port %s added\n", port)
Expand All @@ -208,7 +223,7 @@ func (pm *PasswdManager) updatePortPasswd(port, password string) {
}
// run will add the new port listener to passwdManager.
// So there maybe concurrent access to passwdManager and we need lock to protect it.
go run(port, password)
go run(port, password, auth)
}

var passwdManager = PasswdManager{portListener: map[string]*PortListener{}}
Expand All @@ -227,7 +242,7 @@ func updatePasswd() {
return
}
for port, passwd := range config.PortPassword {
passwdManager.updatePortPasswd(port, passwd)
passwdManager.updatePortPasswd(port, passwd, config.Auth)
if oldconfig.PortPassword != nil {
delete(oldconfig.PortPassword, port)
}
Expand All @@ -254,7 +269,7 @@ func waitSignal() {
}
}

func run(port, password string) {
func run(port, password string, auth bool) {
ln, err := net.Listen("tcp", ":"+port)
if err != nil {
log.Printf("error listening port %v: %v\n", port, err)
Expand All @@ -280,7 +295,7 @@ func run(port, password string) {
continue
}
}
go handleConnection(ss.NewConn(conn, cipher.Copy()))
go handleConnection(ss.NewConn(conn, cipher.Copy()), auth)
}
}

Expand Down Expand Up @@ -357,7 +372,7 @@ func main() {
runtime.GOMAXPROCS(core)
}
for port, password := range config.PortPassword {
go run(port, password)
go run(port, password, config.Auth)
}

waitSignal()
Expand Down
4 changes: 1 addition & 3 deletions script/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ test_get() {
# -s silent to disable progress meter, but enable --show-error
# -i to include http header
# -L to follow redirect so we should always get HTTP 200
cont=`curl --socks5 $SOCKS -s --show-error -i -L $url 2>&1`
cont=`curl -m 5 --socks5 $SOCKS -s --show-error -i -L $url 2>&1`
ok=`echo $cont | grep -E -o "HTTP/1\.1 +$code"`
html=`echo $cont | grep -E -o -i "$target"`
if [[ -z $ok || -z $html ]] ; then
Expand Down Expand Up @@ -101,8 +101,6 @@ test_server_local_pair() {

local url
url=http://127.0.0.1:$HTTP_PORT/README.md
test_shadowsocks $url table
test_shadowsocks $url rc4
test_shadowsocks $url rc4-md5
test_shadowsocks $url aes-128-cfb
test_shadowsocks $url aes-192-cfb
Expand Down
9 changes: 6 additions & 3 deletions shadowsocks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"reflect"
"time"
"strings"
)

type Config struct {
Expand All @@ -23,6 +24,7 @@ type Config struct {
LocalPort int `json:"local_port"`
Password string `json:"password"`
Method string `json:"method"` // encryption method
Auth bool `json:"auth"` // one time auth

// following options are only used by server
PortPassword map[string]string `json:"port_password"`
Expand Down Expand Up @@ -85,6 +87,10 @@ func ParseConfig(path string) (config *Config, err error) {
return nil, err
}
readTimeout = time.Duration(config.Timeout) * time.Second
if strings.HasSuffix(strings.ToLower(config.Method), "-ota") {
config.Method = config.Method[:len(config.Method) - 4]
config.Auth = true
}
return
}

Expand Down Expand Up @@ -124,9 +130,6 @@ func UpdateConfig(old, new *Config) {
}
}
}
if old.Method == "table" {
old.Method = ""
}

old.Timeout = new.Timeout
readTimeout = time.Duration(old.Timeout) * time.Second
Expand Down
Loading