Skip to content

Commit

Permalink
1.1.0 mlbam proxy
Browse files Browse the repository at this point in the history
- updated to go 1.12 beta
- added version flag
- added web sockets headers
- removed timeout
- tweak proxy requests
  • Loading branch information
jwallet committed Feb 9, 2019
1 parent 5de54f7 commit f30d934
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 109 deletions.
28 changes: 20 additions & 8 deletions mlbamproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@ import (
"net/url"
"os"
"strings"
"time"
)

const version = "1.1.0"

type baseHandle struct{}

var (
_port int
_destination string
_sources []string
_debug bool
_version bool
)

func contains(domains []string, url string) bool {
Expand Down Expand Up @@ -67,7 +69,7 @@ func getScheme(r *http.Request) string {

func getPort(r *http.Request) string {
if r.URL.Scheme == "wss" {
return fmt.Sprintf("%v", 443)
return fmt.Sprintf("%v", _port)
}
return r.URL.Port()
}
Expand Down Expand Up @@ -100,20 +102,26 @@ func initParameters() {
_sources = []string{}

flag.BoolVar(&_debug, "debug", false, "Debug mode")
flag.BoolVar(&_version, "v", false, "Version")
flag.IntVar(&_port, "p", 17070, "Port used by the local proxy")
flag.StringVar(&_destination, "d", "", "Destination domain to forward source domains requests to.")
sources := flag.String("s", "", "Source domains to redirect requests from, separated by commas.")
sources := flag.String("s", "", "Source domains to redirect requests from, separated by commas. (e.g.: --s google.com,facebook.com)")

flag.Parse()

if _version {
fmt.Printf("version %v", version)
os.Exit(1)
}

for _, hostname := range strings.Split(*sources, ",") {
if hostname != "" {
_sources = append(_sources, hostname)
}
}

if !canRedirect() {
printf("Proxy won't redirect, missing flags -s (sources) and/or -d (destination)")
printf("Proxy will act as a default proxy and won't redirect a domain to another, no sources and/or destination were specified")
}
}

Expand All @@ -140,6 +148,8 @@ func copyRequest(u *url.URL, r *http.Request) (*http.Request, error) {
req.Header.Set("X-Forwarded-Proto", "http")
if r.TLS != nil {
req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set(http.CanonicalHeaderKey("X-Forwarded-Proto"), "https")
req.Header.Set(http.CanonicalHeaderKey("X-Forwarded-Port"), fmt.Sprintf("%v", _port))
}

req.Header.Del("Accept-Encoding")
Expand All @@ -148,7 +158,7 @@ func copyRequest(u *url.URL, r *http.Request) (*http.Request, error) {
}

func setupResponse(w *http.ResponseWriter) {
(*w).Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, accessToken, Authorization, Accept, Range")
(*w).Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, accessToken, Authorization, Accept, Range, Upgrade")
}

func dialTLS(network, addr string) (net.Conn, error) {
Expand All @@ -162,7 +172,11 @@ func dialTLS(network, addr string) (net.Conn, error) {
return nil, err
}

cfg := &tls.Config{ServerName: host}
cfg := &tls.Config{
ServerName: host,
InsecureSkipVerify: true,
NextProtos: []string{"h2", "http/1.1"},
}

tlsConn := tls.Client(conn, cfg)
err = tlsConn.Handshake()
Expand Down Expand Up @@ -199,8 +213,6 @@ func (h *baseHandle) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

proxy := NewReverseProxy(proxyURL)
// default timeout to 12 hour
proxy.Timeout = time.Hour * 12

r, _ = copyRequest(proxyURL, r)

Expand Down
115 changes: 14 additions & 101 deletions reverse.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,69 +13,28 @@ import (

var onExitFlushLoop func()

const (
defaultTimeout = time.Minute * 5
)

// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client, support http, also support https tunnel using http.hijacker
// ReverseProxy is an HTTP Handler that takes an incoming request
// and sends it to another server
type ReverseProxy struct {
// Set the timeout of the proxy server, default is 5 minutes
Timeout time.Duration

// Director must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
// Director must not access the provided Request
// after returning.
Director func(*http.Request)

// The transport used to perform proxy requests.
// default is http.DefaultTransport.
Transport http.RoundTripper

// FlushInterval specifies the flush interval
// to flush to the client while copying the
// response body. If zero, no periodic flushing is done.
FlushInterval time.Duration

// ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request.
// If nil, logging goes to os.Stderr via the log package's
// standard logger.
ErrorLog *log.Logger

// ModifyResponse is an optional function that
// modifies the Response from the backend.
// If it returns an error, the proxy returns a StatusBadGateway error.
Director func(*http.Request)
Transport http.RoundTripper
FlushInterval time.Duration
ErrorLog *log.Logger
ModifyResponse func(*http.Response) error
}

type requestCanceler interface {
CancelRequest(req *http.Request)
}

// NewReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir. if the target's query is a=10
// and the incoming request's query is b=100, the target's request's query
// will be a=10&b=100.
// NewReverseProxy does not rewrite the Host header.
// To rewrite Host headers, use ReverseProxy directly with a custom
// Director policy.
// NewReverseProxy returns a new ReverseProxy that routes URLs
func NewReverseProxy(target *url.URL) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)

// If Host is empty, the Request.Write method uses
// the value of URL.Host.
// force use URL.Host
req.Host = req.URL.Host
if targetQuery == "" || req.URL.RawQuery == "" {
req.URL.RawQuery = targetQuery + req.URL.RawQuery
Expand Down Expand Up @@ -111,16 +70,14 @@ func copyHeader(dst, src http.Header) {
}
}

// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
var hopHeaders = []string{
"Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
"Proxy-Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
"Te",
"Trailer",
"Transfer-Encoding",
"Upgrade",
}
Expand Down Expand Up @@ -193,7 +150,6 @@ func (p *ReverseProxy) logf(format string, args ...interface{}) {
}

func removeHeaders(header http.Header) {
// Remove hop-by-hop headers listed in the "Connection" header.
if c := header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
Expand All @@ -202,7 +158,6 @@ func removeHeaders(header http.Header) {
}
}

// Remove hop-by-hop headers
for _, h := range hopHeaders {
if header.Get(h) != "" {
header.Del(h)
Expand All @@ -212,18 +167,12 @@ func removeHeaders(header http.Header) {

func addXForwardedForHeader(req *http.Request) {
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
req.Header.Set("X-Forwarded-For", clientIP)
}

// Set the originating protocol of the incoming HTTP request. The SSL might
// be terminated on our site and because we doing proxy adding this would
// be helpful for applications on the backend.
req.Header.Set("X-Forwarded-Proto", "http")
if req.TLS != nil {
req.Header.Set("X-Forwarded-Proto", "https")
Expand All @@ -238,13 +187,10 @@ func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) {
}

outreq := new(http.Request)
// Shallow copies of maps, like header
*outreq = *req

if cn, ok := rw.(http.CloseNotifier); ok {
if requestCanceler, ok := transport.(requestCanceler); ok {
// After the Handler has returned, there is no guarantee
// that the channel receives a value, so to make sure
reqDone := make(chan struct{})
defer close(reqDone)
clientGone := cn.CloseNotify()
Expand All @@ -264,14 +210,11 @@ func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) {
p.Director(outreq)
outreq.Close = false

// We may modify the header (shallow copied above), so we only copy it.
outreq.Header = make(http.Header)
copyHeader(outreq.Header, req.Header)

// Remove hop-by-hop headers listed in the "Connection" header, Remove hop-by-hop headers.
removeHeaders(outreq.Header)

// Add X-Forwarded-For Header.
addXForwardedForHeader(outreq)

res, err := transport.RoundTrip(outreq)
Expand All @@ -281,7 +224,6 @@ func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

// Remove hop-by-hop headers listed in the "Connection" header of the response, Remove hop-by-hop headers.
removeHeaders(res.Header)

if p.ModifyResponse != nil {
Expand All @@ -292,10 +234,8 @@ func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) {
}
}

// Copy header from response to client.
copyHeader(rw.Header(), res.Header)

// The "Trailer" header isn't included in the Transport's response, Build it up from Trailer.
if len(res.Trailer) > 0 {
trailerKeys := make([]string, 0, len(res.Trailer))
for k := range res.Trailer {
Expand All @@ -306,16 +246,13 @@ func (p *ReverseProxy) ProxyHTTP(rw http.ResponseWriter, req *http.Request) {

rw.WriteHeader(res.StatusCode)
if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short
// bodies and adding a Content-Length.
if fl, ok := rw.(http.Flusher); ok {
fl.Flush()
}
}

p.copyResponse(rw, res.Body)
// close now, instead of defer, to populate res.Trailer

res.Body.Close()
copyHeader(rw.Header(), res.Trailer)
}
Expand All @@ -337,32 +274,6 @@ func (p *ReverseProxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) {
proxyConn, err := net.Dial("tcp", req.URL.Host)
if err != nil {
p.logf("http: proxy error: %v", err)
clientConn.Close()
return
}

// The returned net.Conn may have read or write deadlines
// already set, depending on the configuration of the
// Server, to set or clear those deadlines as needed
// we set timeout to 5 minutes
deadline := time.Now()
if p.Timeout == 0 {
deadline = deadline.Add(time.Minute * 5)
} else {
deadline = deadline.Add(p.Timeout)
}

err = clientConn.SetDeadline(deadline)
if err != nil {
p.logf("http: proxy error: %v", err)
return
}

err = proxyConn.SetDeadline(deadline)
if err != nil {
p.logf("http: proxy error: %v", err)
clientConn.Close()
proxyConn.Close()
return
}

Expand All @@ -376,15 +287,17 @@ func (p *ReverseProxy) ProxyHTTPS(rw http.ResponseWriter, req *http.Request) {

go func() {
io.Copy(clientConn, proxyConn)
clientConn.Close()
proxyConn.Close()
clientConn.Close()
}()

io.Copy(proxyConn, clientConn)
proxyConn.Close()
clientConn.Close()
}

// ServeHTTP determines if the request is a secured scheme or not
// and reroute to the proper proxy method
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.Method == "CONNECT" || req.URL.Scheme == "wss" {
p.ProxyHTTPS(rw, req)
Expand Down

0 comments on commit f30d934

Please sign in to comment.