Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handles MessageSystemAttributeNames #1005

Merged
merged 3 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ val scalatest = "org.scalatest" %% "scalatest" % "3.2.18"
val awaitility = "org.awaitility" % "awaitility-scala" % "4.2.1"

val amazonJavaSdkSqs = "com.amazonaws" % "aws-java-sdk-sqs" % "1.12.699" exclude ("commons-logging", "commons-logging")
val amazonJavaV2SdkSqs = "software.amazon.awssdk" % "sqs" % "2.25.30"
val amazonJavaV2SdkSqs = "software.amazon.awssdk" % "sqs" % "2.25.60"

val pekkoVersion = "1.0.2"
val pekkoHttpVersion = "1.0.1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,29 @@ abstract class AmazonJavaSdkNewTestSuite
result should contain theSameElementsAs Set(queue1Url, queue2Url, queue4Url)
}

test("should receive system all attributes") {
// given
val queueUrl = testClient.createQueue(
"testQueue2.fifo",
Map(FifoQueueAttributeName -> "true", ContentBasedDeduplicationAttributeName -> "false")
)
testClient.sendMessage(
queueUrl,
"test123",
messageGroupId = Option("gp1"),
messageDeduplicationId = Option("dup1")
)

// when
val messages = testClient.receiveMessage(queueUrl, systemAttributes = List("All"))

// then
messages should have size 1
val messageAttributes = messages.head.attributes
messageAttributes(MessageDeduplicationId) shouldBe "dup1"
messageAttributes(MessageGroupId) shouldBe "gp1"
}

private def doTestSendAndReceiveMessageWithAttributes(
content: String,
messageAttributes: Map[String, MessageAttribute] = Map.empty,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
package org.elasticmq.rest.sqs.client

import com.amazonaws.services.sqs.AmazonSQS
import com.amazonaws.services.sqs.model.{BatchResultErrorEntry, CancelMessageMoveTaskRequest, ChangeMessageVisibilityBatchRequest, ChangeMessageVisibilityBatchRequestEntry, CreateQueueRequest, DeleteMessageBatchRequest, DeleteMessageBatchRequestEntry, GetQueueAttributesRequest, GetQueueUrlRequest, ListDeadLetterSourceQueuesRequest, ListMessageMoveTasksRequest, MessageAttributeValue, MessageSystemAttributeValue, PurgeQueueRequest, QueueDoesNotExistException, ReceiveMessageRequest, ResourceNotFoundException, SendMessageBatchRequest, SendMessageBatchRequestEntry, SendMessageRequest, StartMessageMoveTaskRequest, UnsupportedOperationException}
import com.amazonaws.services.sqs.model.{
BatchResultErrorEntry,
CancelMessageMoveTaskRequest,
ChangeMessageVisibilityBatchRequest,
ChangeMessageVisibilityBatchRequestEntry,
CreateQueueRequest,
DeleteMessageBatchRequest,
DeleteMessageBatchRequestEntry,
GetQueueAttributesRequest,
GetQueueUrlRequest,
ListDeadLetterSourceQueuesRequest,
ListMessageMoveTasksRequest,
MessageAttributeValue,
MessageSystemAttributeValue,
PurgeQueueRequest,
QueueDoesNotExistException,
ReceiveMessageRequest,
ResourceNotFoundException,
SendMessageBatchRequest,
SendMessageBatchRequestEntry,
SendMessageRequest,
StartMessageMoveTaskRequest,
UnsupportedOperationException
}
import org.elasticmq._

import java.nio.ByteBuffer
Expand Down Expand Up @@ -49,7 +72,9 @@ class AwsSdkV1SqsClient(client: AmazonSQS) extends SqsClient {
queueUrl: QueueUrl,
messageBody: String,
messageAttributes: Map[String, MessageAttribute] = Map.empty,
awsTraceHeader: Option[String] = None
awsTraceHeader: Option[String] = None,
messageGroupId: Option[String] = None,
messageDeduplicationId: Option[String] = None
): Either[SqsClientError, Unit] = interceptErrors {
client.sendMessage(
new SendMessageRequest()
Expand All @@ -59,6 +84,8 @@ class AwsSdkV1SqsClient(client: AmazonSQS) extends SqsClient {
mapAwsTraceHeader(awsTraceHeader)
)
.withMessageAttributes(mapMessageAttributes(messageAttributes))
.withMessageGroupId(messageGroupId.orNull)
.withMessageDeduplicationId(messageDeduplicationId.orNull)
)
}

Expand Down Expand Up @@ -190,7 +217,9 @@ class AwsSdkV1SqsClient(client: AmazonSQS) extends SqsClient {

private def mapAwsTraceHeader(awsTraceHeader: Option[MessageMoveTaskStatus]) = {
awsTraceHeader
.map(header => Map("AWSTraceHeader" -> new MessageSystemAttributeValue().withStringValue(header).withDataType("String")).asJava)
.map(header =>
Map("AWSTraceHeader" -> new MessageSystemAttributeValue().withStringValue(header).withDataType("String")).asJava
)
.orNull
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ class AwsSdkV2SqsClient(client: software.amazon.awssdk.services.sqs.SqsClient) e
queueUrl: QueueUrl,
messageBody: String,
messageAttributes: Map[String, MessageAttribute] = Map.empty,
awsTraceHeader: Option[String] = None
awsTraceHeader: Option[String] = None,
messageGroupId: Option[String] = None,
messageDeduplicationId: Option[String] = None
): Either[SqsClientError, Unit] = interceptErrors {
client.sendMessage(
SendMessageRequest
Expand All @@ -50,6 +52,8 @@ class AwsSdkV2SqsClient(client: software.amazon.awssdk.services.sqs.SqsClient) e
.messageBody(messageBody)
.messageSystemAttributes(mapAwsTraceHeader(awsTraceHeader))
.messageAttributes(mapMessageAttributes(messageAttributes))
.messageGroupId(messageGroupId.orNull)
.messageDeduplicationId(messageDeduplicationId.orNull)
.build()
)
}
Expand Down Expand Up @@ -181,7 +185,7 @@ class AwsSdkV2SqsClient(client: software.amazon.awssdk.services.sqs.SqsClient) e
ReceiveMessageRequest
.builder()
.queueUrl(queueUrl)
.attributeNamesWithStrings(systemAttributes.asJava)
.messageSystemAttributeNames(systemAttributes.map(SdkMessageSystemAttributeName.fromValue).asJava)
.messageAttributeNames(messageAttributes.asJava)
.maxNumberOfMessages(maxNumberOfMessages.map(Int.box).orNull)
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ trait SqsClient {
queueUrl: QueueUrl,
messageBody: String,
messageAttributes: Map[String, MessageAttribute] = Map.empty,
awsTraceHeader: Option[String] = None
awsTraceHeader: Option[String] = None,
messageGroupId: Option[String] = None,
messageDeduplicationId: Option[String] = None
): Either[SqsClientError, Unit]

def receiveMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ trait ReceiveMessageDirectives {
)

val attributeNames = requestParameters.AttributeNames.getOrElse(List.empty)
val systemAttributeNames = requestParameters.MessageSystemAttributeNames.getOrElse(List.empty)
def calculateAttributeValues(msg: MessageData): List[(String, String)] = {
import AttributeValuesCalculator.Rule

possiblyEmptyAttributeValuesCalculator.calculate[String](
attributeNames,
attributeNames ++ systemAttributeNames,
Rule(SenderIdAttribute, () => Some("127.0.0.1")),
Rule(SentTimestampAttribute, () => Some(msg.created.toInstant.toEpochMilli.toString)),
Rule(ApproximateReceiveCountAttribute, () => Some(msg.statistics.approximateReceiveCount.toString)),
Expand Down Expand Up @@ -158,6 +159,7 @@ trait ReceiveMessageDirectives {

case class ReceiveMessageActionRequest(
AttributeNames: Option[List[String]],
MessageSystemAttributeNames: Option[List[String]],
MaxNumberOfMessages: Option[Int],
MessageAttributeNames: Option[List[String]],
QueueUrl: String,
Expand All @@ -169,6 +171,7 @@ trait ReceiveMessageDirectives {
object ReceiveMessageActionRequest {
def apply(
AttributeNames: Option[List[String]],
MessageSystemAttributeNames: Option[List[String]],
MaxNumberOfMessages: Option[Int],
MessageAttributeNames: Option[List[String]],
QueueUrl: String,
Expand All @@ -179,6 +182,9 @@ trait ReceiveMessageDirectives {
new ReceiveMessageActionRequest(
AttributeNames =
AttributeNames.map(atr => if (atr.contains("All")) MessageReadeableAttributeNames.AllAttributeNames else atr),
MessageSystemAttributeNames = MessageSystemAttributeNames.map(atr =>
if (atr.contains("All")) MessageReadeableAttributeNames.AllAttributeNames else atr
),
MaxNumberOfMessages = MaxNumberOfMessages,
MessageAttributeNames = MessageAttributeNames,
QueueUrl = QueueUrl,
Expand All @@ -188,14 +194,15 @@ trait ReceiveMessageDirectives {
)
}

implicit val requestJsonFormat: RootJsonFormat[ReceiveMessageActionRequest] = jsonFormat7(
implicit val requestJsonFormat: RootJsonFormat[ReceiveMessageActionRequest] = jsonFormat8(
ReceiveMessageActionRequest.apply
)

implicit val requestParamReader: FlatParamsReader[ReceiveMessageActionRequest] =
new FlatParamsReader[ReceiveMessageActionRequest] {
override def read(params: Map[String, String]): ReceiveMessageActionRequest = {
val attributeNames = attributeNamesReader.read(params, MessageReadeableAttributeNames.AllAttributeNames)
val systemAttributeNames = attributeNamesReader.read(params, MessageReadeableAttributeNames.AllAttributeNames)
val maxNumberOfMessages = params.get(MessageReadeableAttributeNames.MaxNumberOfMessagesAttribute).map(_.toInt)
val messageAttributeNames = getMessageAttributeNames(params).toList
val queueUrl = requiredParameter(params)(QueueUrlParameter)
Expand All @@ -204,6 +211,7 @@ trait ReceiveMessageDirectives {
val waitTimeSeconds = params.get(MessageReadeableAttributeNames.WaitTimeSecondsAttribute).map(_.toLong)
ReceiveMessageActionRequest(
Some(attributeNames),
Some(systemAttributeNames),
maxNumberOfMessages,
Some(messageAttributeNames),
queueUrl,
Expand Down
Loading