Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for setting headers in async relay requests #47

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
COPY --from=builder /app/main .

# Set default value for port exposed
ENV HTTP_SERVER_PORT 8080
ENV HTTP_SERVER_PORT=8080

EXPOSE $HTTP_SERVER_PORT

Expand Down
13 changes: 10 additions & 3 deletions cmd/gateway_server/internal/controllers/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,18 @@ func (c *RelayController) HandleRelay(ctx *fasthttp.RequestCtx) {
return
}

contentType := string(ctx.Request.Header.Peek("content-type"))
if contentType == "" {
contentType = "application/json"
}

relay, err := c.relayer.SendRelay(&models.SendRelayRequest{
Payload: &models.Payload{
Data: string(ctx.PostBody()),
Method: string(ctx.Method()),
Path: path,
// TODO: the best here will been able to get the chain configuration to use the configure headers.
Headers: map[string]string{"content-type": contentType},
Data: string(ctx.PostBody()),
Method: string(ctx.Method()),
Path: path,
},
Chain: chainID,
})
Expand Down
1 change: 1 addition & 0 deletions internal/db_query/queries.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 8 additions & 3 deletions internal/node_selector_service/checks/async_relay_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type nodeRelayResponse struct {
Error error
}

func SendRelaysAsync(relayer pokt_v0.PocketRelayer, nodes []*models.QosNode, payload string, method string, path string) chan *nodeRelayResponse {
func SendRelaysAsync(relayer pokt_v0.PocketRelayer, nodes []*models.QosNode, payload string, method string, path string, headers map[string]string) chan *nodeRelayResponse {
// Define a channel to receive relay responses
relayResponses := make(chan *nodeRelayResponse, len(nodes))
var wg sync.WaitGroup
Expand All @@ -22,8 +22,13 @@ func SendRelaysAsync(relayer pokt_v0.PocketRelayer, nodes []*models.QosNode, pay
sendRelayAsync := func(node *models.QosNode) {
defer wg.Done()
relay, err := relayer.SendRelay(&relayer_models.SendRelayRequest{
Signer: node.GetAppStakeSigner(),
Payload: &relayer_models.Payload{Data: payload, Method: method, Path: path},
Signer: node.GetAppStakeSigner(),
Payload: &relayer_models.Payload{
Data: payload,
Method: method,
Path: path,
Headers: headers,
},
Chain: node.GetChain(),
SelectedNodePubKey: node.GetPublicKey(),
Session: node.MorseSession,
Expand Down
30 changes: 29 additions & 1 deletion internal/node_selector_service/checks/chain_config_handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package checks

import "github.com/pokt-network/gateway-server/internal/chain_configurations_registry"
import (
"github.com/pokt-network/gateway-server/internal/chain_configurations_registry"
"github.com/pokt-network/gateway-server/pkg/common"
)

// GetBlockHeightTolerance - helper function to retrieve block height tolerance across checks
func GetBlockHeightTolerance(chainConfiguration chain_configurations_registry.ChainConfigurationsService, chainId string, defaultValue int) int {
Expand All @@ -19,3 +22,28 @@ func GetDataIntegrityHeightLookback(chainConfiguration chain_configurations_regi
}
return int(*chainConfig.DataIntegrityCheckLookbackHeight)
}

// GetFixedHeaders returns the fixed headers for a specific chain configuration.
// It takes a ChainConfigurationsService to retrieve the chain configuration,
// the chainId string to identify the specific chain,
// and a defaultValue map[string]string to return in case the chain configuration is not found.
// The function first retrieves the chain configuration using the chainConfiguration.GetChainConfiguration method.
// If the chain configuration is not found, it returns the defaultValue.
// If the chain configuration is found, it retrieves the fixed headers as a map[string]string from the chain configuration.
// If the fixed headers cannot be cast into a map[string]string, it returns the defaultValue.
// Otherwise, it returns the retrieved fixed headers.
func GetFixedHeaders(chainConfiguration chain_configurations_registry.ChainConfigurationsService, chainId string, defaultValue map[string]string) map[string]string {
chainConfig, ok := chainConfiguration.GetChainConfiguration(chainId)
value := defaultValue

if ok && chainConfig.FixedHeaders != nil {
if headers, castOk := chainConfig.FixedHeaders.Get().(map[string]string); castOk {
// apply the specific headers override coming from chain configuration over the defaults one.
// in this way, the chain configuration on db only needs to hold the overrides or additions that are may
// not add to base code.
value = common.MergeStringMaps(value, headers)
}
}

return value
}
12 changes: 9 additions & 3 deletions internal/node_selector_service/checks/data_integrity_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ const (
dataIntegrityHeightLookbackDefault = 25
)

var (
dataIntegrityHeadersDefault = map[string]string{"content-type": "application/json"}
)

type nodeHashRspPair struct {
node *models.QosNode
blockIdentifier string
Expand Down Expand Up @@ -47,10 +51,12 @@ func PerformDataIntegrityCheck(check *Check, calculatePayload GetBlockByNumberPa

var nodeResponsePairs []*nodeHashRspPair

// find a random block to search that nodes should have access too
blockNumberToSearch := sourceOfTruth.GetLastKnownHeight() - uint64(GetDataIntegrityHeightLookback(check.ChainConfiguration, sourceOfTruth.GetChain(), dataIntegrityHeightLookbackDefault))
chainId := sourceOfTruth.GetChain()
checkHeaders := GetFixedHeaders(check.ChainConfiguration, chainId, dataIntegrityHeadersDefault)

attestationResponses := SendRelaysAsync(check.PocketRelayer, getEligibleDataIntegrityCheckNodes(check.NodeList), calculatePayload(blockNumberToSearch), "POST", path)
// find a random block to search that nodes should have access too
blockNumberToSearch := sourceOfTruth.GetLastKnownHeight() - uint64(GetDataIntegrityHeightLookback(check.ChainConfiguration, chainId, dataIntegrityHeightLookbackDefault))
attestationResponses := SendRelaysAsync(check.PocketRelayer, getEligibleDataIntegrityCheckNodes(check.NodeList), calculatePayload(blockNumberToSearch), "POST", path, checkHeaders)
for rsp := range attestationResponses {

if rsp.Error != nil {
Expand Down
11 changes: 8 additions & 3 deletions internal/node_selector_service/checks/height_check_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ const (
defaultCheckPenalty = time.Minute * 5
)

var (
defaultHeaders = map[string]string{"content-type": "application/json"}
)

type HeightJsonParser func(response string) (uint64, error)

// PerformDefaultHeightCheck is the default implementation of a height check by:
Expand All @@ -26,12 +30,13 @@ type HeightJsonParser func(response string) (uint64, error)
// 3. Filtering out nodes that are returning a height out of the zScore threshold
// 4. Punishing the nodes with defaultCheckPenalty that exceed the height tolerance.
func PerformDefaultHeightCheck(check *Check, payload string, path string, parseHeight HeightJsonParser, logger *zap.Logger) {

logger.Sugar().Infow("running default height check", "chain", check.NodeList[0].GetChain())
chainId := check.NodeList[0].GetChain()
logger.Sugar().Infow("running default height check", "chain", chainId)
checkHeaders := GetFixedHeaders(check.ChainConfiguration, chainId, defaultHeaders)

var nodesResponded []*models.QosNode
// Send request to all nodes
relayResponses := SendRelaysAsync(check.PocketRelayer, getEligibleHeightCheckNodes(check.NodeList), payload, "POST", path)
relayResponses := SendRelaysAsync(check.PocketRelayer, getEligibleHeightCheckNodes(check.NodeList), payload, "POST", path, checkHeaders)

// Process relay responses
for resp := range relayResponses {
Expand Down
5 changes: 5 additions & 0 deletions internal/relayer/relayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,11 @@ func (r *Relayer) altruistRelay(req *models.SendRelayRequest) (*models.SendRelay
fasthttp.ReleaseResponse(response)
}()

checkHeaders := checks.GetFixedHeaders(r.chainConfigurationRegistry, req.Chain, map[string]string{"content-type": "application/json"})
for key, value := range checkHeaders {
request.Header.Set(key, value)
}

requestTimeout := r.getAltruistRequestTimeout(req.Chain)
request.Header.SetUserAgent(r.userAgent)
request.SetRequestURI(chainConfig.AltruistUrl.String)
Expand Down
18 changes: 18 additions & 0 deletions pkg/common/maps.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package common

func MergeStringMaps(map1, map2 map[string]string) map[string]string {
// Create a new map to store the merged result
mergedMap := make(map[string]string)

// Add all entries from map1 to mergedMap
for k, v := range map1 {
mergedMap[k] = v
}

// Add all entries from map2 to mergedMap
for k, v := range map2 {
mergedMap[k] = v
}

return mergedMap
}
1 change: 1 addition & 0 deletions pkg/pokt/pokt_v0/basic_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ func (r BasicClient) makeRequest(endpoint string, method string, requestData any
}()

request.Header.SetUserAgent(r.userAgent)
request.Header.SetContentType("application/json")

if hostOverride != nil {
request.SetRequestURI(*hostOverride + endpoint)
Expand Down
Loading