Skip to content

Commit

Permalink
Add gocql wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
wxing1292 committed Mar 2, 2021
1 parent e810ca5 commit aa3fd70
Show file tree
Hide file tree
Showing 10 changed files with 1,488 additions and 0 deletions.
89 changes: 89 additions & 0 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/batch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package gocql

import (
"context"
"fmt"

"github.com/gocql/gocql"
)

var _ Batch = (*batch)(nil)

type (
batch struct {
session *session

batch *gocql.Batch
}
)

// Definition of all BatchTypes
const (
LoggedBatch BatchType = iota
UnloggedBatch
CounterBatch
)

func newBatch(
session *session,
gocqlBatch *gocql.Batch,
) *batch {
return &batch{
session: session,
batch: gocqlBatch,
}
}

func (b *batch) Query(stmt string, args ...interface{}) {
b.batch.Query(stmt, args...)
}

func (b *batch) WithContext(ctx context.Context) Batch {
b2 := b.batch.WithContext(ctx)
if b2 == nil {
return nil
}
return newBatch(b.session, b2)
}

func (b *batch) WithTimestamp(timestamp int64) Batch {
b.batch.WithTimestamp(timestamp)
return newBatch(b.session, b.batch)
}

func mustConvertBatchType(batchType BatchType) gocql.BatchType {
switch batchType {
case LoggedBatch:
return gocql.LoggedBatch
case UnloggedBatch:
return gocql.UnloggedBatch
case CounterBatch:
return gocql.CounterBatch
default:
panic(fmt.Sprintf("Unknown gocql BatchType: %v", batchType))
}
}
175 changes: 175 additions & 0 deletions common/persistence/nosql/nosqlplugin/cassandra/gocql/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package gocql

import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"io/ioutil"
"strings"
"time"

"github.com/gocql/gocql"

"go.temporal.io/server/common/auth"
"go.temporal.io/server/common/resolver"
"go.temporal.io/server/common/service/config"
)

func NewCassandraCluster(
cfg config.Cassandra,
resolver resolver.ServiceResolver,
) (*gocql.ClusterConfig, error) {
var resolvedHosts []string
for _, host := range parseHosts(cfg.Hosts) {
resolvedHosts = append(resolvedHosts, resolver.Resolve(host)...)
}
cluster := gocql.NewCluster(resolvedHosts...)
cluster.ProtoVersion = 4
if cfg.Port > 0 {
cluster.Port = cfg.Port
}
if cfg.User != "" && cfg.Password != "" {
cluster.Authenticator = gocql.PasswordAuthenticator{
Username: cfg.User,
Password: cfg.Password,
}
}
if cfg.Keyspace != "" {
cluster.Keyspace = cfg.Keyspace
}
if cfg.Datacenter != "" {
cluster.HostFilter = gocql.DataCentreHostFilter(cfg.Datacenter)
}
if cfg.TLS != nil && cfg.TLS.Enabled {
if cfg.TLS.CertData != "" && cfg.TLS.CertFile != "" {
return nil, errors.New("Cannot specify both certData and certFile properties")
}

if cfg.TLS.KeyData != "" && cfg.TLS.KeyFile != "" {
return nil, errors.New("Cannot specify both keyData and keyFile properties")
}

if cfg.TLS.CaData != "" && cfg.TLS.CaFile != "" {
return nil, errors.New("Cannot specify both caData and caFile properties")
}

cluster.SslOpts = &gocql.SslOptions{
CaPath: cfg.TLS.CaFile,
EnableHostVerification: cfg.TLS.EnableHostVerification,
Config: auth.NewTLSConfigForServer(cfg.TLS.ServerName, cfg.TLS.EnableHostVerification),
}

var certBytes []byte
var keyBytes []byte
var err error

if cfg.TLS.CertFile != "" {
certBytes, err = ioutil.ReadFile(cfg.TLS.CertFile)
if err != nil {
return nil, fmt.Errorf("error reading client certificate file: %w", err)
}
} else if cfg.TLS.CertData != "" {
certBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.CertData)
if err != nil {
return nil, fmt.Errorf("client certificate could not be decoded: %w", err)
}
}

if cfg.TLS.KeyFile != "" {
keyBytes, err = ioutil.ReadFile(cfg.TLS.KeyFile)
if err != nil {
return nil, fmt.Errorf("error reading client certificate private key file: %w", err)
}
} else if cfg.TLS.KeyData != "" {
keyBytes, err = base64.StdEncoding.DecodeString(cfg.TLS.KeyData)
if err != nil {
return nil, fmt.Errorf("client certificate private key could not be decoded: %w", err)
}
}

if len(certBytes) > 0 {
clientCert, err := tls.X509KeyPair(certBytes, keyBytes)
if err != nil {
return nil, fmt.Errorf("unable to generate x509 key pair: %w", err)
}

cluster.SslOpts.Certificates = []tls.Certificate{clientCert}
}

if cfg.TLS.CaData != "" {
cluster.SslOpts.RootCAs = x509.NewCertPool()
pem, err := base64.StdEncoding.DecodeString(cfg.TLS.CaData)
if err != nil {
return nil, fmt.Errorf("caData could not be decoded: %w", err)
}
if !cluster.SslOpts.RootCAs.AppendCertsFromPEM(pem) {
return nil, errors.New("failed to load decoded CA Cert as PEM")
}
}
}

if cfg.MaxConns > 0 {
cluster.NumConns = cfg.MaxConns
}

if cfg.ConnectTimeout > 0 {
cluster.ConnectTimeout = cfg.ConnectTimeout
}

cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
MaxRetries: 30,
InitialInterval: 1 * time.Second,
MaxInterval: 8 * time.Second,
}

cluster.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(gocql.RoundRobinHostPolicy())
return cluster, nil
}

// regionHostFilter returns a gocql host filter for the given region name
func regionHostFilter(region string) gocql.HostFilter {
return gocql.HostFilterFunc(func(host *gocql.HostInfo) bool {
applicationRegion := region
if len(host.DataCenter()) < 3 {
return false
}
return host.DataCenter()[:3] == applicationRegion
})
}

// parseHosts returns parses a list of hosts separated by comma
func parseHosts(input string) []string {
var hosts []string
for _, h := range strings.Split(input, ",") {
if host := strings.TrimSpace(h); len(host) > 0 {
hosts = append(hosts, host)
}
}
return hosts
}
Loading

0 comments on commit aa3fd70

Please sign in to comment.