diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java index 9c6fba6fb2e4..36ecbc8863ad 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SpannerImpl.java @@ -1319,7 +1319,7 @@ private T runInternal(TransactionCallable callable) { shouldRollback = false; } catch (Exception e) { txnLogger.log(Level.FINE, "User-provided TransactionCallable raised exception", e); - if (txn.isAborted()) { + if (txn.isAborted() || (e instanceof AbortedException)) { span.addAnnotation( "Transaction Attempt Aborted in user operation. Retrying", ImmutableMap.of("Attempt", AttributeValue.longAttributeValue(attempt))); diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java index 35ba83e93203..aff4dace54a5 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/it/ITDMLTest.java @@ -19,6 +19,7 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import com.google.cloud.spanner.AbortedException; import com.google.cloud.spanner.Database; import com.google.cloud.spanner.DatabaseClient; import com.google.cloud.spanner.ErrorCode; @@ -30,6 +31,7 @@ import com.google.cloud.spanner.Mutation; import com.google.cloud.spanner.ResultSet; import com.google.cloud.spanner.SpannerException; +import com.google.cloud.spanner.SpannerExceptionFactory; import com.google.cloud.spanner.Statement; import com.google.cloud.spanner.TimestampBound; import com.google.cloud.spanner.TransactionContext; @@ -59,6 +61,8 @@ public final class ITDMLTest { private static final String DELETE_DML = "DELETE FROM T WHERE T.K like 'boo%';"; private static final long DML_COUNT = 4; + private static boolean throwAbortOnce = false; + @BeforeClass public static void setUpDatabase() { db = @@ -82,6 +86,12 @@ private void executeUpdate(long expectedCount, final String... stmts) { public Long run(TransactionContext transaction) { long rowCount = 0; for (String stmt : stmts) { + if (throwAbortOnce) { + throwAbortOnce = false; + throw SpannerExceptionFactory.newSpannerException( + ErrorCode.ABORTED, "Abort in test"); + } + rowCount += transaction.executeUpdate(Statement.of(stmt)); } return rowCount; @@ -92,6 +102,17 @@ public Long run(TransactionContext transaction) { assertThat(rowCount).isEqualTo(expectedCount); } + @Test + public void abortOnceShouldSucceedAfterRetry() { + try { + throwAbortOnce = true; + executeUpdate(DML_COUNT, INSERT_DML); + assertThat(throwAbortOnce).isFalse(); + } catch (AbortedException e) { + fail("Abort Exception not caught and retried"); + } + } + @Test public void partitionedDML() { executeUpdate(DML_COUNT, INSERT_DML);