Skip to content

Commit

Permalink
Merge pull request #182 from newrelic/feature/NR-223852
Browse files Browse the repository at this point in the history
NR-223852 : Retry request with different endpoint
  • Loading branch information
lovesh-ap authored Feb 14, 2024
2 parents 128a92d + 3bf3e69 commit 0711914
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ private void setApplicationConfig(Connector[] connectors) {
String protocol = JettyUtils.getProtocol(connector.getProtocols());
if(protocol != null) {
NewRelicSecurity.getAgent().setApplicationConnectionConfig(((NetworkConnector) connector).getPort(), protocol);
System.out.println("setting server config as : "+((NetworkConnector) connector).getPort() + ":"+protocol);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ private void setApplicationConfig(Connector[] connectors) {
String protocol = JettyUtils.getProtocol(connector.getProtocols());
if(protocol != null) {
NewRelicSecurity.getAgent().setApplicationConnectionConfig(((NetworkConnector) connector).getPort(), protocol);
System.out.println("setting server config as : "+((NetworkConnector) connector).getPort() + ":"+protocol);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ private void setApplicationConfig(Connector[] connectors) {
String protocol = JettyUtils.getProtocol(connector.getProtocols());
if(protocol != null) {
NewRelicSecurity.getAgent().setApplicationConnectionConfig(((NetworkConnector) connector).getPort(), protocol);
System.out.println("setting server config as : "+((NetworkConnector) connector).getPort() + ":"+protocol);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import okhttp3.internal.http.HttpMethod;
import org.apache.commons.lang3.StringUtils;

import java.net.MalformedURLException;
import java.net.URL;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

Expand All @@ -20,13 +19,11 @@ public class RequestUtils {
private static final FileLoggerThreadPool logger = FileLoggerThreadPool.getInstance();
public static final String ERROR_IN_FUZZ_REQUEST_GENERATION = "Error in fuzz request generation {}";

public static Request generateK2Request(FuzzRequestBean httpRequest) {
public static Request generateK2Request(FuzzRequestBean httpRequest, String endpoint) {
try {
String scheme = NewRelicSecurity.getAgent().getApplicationConnectionConfig(httpRequest.getServerPort());
logger.log(LogLevel.FINER, String.format("Firing request : %s", JsonConverter.toJSON(httpRequest)), RequestUtils.class.getName());
StringBuilder url = new StringBuilder(String.format("%s://localhost", scheme!=null?scheme:httpRequest.getProtocol()));
url.append(":");
url.append(httpRequest.getServerPort());
StringBuilder url = new StringBuilder(endpoint);
url.append(httpRequest.getUrl());
RequestBody requestBody = null;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.newrelic.agent.security.AgentInfo;
import com.newrelic.agent.security.intcodeagent.filelogging.FileLoggerThreadPool;
import com.newrelic.agent.security.intcodeagent.models.FuzzRequestBean;
import com.newrelic.api.agent.security.utils.logging.LogLevel;
import com.newrelic.agent.security.intcodeagent.models.javaagent.FuzzFailEvent;
import com.newrelic.agent.security.intcodeagent.websocket.EventSendPool;
Expand All @@ -15,6 +16,7 @@
import java.io.InterruptedIOException;
import java.security.cert.CertificateException;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.TimeUnit;

public class RestClient {
Expand Down Expand Up @@ -118,7 +120,24 @@ public OkHttpClient getClient() {
return clientThreadLocal.get();
}

public void fireRequest(Request request, int repeatCount, String fuzzRequestId) {
public void fireRequest(FuzzRequestBean httpRequest, List<String> endpoints, int repeatCount, String fuzzRequestId){

int responseCode = 999;
for (String endpoint : endpoints) {
try {
Request request = RequestUtils.generateK2Request(httpRequest, endpoint);
if (request != null) {
responseCode = RestClient.getInstance().fireRequest(request, repeatCount + endpoints.size() -1, fuzzRequestId);
}
if(responseCode == 301){continue;}
break;
} catch (SSLException e){}
}


}

public int fireRequest(Request request, int repeatCount, String fuzzRequestId) throws SSLException {
OkHttpClient client = clientThreadLocal.get();

logger.log(LogLevel.FINER, String.format(FIRING_REQUEST_METHOD_S, request.method()), RestClient.class.getName());
Expand All @@ -143,9 +162,13 @@ public void fireRequest(Request request, int repeatCount, String fuzzRequestId)
if (client.connectionPool() != null) {
client.connectionPool().evictAll();
}
return response.code();
} catch (SSLException e){
logger.log(LogLevel.FINE, String.format("Request failed due to SSL Exception %s ", request, e), RestClient.class.getName());
throw e;
} catch (InterruptedIOException e){
if(repeatCount >= 0){
fireRequest(request, --repeatCount, fuzzRequestId);
return fireRequest(request, --repeatCount, fuzzRequestId);
}
} catch (IOException e) {
logger.log(LogLevel.FINER, String.format(CALL_FAILED_REQUEST_S_REASON, request), e, RestClient.class.getName());
Expand All @@ -159,6 +182,7 @@ public void fireRequest(Request request, int repeatCount, String fuzzRequestId)
EventSendPool.getInstance().sendEvent(fuzzFailEvent);
}

return 999;
}

public boolean isConnected() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.newrelic.agent.security.instrumentator.os.OsVariablesInstance;
import com.newrelic.agent.security.instrumentator.utils.CallbackUtils;
import com.newrelic.agent.security.intcodeagent.filelogging.FileLoggerThreadPool;
import com.newrelic.api.agent.security.NewRelicSecurity;
import com.newrelic.api.agent.security.utils.logging.LogLevel;
import com.newrelic.agent.security.intcodeagent.models.FuzzRequestBean;
import com.newrelic.agent.security.intcodeagent.models.javaagent.IntCodeControlCommand;
Expand All @@ -16,9 +17,11 @@
import okhttp3.Request;
import org.apache.commons.lang3.StringUtils;

import javax.net.ssl.SSLException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;

/**
Expand Down Expand Up @@ -100,10 +103,8 @@ public Boolean call() throws InterruptedException {
MonitorGrpcFuzzFailRequestQueueThread.submitNewTask();
GrpcClientRequestReplayHelper.getInstance().addToRequestQueue(new ControlCommandDto(controlCommand.getId(), httpRequest, payloadList));
} else {
Request request = RequestUtils.generateK2Request(httpRequest);
if(request != null) {
RestClient.getInstance().fireRequest(request, repeatCount, controlCommand.getId());
}
List<String> endpoints = prepareAllEndpoints(NewRelicSecurity.getAgent().getApplicationConnectionConfig());
RestClient.getInstance().fireRequest(httpRequest, endpoints, repeatCount + endpoints.size() -1, controlCommand.getId());
}
return true;
} catch (JsonProcessingException e){
Expand All @@ -126,6 +127,19 @@ public Boolean call() throws InterruptedException {
return true;
}

private List<String> prepareAllEndpoints(Map<Integer, String> applicationConnectionConfig) {
List<String> endpoitns = new ArrayList<>();
for (Map.Entry<Integer, String> connectionConfig : applicationConnectionConfig.entrySet()) {
endpoitns.add(String.format("%s://localhost:%s", connectionConfig.getValue(), connectionConfig.getKey()));
endpoitns.add(String.format("%s://localhost:%s", toggleProtocol(connectionConfig.getValue()), connectionConfig.getKey()));
}
return endpoitns;
}

private String toggleProtocol(String value) {
return StringUtils.equalsAnyIgnoreCase(value, "https")? "http": "https";
}

public static void processControlCommand(IntCodeControlCommand command) {
RestRequestThreadPool.getInstance().executor
.submit(new RestRequestProcessor(command, MAX_REPETITION));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,11 @@ public String getApplicationConnectionConfig(int port) {
return appServerInfo.getConnectionConfiguration().get(port);
}

@Override
public Map<Integer, String> getApplicationConnectionConfig() {
return AppServerInfoHelper.getAppServerInfo().getConnectionConfiguration();
}

@Override
public void setServerInfo(String key, String value) {
AppServerInfo appServerInfo = AppServerInfoHelper.getAppServerInfo();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ public String getApplicationConnectionConfig(int port) {
return null;
}

@Override
public Map<Integer, String> getApplicationConnectionConfig() {
//TODO Ishika please fill this as per your needs
return null;
}

@Override
public void log(LogLevel logLevel, String event, Throwable throwableEvent, String logSourceClassName) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import java.lang.instrument.Instrumentation;
import java.net.URL;
import java.util.Collections;
import java.util.Map;

/**
* Provides NoOps for API objects to avoid returning <code>null</code>. Do not call these objects directly.
Expand Down Expand Up @@ -96,6 +98,11 @@ public String getApplicationConnectionConfig(int port) {
return null;
}

@Override
public Map<Integer, String> getApplicationConnectionConfig() {
return Collections.emptyMap();
}

@Override
public void log(LogLevel logLevel, String event, Throwable throwableEvent, String logSourceClassName) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.lang.instrument.Instrumentation;
import java.net.URL;
import java.util.Map;

/**
* The New Relic Security Java Agent's API.
Expand Down Expand Up @@ -55,6 +56,8 @@ public interface SecurityAgent {

String getApplicationConnectionConfig(int port);

Map<Integer, String> getApplicationConnectionConfig();

void log(LogLevel logLevel, String event, Throwable throwableEvent, String logSourceClassName);

void log(LogLevel logLevel, String event, String logSourceClassName);
Expand Down

0 comments on commit 0711914

Please sign in to comment.