Skip to content

Commit

Permalink
feat: enable DirectPath bound token in InstantiatingGrpcChannelProvid…
Browse files Browse the repository at this point in the history
…er (#3572)

Prepares a ComputeEngineCredentials that fetches DirectPath bound tokens
for the gRPC ChannelCredentials if applicable.
  • Loading branch information
rockspore authored Feb 10, 2025
1 parent 316c425 commit 5080495
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final Boolean keepAliveWithoutCalls;
private final ChannelPoolSettings channelPoolSettings;
@Nullable private final Credentials credentials;
@Nullable private final CallCredentials altsCallCredentials;
@Nullable private final CallCredentials mtlsS2ACallCredentials;
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
Expand Down Expand Up @@ -191,6 +192,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.channelPoolSettings = builder.channelPoolSettings;
this.channelConfigurator = builder.channelConfigurator;
this.credentials = builder.credentials;
this.altsCallCredentials = builder.altsCallCredentials;
this.mtlsS2ACallCredentials = builder.mtlsS2ACallCredentials;
this.channelPrimer = builder.channelPrimer;
this.attemptDirectPath = builder.attemptDirectPath;
Expand Down Expand Up @@ -616,8 +618,14 @@ private ManagedChannel createSingleChannel() throws IOException {
boolean useDirectPathXds = false;
if (canUseDirectPath()) {
CallCredentials callCreds = MoreCallCredentials.from(credentials);
// altsCallCredentials may be null and GoogleDefaultChannelCredentials
// will solely use callCreds. Otherwise it uses altsCallCredentials
// for DirectPath connections and callCreds for CloudPath fallbacks.
ChannelCredentials channelCreds =
GoogleDefaultChannelCredentials.newBuilder().callCredentials(callCreds).build();
GoogleDefaultChannelCredentials.newBuilder()
.callCredentials(callCreds)
.altsCallCredentials(altsCallCredentials)
.build();
useDirectPathXds = isDirectPathXdsEnabled();
if (useDirectPathXds) {
// google-c2p: CloudToProd(C2P) Directpath. This scheme is defined in
Expand Down Expand Up @@ -822,6 +830,7 @@ public static final class Builder {
@Nullable private Boolean keepAliveWithoutCalls;
@Nullable private ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
@Nullable private Credentials credentials;
@Nullable private CallCredentials altsCallCredentials;
@Nullable private CallCredentials mtlsS2ACallCredentials;
@Nullable private ChannelPrimer channelPrimer;
private ChannelPoolSettings channelPoolSettings;
Expand Down Expand Up @@ -853,6 +862,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.keepAliveWithoutCalls = provider.keepAliveWithoutCalls;
this.channelConfigurator = provider.channelConfigurator;
this.credentials = provider.credentials;
this.altsCallCredentials = provider.altsCallCredentials;
this.mtlsS2ACallCredentials = provider.mtlsS2ACallCredentials;
this.channelPrimer = provider.channelPrimer;
this.channelPoolSettings = provider.channelPoolSettings;
Expand Down Expand Up @@ -919,6 +929,7 @@ Builder setUseS2A(boolean useS2A) {
this.useS2A = useS2A;
return this;
}

/*
* Sets the allowed hard bound token types for this TransportChannelProvider.
*
Expand Down Expand Up @@ -996,6 +1007,7 @@ public Integer getMaxInboundMetadataSize() {
public Builder setKeepAliveTime(org.threeten.bp.Duration duration) {
return setKeepAliveTimeDuration(toJavaTimeDuration(duration));
}

/** The time without read activity before sending a keepalive ping. */
public Builder setKeepAliveTimeDuration(java.time.Duration duration) {
this.keepAliveTime = duration;
Expand Down Expand Up @@ -1172,6 +1184,18 @@ boolean isMtlsS2AHardBoundTokensEnabled() {
.anyMatch(val -> val.equals(HardBoundTokenTypes.MTLS_S2A));
}

boolean isDirectPathBoundTokenEnabled() {
// If the list of allowed hard bound token types is empty or doesn't contain
// {@code HardBoundTokenTypes.ALTS}, the {@code credentials} are null or not of type
// {@code ComputeEngineCredentials} then DirectPath hard bound tokens should not be used.
// DirectPath hard bound tokens should only be used on ALTS channels.
if (allowedHardBoundTokenTypes.isEmpty()
|| this.credentials == null
|| !(credentials instanceof ComputeEngineCredentials)) return false;
return allowedHardBoundTokenTypes.stream()
.anyMatch(val -> val.equals(HardBoundTokenTypes.ALTS));
}

CallCredentials createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport googleAuthTransport,
ComputeEngineCredentials.BindingEnforcement bindingEnforcement) {
Expand All @@ -1194,6 +1218,11 @@ public InstantiatingGrpcChannelProvider build() {
ComputeEngineCredentials.GoogleAuthTransport.MTLS,
ComputeEngineCredentials.BindingEnforcement.ON);
}
if (isDirectPathBoundTokenEnabled()) {
this.altsCallCredentials =
createHardBoundTokensCallCredentials(
ComputeEngineCredentials.GoogleAuthTransport.ALTS, null);
}
InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
new InstantiatingGrpcChannelProvider(this);
instantiatingGrpcChannelProvider.removeApiKeyCredentialDuplicateHeaders();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import com.google.api.core.ApiFunction;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.HardBoundTokenTypes;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
Expand Down Expand Up @@ -735,6 +736,59 @@ public void canUseDirectPath_happyPath() throws IOException {
.setEndpoint(DEFAULT_ENDPOINT)
.setEnvProvider(envProvider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class));
Truth.assertThat(builder.isDirectPathBoundTokenEnabled()).isFalse();
InstantiatingGrpcChannelProvider provider =
new InstantiatingGrpcChannelProvider(builder, GCE_PRODUCTION_NAME_AFTER_2016);
Truth.assertThat(provider.canUseDirectPath()).isTrue();

// verify this info is passed correctly to transport channel
TransportChannel transportChannel = provider.getTransportChannel();
Truth.assertThat(((GrpcTransportChannel) transportChannel).isDirectPath()).isTrue();
transportChannel.shutdownNow();
}

@Test
public void canUseDirectPath_boundTokenNotEnabledWithNonComputeCredentials() {
System.setProperty("os.name", "Linux");
Credentials credentials = Mockito.mock(Credentials.class);
EnvironmentProvider envProvider = Mockito.mock(EnvironmentProvider.class);
Mockito.when(
envProvider.getenv(
InstantiatingGrpcChannelProvider.DIRECT_PATH_ENV_DISABLE_DIRECT_PATH))
.thenReturn("false");
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder()
.setAttemptDirectPath(true)
.setAllowHardBoundTokenTypes(Collections.singletonList(HardBoundTokenTypes.ALTS))
.setCredentials(credentials)
.setEndpoint(DEFAULT_ENDPOINT)
.setEnvProvider(envProvider);
Truth.assertThat(builder.isDirectPathBoundTokenEnabled()).isFalse();
InstantiatingGrpcChannelProvider provider =
new InstantiatingGrpcChannelProvider(builder, GCE_PRODUCTION_NAME_AFTER_2016);
Truth.assertThat(provider.canUseDirectPath()).isFalse();
}

@Test
public void canUseDirectPath_happyPathWithBoundToken() throws IOException {
System.setProperty("os.name", "Linux");
EnvironmentProvider envProvider = Mockito.mock(EnvironmentProvider.class);
Mockito.when(
envProvider.getenv(
InstantiatingGrpcChannelProvider.DIRECT_PATH_ENV_DISABLE_DIRECT_PATH))
.thenReturn("false");
// verify the credentials gets called and returns a non-null builder.
Mockito.when(computeEngineCredentials.toBuilder())
.thenReturn(ComputeEngineCredentials.newBuilder());
InstantiatingGrpcChannelProvider.Builder builder =
InstantiatingGrpcChannelProvider.newBuilder()
.setAttemptDirectPath(true)
.setCredentials(computeEngineCredentials)
.setAllowHardBoundTokenTypes(Collections.singletonList(HardBoundTokenTypes.ALTS))
.setEndpoint(DEFAULT_ENDPOINT)
.setEnvProvider(envProvider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class));
Truth.assertThat(builder.isDirectPathBoundTokenEnabled()).isTrue();
InstantiatingGrpcChannelProvider provider =
new InstantiatingGrpcChannelProvider(builder, GCE_PRODUCTION_NAME_AFTER_2016);
Truth.assertThat(provider.canUseDirectPath()).isTrue();
Expand Down

0 comments on commit 5080495

Please sign in to comment.