-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rewrite this entire package and bring it up to v2
- Loading branch information
Showing
24 changed files
with
2,250 additions
and
1,979 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
package beanstalk | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"net/url" | ||
"strings" | ||
"time" | ||
) | ||
|
||
// ParseURI returns the socket of the specified URI and if the connection is | ||
// supposed to be a TLS or plaintext connection. Valid URI schemes are: | ||
// | ||
// beanstalk://host:port | ||
// beanstalks://host:port | ||
// tls://host:port | ||
// | ||
// Where both the beanstalks and tls scheme mean the same thing. Alternatively, | ||
// it is also possibly to just specify the host:port combo which is assumed to | ||
// be a plaintext connection. | ||
func ParseURI(uri string) (string, bool, error) { | ||
var host string | ||
var isTLS bool | ||
|
||
if strings.Contains(uri, "://") { | ||
url, err := url.Parse(uri) | ||
if err != nil { | ||
return "", false, err | ||
} | ||
|
||
// Determine the protocol scheme of the URI. | ||
switch strings.ToLower(url.Scheme) { | ||
case "beanstalk": | ||
case "beanstalks", "tls": | ||
isTLS = true | ||
default: | ||
return "", false, fmt.Errorf("%s: unknown beanstalk URI scheme", url.Scheme) | ||
} | ||
|
||
host = url.Host | ||
} else { | ||
host = uri | ||
} | ||
|
||
// Validate the resulting host:port combo. | ||
_, _, err := net.SplitHostPort(host) | ||
switch { | ||
case err != nil && strings.Contains(err.Error(), "missing port in address"): | ||
if isTLS { | ||
host += ":11400" | ||
} else { | ||
host += ":11300" | ||
} | ||
case err != nil: | ||
return "", false, err | ||
} | ||
|
||
return host, isTLS, nil | ||
} | ||
|
||
func includes(a []string, s string) bool { | ||
for _, e := range a { | ||
if e == s { | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
} | ||
|
||
func contextTimeoutFunc(d time.Duration, fn func(ctx context.Context) error) error { | ||
ctx, cancel := context.WithTimeout(context.Background(), d) | ||
defer cancel() | ||
|
||
return fn(ctx) | ||
} | ||
|
||
type ioHandler interface { | ||
setupConnection(conn *Conn, config Config) error | ||
handleIO(conn *Conn, config Config) error | ||
} | ||
|
||
// keepConnected is responsible for keeping a connection to a URI up. | ||
func keepConnected(handler ioHandler, conn *Conn, config Config, close chan struct{}) { | ||
URI := conn.URI | ||
|
||
go func() { | ||
var err error | ||
for { | ||
// Reconnect to the beanstalk server if no connection is active. | ||
for conn == nil { | ||
if conn, err = Dial(URI, config); err != nil { | ||
config.ErrorLog.Printf("Unable to connect to beanstalk server %s: %s", URI, err) | ||
|
||
select { | ||
// Wait a bit and try again. | ||
case <-time.After(config.ReconnectTimeout): | ||
continue | ||
case <-close: | ||
return | ||
} | ||
} | ||
} | ||
|
||
config.InfoLog.Printf("Connected to beanstalk server %s", conn) | ||
|
||
// Set up the connection. If not successful, close the connection, wait | ||
// a bit and reconnect. | ||
err := handler.setupConnection(conn, config) | ||
if err != nil { | ||
config.InfoLog.Printf("Unable to set up the beanstalk connection: %s", err) | ||
_ = conn.Close() | ||
conn = nil | ||
|
||
select { | ||
case <-time.After(config.ReconnectTimeout): | ||
case <-close: | ||
return | ||
} | ||
|
||
continue | ||
} | ||
|
||
// call the IO handler for as long as it wants it, or the connection is up. | ||
if err = handler.handleIO(conn, config); err != nil && err != ErrDisconnected { | ||
config.ErrorLog.Printf("Disconnected from beanstalk server %s: %s", conn, err) | ||
} else { | ||
config.InfoLog.Printf("Disconnected from beanstalk server %s", conn) | ||
} | ||
|
||
_ = conn.Close() | ||
conn = nil | ||
|
||
select { | ||
case <-close: | ||
return | ||
default: | ||
} | ||
} | ||
}() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
package beanstalk | ||
|
||
import "testing" | ||
|
||
func TestParseURI(t *testing.T) { | ||
t.Run("WithValidSchemes", func(t *testing.T) { | ||
for _, scheme := range []string{"beanstalk", "beanstalks", "tls"} { | ||
uri := scheme + "://localhost:12345" | ||
|
||
host, useTLS, err := ParseURI(uri) | ||
switch { | ||
case err != nil: | ||
t.Errorf("Unable to parse URI: %s", uri) | ||
case host != "localhost:12345": | ||
t.Errorf("Unexpected host: %s", host) | ||
} | ||
|
||
switch scheme { | ||
case "beanstalk": | ||
if useTLS { | ||
t.Errorf("%s: scheme shouldn't support TLS", scheme) | ||
} | ||
case "beanstalks", "tls": | ||
if !useTLS { | ||
t.Errorf("%s: scheme should support TLS", scheme) | ||
} | ||
default: | ||
t.Fatalf("%s: unknown scheme", scheme) | ||
} | ||
} | ||
}) | ||
|
||
t.Run("WithMissingScheme", func(t *testing.T) { | ||
host, useTLS, err := ParseURI("localhost:11300") | ||
switch { | ||
case err != nil: | ||
t.Fatalf("Error parsing URI without scheme: %s", err) | ||
case host != "localhost:11300": | ||
t.Errorf("Unexpected host: %s", host) | ||
case useTLS: | ||
t.Error("Unexpected TLS to be set") | ||
} | ||
}) | ||
|
||
t.Run("WithMissingPort", func(t *testing.T) { | ||
host, _, err := ParseURI("beanstalk://localhost") | ||
switch { | ||
case err != nil: | ||
t.Fatalf("Error parsing URI without port") | ||
case host != "localhost:11300": | ||
t.Errorf("%s: Expected port 11300 to be added to the socket", host) | ||
} | ||
}) | ||
|
||
t.Run("WithMissingTLSPort", func(t *testing.T) { | ||
host, _, err := ParseURI("beanstalks://localhost") | ||
switch { | ||
case err != nil: | ||
t.Fatalf("Error parsing URI without port") | ||
case host != "localhost:11400": | ||
t.Errorf("%s: Expected port 11400 to be added to the socket", host) | ||
} | ||
}) | ||
|
||
t.Run("WithInvalidScheme", func(t *testing.T) { | ||
if _, _, err := ParseURI("foo://localhost:12345"); err == nil { | ||
t.Fatal("Expected an error, but got nothing") | ||
} | ||
}) | ||
} |
Oops, something went wrong.