Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oauth config override #63

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ The general philosophy is for the configuration to drive Polyglot's behavior and

Polyglot uses statically linked [boringssl](https://boringssl.googlesource.com/boringssl/) libraries under the hood and doesn't require the host machine to have any specific libraries. Whether or not the client uses TLS to talk to the server can be controlled using the `--use_tls` flag or the corresponding configuration entry.

Polyglot can also do client certificate authentication with the `--tls_client_cert_path` and `--tls_client_key_path` flags. If the hostname on the server does not match the endpoint (e.g. connecting
to `localhost`, but the server thinks it's `foo.example.com`), `--tls_client_override_authority=foo.example.com` can be used.

### Authenticating requests using OAuth

Polyglot has built-in support for authentication of requests using OAuth tokens in two ways:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ public static void callEndpoint(
dynamicClient = DynamicGrpcClient.create(methodDescriptor, hostAndPort, callConfig);
}

logger.info("Reading input from stdin");

ImmutableList<DynamicMessage> requestMessages =
MessageReader.forStdin(methodDescriptor.getInputType()).read();
StreamObserver<DynamicMessage> streamObserver =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@

import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

/** Provides easy access to the arguments passed on the command line. */
public class CommandLineArgs {
@Option(name = "--full_method", metaVar = "<some.package.Service/doSomething>")
Expand Down Expand Up @@ -47,6 +49,30 @@ public class CommandLineArgs {
@Option(name = "--tls_ca_cert_path", metaVar = "<path>")
private String tlsCaCertPath;

@Option(name = "--tls_client_cert_path", metaVar = "<path>")
private String tlsClientCertPath;

@Option(name = "--tls_client_key_path", metaVar = "<path>")
private String tlsClientKeyPath;

@Option(name = "--tls_client_override_authority", metaVar = "<host>")
private String tlsClientOverrideAuthority;

@Option(name = "--oauth_refresh_token_endpoint_url", metaVar = "<url>")
private String oauthRefreshTokenEndpointUrl;

@Option(name = "--oauth_client_id", metaVar = "<client-id>")
private String oauthClientId;

@Option(name = "--oauth_client_secret", metaVar = "<client-secret>")
private String oauthClientSecret;

@Option(name = "--oauth_refresh_token_path", metaVar = "<path>")
private String oauthRefreshTokenPath;

@Option(name = "--oauth_access_token_path", metaVar = "<path>")
private String oauthAccessTokenPath;

@Option(name = "--help")
private Boolean help;

Expand All @@ -65,23 +91,23 @@ public class CommandLineArgs {

// TODO: Move to a "list_services"-specific flag container
@Option(
name = "--service_filter",
metaVar = "service_name",
usage="Filters service names containing this string e.g. --service_filter TestService")
name = "--service_filter",
metaVar = "service_name",
usage="Filters service names containing this string e.g. --service_filter TestService")
private String serviceFilterArg;

// TODO: Move to a "list_services"-specific flag container
@Option(
name = "--method_filter",
metaVar = "method_name",
usage="Filters service methods to those containing this string e.g. --method_name List")
name = "--method_filter",
metaVar = "method_name",
usage="Filters service methods to those containing this string e.g. --method_name List")
private String methodFilterArg;

//TODO: Move to a "list_services"-specific flag container
@Option(
name = "--with_message",
metaVar = "true|false",
usage="If true, then the message specification for the method is rendered")
name = "--with_message",
metaVar = "true|false",
usage="If true, then the message specification for the method is rendered")
private String withMessageArg;

// *************************************************************************
Expand Down Expand Up @@ -153,6 +179,38 @@ public Optional<Path> tlsCaCertPath() {
return maybePath(tlsCaCertPath);
}

public Optional<Path> tlsClientCertPath() {
return maybePath(tlsClientCertPath);
}

public Optional<Path> tlsClientKeyPath() {
return maybePath(tlsClientKeyPath);
}

public Optional<String> tlsClientOverrideAuthority() {
return Optional.ofNullable(tlsClientOverrideAuthority);
}

public Optional<URL> oauthRefreshTokenEndpointUrl() {
return maybeUrl(oauthRefreshTokenEndpointUrl);
}

public Optional<String> oauthClientId() {
return Optional.ofNullable(oauthClientId);
}

public Optional<String> oauthClientSecret() {
return Optional.ofNullable(oauthClientSecret);
}

public Optional<Path> oauthRefreshTokenPath() {
return maybePath(oauthRefreshTokenPath);
}

public Optional<Path> oauthAccessTokenPath() {
return maybePath(oauthAccessTokenPath);
}

/**
* First stage of a migration towards a "command"-based instantiation of polyglot.
* Supported commands:
Expand Down Expand Up @@ -214,4 +272,17 @@ private static Optional<Path> maybePath(String rawPath) {
Preconditions.checkArgument(Files.exists(path), "File " + rawPath + " does not exist");
return Optional.of(Paths.get(rawPath));
}

private static Optional<URL> maybeUrl(String rawUrl) {
if (rawUrl == null) {
return Optional.empty();
}
try {
URL url = new URL(rawUrl);
return Optional.of(url);
} catch (MalformedURLException e) {
throw new IllegalArgumentException("URL " + rawUrl + " is invalid", e);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ private Configuration getDefaultConfigurationInternal() {
private Configuration getNamedConfigurationInternal(String name) {
Preconditions.checkState(!isEmptyConfig(), "Cannot load named config with a config set");
return configSet.get().getConfigurationsList().stream()
.filter(config -> config.getName().equals(name))
.findAny()
.orElseThrow(() -> new IllegalArgumentException("Could not find named config: " + name));
.filter(config -> config.getName().equals(name))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indentation still off

.findAny()
.orElseThrow(() -> new IllegalArgumentException("Could not find named config: " + name));
}

/** Returns the {@link Configuration} with overrides, if any, applied to it. */
Expand All @@ -111,7 +111,7 @@ private Configuration applyOverrides(Configuration configuration) {
if (overrides.get().outputFilePath().isPresent()) {
resultBuilder.getOutputConfigBuilder().setDestination(Destination.FILE);
resultBuilder.getOutputConfigBuilder().setFilePath(
overrides.get().outputFilePath().get().toString());
overrides.get().outputFilePath().get().toString());
}
if (!overrides.get().additionalProtocIncludes().isEmpty()) {
List<String> additionalIncludes = new ArrayList<>();
Expand All @@ -122,14 +122,49 @@ private Configuration applyOverrides(Configuration configuration) {
}
if (overrides.get().protoDiscoveryRoot().isPresent()) {
resultBuilder.getProtoConfigBuilder().setProtoDiscoveryRoot(
overrides.get().protoDiscoveryRoot().get().toString());
overrides.get().protoDiscoveryRoot().get().toString());
}
if (overrides.get().getRpcDeadlineMs().isPresent()) {
resultBuilder.getCallConfigBuilder().setDeadlineMs(overrides.get().getRpcDeadlineMs().get());
}
if (overrides.get().tlsCaCertPath().isPresent()) {
resultBuilder.getCallConfigBuilder().setTlsCaCertPath(
overrides.get().tlsCaCertPath().get().toString());
overrides.get().tlsCaCertPath().get().toString());
}
if (overrides.get().tlsClientCertPath().isPresent()) {
resultBuilder.getCallConfigBuilder().setTlsClientCertPath(
overrides.get().tlsClientCertPath().get().toString());
}
if (overrides.get().tlsClientKeyPath().isPresent()) {
resultBuilder.getCallConfigBuilder().setTlsClientKeyPath(
overrides.get().tlsClientKeyPath().get().toString());
}
if (overrides.get().tlsClientOverrideAuthority().isPresent()) {
resultBuilder.getCallConfigBuilder().setTlsClientOverrideAuthority(
overrides.get().tlsClientOverrideAuthority().get());
}
if (overrides.get().oauthRefreshTokenEndpointUrl().isPresent()) {
resultBuilder.getCallConfigBuilder().getOauthConfigBuilder().getRefreshTokenCredentialsBuilder()
.setTokenEndpointUrl(overrides.get().oauthRefreshTokenEndpointUrl().get().toString());
}
if (overrides.get().oauthClientId().isPresent()) {
resultBuilder.getCallConfigBuilder().getOauthConfigBuilder().getRefreshTokenCredentialsBuilder()
.getClientBuilder().setId(overrides.get().oauthClientId().get());
}
if (overrides.get().oauthClientSecret().isPresent()) {
resultBuilder.getCallConfigBuilder().getOauthConfigBuilder().getRefreshTokenCredentialsBuilder()
.getClientBuilder().setSecret(overrides.get().oauthClientSecret().get());
}
if (overrides.get().oauthRefreshTokenPath().isPresent()) {
resultBuilder.getCallConfigBuilder().getOauthConfigBuilder().getRefreshTokenCredentialsBuilder()
.setRefreshTokenPath(overrides.get().oauthRefreshTokenPath().get().toString());
}
// Note the ordering of setting these fields is important. Oauth configuration has a oneof field, corresponding
// to access or refresh tokens. We want access tokens to take precedence, setting this field last will ensure this
// occurs. See https://developers.google.com/protocol-buffers/docs/proto#oneof
if (overrides.get().oauthAccessTokenPath().isPresent()) {
resultBuilder.getCallConfigBuilder().getOauthConfigBuilder().getAccessTokenCredentialsBuilder()
.setAccessTokenPath(overrides.get().oauthAccessTokenPath().get().toString());
}
return resultBuilder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;

import javax.net.ssl.SSLException;

import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
Expand All @@ -21,6 +19,10 @@
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.DynamicMessage;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
Expand All @@ -34,9 +36,8 @@
import io.grpc.stub.StreamObserver;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import javax.net.ssl.SSLException;
import me.dinowernli.grpc.polyglot.protobuf.DynamicMessageMarshaller;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import polyglot.ConfigProto.CallConfiguration;

/** A grpc client which operates on dynamic messages. */
Expand Down Expand Up @@ -232,17 +233,28 @@ private static Channel createChannel(HostAndPort endpoint, CallConfiguration cal
if (!callConfiguration.getUseTls()) {
return createPlaintextChannel(endpoint);
}
return NettyChannelBuilder.forAddress(endpoint.getHostText(), endpoint.getPort())
.sslContext(createSslContext(callConfiguration))
.negotiationType(NegotiationType.TLS)
.build();
NettyChannelBuilder nettyChannelBuilder =
NettyChannelBuilder.forAddress(endpoint.getHostText(), endpoint.getPort())
.sslContext(createSslContext(callConfiguration))
.negotiationType(NegotiationType.TLS);

if (!callConfiguration.getTlsClientOverrideAuthority().isEmpty()) {
nettyChannelBuilder.overrideAuthority(callConfiguration.getTlsClientOverrideAuthority());
}

return nettyChannelBuilder.build();
}

private static SslContext createSslContext(CallConfiguration callConfiguration) {
SslContextBuilder resultBuilder = GrpcSslContexts.forClient();
if (!callConfiguration.getTlsCaCertPath().isEmpty()) {
resultBuilder.trustManager(loadFile(callConfiguration.getTlsCaCertPath()));
}
if (!callConfiguration.getTlsClientCertPath().isEmpty()) {
resultBuilder.keyManager(
loadFile(callConfiguration.getTlsClientCertPath()),
loadFile(callConfiguration.getTlsClientKeyPath()));
}
try {
return resultBuilder.build();
} catch (SSLException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ public AccessToken refreshAccessToken() throws IOException {
logger.info("Refresh successful, got access token");
return new AccessToken(
refreshResponse.getAccessToken(),
computeExpirtyDate(refreshResponse.getExpiresInSeconds()));
computeExpiryDate(refreshResponse.getExpiresInSeconds()));
}

private Date computeExpirtyDate(long expiresInSeconds) {
private Date computeExpiryDate(long expiresInSeconds) {
long expiresInSecondsWithMargin = (long) (expiresInSeconds * ACCESS_TOKEN_EXPIRY_MARGIN);
return Date.from(clock.instant().plusSeconds(expiresInSecondsWithMargin));
}
Expand Down
20 changes: 16 additions & 4 deletions src/main/java/me/dinowernli/grpc/polyglot/testing/TestServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import java.util.Random;

import com.google.common.base.Throwables;

import io.grpc.Server;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyServerBuilder;
import io.netty.handler.ssl.ClientAuth;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import polyglot.test.TestProto.TestResponse;
import polyglot.test.TestServiceGrpc.TestServiceImplBase;
Expand Down Expand Up @@ -101,13 +104,22 @@ public void blockingShutdown() {

/** An {@link SslContext} for use in unit test servers. Loads our testing certificates. */
public static SslContext serverSslContextForTesting() throws IOException {
return GrpcSslContexts
.forServer(TestUtils.loadServerChainCert(), TestUtils.loadServerKey())
.trustManager(TestUtils.loadRootCaCert())
.sslProvider(SslProvider.OPENSSL)
return getSslContextBuilder().build();
}

/** An {@link SslContext} for use in unit test servers with client certs. Loads our testing certificates. */
public static SslContext serverSslContextWithClientCertsForTesting() throws IOException {
return getSslContextBuilder()
.clientAuth(ClientAuth.REQUIRE)
.build();
}

private static SslContextBuilder getSslContextBuilder() {
return GrpcSslContexts.forServer(TestUtils.loadServerChainCert(), TestUtils.loadServerKey())
.trustManager(TestUtils.loadRootCaCert())
.sslProvider(SslProvider.OPENSSL);
}

/** Starts a grpc server on the given port, throws {@link IOException} on failure. */
private static Server tryStartServer(
int port,
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/me/dinowernli/grpc/polyglot/testing/TestUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,16 @@ public static File loadRootCaCert() {
return Paths.get(TESTING_CERTS_DIR.toString(), "ca.pem").toFile();
}

/** Returns a file containing a client certificate for use in tests. */
public static File loadClientCert() {
return Paths.get(TESTING_CERTS_DIR.toString(), "client.pem").toFile();
}

/** Returns a file containing a client key for use in tests. */
public static File loadClientKey() {
return Paths.get(TESTING_CERTS_DIR.toString(), "client.key").toFile();
}

/** Returns a file containing a certificate chain from our testing root CA to our server. */
public static File loadServerChainCert() {
return Paths.get(TESTING_CERTS_DIR.toString(), "server.pem").toFile();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
-----BEGIN CERTIFICATE-----
MIICZjCCAc+gAwIBAgIJAOsqHrpa5cF9MA0GCSqGSIb3DQEBCwUAMFkxCzAJBgNV
MIICZjCCAc+gAwIBAgIJAJqx8JYIArI0MA0GCSqGSIb3DQEBCwUAMFkxCzAJBgNV
BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX
aWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xNjA1MTUwOTIz
MDZaFw0yNjA1MTMwOTIzMDZaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21l
aWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xNzA4MzEyMzM3
MDNaFw0yNzA4MjkyMzM3MDNaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21l
LVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNV
BAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA9T39HWNC
0r9NjD7wmFF6luJJ+NuWG0tuuZoGpbNldQpsZXlS0J/OwNAk+55p6it2Yr89jxM9
Ea83oYTnjLuGQ/tJmUmPNau2Z4Q/M41000lD6Hd0Sxw7St2nLlgTOMRyEJEAaBBC
yKtHiq6cvu3UmNzY+jok5hmRjGlWHnNsWisCAwEAAaM2MDQwCQYDVR0TBAIwADAL
BAMMCWxvY2FsaG9zdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAzroPZgt6
jX0qwmuo5Y04J8eATK2Lq/ohFfI+LoaqXJ7nOrHViPLpQfqErtHCrcuOD++BcTJ4
0iw6d0uwxTT0L73cfKI+1zJKcI6jAJ1a86kpYJVqNY5mIDqVXSP2/Ig6Q7I3qDOZ
RM/sNIypd56bQy8GmQYcL1ng1zqy1kBn4jsCAwEAAaM2MDQwCQYDVR0TBAIwADAL
BgNVHQ8EBAMCBeAwGgYDVR0RBBMwEYIJbG9jYWxob3N0hwR/AAABMA0GCSqGSIb3
DQEBCwUAA4GBAJg2NDTZZB9Kl9mFgjIsL/M4dz/wspsGhwuglpOSwarFkKvSYkxD
61Ls4rp4qT5vEt0EJjksTsxNVdzR9DD0k+LENuEzM+VlzPaKoKrrZRZeiLYnfY28
etxVuVVW78jd03rx+FpVOql+lKT1hnWn40IVLjLdT60shHfVt34Z6t98
DQEBCwUAA4GBAAPiapcOh3JBfC6f7dkUAP9KDpNHK8JA1My4+CxkRyShC0rKf+K6
3wRLLb6f9qyvs3FkSF5uTcD12Irj88SzlMPiu/civVv4ldY/5w1XKmh3BoWwe+cH
jaqPi0MX/uarPCgbgkt219INsBi/Sc8V8Yp1qjZp+pvJ0A80/XdD56/x
-----END CERTIFICATE-----
Loading