Skip to content

Commit

Permalink
Fix for duplicate RXSS events and code optimization in http4s blaze s…
Browse files Browse the repository at this point in the history
…erver
  • Loading branch information
IshikaDawda committed Nov 11, 2024
1 parent ef4a6f1 commit 53b3284
Show file tree
Hide file tree
Showing 12 changed files with 218 additions and 522 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import com.newrelic.api.agent.security.instrumentation.helpers.{GenericHelper, I
import com.newrelic.api.agent.security.schema._
import com.newrelic.api.agent.security.schema.exceptions.NewRelicSecurityException
import com.newrelic.api.agent.security.schema.operation.RXSSOperation
import com.newrelic.api.agent.security.schema.policy.AgentPolicy
import com.newrelic.api.agent.security.utils.logging.LogLevel
import org.http4s.{Headers, Request, Response}

Expand All @@ -29,16 +28,16 @@ object RequestProcessor {
val result = construct((): Unit)
.redeemWith(_ => httpApp(request),
_ => for {
_ <- preprocessHttpRequest(request)
isLockAcquired <- preprocessHttpRequest(request)
resp <- httpApp(request)
_ <- postProcessSecurityHook(resp)
_ <- postProcessSecurityHook(isLockAcquired, resp)
} yield resp
)
result
}

private def preprocessHttpRequest[F[_]: Sync](request: Request[F]): F[Unit] = construct {
val isLockAcquired = BlazeUtils.acquireLockIfPossible()
private def preprocessHttpRequest[F[_]: Sync](request: Request[F]): F[Boolean] = construct {
val isLockAcquired = GenericHelper.acquireLockIfPossible("HTTP4S-BLAZE-REQUEST_LOCK")
try {
if (NewRelicSecurity.isHookProcessingActive && isLockAcquired && !NewRelicSecurity.getAgent.getSecurityMetaData.getRequest.isRequestParsed){

Expand All @@ -48,8 +47,13 @@ object RequestProcessor {

securityRequest.setMethod(request.method.name)
securityRequest.setServerPort(request.serverPort.toInt)
securityRequest.setClientIP(request.remoteAddr.get)
securityRequest.setProtocol(BlazeUtils.getProtocol(request.isSecure.get))
securityRequest.setClientIP(request.remoteAddr.get.toString)

securityRequest.setProtocol("http")
if (request.isSecure.get) {
securityRequest.setProtocol("https")
}

securityRequest.setUrl(request.uri.toString)

if (securityRequest.getClientIP != null && securityRequest.getClientIP.trim.nonEmpty) {
Expand All @@ -58,8 +62,8 @@ object RequestProcessor {
}

processRequestHeaders(request.headers, securityRequest)
securityMetaData.setTracingHeaderValue(BlazeUtils.getTraceHeader(securityRequest.getHeaders))
securityRequest.setContentType(BlazeUtils.getContentType(securityRequest.getHeaders))
securityMetaData.setTracingHeaderValue(getTraceHeader(securityRequest.getHeaders))
securityRequest.setContentType(getContentType(securityRequest.getHeaders))

// TODO extract request body & user class detection

Expand All @@ -70,24 +74,27 @@ object RequestProcessor {

} catch {
case e: Throwable => NewRelicSecurity.getAgent.log(LogLevel.WARNING, String.format(GenericHelper.ERROR_GENERATING_HTTP_REQUEST, HTTP_4S_EMBER_SERVER_2_12_0_23, e.getMessage), e, this.getClass.getName)
} finally {
if (isLockAcquired) {
BlazeUtils.releaseLock()
}
}
isLockAcquired
}

private def getContentType(headers: util.Map[String, String]): String = {
var contentType = StringUtils.EMPTY
if (headers.containsKey("content-type")) contentType = headers.get("content-type")
contentType
}

private def processRequestHeaders(headers: Headers, securityRequest: HttpRequest): Unit = {
headers.foreach(header => {
var takeNextValue = false
var headerKey: String = StringUtils.EMPTY
if (header.name != null && header.name.isEmpty) {
var headerKey = StringUtils.EMPTY
if (header.name != null && !header.name.isEmpty) {
headerKey = header.name.toString
}
val headerValue: String = header.value
val headerValue = header.value

val agentPolicy: AgentPolicy = NewRelicSecurity.getAgent.getCurrentPolicy
val agentMetaData: AgentMetaData = NewRelicSecurity.getAgent.getSecurityMetaData.getMetaData
val agentPolicy = NewRelicSecurity.getAgent.getCurrentPolicy
val agentMetaData = NewRelicSecurity.getAgent.getSecurityMetaData.getMetaData
if (agentPolicy != null
&& agentPolicy.getProtectionMode.getEnabled()
&& agentPolicy.getProtectionMode.getIpBlocking.getEnabled()
Expand Down Expand Up @@ -118,15 +125,16 @@ object RequestProcessor {
})
}

private def postProcessSecurityHook[F[_]: Sync](response: Response[F]): F[Unit] = construct {
private def postProcessSecurityHook[F[_]: Sync](isLockAcquired:Boolean, response: Response[F]): F[Unit] = construct {
try {
if (NewRelicSecurity.isHookProcessingActive) {
if (NewRelicSecurity.isHookProcessingActive && isLockAcquired) {
val securityResponse = NewRelicSecurity.getAgent.getSecurityMetaData.getResponse
securityResponse.setResponseCode(response.status.code)
processResponseHeaders(response.headers, securityResponse)
securityResponse.setResponseContentType(BlazeUtils.getContentType(securityResponse.getHeaders))
securityResponse.setResponseContentType(getContentType(securityResponse.getHeaders))

// TODO extract response body

ServletHelper.executeBeforeExitingTransaction()
if (!ServletHelper.isResponseContentTypeExcluded(NewRelicSecurity.getAgent.getSecurityMetaData.getResponse.getResponseContentType)) {
val rxssOperation = new RXSSOperation(NewRelicSecurity.getAgent.getSecurityMetaData.getRequest, NewRelicSecurity.getAgent.getSecurityMetaData.getResponse, this.getClass.getName, METHOD_WITH_HTTP_APP)
Expand All @@ -152,5 +160,14 @@ object RequestProcessor {
})
}

private def getTraceHeader(headers: util.Map[String, String]): String = {
var data = StringUtils.EMPTY
if (headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER) || headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER.toLowerCase)) {
data = headers.get(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER)
if (data == null || data.trim.isEmpty) data = headers.get(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER.toLowerCase)
}
data
}

private def construct[F[_]: Sync, T](t: => T): F[T] = Sync[F].delay(t)
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import com.newrelic.api.agent.security.instrumentation.helpers.{GenericHelper, I
import com.newrelic.api.agent.security.schema._
import com.newrelic.api.agent.security.schema.exceptions.NewRelicSecurityException
import com.newrelic.api.agent.security.schema.operation.RXSSOperation
import com.newrelic.api.agent.security.schema.policy.AgentPolicy
import com.newrelic.api.agent.security.utils.logging.LogLevel
import org.http4s.{Headers, Request, Response}

Expand All @@ -30,16 +29,16 @@ object RequestProcessor {
val result = construct((): Unit)
.redeemWith(_ => httpApp(request),
_ => for {
_ <- preprocessHttpRequest(request)
isLockAcquired <- preprocessHttpRequest(request)
resp <- httpApp(request)
_ <- postProcessSecurityHook(resp)
_ <- postProcessSecurityHook(isLockAcquired, resp)
} yield resp
)
result
}

private def preprocessHttpRequest[F[_]: Sync](request: Request[F]): F[Unit] = construct {
val isLockAcquired = BlazeUtils.acquireLockIfPossible()
private def preprocessHttpRequest[F[_]: Sync](request: Request[F]): F[Boolean] = construct {
val isLockAcquired = GenericHelper.acquireLockIfPossible("HTTP4S-BLAZE-REQUEST_LOCK")
try {
if (NewRelicSecurity.isHookProcessingActive && isLockAcquired && !NewRelicSecurity.getAgent.getSecurityMetaData.getRequest.isRequestParsed){

Expand All @@ -50,7 +49,12 @@ object RequestProcessor {
securityRequest.setMethod(request.method.name)
securityRequest.setServerPort((request.serverPort).get.asInstanceOf[Port].value)
securityRequest.setClientIP(request.remoteAddr.get.toString)
securityRequest.setProtocol(BlazeUtils.getProtocol(request.isSecure.get))

securityRequest.setProtocol("http")
if (request.isSecure.get) {
securityRequest.setProtocol("https")
}

securityRequest.setUrl(request.uri.toString)

if (securityRequest.getClientIP != null && securityRequest.getClientIP.trim.nonEmpty) {
Expand All @@ -59,8 +63,8 @@ object RequestProcessor {
}

processRequestHeaders(request.headers, securityRequest)
securityMetaData.setTracingHeaderValue(BlazeUtils.getTraceHeader(securityRequest.getHeaders))
securityRequest.setContentType(BlazeUtils.getContentType(securityRequest.getHeaders))
securityMetaData.setTracingHeaderValue(getTraceHeader(securityRequest.getHeaders))
securityRequest.setContentType(getContentType(securityRequest.getHeaders))

// TODO extract request body & user class detection

Expand All @@ -71,24 +75,27 @@ object RequestProcessor {

} catch {
case e: Throwable => NewRelicSecurity.getAgent.log(LogLevel.WARNING, String.format(GenericHelper.ERROR_GENERATING_HTTP_REQUEST, HTTP_4S_EMBER_SERVER_2_12_0_23, e.getMessage), e, this.getClass.getName)
} finally {
if (isLockAcquired) {
BlazeUtils.releaseLock()
}
}
isLockAcquired
}

private def getContentType(headers: util.Map[String, String]): String = {
var contentType = StringUtils.EMPTY
if (headers.containsKey("content-type")) contentType = headers.get("content-type")
contentType
}

private def processRequestHeaders(headers: Headers, securityRequest: HttpRequest): Unit = {
headers.foreach(header => {
var takeNextValue = false
var headerKey: String = StringUtils.EMPTY
var headerKey = StringUtils.EMPTY
if (header.name != null && header.name.nonEmpty) {
headerKey = header.name.toString
}
val headerValue: String = header.value
val headerValue = header.value

val agentPolicy: AgentPolicy = NewRelicSecurity.getAgent.getCurrentPolicy
val agentMetaData: AgentMetaData = NewRelicSecurity.getAgent.getSecurityMetaData.getMetaData
val agentPolicy = NewRelicSecurity.getAgent.getCurrentPolicy
val agentMetaData = NewRelicSecurity.getAgent.getSecurityMetaData.getMetaData
if (agentPolicy != null
&& agentPolicy.getProtectionMode.getEnabled()
&& agentPolicy.getProtectionMode.getIpBlocking.getEnabled()
Expand Down Expand Up @@ -119,13 +126,13 @@ object RequestProcessor {
})
}

private def postProcessSecurityHook[F[_]: Sync](response: Response[F]): F[Unit] = construct {
private def postProcessSecurityHook[F[_]: Sync](isLockAcquired:Boolean, response: Response[F]): F[Unit] = construct {
try {
if (NewRelicSecurity.isHookProcessingActive) {
if (NewRelicSecurity.isHookProcessingActive && isLockAcquired) {
val securityResponse = NewRelicSecurity.getAgent.getSecurityMetaData.getResponse
securityResponse.setResponseCode(response.status.code)
processResponseHeaders(response.headers, securityResponse)
securityResponse.setResponseContentType(BlazeUtils.getContentType(securityResponse.getHeaders))
securityResponse.setResponseContentType(getContentType(securityResponse.getHeaders))

// TODO extract response body

Expand Down Expand Up @@ -154,5 +161,14 @@ object RequestProcessor {
})
}

private def getTraceHeader(headers: util.Map[String, String]): String = {
var data = StringUtils.EMPTY
if (headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER) || headers.containsKey(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER.toLowerCase)) {
data = headers.get(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER)
if (data == null || data.trim.isEmpty) data = headers.get(ServletHelper.CSEC_DISTRIBUTED_TRACING_HEADER.toLowerCase)
}
data
}

private def construct[F[_]: Sync, T](t: => T): F[T] = Sync[F].delay(t)
}
Loading

0 comments on commit 53b3284

Please sign in to comment.