Skip to content

Commit

Permalink
Rewrite this entire package and bring it up to v2
Browse files Browse the repository at this point in the history
  • Loading branch information
prep authored Aug 5, 2019
1 parent e9ac7ca commit 1fddb00
Show file tree
Hide file tree
Showing 24 changed files with 2,250 additions and 1,979 deletions.
586 changes: 501 additions & 85 deletions README.md

Large diffs are not rendered by default.

142 changes: 142 additions & 0 deletions beanstalk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
package beanstalk

import (
"context"
"fmt"
"net"
"net/url"
"strings"
"time"
)

// ParseURI returns the socket of the specified URI and if the connection is
// supposed to be a TLS or plaintext connection. Valid URI schemes are:
//
// beanstalk://host:port
// beanstalks://host:port
// tls://host:port
//
// Where both the beanstalks and tls scheme mean the same thing. Alternatively,
// it is also possibly to just specify the host:port combo which is assumed to
// be a plaintext connection.
func ParseURI(uri string) (string, bool, error) {
var host string
var isTLS bool

if strings.Contains(uri, "://") {
url, err := url.Parse(uri)
if err != nil {
return "", false, err
}

// Determine the protocol scheme of the URI.
switch strings.ToLower(url.Scheme) {
case "beanstalk":
case "beanstalks", "tls":
isTLS = true
default:
return "", false, fmt.Errorf("%s: unknown beanstalk URI scheme", url.Scheme)
}

host = url.Host
} else {
host = uri
}

// Validate the resulting host:port combo.
_, _, err := net.SplitHostPort(host)
switch {
case err != nil && strings.Contains(err.Error(), "missing port in address"):
if isTLS {
host += ":11400"
} else {
host += ":11300"
}
case err != nil:
return "", false, err
}

return host, isTLS, nil
}

func includes(a []string, s string) bool {
for _, e := range a {
if e == s {
return true
}
}

return false
}

func contextTimeoutFunc(d time.Duration, fn func(ctx context.Context) error) error {
ctx, cancel := context.WithTimeout(context.Background(), d)
defer cancel()

return fn(ctx)
}

type ioHandler interface {
setupConnection(conn *Conn, config Config) error
handleIO(conn *Conn, config Config) error
}

// keepConnected is responsible for keeping a connection to a URI up.
func keepConnected(handler ioHandler, conn *Conn, config Config, close chan struct{}) {
URI := conn.URI

go func() {
var err error
for {
// Reconnect to the beanstalk server if no connection is active.
for conn == nil {
if conn, err = Dial(URI, config); err != nil {
config.ErrorLog.Printf("Unable to connect to beanstalk server %s: %s", URI, err)

select {
// Wait a bit and try again.
case <-time.After(config.ReconnectTimeout):
continue
case <-close:
return
}
}
}

config.InfoLog.Printf("Connected to beanstalk server %s", conn)

// Set up the connection. If not successful, close the connection, wait
// a bit and reconnect.
err := handler.setupConnection(conn, config)
if err != nil {
config.InfoLog.Printf("Unable to set up the beanstalk connection: %s", err)
_ = conn.Close()
conn = nil

select {
case <-time.After(config.ReconnectTimeout):
case <-close:
return
}

continue
}

// call the IO handler for as long as it wants it, or the connection is up.
if err = handler.handleIO(conn, config); err != nil && err != ErrDisconnected {
config.ErrorLog.Printf("Disconnected from beanstalk server %s: %s", conn, err)
} else {
config.InfoLog.Printf("Disconnected from beanstalk server %s", conn)
}

_ = conn.Close()
conn = nil

select {
case <-close:
return
default:
}
}
}()
}
70 changes: 70 additions & 0 deletions beanstalk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package beanstalk

import "testing"

func TestParseURI(t *testing.T) {
t.Run("WithValidSchemes", func(t *testing.T) {
for _, scheme := range []string{"beanstalk", "beanstalks", "tls"} {
uri := scheme + "://localhost:12345"

host, useTLS, err := ParseURI(uri)
switch {
case err != nil:
t.Errorf("Unable to parse URI: %s", uri)
case host != "localhost:12345":
t.Errorf("Unexpected host: %s", host)
}

switch scheme {
case "beanstalk":
if useTLS {
t.Errorf("%s: scheme shouldn't support TLS", scheme)
}
case "beanstalks", "tls":
if !useTLS {
t.Errorf("%s: scheme should support TLS", scheme)
}
default:
t.Fatalf("%s: unknown scheme", scheme)
}
}
})

t.Run("WithMissingScheme", func(t *testing.T) {
host, useTLS, err := ParseURI("localhost:11300")
switch {
case err != nil:
t.Fatalf("Error parsing URI without scheme: %s", err)
case host != "localhost:11300":
t.Errorf("Unexpected host: %s", host)
case useTLS:
t.Error("Unexpected TLS to be set")
}
})

t.Run("WithMissingPort", func(t *testing.T) {
host, _, err := ParseURI("beanstalk://localhost")
switch {
case err != nil:
t.Fatalf("Error parsing URI without port")
case host != "localhost:11300":
t.Errorf("%s: Expected port 11300 to be added to the socket", host)
}
})

t.Run("WithMissingTLSPort", func(t *testing.T) {
host, _, err := ParseURI("beanstalks://localhost")
switch {
case err != nil:
t.Fatalf("Error parsing URI without port")
case host != "localhost:11400":
t.Errorf("%s: Expected port 11400 to be added to the socket", host)
}
})

t.Run("WithInvalidScheme", func(t *testing.T) {
if _, _, err := ParseURI("foo://localhost:12345"); err == nil {
t.Fatal("Expected an error, but got nothing")
}
})
}
Loading

0 comments on commit 1fddb00

Please sign in to comment.