From e7463fb2c65cc9e892bffa1f0be1ab05f63858c8 Mon Sep 17 00:00:00 2001 From: Emmanuel T Odeke Date: Thu, 10 Apr 2025 22:36:10 +0300 Subject: [PATCH] chore(x-goog-spanner-request-id): plumb for BatchCreateSessions This change plumbs x-goog-spanner-request-id into BatchCreateSessions and asserts that the header is present for that method. Updates #3537 --- .../cloud/spanner/DatabaseClientImpl.java | 73 ++++++++- .../com/google/cloud/spanner/Options.java | 48 +++++- .../google/cloud/spanner/SessionClient.java | 43 +++-- .../com/google/cloud/spanner/SessionImpl.java | 62 +++++++- .../com/google/cloud/spanner/SessionPool.java | 4 + .../cloud/spanner/XGoogSpannerRequestId.java | 88 +++++++++++ .../cloud/spanner/spi/v1/GapicSpannerRpc.java | 25 ++- .../cloud/spanner/spi/v1/SpannerRpc.java | 3 +- .../cloud/spanner/DatabaseClientImplTest.java | 45 +++++- .../com/google/cloud/spanner/OptionsTest.java | 38 +++++ .../cloud/spanner/SessionClientTests.java | 11 +- .../google/cloud/spanner/SessionImplTest.java | 1 + .../spanner/XGoogSpannerRequestIdTest.java | 147 +++++++++++++++++- 13 files changed, 555 insertions(+), 33 deletions(-) diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java index 8e0e07c457..660b74ad9d 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/DatabaseClientImpl.java @@ -24,14 +24,18 @@ import com.google.cloud.spanner.SpannerImpl.ClosedException; import com.google.cloud.spanner.Statement.StatementFactory; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Function; import com.google.common.util.concurrent.ListenableFuture; import com.google.spanner.v1.BatchWriteResponse; import io.opentelemetry.api.common.Attributes; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Objects; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; import javax.annotation.Nullable; class DatabaseClientImpl implements DatabaseClient { @@ -45,6 +49,8 @@ class DatabaseClientImpl implements DatabaseClient { @VisibleForTesting final MultiplexedSessionDatabaseClient multiplexedSessionDatabaseClient; @VisibleForTesting final boolean useMultiplexedSessionPartitionedOps; @VisibleForTesting final boolean useMultiplexedSessionForRW; + private final int dbId; + private final AtomicInteger nthRequest; final boolean useMultiplexedSessionBlindWrite; @@ -91,6 +97,18 @@ class DatabaseClientImpl implements DatabaseClient { this.tracer = tracer; this.useMultiplexedSessionForRW = useMultiplexedSessionForRW; this.commonAttributes = commonAttributes; + + this.dbId = this.dbIdFromClientId(this.clientId); + this.nthRequest = new AtomicInteger(0); + } + + private int dbIdFromClientId(String clientId) { + int i = clientId.indexOf("-"); + String strWithValue = clientId.substring(i + 1); + if (Objects.equals(strWithValue, "")) { + strWithValue = "0"; + } + return Integer.parseInt(strWithValue); } @VisibleForTesting @@ -188,7 +206,11 @@ public CommitResponse writeWithOptions( if (canUseMultiplexedSessionsForRW() && getMultiplexedSessionDatabaseClient() != null) { return getMultiplexedSessionDatabaseClient().writeWithOptions(mutations, options); } - return runWithSessionRetry(session -> session.writeWithOptions(mutations, options)); + + return runWithSessionRetry( + (session, reqId) -> { + return session.writeWithOptions(mutations, withReqId(reqId, options)); + }); } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -213,7 +235,8 @@ public CommitResponse writeAtLeastOnceWithOptions( .writeAtLeastOnceWithOptions(mutations, options); } return runWithSessionRetry( - session -> session.writeAtLeastOnceWithOptions(mutations, options)); + (session, reqId) -> + session.writeAtLeastOnceWithOptions(mutations, withReqId(reqId, options))); } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -222,6 +245,10 @@ public CommitResponse writeAtLeastOnceWithOptions( } } + private int nextNthRequest() { + return this.nthRequest.incrementAndGet(); + } + @Override public ServerStream batchWriteAtLeastOnce( final Iterable mutationGroups, final TransactionOption... options) @@ -231,7 +258,9 @@ public ServerStream batchWriteAtLeastOnce( if (canUseMultiplexedSessionsForRW() && getMultiplexedSessionDatabaseClient() != null) { return getMultiplexedSessionDatabaseClient().batchWriteAtLeastOnce(mutationGroups, options); } - return runWithSessionRetry(session -> session.batchWriteAtLeastOnce(mutationGroups, options)); + return runWithSessionRetry( + (session, reqId) -> + session.batchWriteAtLeastOnce(mutationGroups, withReqId(reqId, options))); } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -383,11 +412,34 @@ private Future getDialectAsync() { return pool.getDialectAsync(); } + private UpdateOption[] withReqId( + final XGoogSpannerRequestId reqId, final UpdateOption... options) { + if (reqId == null) { + return options; + } + ArrayList allOptions = new ArrayList(Arrays.asList(options)); + allOptions.add(new Options.RequestIdOption(reqId)); + return allOptions.toArray(new UpdateOption[0]); + } + + private TransactionOption[] withReqId( + final XGoogSpannerRequestId reqId, final TransactionOption... options) { + if (reqId == null) { + return options; + } + ArrayList allOptions = new ArrayList(Arrays.asList(options)); + allOptions.add(new Options.RequestIdOption(reqId)); + return allOptions.toArray(new TransactionOption[0]); + } + private long executePartitionedUpdateWithPooledSession( final Statement stmt, final UpdateOption... options) { ISpan span = tracer.spanBuilder(PARTITION_DML_TRANSACTION, commonAttributes); try (IScope s = tracer.withSpan(span)) { - return runWithSessionRetry(session -> session.executePartitionedUpdate(stmt, options)); + return runWithSessionRetry( + (session, reqId) -> { + return session.executePartitionedUpdate(stmt, withReqId(reqId, options)); + }); } catch (RuntimeException e) { span.setStatus(e); span.end(); @@ -395,15 +447,22 @@ private long executePartitionedUpdateWithPooledSession( } } - private T runWithSessionRetry(Function callable) { + private T runWithSessionRetry(BiFunction callable) { PooledSessionFuture session = getSession(); + XGoogSpannerRequestId reqId = + XGoogSpannerRequestId.of( + this.dbId, Long.valueOf(session.getChannel()), this.nextNthRequest(), 0); while (true) { try { - return callable.apply(session); + reqId.incrementAttempt(); + return callable.apply(session, reqId); } catch (SessionNotFoundException e) { session = (PooledSessionFuture) pool.getPooledSessionReplacementHandler().replaceSession(e, session); + reqId = + XGoogSpannerRequestId.of( + this.dbId, Long.valueOf(session.getChannel()), this.nextNthRequest(), 0); } } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java index c36f190264..0b9556084e 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/Options.java @@ -177,6 +177,10 @@ public static UpdateTransactionOption excludeTxnFromChangeStreams() { return EXCLUDE_TXN_FROM_CHANGE_STREAMS_OPTION; } + public static RequestIdOption requestId(XGoogSpannerRequestId reqId) { + return new RequestIdOption(reqId); + } + /** * Specifying this will cause the read to yield at most this many rows. This should be greater * than 0. @@ -535,6 +539,7 @@ void appendToOptions(Options options) { private RpcLockHint lockHint; private Boolean lastStatement; private IsolationLevel isolationLevel; + private XGoogSpannerRequestId reqId; // Construction is via factory methods below. private Options() {} @@ -599,6 +604,14 @@ String filter() { return filter; } + boolean hasReqId() { + return reqId != null; + } + + XGoogSpannerRequestId reqId() { + return reqId; + } + boolean hasPriority() { return priority != null; } @@ -756,6 +769,9 @@ public String toString() { if (isolationLevel != null) { b.append("isolationLevel: ").append(isolationLevel).append(' '); } + if (reqId != null) { + b.append("requestId: ").append(reqId.toString()); + } return b.toString(); } @@ -798,7 +814,8 @@ public boolean equals(Object o) { && Objects.equals(orderBy(), that.orderBy()) && Objects.equals(isLastStatement(), that.isLastStatement()) && Objects.equals(lockHint(), that.lockHint()) - && Objects.equals(isolationLevel(), that.isolationLevel()); + && Objects.equals(isolationLevel(), that.isolationLevel()) + && Objects.equals(reqId(), that.reqId()); } @Override @@ -867,6 +884,9 @@ public int hashCode() { if (isolationLevel != null) { result = 31 * result + isolationLevel.hashCode(); } + if (reqId != null) { + result = 31 * result + reqId.hashCode(); + } return result; } @@ -1052,4 +1072,30 @@ public boolean equals(Object o) { return o instanceof LastStatementUpdateOption; } } + + static final class RequestIdOption extends InternalOption + implements ReadOption, TransactionOption, UpdateOption { + private final XGoogSpannerRequestId reqId; + + RequestIdOption(XGoogSpannerRequestId reqId) { + this.reqId = reqId; + } + + @Override + void appendToOptions(Options options) { + options.reqId = this.reqId; + } + + @Override + public int hashCode() { + return RequestIdOption.class.hashCode(); + } + + @Override + public boolean equals(Object o) { + // TODO: Examine why the precedent for LastStatementUpdateOption + // does not check against the actual value. + return o instanceof RequestIdOption; + } + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java index 2edfb66d89..405c5f8681 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java @@ -31,10 +31,11 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; import javax.annotation.concurrent.GuardedBy; /** Client for creating single sessions and batches of sessions. */ -class SessionClient implements AutoCloseable { +class SessionClient implements AutoCloseable, XGoogSpannerRequestId.RequestIdCreator { static class SessionId { private static final PathTemplate NAME_TEMPLATE = PathTemplate.create( @@ -174,6 +175,12 @@ interface SessionConsumer { private final DatabaseId db; private final Attributes commonAttributes; + // SessionClient is created long before a DatabaseClientImpl is created, + // as batch sessions are firstly created then later attached to each Client. + private static AtomicInteger NTH_ID = new AtomicInteger(0); + private final int nthId; + private final AtomicInteger nthRequest; + @GuardedBy("this") private volatile long sessionChannelCounter; @@ -186,6 +193,8 @@ interface SessionConsumer { this.executorFactory = executorFactory; this.executor = executorFactory.get(); this.commonAttributes = spanner.getTracer().createCommonAttributes(db); + this.nthId = SessionClient.NTH_ID.incrementAndGet(); + this.nthRequest = new AtomicInteger(0); } @Override @@ -201,16 +210,24 @@ DatabaseId getDatabaseId() { return db; } + @Override + public XGoogSpannerRequestId nextRequestId(long channelId, int attempt) { + return XGoogSpannerRequestId.of(this.nthId, this.nthRequest.incrementAndGet(), channelId, 1); + } + /** Create a single session. */ SessionImpl createSession() { // The sessionChannelCounter could overflow, but that will just flip it to Integer.MIN_VALUE, // which is also a valid channel hint. final Map options; + final long channelId; synchronized (this) { options = optionMap(SessionOption.channelHint(sessionChannelCounter++)); + channelId = sessionChannelCounter; } ISpan span = spanner.getTracer().spanBuilder(SpannerImpl.CREATE_SESSION, this.commonAttributes); try (IScope s = spanner.getTracer().withSpan(span)) { + XGoogSpannerRequestId reqId = this.nextRequestId(channelId, 1); com.google.spanner.v1.Session session = spanner .getRpc() @@ -218,11 +235,13 @@ SessionImpl createSession() { db.getName(), spanner.getOptions().getDatabaseRole(), spanner.getOptions().getSessionLabels(), - options); + reqId.withOptions(options)); SessionReference sessionReference = new SessionReference( session.getName(), session.getCreateTime(), session.getMultiplexed(), options); - return new SessionImpl(spanner, sessionReference); + SessionImpl sessionImpl = new SessionImpl(spanner, sessionReference); + sessionImpl.setRequestIdCreator(this); + return sessionImpl; } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -273,6 +292,7 @@ SessionImpl createMultiplexedSession() { spanner, new SessionReference( session.getName(), session.getCreateTime(), session.getMultiplexed(), null)); + sessionImpl.setRequestIdCreator(this); span.addAnnotation( String.format("Request for %d multiplexed session returned %d session", 1, 1)); return sessionImpl; @@ -387,6 +407,8 @@ private List internalBatchCreateSessions( .spanBuilderWithExplicitParent(SpannerImpl.BATCH_CREATE_SESSIONS_REQUEST, parent); span.addAnnotation(String.format("Requesting %d sessions", sessionCount)); try (IScope s = spanner.getTracer().withSpan(span)) { + XGoogSpannerRequestId reqId = + XGoogSpannerRequestId.of(this.nthId, this.nthRequest.incrementAndGet(), channelHint, 1); List sessions = spanner .getRpc() @@ -395,21 +417,20 @@ private List internalBatchCreateSessions( sessionCount, spanner.getOptions().getDatabaseRole(), spanner.getOptions().getSessionLabels(), - options); + reqId.withOptions(options)); span.addAnnotation( String.format( "Request for %d sessions returned %d sessions", sessionCount, sessions.size())); span.end(); List res = new ArrayList<>(sessionCount); for (com.google.spanner.v1.Session session : sessions) { - res.add( + SessionImpl sessionImpl = new SessionImpl( spanner, new SessionReference( - session.getName(), - session.getCreateTime(), - session.getMultiplexed(), - options))); + session.getName(), session.getCreateTime(), session.getMultiplexed(), options)); + sessionImpl.setRequestIdCreator(this); + res.add(sessionImpl); } return res; } catch (RuntimeException e) { @@ -425,6 +446,8 @@ SessionImpl sessionWithId(String name) { synchronized (this) { options = optionMap(SessionOption.channelHint(sessionChannelCounter++)); } - return new SessionImpl(spanner, new SessionReference(name, options)); + SessionImpl sessionImpl = new SessionImpl(spanner, new SessionReference(name, options)); + sessionImpl.setRequestIdCreator(this); + return sessionImpl; } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java index 3ad60cf79b..e6b29f55c1 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionImpl.java @@ -126,18 +126,31 @@ interface SessionTransaction { private final Clock clock; private final Map options; private final ErrorHandler errorHandler; + private XGoogSpannerRequestId.RequestIdCreator requestIdCreator; SessionImpl(SpannerImpl spanner, SessionReference sessionReference) { this(spanner, sessionReference, NO_CHANNEL_HINT); } SessionImpl(SpannerImpl spanner, SessionReference sessionReference, int channelHint) { + this(spanner, sessionReference, channelHint, new XGoogSpannerRequestId.NoopRequestIdCreator()); + } + + SessionImpl( + SpannerImpl spanner, + SessionReference sessionReference, + int channelHint, + XGoogSpannerRequestId.RequestIdCreator requestIdCreator) { this.spanner = spanner; this.tracer = spanner.getTracer(); this.sessionReference = sessionReference; this.clock = spanner.getOptions().getSessionPoolOptions().getPoolMaintainerClock(); this.options = createOptions(sessionReference, channelHint); this.errorHandler = createErrorHandler(spanner.getOptions()); + this.requestIdCreator = requestIdCreator; + if (this.requestIdCreator == null) { + this.requestIdCreator = new XGoogSpannerRequestId.NoopRequestIdCreator(); + } } static Map createOptions( @@ -287,9 +300,16 @@ public CommitResponse writeAtLeastOnceWithOptions( } CommitRequest request = requestBuilder.build(); ISpan span = tracer.spanBuilder(SpannerImpl.COMMIT); + final XGoogSpannerRequestId reqId = reqIdOrFresh(options); + try (IScope s = tracer.withSpan(span)) { return SpannerRetryHelper.runTxWithRetriesOnAborted( - () -> new CommitResponse(spanner.getRpc().commit(request, getOptions()))); + () -> { + // TODO: Detect an abort and then refresh the reqId. + reqId.incrementAttempt(); + return new CommitResponse( + spanner.getRpc().commit(request, reqId.withOptions(getOptions()))); + }); } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -298,6 +318,14 @@ public CommitResponse writeAtLeastOnceWithOptions( } } + private XGoogSpannerRequestId reqIdOrFresh(Options options) { + XGoogSpannerRequestId reqId = options.reqId(); + if (reqId == null) { + reqId = this.getRequestIdCreator().nextRequestId(1 /* TODO: channelId */, 0); + } + return reqId; + } + private RequestOptions getRequestOptions(TransactionOption... transactionOptions) { Options requestOptions = Options.fromTransactionOptions(transactionOptions); if (requestOptions.hasPriority() || requestOptions.hasTag()) { @@ -325,16 +353,19 @@ public ServerStream batchWriteAtLeastOnce( .setSession(getName()) .addAllMutationGroups(mutationGroupsProto); RequestOptions batchWriteRequestOptions = getRequestOptions(transactionOptions); + Options allOptions = Options.fromTransactionOptions(transactionOptions); + final XGoogSpannerRequestId reqId = reqIdOrFresh(allOptions); if (batchWriteRequestOptions != null) { requestBuilder.setRequestOptions(batchWriteRequestOptions); } - if (Options.fromTransactionOptions(transactionOptions).withExcludeTxnFromChangeStreams() - == Boolean.TRUE) { + if (allOptions.withExcludeTxnFromChangeStreams() == Boolean.TRUE) { requestBuilder.setExcludeTxnFromChangeStreams(true); } ISpan span = tracer.spanBuilder(SpannerImpl.BATCH_WRITE); try (IScope s = tracer.withSpan(span)) { - return spanner.getRpc().batchWriteAtLeastOnce(requestBuilder.build(), getOptions()); + return spanner + .getRpc() + .batchWriteAtLeastOnce(requestBuilder.build(), reqId.withOptions(getOptions())); } catch (Throwable e) { span.setStatus(e); throw SpannerExceptionFactory.newSpannerException(e); @@ -435,14 +466,18 @@ public AsyncTransactionManagerImpl transactionManagerAsync(TransactionOption... @Override public ApiFuture asyncClose() { - return spanner.getRpc().asyncDeleteSession(getName(), getOptions()); + XGoogSpannerRequestId reqId = + this.getRequestIdCreator().nextRequestId(1 /* TODO: channelId */, 0); + return spanner.getRpc().asyncDeleteSession(getName(), reqId.withOptions(getOptions())); } @Override public void close() { ISpan span = tracer.spanBuilder(SpannerImpl.DELETE_SESSION); try (IScope s = tracer.withSpan(span)) { - spanner.getRpc().deleteSession(getName(), getOptions()); + XGoogSpannerRequestId reqId = + this.getRequestIdCreator().nextRequestId(1 /* TODO: channelId */, 0); + spanner.getRpc().deleteSession(getName(), reqId.withOptions(getOptions())); } catch (RuntimeException e) { span.setStatus(e); throw e; @@ -473,8 +508,13 @@ ApiFuture beginTransactionAsync( } final BeginTransactionRequest request = requestBuilder.build(); final ApiFuture requestFuture; + XGoogSpannerRequestId reqId = + this.getRequestIdCreator().nextRequestId(1 /* TODO: channelId */, 1); try (IScope ignore = tracer.withSpan(span)) { - requestFuture = spanner.getRpc().beginTransactionAsync(request, channelHint, routeToLeader); + requestFuture = + spanner + .getRpc() + .beginTransactionAsync(request, reqId.withOptions(channelHint), routeToLeader); } requestFuture.addListener( () -> { @@ -552,4 +592,12 @@ void onTransactionDone() {} TraceWrapper getTracer() { return tracer; } + + public void setRequestIdCreator(XGoogSpannerRequestId.RequestIdCreator creator) { + this.requestIdCreator = creator; + } + + public XGoogSpannerRequestId.RequestIdCreator getRequestIdCreator() { + return this.requestIdCreator; + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index 37fa2c5d20..3bc70e7f5b 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -1569,6 +1569,10 @@ PooledSession get(final boolean eligibleForLongRunning) { throw SpannerExceptionFactory.propagateInterrupt(e); } } + + public int getChannel() { + return get().getChannel(); + } } interface CachedSession extends Session { diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java index 4f6c011475..325aace2d2 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java @@ -17,10 +17,19 @@ package com.google.cloud.spanner; import com.google.api.core.InternalApi; +import com.google.cloud.spanner.spi.v1.SpannerRpc; import com.google.common.annotations.VisibleForTesting; +import io.grpc.Metadata; import java.math.BigInteger; import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.regex.MatchResult; +import java.util.regex.Matcher; +import java.util.regex.Pattern; @InternalApi public class XGoogSpannerRequestId { @@ -28,6 +37,9 @@ public class XGoogSpannerRequestId { @VisibleForTesting static final String RAND_PROCESS_ID = XGoogSpannerRequestId.generateRandProcessId(); + public static final Metadata.Key REQUEST_HEADER_KEY = + Metadata.Key.of("x-goog-spanner-request-id", Metadata.ASCII_STRING_MARSHALLER); + @VisibleForTesting static final long VERSION = 1; // The version of the specification being implemented. @@ -48,6 +60,26 @@ public static XGoogSpannerRequestId of( return new XGoogSpannerRequestId(nthClientId, nthChannelId, nthRequest, attempt); } + @VisibleForTesting + static final Pattern REGEX = + Pattern.compile("^(\\d)\\.([0-9a-z]{16})\\.(\\d+)\\.(\\d+)\\.(\\d+)\\.(\\d+)$"); + + public static XGoogSpannerRequestId of(String s) { + Matcher m = XGoogSpannerRequestId.REGEX.matcher(s); + if (!m.matches()) { + throw new IllegalStateException( + s + " does not match " + XGoogSpannerRequestId.REGEX.pattern()); + } + + MatchResult mr = m.toMatchResult(); + + return new XGoogSpannerRequestId( + Long.parseLong(mr.group(3)), + Long.parseLong(mr.group(4)), + Long.parseLong(mr.group(5)), + Long.parseLong(mr.group(6))); + } + private static String generateRandProcessId() { // Expecting to use 64-bits of randomness to avoid clashes. BigInteger bigInt = new BigInteger(64, new SecureRandom()); @@ -66,6 +98,13 @@ public String toString() { this.attempt); } + private boolean isGreaterThan(XGoogSpannerRequestId other) { + return this.nthClientId > other.nthClientId + && this.nthChannelId > other.nthChannelId + && this.nthRequest > other.nthRequest + && this.attempt > other.attempt; + } + @Override public boolean equals(Object other) { // instanceof for a null object returns false. @@ -81,8 +120,57 @@ public boolean equals(Object other) { && Objects.equals(this.attempt, otherReqId.attempt); } + public void incrementAttempt() { + this.attempt++; + } + + @SuppressWarnings("unchecked") + public Map withOptions(Map options) { + Map copyOptions = new HashMap<>(); + if (options != null) { + copyOptions.putAll(options); + } + copyOptions.put(SpannerRpc.Option.REQUEST_ID, this); + return copyOptions; + } + @Override public int hashCode() { return Objects.hash(this.nthClientId, this.nthChannelId, this.nthRequest, this.attempt); } + + public interface RequestIdCreator { + XGoogSpannerRequestId nextRequestId(long channelId, int attempt); + } + + public static class NoopRequestIdCreator implements RequestIdCreator { + NoopRequestIdCreator() {} + + @Override + public XGoogSpannerRequestId nextRequestId(long channelId, int attempt) { + return XGoogSpannerRequestId.of(1, 1, 1, 0); + } + } + + public static void assertMonotonicityOfIds(String prefix, List reqIds) { + int size = reqIds.size(); + + List violations = new ArrayList<>(); + for (int i = 1; i < size; i++) { + XGoogSpannerRequestId prev = reqIds.get(i - 1); + XGoogSpannerRequestId curr = reqIds.get(i); + if (prev.isGreaterThan(curr)) { + violations.add(String.format("#%d(%s) > #%d(%s)", i - 1, prev, i, curr)); + } + } + + if (violations.size() == 0) { + return; + } + + throw new IllegalStateException( + prefix + + " monotonicity violation:" + + String.join("\n\t", violations.toArray(new String[0]))); + } } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java index 0f51c9544f..56960da3c8 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java @@ -71,6 +71,7 @@ import com.google.cloud.spanner.SpannerOptions; import com.google.cloud.spanner.SpannerOptions.CallContextConfigurator; import com.google.cloud.spanner.SpannerOptions.CallCredentialsProvider; +import com.google.cloud.spanner.XGoogSpannerRequestId; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStub; import com.google.cloud.spanner.admin.database.v1.stub.DatabaseAdminStubSettings; import com.google.cloud.spanner.admin.database.v1.stub.GrpcDatabaseAdminCallableFactory; @@ -88,6 +89,7 @@ import com.google.common.base.Supplier; import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.io.Resources; import com.google.common.util.concurrent.RateLimiter; @@ -193,6 +195,7 @@ import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.Collections; import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -405,6 +408,8 @@ public GapicSpannerRpc(final SpannerOptions options) { final String emulatorHost = System.getenv("SPANNER_EMULATOR_HOST"); try { + // TODO: make our retry settings to inject and increment + // XGoogSpannerRequestId whenever a retry occurs. SpannerStubSettings spannerStubSettings = options .getSpannerStubSettings() @@ -2033,8 +2038,13 @@ GrpcCallContext newCallContext( GcpManagedChannel.AFFINITY_KEY, String.valueOf(boundedChannelHint))); } else { // Set channel affinity in GAX. - context = context.withChannelAffinity(Option.CHANNEL_HINT.getLong(options).intValue()); + Long affinity = Option.CHANNEL_HINT.getLong(options); + if (affinity != null) { + context = context.withChannelAffinity(affinity.intValue()); + } } + String methodName = method.getFullMethodName(); + context = withRequestId(context, options, methodName); } context = context.withExtraHeaders(metadataProvider.newExtraHeaders(resource, projectName)); if (routeToLeader && leaderAwareRoutingEnabled) { @@ -2055,6 +2065,19 @@ GrpcCallContext newCallContext( return (GrpcCallContext) context.merge(apiCallContextFromContext); } + GrpcCallContext withRequestId(GrpcCallContext context, Map options, String methodName) { + XGoogSpannerRequestId reqId = (XGoogSpannerRequestId) options.get(Option.REQUEST_ID); + if (reqId == null) { + return context; + } + + Map> withReqId = + ImmutableMap.of( + XGoogSpannerRequestId.REQUEST_HEADER_KEY.name(), + Collections.singletonList(reqId.toString())); + return context.withExtraHeaders(withReqId); + } + void registerResponseObserver(SpannerResponseObserver responseObserver) { responseObservers.add(responseObserver); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java index 9ad9420474..d029084477 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/SpannerRpc.java @@ -78,7 +78,8 @@ public interface SpannerRpc extends ServiceRpc { /** Options passed in {@link SpannerRpc} methods to control how an RPC is issued. */ enum Option { - CHANNEL_HINT("Channel Hint"); + CHANNEL_HINT("Channel Hint"), + REQUEST_ID("Request Id"); private final String value; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index 70209917f0..56dc38e014 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -105,6 +105,7 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Server; +import io.grpc.ServerInterceptors; import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.inprocess.InProcessServerBuilder; @@ -119,6 +120,7 @@ import java.util.Arrays; import java.util.Base64; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Random; import java.util.Set; @@ -152,6 +154,7 @@ public class DatabaseClientImplTest { private static final String DATABASE_NAME = String.format( "projects/%s/instances/%s/databases/%s", TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE); + private static XGoogSpannerRequestIdTest.ServerHeaderEnforcer xGoogReqIdInterceptor; private static MockSpannerServiceImpl mockSpanner; private static Server server; private static LocalChannelProvider channelProvider; @@ -220,13 +223,31 @@ public static void startStaticServer() throws IOException { StatementResult.query(SELECT1_FROM_TABLE, MockSpannerTestUtil.SELECT1_RESULTSET)); mockSpanner.setBatchWriteResult(BATCH_WRITE_RESPONSES); + Set checkMethods = + new HashSet( + Arrays.asList( + "google.spanner.v1.Spanner/BatchCreateSessions" + // As functionality is added, uncomment each method. + // "google.spanner.v1.Spanner/BatchWrite", + // "google.spanner.v1.Spanner/BeginTransaction", + // "google.spanner.v1.Spanner/CreateSession", + // "google.spanner.v1.Spanner/DeleteSession", + // "google.spanner.v1.Spanner/ExecuteBatchDml", + // "google.spanner.v1.Spanner/ExecuteSql", + // "google.spanner.v1.Spanner/ExecuteStreamingSql", + // "google.spanner.v1.Spanner/StreamingRead", + // "google.spanner.v1.Spanner/PartitionQuery", + // "google.spanner.v1.Spanner/PartitionRead", + // "google.spanner.v1.Spanner/Commit", + )); + xGoogReqIdInterceptor = new XGoogSpannerRequestIdTest.ServerHeaderEnforcer(checkMethods); executor = Executors.newSingleThreadExecutor(); String uniqueName = InProcessServerBuilder.generateName(); server = InProcessServerBuilder.forName(uniqueName) // We need to use a real executor for timeouts to occur. .scheduledExecutorService(new ScheduledThreadPoolExecutor(1)) - .addService(mockSpanner) + .addService(ServerInterceptors.intercept(mockSpanner, xGoogReqIdInterceptor)) .build() .start(); channelProvider = LocalChannelProvider.create(uniqueName); @@ -266,6 +287,7 @@ public void tearDown() { spanner.close(); spannerWithEmptySessionPool.close(); mockSpanner.reset(); + xGoogReqIdInterceptor.reset(); mockSpanner.removeAllExecutionTimes(); } @@ -1393,6 +1415,7 @@ public void testWriteAtLeastOnceAborted() { List commitRequests = mockSpanner.getRequestsOfType(CommitRequest.class); assertEquals(2, commitRequests.size()); + xGoogReqIdInterceptor.assertIntegrity(); } @Test @@ -5195,6 +5218,26 @@ public void testRetryOnResourceExhausted() { } } + @Test + public void testSelectHasXGoogRequestIdHeader() { + Statement statement = + Statement.newBuilder("select id from test where b=@p1") + .bind("p1") + .toBytesArray( + Arrays.asList(ByteArray.copyFrom("test1"), null, ByteArray.copyFrom("test2"))) + .build(); + mockSpanner.putStatementResult(StatementResult.query(statement, SELECT1_RESULTSET)); + DatabaseClient client = + spanner.getDatabaseClient(DatabaseId.of(TEST_PROJECT, TEST_INSTANCE, TEST_DATABASE)); + try (ResultSet resultSet = client.singleUse().executeQuery(statement)) { + assertTrue(resultSet.next()); + assertEquals(1L, resultSet.getLong(0)); + assertFalse(resultSet.next()); + } finally { + xGoogReqIdInterceptor.assertIntegrity(); + } + } + @Test public void testSessionPoolExhaustedError_containsStackTraces() { assumeFalse( diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java index e2bcc92fed..f8b5304a70 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/OptionsTest.java @@ -83,6 +83,7 @@ public void zeroPrefetchChunksNotAllowed() { @Test public void allOptionsPresent() { + XGoogSpannerRequestId reqId1 = XGoogSpannerRequestId.of(2, 3, 4, 5); Options options = Options.fromReadOptions( Options.limit(10), @@ -90,6 +91,7 @@ public void allOptionsPresent() { Options.dataBoostEnabled(true), Options.directedRead(DIRECTED_READ_OPTIONS), Options.orderBy(RpcOrderBy.NO_ORDER), + Options.requestId(reqId1), Options.lockHint(Options.RpcLockHint.SHARED)); assertThat(options.hasLimit()).isTrue(); assertThat(options.limit()).isEqualTo(10); @@ -101,6 +103,7 @@ public void allOptionsPresent() { assertTrue(options.hasOrderBy()); assertTrue(options.hasLockHint()); assertEquals(DIRECTED_READ_OPTIONS, options.directedReadOptions()); + assertEquals(options.reqId(), reqId1); } @Test @@ -873,4 +876,39 @@ public void testOptions_WithMultipleDifferentIsolationLevels() { Options options = Options.fromTransactionOptions(transactionOptions); assertEquals(options.isolationLevel(), IsolationLevel.SERIALIZABLE); } + + @Test + public void testRequestId() { + XGoogSpannerRequestId reqId1 = XGoogSpannerRequestId.of(1, 2, 3, 4); + XGoogSpannerRequestId reqId2 = XGoogSpannerRequestId.of(2, 3, 4, 5); + Options option1 = Options.fromUpdateOptions(Options.requestId(reqId1)); + Options option1Prime = Options.fromUpdateOptions(Options.requestId(reqId1)); + Options option2 = Options.fromUpdateOptions(Options.requestId(reqId2)); + Options option3 = Options.fromUpdateOptions(); + + assertEquals(option1, option1Prime); + assertNotEquals(option1, option2); + assertEquals(option1.hashCode(), option1Prime.hashCode()); + assertNotEquals(option1, option2); + assertNotEquals(option1, option3); + assertNotEquals(option1.hashCode(), option3.hashCode()); + + assertTrue(option1.hasReqId()); + assertThat(option1.toString()).contains("requestId: " + reqId1.toString()); + + assertFalse(option3.hasReqId()); + assertThat(option3.toString()).doesNotContain("requestId"); + } + + @Test + public void testOptions_WithMultipleDifferentRequestIds() { + XGoogSpannerRequestId reqId1 = XGoogSpannerRequestId.of(1, 1, 1, 1); + XGoogSpannerRequestId reqId2 = XGoogSpannerRequestId.of(1, 1, 1, 2); + TransactionOption[] transactionOptions = { + Options.requestId(reqId1), Options.requestId(reqId2), + }; + Options options = Options.fromTransactionOptions(transactionOptions); + assertNotEquals(options.reqId(), reqId1); + assertEquals(options.reqId(), reqId2); + } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionClientTests.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionClientTests.java index bcba430c52..f04d9678d1 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionClientTests.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionClientTests.java @@ -153,8 +153,17 @@ public void createAndCloseSession() { assertThat(session.getName()).isEqualTo(sessionName); session.close(); + + final ArgumentCaptor> deleteOptionsCaptor = + ArgumentCaptor.forClass(Map.class); + final ArgumentCaptor sessionNameCaptor = ArgumentCaptor.forClass(String.class); + Mockito.verify(rpc).deleteSession(sessionNameCaptor.capture(), deleteOptionsCaptor.capture()); + assertEquals(sessionName, sessionNameCaptor.getValue()); // The same channelHint is passed for deleteSession (contained in "options"). - Mockito.verify(rpc).deleteSession(sessionName, options.getValue()); + assertEquals( + deleteOptionsCaptor.getValue().get(SpannerRpc.Option.CHANNEL_HINT), + options.getValue().get(SpannerRpc.Option.CHANNEL_HINT)); + assertTrue(deleteOptionsCaptor.getValue().containsKey(SpannerRpc.Option.REQUEST_ID)); } } diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index e0403f72d1..a6196df01e 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -144,6 +144,7 @@ public void setUp() { when(rpc.getCommitRetrySettings()) .thenReturn(SpannerStubSettings.newBuilder().commitSettings().getRetrySettings()); session = spanner.getSessionClient(db).createSession(); + ((SessionImpl) session).setRequestIdCreator(new XGoogSpannerRequestId.NoopRequestIdCreator()); Span oTspan = mock(Span.class); ISpan span = new OpenTelemetrySpan(oTspan); when(oTspan.makeCurrent()).thenReturn(mock(Scope.class)); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java index 12c9213c7d..847a4adf7b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/XGoogSpannerRequestIdTest.java @@ -18,18 +18,29 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CopyOnWriteArrayList; import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class XGoogSpannerRequestIdTest { - private static final Pattern REGEX_RAND_PROCESS_ID = - Pattern.compile("1.([0-9a-z]{16})(\\.\\d+){3}\\.(\\d+)$"); @Test public void testEquals() { @@ -48,7 +59,135 @@ public void testEquals() { @Test public void testEnsureHexadecimalFormatForRandProcessID() { String str = XGoogSpannerRequestId.of(1, 2, 3, 4).toString(); - Matcher m = XGoogSpannerRequestIdTest.REGEX_RAND_PROCESS_ID.matcher(str); + Matcher m = XGoogSpannerRequestId.REGEX.matcher(str); assertTrue(m.matches()); } + + public static class ServerHeaderEnforcer implements ServerInterceptor { + private Map> unaryResults; + private Map> streamingResults; + private List gotValues; + private Set checkMethods; + + ServerHeaderEnforcer(Set checkMethods) { + this.gotValues = new CopyOnWriteArrayList(); + this.unaryResults = + new ConcurrentHashMap>(); + this.streamingResults = + new ConcurrentHashMap>(); + this.checkMethods = checkMethods; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, + final Metadata requestHeaders, + ServerCallHandler next) { + boolean isUnary = call.getMethodDescriptor().getType() == MethodType.UNARY; + String methodName = call.getMethodDescriptor().getFullMethodName(); + String gotReqIdStr = requestHeaders.get(XGoogSpannerRequestId.REQUEST_HEADER_KEY); + if (!this.checkMethods.contains(methodName)) { + return next.startCall(call, requestHeaders); + } + + Map> saver = this.streamingResults; + if (isUnary) { + saver = this.unaryResults; + } + + if (Objects.equals(gotReqIdStr, null) || Objects.equals(gotReqIdStr, "")) { + Status status = + Status.fromCode(Status.Code.INVALID_ARGUMENT) + .augmentDescription( + methodName + " lacks " + XGoogSpannerRequestId.REQUEST_HEADER_KEY); + call.close(status, requestHeaders); + return next.startCall(call, requestHeaders); + } + + assertNotNull(gotReqIdStr); + // Firstly assert and validate that at least we've got a requestId. + Matcher m = XGoogSpannerRequestId.REGEX.matcher(gotReqIdStr); + assertTrue(m.matches()); + + XGoogSpannerRequestId reqId = XGoogSpannerRequestId.of(gotReqIdStr); + if (!saver.containsKey(methodName)) { + saver.put(methodName, new CopyOnWriteArrayList()); + } + + saver.get(methodName).add(reqId); + + // Finally proceed with the call. + return next.startCall(call, requestHeaders); + } + + public String[] accumulatedValues() { + return this.gotValues.toArray(new String[0]); + } + + public void assertIntegrity() { + this.unaryResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + XGoogSpannerRequestId.assertMonotonicityOfIds(method, values); + }); + this.streamingResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + XGoogSpannerRequestId.assertMonotonicityOfIds(method, values); + }); + } + + public static class methodAndRequestId { + String method; + String requestId; + + public methodAndRequestId(String method, String requestId) { + this.method = method; + this.requestId = requestId; + } + + public String toString() { + return "{" + this.method + ":" + this.requestId + "}"; + } + } + + public methodAndRequestId[] accumulatedUnaryValues() { + List accumulated = new ArrayList(); + this.unaryResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + for (int i = 0; i < values.size(); i++) { + accumulated.add(new methodAndRequestId(method, values.get(i).toString())); + } + }); + return accumulated.toArray(new methodAndRequestId[0]); + } + + public methodAndRequestId[] accumulatedStreamingValues() { + List accumulated = new ArrayList(); + this.streamingResults.forEach( + (String method, CopyOnWriteArrayList values) -> { + for (int i = 0; i < values.size(); i++) { + accumulated.add(new methodAndRequestId(method, values.get(i).toString())); + } + }); + return accumulated.toArray(new methodAndRequestId[0]); + } + + public void printAccumulatedValues() { + methodAndRequestId[] unary = this.accumulatedUnaryValues(); + System.out.println("accumulatedUnaryvalues"); + for (int i = 0; i < unary.length; i++) { + System.out.println("\t" + unary[i].toString()); + } + methodAndRequestId[] streaming = this.accumulatedStreamingValues(); + System.out.println("accumulatedStreaminvalues"); + for (int i = 0; i < streaming.length; i++) { + System.out.println("\t" + streaming[i].toString()); + } + } + + public void reset() { + this.gotValues.clear(); + this.unaryResults.clear(); + this.streamingResults.clear(); + } + } }