diff --git a/server/apiv2/handlers/tso_keyspace_group.go b/server/apiv2/handlers/tso_keyspace_group.go index 5de8fd6a9cc..7030c332406 100644 --- a/server/apiv2/handlers/tso_keyspace_group.go +++ b/server/apiv2/handlers/tso_keyspace_group.go @@ -40,9 +40,9 @@ func RegisterTSOKeyspaceGroup(r *gin.RouterGroup) { router.GET("", GetKeyspaceGroups) router.GET("/:id", GetKeyspaceGroupByID) router.DELETE("/:id", DeleteKeyspaceGroupByID) + router.PATCH("/:id", SetNodesForKeyspaceGroup) // only to support set nodes + router.PATCH("/:id/*node", SetPriorityForKeyspaceGroup) // only to support set priority router.POST("/:id/alloc", AllocNodesForKeyspaceGroup) - router.POST("/:id/nodes", SetNodesForKeyspaceGroup) - router.POST("/:id/priority", SetPriorityForKeyspaceGroup) router.POST("/:id/split", SplitKeyspaceGroupByID) router.DELETE("/:id/split", FinishSplitKeyspaceByID) router.POST("/:id/merge", MergeKeyspaceGroups) @@ -436,8 +436,7 @@ func SetNodesForKeyspaceGroup(c *gin.Context) { // SetPriorityForKeyspaceGroupParams defines the params for setting priority of tso node for the keyspace group. type SetPriorityForKeyspaceGroupParams struct { - Node string `json:"node"` - Priority int `json:"priority"` + Priority int `json:"priority"` } // SetPriorityForKeyspaceGroup sets priority of tso node for the keyspace group. @@ -447,6 +446,11 @@ func SetPriorityForKeyspaceGroup(c *gin.Context) { c.AbortWithStatusJSON(http.StatusBadRequest, "invalid keyspace group id") return } + node, err := parseNodeAddress(c) + if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, "invalid node address") + return + } svr := c.MustGet(middlewares.ServerContextKey).(*server.Server) manager := svr.GetKeyspaceGroupManager() if manager == nil { @@ -468,12 +472,12 @@ func SetPriorityForKeyspaceGroup(c *gin.Context) { // check if node exists members := kg.Members if slice.NoneOf(members, func(i int) bool { - return members[i].Address == setParams.Node + return members[i].Address == node }) { c.AbortWithStatusJSON(http.StatusBadRequest, "tso node does not exist in the keyspace group") } // set priority - err = manager.SetPriorityForKeyspaceGroup(id, setParams.Node, setParams.Priority) + err = manager.SetPriorityForKeyspaceGroup(id, node, setParams.Priority) if err != nil { c.AbortWithStatusJSON(http.StatusInternalServerError, err.Error()) return @@ -492,6 +496,19 @@ func validateKeyspaceGroupID(c *gin.Context) (uint32, error) { return uint32(id), nil } +func parseNodeAddress(c *gin.Context) (string, error) { + node := c.Param("node") + if node == "" { + return "", errors.New("invalid node address") + } + // In pd-ctl, we use url.PathEscape to escape the node address and replace the % to \%. + // But in the gin framework, it will unescape the node address automatically. + // So we need to replace the \/ to /. + node = strings.ReplaceAll(node, "\\/", "/") + node = strings.TrimPrefix(node, "/") + return node, nil +} + func isValid(id uint32) bool { return id >= utils.DefaultKeyspaceGroupID && id <= utils.MaxKeyspaceGroupCountInUse } diff --git a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go index 41bcba0e90b..dc33016eafb 100644 --- a/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go +++ b/tests/integrations/mcs/keyspace/tso_keyspace_group_test.go @@ -354,7 +354,7 @@ func (suite *keyspaceGroupTestSuite) tryGetKeyspaceGroup(id uint32) (*endpoint.K func (suite *keyspaceGroupTestSuite) trySetNodesForKeyspaceGroup(id int, request *handlers.SetNodesForKeyspaceGroupParams) (*endpoint.KeyspaceGroup, int) { data, err := json.Marshal(request) suite.NoError(err) - httpReq, err := http.NewRequest(http.MethodPost, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d/nodes", id), bytes.NewBuffer(data)) + httpReq, err := http.NewRequest(http.MethodPatch, suite.server.GetAddr()+keyspaceGroupsPrefix+fmt.Sprintf("/%d", id), bytes.NewBuffer(data)) suite.NoError(err) resp, err := suite.dialClient.Do(httpReq) suite.NoError(err) diff --git a/tools/pd-ctl/pdctl/command/global.go b/tools/pd-ctl/pdctl/command/global.go index 8d888b60b1f..85fe63ac8be 100644 --- a/tools/pd-ctl/pdctl/command/global.go +++ b/tools/pd-ctl/pdctl/command/global.go @@ -165,7 +165,7 @@ func getEndpoints(cmd *cobra.Command) []string { return strings.Split(addrs, ",") } -func postJSON(cmd *cobra.Command, prefix string, input map[string]interface{}) { +func requestJSON(cmd *cobra.Command, method, prefix string, input map[string]interface{}) { data, err := json.Marshal(input) if err != nil { cmd.Println(err) @@ -175,19 +175,31 @@ func postJSON(cmd *cobra.Command, prefix string, input map[string]interface{}) { endpoints := getEndpoints(cmd) err = tryURLs(cmd, endpoints, func(endpoint string) error { var msg []byte - var r *http.Response + var req *http.Request + var resp *http.Response url := endpoint + "/" + prefix - r, err = dialClient.Post(url, "application/json", bytes.NewBuffer(data)) + switch method { + case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete, http.MethodGet: + req, err = http.NewRequest(method, url, bytes.NewBuffer(data)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + resp, err = dialClient.Do(req) + default: + err := errors.Errorf("method %s not supported", method) + return err + } if err != nil { return err } - defer r.Body.Close() - if r.StatusCode != http.StatusOK { - msg, err = io.ReadAll(r.Body) + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + msg, err = io.ReadAll(resp.Body) if err != nil { return err } - return errors.Errorf("[%d] %s", r.StatusCode, msg) + return errors.Errorf("[%d] %s", resp.StatusCode, msg) } return nil }) @@ -198,6 +210,14 @@ func postJSON(cmd *cobra.Command, prefix string, input map[string]interface{}) { cmd.Println("Success!") } +func postJSON(cmd *cobra.Command, prefix string, input map[string]interface{}) { + requestJSON(cmd, http.MethodPost, prefix, input) +} + +func patchJSON(cmd *cobra.Command, prefix string, input map[string]interface{}) { + requestJSON(cmd, http.MethodPatch, prefix, input) +} + // do send a request to server. Default is Get. func do(endpoint, prefix, method string, resp *string, customHeader http.Header, b *bodyOption) error { var err error diff --git a/tools/pd-ctl/pdctl/command/keyspace_group_command.go b/tools/pd-ctl/pdctl/command/keyspace_group_command.go index a4be612a301..b5acf0fa7e8 100644 --- a/tools/pd-ctl/pdctl/command/keyspace_group_command.go +++ b/tools/pd-ctl/pdctl/command/keyspace_group_command.go @@ -288,17 +288,17 @@ func setNodesKeyspaceGroupCommandFunc(cmd *cobra.Command, args []string) { cmd.Printf("Failed to parse the keyspace group ID: %s\n", err) return } - addresses := make([]string, 0, len(args)-1) + nodes := make([]string, 0, len(args)-1) for _, arg := range args[1:] { u, err := url.ParseRequestURI(arg) if u == nil || err != nil { cmd.Printf("Failed to parse the tso node address: %s\n", err) return } - addresses = append(addresses, arg) + nodes = append(nodes, arg) } - postJSON(cmd, fmt.Sprintf("%s/%s/nodes", keyspaceGroupsPrefix, args[0]), map[string]interface{}{ - "Nodes": addresses, + patchJSON(cmd, fmt.Sprintf("%s/%s", keyspaceGroupsPrefix, args[0]), map[string]interface{}{ + "Nodes": nodes, }) } @@ -313,21 +313,26 @@ func setPriorityKeyspaceGroupCommandFunc(cmd *cobra.Command, args []string) { return } - address := args[1] - u, err := url.ParseRequestURI(address) + node := args[1] + u, err := url.ParseRequestURI(node) if u == nil || err != nil { cmd.Printf("Failed to parse the tso node address: %s\n", err) return } + // Escape the node address to avoid the error of parsing the url + // But the url.PathEscape will escape the '/' to '%2F', which % will cause the error of parsing the url + // So we need to replace the % to \% + node = url.PathEscape(node) + node = strings.ReplaceAll(node, "%", "\\%") + priority, err := strconv.ParseInt(args[2], 10, 32) if err != nil { cmd.Printf("Failed to parse the priority: %s\n", err) return } - postJSON(cmd, fmt.Sprintf("%s/%s/priority", keyspaceGroupsPrefix, args[0]), map[string]interface{}{ - "Node": address, + patchJSON(cmd, fmt.Sprintf("%s/%s/%s", keyspaceGroupsPrefix, args[0], node), map[string]interface{}{ "Priority": priority, }) }