diff --git a/dgraph/cmd/alpha/run.go b/dgraph/cmd/alpha/run.go index 0bdd2eee154..881f3fff6cb 100644 --- a/dgraph/cmd/alpha/run.go +++ b/dgraph/cmd/alpha/run.go @@ -241,6 +241,8 @@ they form a Raft group and provide synchronous replication. "The SASL username for Kafka."). Flag("sasl-password", "The SASL password for Kafka."). + Flag("sasl-mechanism", + "The SASL mechanism for Kafka (PLAIN, SCRAM-SHA-256 or SCRAM-SHA-512)"). Flag("ca-cert", "The path to CA cert file for TLS encryption."). Flag("client-cert", diff --git a/go.mod b/go.mod index 4f0fc6e2c10..2f3d0b462ea 100644 --- a/go.mod +++ b/go.mod @@ -58,6 +58,7 @@ require ( github.com/spf13/viper v1.7.1 github.com/stretchr/testify v1.6.1 github.com/twpayne/go-geom v1.0.5 + github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c go.etcd.io/etcd v0.0.0-20190228193606-a943ad0ee4c9 go.opencensus.io v0.22.5 go.uber.org/zap v1.16.0 diff --git a/go.sum b/go.sum index 3883123e7b2..e043e8f4ba5 100644 --- a/go.sum +++ b/go.sum @@ -595,7 +595,9 @@ github.com/vektah/dataloaden v0.2.1-0.20190515034641-a19b9a6e7c9e/go.mod h1:/HUd github.com/vektah/gqlparser/v2 v2.1.0/go.mod h1:SyUiHgLATUR8BiYURfTirrTcGpcE+4XkV2se04Px1Ms= github.com/willf/bitset v1.1.10 h1:NotGKqX0KwQ72NUzqrjZq5ipPNDQex9lo3WpaS8L2sc= github.com/willf/bitset v1.1.10/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= +github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c h1:u40Z8hqBAAQyv+vATcGgV0YCnDjqSL7/q/JyPhhJSPk= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= +github.com/xdg/stringprep v1.0.0 h1:d9X0esnoa3dFsV0FG35rAT0RIhYFlPq7MiP+DW89La0= github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:GwrjFmJcFw6At/Gs6z4yjiIwzuJ1/+UwLxMQDVQXShQ= diff --git a/worker/server_state.go b/worker/server_state.go index 8dcd07af41a..9191b1db019 100644 --- a/worker/server_state.go +++ b/worker/server_state.go @@ -45,7 +45,7 @@ const ( SecurityDefaults = `token=; whitelist=;` LudicrousDefaults = `enabled=false; concurrency=2000;` CDCDefaults = `file=; kafka=; sasl_user=; sasl_password=; ca_cert=; client_cert=; ` + - `client_key=;` + `client_key=; sasl-mechanism=PLAIN;` LimitDefaults = `mutations=allow; query-edge=1000000; normalize-node=10000; ` + `mutations-nquad=1000000; disallow-drop=false; query-timeout=0ms; txn-abort-after=5m; ` + ` max-retries=-1;max-pending-queries=10000` diff --git a/worker/sink_handler.go b/worker/sink_handler.go index 592d2bf65d1..b43b261b7fa 100644 --- a/worker/sink_handler.go +++ b/worker/sink_handler.go @@ -17,6 +17,8 @@ package worker import ( + "crypto/sha256" + "crypto/sha512" "crypto/tls" "crypto/x509" "encoding/binary" @@ -27,6 +29,7 @@ import ( "strings" "github.com/pkg/errors" + "github.com/xdg/scram" "github.com/Shopify/sarama" @@ -116,6 +119,27 @@ func newKafkaSink(config *z.SuperFlag) (Sink, error) { saramaConf.Net.SASL.User = config.GetString("sasl-user") saramaConf.Net.SASL.Password = config.GetString("sasl-password") } + mechanism := config.GetString("sasl-mechanism") + if mechanism != "" { + switch mechanism { + case sarama.SASLTypeSCRAMSHA256: + saramaConf.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA256 + saramaConf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { + return &scramClient{HashGeneratorFcn: sha256.New} + } + case sarama.SASLTypeSCRAMSHA512: + saramaConf.Net.SASL.Mechanism = sarama.SASLTypeSCRAMSHA512 + saramaConf.Net.SASL.SCRAMClientGeneratorFunc = func() sarama.SCRAMClient { + return &scramClient{HashGeneratorFcn: sha512.New} + } + case sarama.SASLTypePlaintext: + saramaConf.Net.SASL.Mechanism = sarama.SASLTypePlaintext + default: + return nil, errors.Errorf("Invalid SASL mechanism. Valid mechanisms are: %s, %s and %s", + sarama.SASLTypePlaintext, sarama.SASLTypeSCRAMSHA256, sarama.SASLTypeSCRAMSHA512) + } + } + brokers := strings.Split(config.GetString("kafka"), ",") client, err := sarama.NewClient(brokers, saramaConf) if err != nil { @@ -195,3 +219,27 @@ func newFileSink(path *z.SuperFlag) (Sink, error) { fileWriter: w, }, nil } + +type scramClient struct { + *scram.Client + *scram.ClientConversation + scram.HashGeneratorFcn +} + +func (sc *scramClient) Begin(userName, password, authzID string) (err error) { + sc.Client, err = sc.HashGeneratorFcn.NewClient(userName, password, authzID) + if err != nil { + return err + } + sc.ClientConversation = sc.Client.NewConversation() + return nil +} + +func (sc *scramClient) Step(challenge string) (response string, err error) { + response, err = sc.ClientConversation.Step(challenge) + return +} + +func (sc *scramClient) Done() bool { + return sc.ClientConversation.Done() +}