Skip to content
This repository has been archived by the owner on Apr 9, 2020. It is now read-only.

Commit

Permalink
support one time auth in client & server
Browse files Browse the repository at this point in the history
append "-ota" suffix in method name to enable one time auth
  • Loading branch information
ayanamist committed Dec 5, 2015
1 parent 07f4a06 commit 88021d8
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 64 deletions.
6 changes: 5 additions & 1 deletion cmd/shadowsocks-local/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,12 @@ func parseServerConfig(config *ss.Config) {
}

if len(config.ServerPassword) == 0 {
method := config.Method
if config.Auth {
method += "-ota"
}
// only one encryption table
cipher, err := ss.NewCipher(config.Method, config.Password)
cipher, err := ss.NewCipher(method, config.Password)
if err != nil {
log.Fatal("Failed generating ciphers:", err)
}
Expand Down
119 changes: 67 additions & 52 deletions cmd/shadowsocks-server/server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"bytes"
"encoding/binary"
"errors"
"flag"
Expand All @@ -17,61 +18,61 @@ import (
"syscall"
)

var debug ss.DebugLog
const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index

typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address

func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
const (
idType = 0 // address type index
idIP0 = 1 // ip addres start index
idDmLen = 1 // domain address length index
idDm0 = 2 // domain address start index
lenIPv4 = net.IPv4len + 2 // ipv4 + 2port
lenIPv6 = net.IPv6len + 2 // ipv6 + 2port
lenDmBase = 2 // 1addrLen + 2port, plus addrLen
lenHmacSha1 = 10
)

typeIPv4 = 1 // type is ipv4 address
typeDm = 3 // type is domain address
typeIPv6 = 4 // type is ipv6 address
var debug ss.DebugLog

lenIPv4 = 1 + net.IPv4len + 2 // 1addrType + ipv4 + 2port
lenIPv6 = 1 + net.IPv6len + 2 // 1addrType + ipv6 + 2port
lenDmBase = 1 + 1 + 2 // 1addrType + 1addrLen + 2port, plus addrLen
)
func getRequest(conn *ss.Conn, auth bool) (host string, ota bool, err error) {
ss.SetReadTimeout(conn)

// buf size should at least have the same size with the largest possible
// request size (when addrType is 3, domain name has at most 256 bytes)
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port)
buf := make([]byte, 260)
var n int
// 1(addrType) + 1(lenByte) + 256(max length address) + 2(port) + 10(hmac-sha1)
buf := make([]byte, 270)
// read till we get possible domain length field
ss.SetReadTimeout(conn)
if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil {
if _, err = io.ReadFull(conn, buf[:idType+1]); err != nil {
return
}

reqLen := -1
switch buf[idType] {
var reqStart, reqEnd int
addrType := buf[idType]
switch addrType & ss.AddrMask {
case typeIPv4:
reqLen = lenIPv4
reqStart, reqEnd = idIP0, idIP0+lenIPv4
case typeIPv6:
reqLen = lenIPv6
reqStart, reqEnd = idIP0, idIP0+lenIPv6
case typeDm:
reqLen = int(buf[idDmLen]) + lenDmBase
if _, err = io.ReadFull(conn, buf[idType+1:idDmLen+1]); err != nil {
return
}
reqStart, reqEnd = idDm0, int(idDm0+buf[idDmLen]+lenDmBase)
default:
err = fmt.Errorf("addr type %d not supported", buf[idType])
err = fmt.Errorf("addr type %d not supported", addrType&ss.AddrMask)
return
}

if n < reqLen { // rare case
if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil {
return
}
} else if n > reqLen {
// it's possible to read more than just the request head
extra = buf[reqLen:n]
if _, err = io.ReadFull(conn, buf[reqStart:reqEnd]); err != nil {
return
}

// Return string for typeIP is not most efficient, but browsers (Chrome,
// Safari, Firefox) all seems using typeDm exclusively. So this is not a
// big problem.
switch buf[idType] {
switch addrType & ss.AddrMask {
case typeIPv4:
host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String()
case typeIPv6:
Expand All @@ -80,8 +81,22 @@ func getRequest(conn *ss.Conn) (host string, extra []byte, err error) {
host = string(buf[idDm0 : idDm0+buf[idDmLen]])
}
// parse port
port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen])
port := binary.BigEndian.Uint16(buf[reqEnd-2 : reqEnd])
host = net.JoinHostPort(host, strconv.Itoa(int(port)))
// if specified one time auth enabled, we should verify this
if auth || addrType&ss.OneTimeAuthMask > 0 {
ota = true
if _, err = io.ReadFull(conn, buf[reqEnd:reqEnd+lenHmacSha1]); err != nil {
return
}
iv := conn.GetIv()
key := conn.GetKey()
actualHmacSha1Buf := ss.HmacSha1(append(iv, key...), buf[:reqEnd])
if !bytes.Equal(buf[reqEnd:reqEnd+lenHmacSha1], actualHmacSha1Buf) {
err = fmt.Errorf("verify one time auth failed, iv=%v key=%v data=%v", iv, key, buf[:reqEnd])
return
}
}
return
}

Expand All @@ -90,7 +105,11 @@ const logCntDelta = 100
var connCnt int
var nextLogConnCnt int = logCntDelta

func handleConnection(conn *ss.Conn) {
type isClosed struct {
isClosed bool
}

func handleConnection(conn *ss.Conn, auth bool) {
var host string

connCnt++ // this maybe not accurate, but should be enough
Expand Down Expand Up @@ -118,7 +137,7 @@ func handleConnection(conn *ss.Conn) {
}
}()

host, extra, err := getRequest(conn)
host, ota, err := getRequest(conn, auth)
if err != nil {
log.Println("error getting request", conn.RemoteAddr(), conn.LocalAddr(), err)
return
Expand All @@ -140,18 +159,14 @@ func handleConnection(conn *ss.Conn) {
remote.Close()
}
}()
// write extra bytes read from
if extra != nil {
// debug.Println("getRequest read extra data, writing to remote, len", len(extra))
if _, err = remote.Write(extra); err != nil {
debug.Println("write request extra error:", err)
return
}
}
if debug {
debug.Printf("piping %s<->%s", conn.RemoteAddr(), host)
debug.Printf("piping %s<->%s ota=%v connOta=%v", conn.RemoteAddr(), host, ota, conn.IsOta())
}
if ota {
go ss.PipeThenCloseOta(conn, remote)
} else {
go ss.PipeThenClose(conn, remote)
}
go ss.PipeThenClose(conn, remote)
ss.PipeThenClose(remote, conn)
closed = true
return
Expand Down Expand Up @@ -195,7 +210,7 @@ func (pm *PasswdManager) del(port string) {
// port. A different approach would be directly change the password used by
// that port, but that requires **sharing** password between the port listener
// and password manager.
func (pm *PasswdManager) updatePortPasswd(port, password string) {
func (pm *PasswdManager) updatePortPasswd(port, password string, auth bool) {
pl, ok := pm.get(port)
if !ok {
log.Printf("new port %s added\n", port)
Expand All @@ -208,7 +223,7 @@ func (pm *PasswdManager) updatePortPasswd(port, password string) {
}
// run will add the new port listener to passwdManager.
// So there maybe concurrent access to passwdManager and we need lock to protect it.
go run(port, password)
go run(port, password, auth)
}

var passwdManager = PasswdManager{portListener: map[string]*PortListener{}}
Expand All @@ -227,7 +242,7 @@ func updatePasswd() {
return
}
for port, passwd := range config.PortPassword {
passwdManager.updatePortPasswd(port, passwd)
passwdManager.updatePortPasswd(port, passwd, config.Auth)
if oldconfig.PortPassword != nil {
delete(oldconfig.PortPassword, port)
}
Expand All @@ -254,7 +269,7 @@ func waitSignal() {
}
}

func run(port, password string) {
func run(port, password string, auth bool) {
ln, err := net.Listen("tcp", ":"+port)
if err != nil {
log.Printf("error listening port %v: %v\n", port, err)
Expand All @@ -280,7 +295,7 @@ func run(port, password string) {
continue
}
}
go handleConnection(ss.NewConn(conn, cipher.Copy()))
go handleConnection(ss.NewConn(conn, cipher.Copy()), auth)
}
}

Expand Down Expand Up @@ -357,7 +372,7 @@ func main() {
runtime.GOMAXPROCS(core)
}
for port, password := range config.PortPassword {
go run(port, password)
go run(port, password, config.Auth)
}

waitSignal()
Expand Down
6 changes: 6 additions & 0 deletions shadowsocks/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"os"
"reflect"
"time"
"strings"
)

type Config struct {
Expand All @@ -23,6 +24,7 @@ type Config struct {
LocalPort int `json:"local_port"`
Password string `json:"password"`
Method string `json:"method"` // encryption method
Auth bool `json:"auth"` // one time auth

// following options are only used by server
PortPassword map[string]string `json:"port_password"`
Expand Down Expand Up @@ -85,6 +87,10 @@ func ParseConfig(path string) (config *Config, err error) {
return nil, err
}
readTimeout = time.Duration(config.Timeout) * time.Second
if strings.HasSuffix(strings.ToLower(config.Method), "-ota") {
config.Method = config.Method[:len(config.Method) - 4]
config.Auth = true
}
return
}

Expand Down
54 changes: 51 additions & 3 deletions shadowsocks/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ import (
"strconv"
)

const (
OneTimeAuthMask byte = 0x10
AddrMask byte = 0xf
)

type Conn struct {
net.Conn
*Cipher
readBuf []byte
writeBuf []byte
readBuf []byte
writeBuf []byte
chunkId uint32
}

func NewConn(c net.Conn, cipher *Cipher) *Conn {
Expand Down Expand Up @@ -58,7 +64,16 @@ func DialWithRawAddr(rawaddr []byte, server string, cipher *Cipher) (c *Conn, er
return
}
c = NewConn(conn, cipher)
if _, err = c.Write(rawaddr); err != nil {
if cipher.ota {
if c.enc == nil {
if _, err = c.initEncrypt(); err != nil {
return
}
}
rawaddr[0] |= OneTimeAuthMask
rawaddr = otaConnectAuth(cipher.iv, cipher.key, rawaddr)
}
if _, err = c.write(rawaddr); err != nil {
c.Close()
return nil, err
}
Expand All @@ -74,6 +89,28 @@ func Dial(addr, server string, cipher *Cipher) (c *Conn, err error) {
return DialWithRawAddr(ra, server, cipher)
}

func (c *Conn) GetIv() (iv []byte) {
iv = make([]byte, len(c.iv))
copy(iv, c.iv)
return
}

func (c *Conn) GetKey() (key []byte) {
key = make([]byte, len(c.key))
copy(key, c.key)
return
}

func (c *Conn) IsOta() bool {
return c.ota
}

func (c *Conn) GetAndIncrChunkId() (chunkId uint32) {
chunkId = c.chunkId
c.chunkId += 1
return
}

func (c *Conn) Read(b []byte) (n int, err error) {
if c.dec == nil {
iv := make([]byte, c.info.ivLen)
Expand All @@ -83,6 +120,9 @@ func (c *Conn) Read(b []byte) (n int, err error) {
if err = c.initDecrypt(iv); err != nil {
return
}
if len(c.iv) == 0 {
c.iv = iv
}
}

cipherData := c.readBuf
Expand All @@ -100,6 +140,14 @@ func (c *Conn) Read(b []byte) (n int, err error) {
}

func (c *Conn) Write(b []byte) (n int, err error) {
if c.ota {
chunkId := c.GetAndIncrChunkId()
b = otaReqChunkAuth(c.iv, chunkId, b)
}
return c.write(b)
}

func (c *Conn) write(b []byte) (n int, err error) {
var iv []byte
if c.enc == nil {
iv, err = c.initEncrypt()
Expand Down
Loading

0 comments on commit 88021d8

Please sign in to comment.