Skip to content

Commit

Permalink
Merge branch 'release-5.3' into cherry-pick-5273-to-release-5.3
Browse files Browse the repository at this point in the history
  • Loading branch information
lance6716 authored May 25, 2022
2 parents 6bdb346 + 1feef71 commit c866462
Show file tree
Hide file tree
Showing 14 changed files with 310 additions and 15 deletions.
36 changes: 31 additions & 5 deletions dm/dm/master/openapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@ package master

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/http"
"net/http/httputil"

"github.com/pingcap/failpoint"

"github.com/deepmap/oapi-codegen/pkg/middleware"
"github.com/labstack/echo/v4"
Expand All @@ -45,7 +48,7 @@ const (

// redirectRequestToLeaderMW a middleware auto redirect request to leader.
// because the leader has some data in memory, only the leader can process the request.
func (s *Server) redirectRequestToLeaderMW() echo.MiddlewareFunc {
func (s *Server) reverseRequestToLeaderMW(tlsCfg *tls.Config) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(ctx echo.Context) error {
ctx2 := ctx.Request().Context()
Expand All @@ -58,13 +61,36 @@ func (s *Server) redirectRequestToLeaderMW() echo.MiddlewareFunc {
if err != nil {
return err
}
return ctx.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("http://%s%s", leaderOpenAPIAddr, ctx.Request().RequestURI))

failpoint.Inject("MockNotSetTls", func() {
tlsCfg = nil
})
// simpleProxy just reverses to leader host
simpleProxy := httputil.ReverseProxy{
Director: func(req *http.Request) {
if tlsCfg != nil {
req.URL.Scheme = "https"
} else {
req.URL.Scheme = "http"
}
req.URL.Host = leaderOpenAPIAddr
req.Host = leaderOpenAPIAddr
},
}
if tlsCfg != nil {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.TLSClientConfig = tlsCfg
simpleProxy.Transport = transport
}
log.L().Info("reverse request to leader", zap.String("Request URL", ctx.Request().URL.String()), zap.String("leader", leaderOpenAPIAddr), zap.Bool("hasTLS", tlsCfg != nil))
simpleProxy.ServeHTTP(ctx.Response(), ctx.Request())
return nil
}
}
}

// InitOpenAPIHandles init openapi handlers.
func (s *Server) InitOpenAPIHandles() error {
func (s *Server) InitOpenAPIHandles(tlsCfg *tls.Config) error {
swagger, err := openapi.GetSwagger()
if err != nil {
return err
Expand All @@ -77,7 +103,7 @@ func (s *Server) InitOpenAPIHandles() error {
// set logger
e.Use(openapi.ZapLogger(logger))
e.Use(echomiddleware.Recover())
e.Use(s.redirectRequestToLeaderMW())
e.Use(s.reverseRequestToLeaderMW(tlsCfg))
// disables swagger server name validation. it seems to work poorly
swagger.Servers = nil
// use our validation middleware to check all requests against the OpenAPI schema.
Expand Down
103 changes: 100 additions & 3 deletions dm/dm/master/openapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@ import (
"context"
"fmt"
"net/http"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"

"github.com/DATA-DOG/go-sqlmock"
"github.com/deepmap/oapi-codegen/pkg/testutil"
"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -75,7 +78,7 @@ func (t *openAPISuite) SetUpTest(c *check.C) {
c.Assert(ha.ClearTestInfoOperation(t.etcdTestCli), check.IsNil)
}

func (t *openAPISuite) TestRedirectRequestToLeader(c *check.C) {
func (t *openAPISuite) TestReverseRequestToLeader(c *check.C) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

Expand Down Expand Up @@ -134,9 +137,103 @@ func (t *openAPISuite) TestRedirectRequestToLeader(c *check.C) {
c.Assert(resultListSource.Data, check.HasLen, 0)
c.Assert(resultListSource.Total, check.Equals, 0)

// list source not from leader will get a redirect
// list source from non-leader will get result too
result2 := testutil.NewRequest().Get(baseURL).Go(t.testT, s2.echo)
c.Assert(result2.Code(), check.Equals, http.StatusTemporaryRedirect)
c.Assert(result2.Code(), check.Equals, http.StatusOK)
var resultListSource2 openapi.GetSourceListResponse
err = result2.UnmarshalBodyToObject(&resultListSource2)
c.Assert(err, check.IsNil)
c.Assert(resultListSource2.Data, check.HasLen, 0)
c.Assert(resultListSource2.Total, check.Equals, 0)
}

func (t *openAPISuite) TestReverseRequestToHttpsLeader(c *check.C) {
pwd, err := os.Getwd()
require.NoError(t.testT, err)
caPath := pwd + "/tls_for_test/ca.pem"
certPath := pwd + "/tls_for_test/dm.pem"
keyPath := pwd + "/tls_for_test/dm.key"

// master1
masterAddr1 := tempurl.Alloc()[len("http://"):]
peerAddr1 := tempurl.Alloc()[len("http://"):]
cfg1 := NewConfig()
require.NoError(t.testT, cfg1.Parse([]string{
"--name=dm-master-tls-1",
fmt.Sprintf("--data-dir=%s", t.testT.TempDir()),
fmt.Sprintf("--master-addr=https://%s", masterAddr1),
fmt.Sprintf("--advertise-addr=https://%s", masterAddr1),
fmt.Sprintf("--peer-urls=https://%s", peerAddr1),
fmt.Sprintf("--advertise-peer-urls=https://%s", peerAddr1),
fmt.Sprintf("--initial-cluster=dm-master-tls-1=https://%s", peerAddr1),
"--ssl-ca=" + caPath,
"--ssl-cert=" + certPath,
"--ssl-key=" + keyPath,
}))
cfg1.ExperimentalFeatures.OpenAPI = true
s1 := NewServer(cfg1)
ctx1, cancel1 := context.WithCancel(context.Background())
require.NoError(t.testT, s1.Start(ctx1))
defer func() {
cancel1()
s1.Close()
}()
// wait the first one become the leader
require.True(t.testT, utils.WaitSomething(30, 100*time.Millisecond, func() bool {
return s1.election.IsLeader() && s1.scheduler.Started()
}))

// master2
masterAddr2 := tempurl.Alloc()[len("http://"):]
peerAddr2 := tempurl.Alloc()[len("http://"):]
cfg2 := NewConfig()
require.NoError(t.testT, cfg2.Parse([]string{
"--name=dm-master-tls-2",
fmt.Sprintf("--data-dir=%s", t.testT.TempDir()),
fmt.Sprintf("--master-addr=https://%s", masterAddr2),
fmt.Sprintf("--advertise-addr=https://%s", masterAddr2),
fmt.Sprintf("--peer-urls=https://%s", peerAddr2),
fmt.Sprintf("--advertise-peer-urls=https://%s", peerAddr2),
"--ssl-ca=" + caPath,
"--ssl-cert=" + certPath,
"--ssl-key=" + keyPath,
}))
cfg2.ExperimentalFeatures.OpenAPI = true
cfg2.Join = s1.cfg.MasterAddr // join to an existing cluster
s2 := NewServer(cfg2)
ctx2, cancel2 := context.WithCancel(context.Background())
require.NoError(t.testT, s2.Start(ctx2))
defer func() {
cancel2()
s2.Close()
}()
// wait the second master ready
require.False(t.testT, utils.WaitSomething(30, 100*time.Millisecond, func() bool {
return s2.election.IsLeader()
}))

baseURL := "/api/v1/sources"
// list source from leader
result := testutil.NewRequest().Get(baseURL).Go(t.testT, s1.echo)
require.Equal(t.testT, http.StatusOK, result.Code())
var resultListSource openapi.GetSourceListResponse
require.NoError(t.testT, result.UnmarshalBodyToObject(&resultListSource))
require.Len(t.testT, resultListSource.Data, 0)
require.Equal(t.testT, 0, resultListSource.Total)

// with tls, list source not from leader will get result too
result = testutil.NewRequest().Get(baseURL).Go(t.testT, s2.echo)
require.Equal(t.testT, http.StatusOK, result.Code())
var resultListSource2 openapi.GetSourceListResponse
require.NoError(t.testT, result.UnmarshalBodyToObject(&resultListSource2))
require.Len(t.testT, resultListSource2.Data, 0)
require.Equal(t.testT, 0, resultListSource2.Total)

// without tls, list source not from leader will be 502
require.NoError(t.testT, failpoint.Enable("github.com/pingcap/tiflow/dm/dm/master/MockNotSetTls", `return()`))
result = testutil.NewRequest().Get(baseURL).Go(t.testT, s2.echo)
require.Equal(t.testT, http.StatusBadGateway, result.Code())
require.NoError(t.testT, failpoint.Disable("github.com/pingcap/tiflow/dm/dm/master/MockNotSetTls"))
}

func (t *openAPISuite) TestOpenAPIWillNotStartInDefaultConfig(c *check.C) {
Expand Down
8 changes: 7 additions & 1 deletion dm/dm/master/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,14 @@ func (s *Server) Start(ctx context.Context) (err error) {
"/status": getStatusHandle(),
"/debug/": getDebugHandler(),
}

if s.cfg.ExperimentalFeatures.OpenAPI {
if initOpenAPIErr := s.InitOpenAPIHandles(); initOpenAPIErr != nil {
// tls3 is used to openapi reverse proxy
tls3, err1 := toolutils.NewTLS(s.cfg.SSLCA, s.cfg.SSLCert, s.cfg.SSLKey, s.cfg.AdvertiseAddr, s.cfg.CertAllowedCN)
if err1 != nil {
return terror.ErrMasterTLSConfigNotValid.Delegate(err1)
}
if initOpenAPIErr := s.InitOpenAPIHandles(tls3.TLSConfig()); initOpenAPIErr != nil {
return terror.ErrOpenAPICommonError.Delegate(initOpenAPIErr)
}
userHandles["/api/v1/"] = s.echo
Expand Down
4 changes: 3 additions & 1 deletion dm/syncer/dml_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ func (w *DMLWorker) executeBatchJobs(queueID int, jobs []*job) {
time.Sleep(time.Duration(t) * time.Second)
})
// use background context to execute sqls as much as possible
ctx, cancel := w.tctx.WithTimeout(maxDMLExecutionDuration)
// set timeout to maxDMLConnectionDuration to make sure dmls can be replicated to downstream event if the latency is high
// if users need to quit this asap, we can support pause-task/stop-task --force in the future
ctx, cancel := w.tctx.WithTimeout(maxDMLConnectionDuration)
defer cancel()
affect, err = db.ExecuteSQL(ctx, queries, args...)
failpoint.Inject("SafeModeExit", func(val failpoint.Value) {
Expand Down
1 change: 0 additions & 1 deletion dm/syncer/syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ var (
maxDDLConnectionTimeout = fmt.Sprintf("%dm", MaxDDLConnectionTimeoutMinute)

maxDMLConnectionDuration, _ = time.ParseDuration(maxDMLConnectionTimeout)
maxDMLExecutionDuration = 30 * time.Second

maxPauseOrStopWaitTime = 10 * time.Second

Expand Down
39 changes: 36 additions & 3 deletions dm/tests/openapi/client/openapi_source_check
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import sys
import requests
import ssl

SOURCE1_NAME = "mysql-01"
SOURCE2_NAME = "mysql-02"
Expand All @@ -11,6 +12,10 @@ WORKER2_NAME = "worker2"
API_ENDPOINT = "http://127.0.0.1:8261/api/v1/sources"
API_ENDPOINT_NOT_LEADER = "http://127.0.0.1:8361/api/v1/sources"

API_ENDPOINT_HTTPS = "https://127.0.0.1:8261/api/v1/sources"
API_ENDPOINT_NOT_LEADER_HTTPS = "https://127.0.0.1:8361/api/v1/sources"



def create_source_failed():
resp = requests.post(url=API_ENDPOINT)
Expand Down Expand Up @@ -47,6 +52,19 @@ def create_source2_success():
print("create_source1_success resp=", resp.json())
assert resp.status_code == 201

def create_source_success_https(ssl_ca, ssl_cert, ssl_key):
req = {
"case_sensitive": False,
"enable_gtid": False,
"host": "127.0.0.1",
"password": "123456",
"port": 3306,
"source_name": SOURCE1_NAME,
"user": "root",
}
resp = requests.post(url=API_ENDPOINT_HTTPS, json=req, verify=ssl_ca, cert=(ssl_cert, ssl_key))
print("create_source_success_https resp=", resp.json())
assert resp.status_code == 201

def list_source_success(source_count):
resp = requests.get(url=API_ENDPOINT)
Expand All @@ -55,6 +73,12 @@ def list_source_success(source_count):
print("list_source_by_openapi_success resp=", data)
assert data["total"] == int(source_count)

def list_source_success_https(source_count, ssl_ca, ssl_cert, ssl_key):
resp = requests.get(url=API_ENDPOINT_HTTPS, verify=ssl_ca, cert=(ssl_cert, ssl_key))
assert resp.status_code == 200
data = resp.json()
print("list_source_success_https resp=", data)
assert data["total"] == int(source_count)

def list_source_with_status_success(source_count, status_count):
resp = requests.get(url=API_ENDPOINT + "?with_status=true")
Expand All @@ -66,13 +90,19 @@ def list_source_with_status_success(source_count, status_count):
assert len(data["data"][i]["status_list"]) == int(status_count)


def list_source_with_redirect(source_count):
def list_source_with_reverse(source_count):
resp = requests.get(url=API_ENDPOINT_NOT_LEADER)
assert resp.status_code == 200
data = resp.json()
print("list_source_by_openapi_redirect resp=", data)
print("list_source_with_reverse resp=", data)
assert data["total"] == int(source_count)

def list_source_with_reverse_https(source_count, ssl_ca, ssl_cert, ssl_key):
resp = requests.get(url=API_ENDPOINT_NOT_LEADER_HTTPS, verify=ssl_ca, cert=(ssl_cert, ssl_key))
assert resp.status_code == 200
data = resp.json()
print("list_source_with_reverse_https resp=", data)
assert data["total"] == int(source_count)

def delete_source_success(source_name):
resp = requests.delete(url=API_ENDPOINT + "/" + source_name)
Expand Down Expand Up @@ -215,8 +245,11 @@ if __name__ == "__main__":
"create_source_failed": create_source_failed,
"create_source1_success": create_source1_success,
"create_source2_success": create_source2_success,
"create_source_success_https": create_source_success_https,
"list_source_success": list_source_success,
"list_source_with_redirect": list_source_with_redirect,
"list_source_success_https": list_source_success_https,
"list_source_with_reverse_https": list_source_with_reverse_https,
"list_source_with_reverse": list_source_with_reverse,
"list_source_with_status_success": list_source_with_status_success,
"delete_source_failed": delete_source_failed,
"delete_source_success": delete_source_success,
Expand Down
Loading

0 comments on commit c866462

Please sign in to comment.