Skip to content

Commit

Permalink
feat: Support for CAS server certificate authority. (#2060)
Browse files Browse the repository at this point in the history
For Cloud SQL instances with CAS enabled, the connector will use the certificate chain of trust reported by the
SQL Admin API to validate the instance connection.
  • Loading branch information
hessjcg authored Sep 18, 2024
1 parent 5dc7e80 commit 4332ffc
Show file tree
Hide file tree
Showing 11 changed files with 236 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ Socket connect(ConnectionConfig config, long timeoutMs) throws IOException {
SSLSocket socket = (SSLSocket) metadata.getSslContext().getSocketFactory().createSocket();
socket.setKeepAlive(true);
socket.setTcpNoDelay(true);

socket.connect(new InetSocketAddress(instanceIp, serverProxyPort));

try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -84,10 +84,15 @@ private void checkDatabaseCompatibility(
}

// Creates a Certificate object from a provided string.
private Certificate createCertificate(String cert) throws CertificateException {
private List<Certificate> parseCertificateChain(String cert) throws CertificateException {
byte[] certBytes = cert.getBytes(StandardCharsets.UTF_8);
ByteArrayInputStream certStream = new ByteArrayInputStream(certBytes);
return CertificateFactory.getInstance("X.509").generateCertificate(certStream);
List<Certificate> certificates = new ArrayList<>();
while (certStream.available() > 0) {
Certificate c = CertificateFactory.getInstance("X.509").generateCertificate(certStream);
certificates.add(c);
}
return certificates;
}

private String generatePublicKeyCert(KeyPair keyPair) {
Expand Down Expand Up @@ -296,18 +301,17 @@ private InstanceMetadata fetchMetadata(CloudSqlInstanceName instanceName, AuthTy
+ "IP address.",
instanceName.getConnectionName()));
}

// Update the Server CA certificate used to create the SSL connection with the instance.
try {
Certificate instanceCaCertificate =
createCertificate(instanceMetadata.getServerCaCert().getCert());
List<Certificate> instanceCaCertificates =
parseCertificateChain(instanceMetadata.getServerCaCert().getCert());

logger.debug(String.format("[%s] METADATA DONE", instanceName));

return new InstanceMetadata(
instanceName,
ipAddrs,
Collections.singletonList(instanceCaCertificate),
instanceCaCertificates,
"GOOGLE_MANAGED_CAS_CA".equals(instanceMetadata.getServerCaMode()),
instanceMetadata.getDnsName(),
pscEnabled);
Expand Down Expand Up @@ -371,7 +375,9 @@ private Certificate fetchEphemeralCertificate(
// Parse the certificate from the response.
Certificate ephemeralCertificate;
try {
ephemeralCertificate = createCertificate(response.getEphemeralCert().getCert());
// The response contains a single certificate. This uses the parseCertificateChain method
// to parse the response, and then uses the first, and only, certificate.
ephemeralCertificate = parseCertificateChain(response.getEphemeralCert().getCert()).get(0);
} catch (CertificateException ex) {
throw new RuntimeException(
String.format(
Expand Down Expand Up @@ -407,8 +413,7 @@ private SslData createSslData(
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
kmf.init(authKeyStore, new char[0]);

TrustManagerFactory tmf =
InstanceCheckingTrustManagerFactory.newInstance(instanceName, instanceMetadata);
TrustManagerFactory tmf = InstanceCheckingTrustManagerFactory.newInstance(instanceMetadata);

SSLContext sslContext;

Expand All @@ -428,7 +433,6 @@ private SslData createSslData(
sslContext = SSLContext.getInstance("TLSv1.2");
}
}

sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), new SecureRandom());

logger.debug(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import javax.net.ssl.TrustManagerFactory;

Expand All @@ -39,32 +40,44 @@
* <p>class ConscryptWorkaroundTrustManager - the workaround for the Conscrypt bug.
*
* <p>class InstanceCheckingTrustManager - delegates TLS checks to the default provider and then
* checks that the Subject CN field contains the Cloud SQL instance ID.
* does custom hostname checking in accordance with these rules:
*
* <p>If the instance supports CAS certificates (instanceMetadata.casEnabled == true), or the
* connection is being made to a PSC endpoint (instanceMetadata.pscEnabled == true) the connector
* should validate that the server certificate subjectAlterantiveNames contains an entry that
* matches instanceMetadata.dnsName.
*
* <p>Otherwise, the connector should check that the Subject CN field contains the Cloud SQL
* instance ID in the form: "project-name:instance-name"
*/
class InstanceCheckingTrustManagerFactory extends TrustManagerFactory {

static InstanceCheckingTrustManagerFactory newInstance(
CloudSqlInstanceName instanceName, InstanceMetadata instanceMetadata)
static TrustManagerFactory newInstance(InstanceMetadata instanceMetadata)
throws NoSuchAlgorithmException, KeyStoreException, CertificateException, IOException {

TrustManagerFactory delegate = TrustManagerFactory.getInstance("X.509");
KeyStore trustedKeyStore = KeyStore.getInstance(KeyStore.getDefaultType());
trustedKeyStore.load(null, null);
trustedKeyStore.setCertificateEntry(
"instance", instanceMetadata.getInstanceCaCertificates().get(0));

InstanceCheckingTrustManagerFactory tmf =
new InstanceCheckingTrustManagerFactory(instanceName, delegate);
// Add all the certificates in the chain of trust to the trust keystore.
for (Certificate cert : instanceMetadata.getInstanceCaCertificates()) {
trustedKeyStore.setCertificateEntry("ca" + cert.hashCode(), cert);
}

// Use a custom trust manager factory that checks the CN against the instance name
// The delegate TrustManagerFactory will check the certificate chain, but will not do
// hostname checking.
InstanceCheckingTrustManagerFactory tmf =
new InstanceCheckingTrustManagerFactory(instanceMetadata, delegate);
tmf.init(trustedKeyStore);

return tmf;
}

private InstanceCheckingTrustManagerFactory(
CloudSqlInstanceName instanceName, TrustManagerFactory delegate) {
InstanceMetadata instanceMetadata, TrustManagerFactory delegate) {
super(
new InstanceCheckingTrustManagerFactorySpi(instanceName, delegate),
new InstanceCheckingTrustManagerFactorySpi(instanceMetadata, delegate),
delegate.getProvider(),
delegate.getAlgorithm());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
*/
class InstanceCheckingTrustManagerFactorySpi extends TrustManagerFactorySpi {
private final TrustManagerFactory delegate;
private final CloudSqlInstanceName instanceName;
private final InstanceMetadata instanceMetadata;

InstanceCheckingTrustManagerFactorySpi(
CloudSqlInstanceName instanceName, TrustManagerFactory delegate) {
this.instanceName = instanceName;
InstanceMetadata instanceMetadata, TrustManagerFactory delegate) {
this.instanceMetadata = instanceMetadata;
this.delegate = delegate;
}

Expand Down Expand Up @@ -65,7 +65,7 @@ protected TrustManager[] engineGetTrustManagers() {
tm = new ConscryptWorkaroundDelegatingTrustManger(tm);
}

delegates[i] = new InstanceCheckingTrustManger(instanceName, tm);
delegates[i] = new InstanceCheckingTrustManger(instanceMetadata, tm);
} else {
delegates[i] = tms[i];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
import java.net.Socket;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import javax.naming.InvalidNameException;
import javax.naming.ldap.LdapName;
import javax.naming.ldap.Rdn;
Expand All @@ -37,11 +40,11 @@
*/
class InstanceCheckingTrustManger extends X509ExtendedTrustManager {
private final X509ExtendedTrustManager tm;
private final CloudSqlInstanceName instanceName;
private final InstanceMetadata instanceMetadata;

public InstanceCheckingTrustManger(
CloudSqlInstanceName instanceName, X509ExtendedTrustManager tm) {
this.instanceName = instanceName;
InstanceMetadata instanceMetadata, X509ExtendedTrustManager tm) {
this.instanceMetadata = instanceMetadata;
this.tm = tm;
}

Expand Down Expand Up @@ -92,6 +95,66 @@ private void checkCertificateChain(X509Certificate[] chain) throws CertificateEx
throw new CertificateException("Subject is missing");
}

if (instanceMetadata.isCasManagedCertificate() || instanceMetadata.isPscEnabled()) {
checkSan(chain);
} else {
checkCn(chain);
}
}

private void checkSan(X509Certificate[] chain) throws CertificateException {
List<String> sans = getSans(chain[0]);
String dns = instanceMetadata.getDnsName();
if (dns == null || dns.isEmpty()) {
throw new CertificateException(
"Instance metadata for " + instanceMetadata.getInstanceName() + " has an empty dnsName");
}
for (String san : sans) {
if (san.equalsIgnoreCase(dns)) {
return;
}
}
throw new CertificateException(
"Server certificate does not contain expected name '"
+ instanceMetadata.getDnsName()
+ "' for Cloud SQL instance "
+ instanceMetadata.getInstanceName());
}

private List<String> getSans(X509Certificate cert) throws CertificateException {
ArrayList<String> names = new ArrayList<>();

Collection<List<?>> sanAsn1Field = cert.getSubjectAlternativeNames();
if (sanAsn1Field == null) {
return names;
}

for (List item : sanAsn1Field) {
Integer type = (Integer) item.get(0);
// RFC 5280 section 4.2.1.6. "Subject Alternative Name"
// describes the structure of subjectAlternativeName record.
// type == 0 means this contains an "otherName"
// type == 2 means this contains a "dNSName"
if (type == 0 || type == 2) {
Object value = item.get(1);
if (value instanceof byte[]) {
// This would only happen if the customer provided a non-standard JSSE encryption
// provider. The standard JSSE providers all return a list of Strings for the SAN.
// To handle this case, the project would need to add the BouncyCastle crypto library
// as a dependency, and follow the example to decode an ASN1 SAN data structure:
// https://stackoverflow.com/questions/30993879/retrieve-subject-alternative-names-of-x-509-certificate-in-java
throw new UnsupportedOperationException(
"Server certificate SAN field cannot be decoded.");
} else if (value instanceof String) {
names.add((String) value);
}
}
}
return names;
}

private void checkCn(X509Certificate[] chain) throws CertificateException {

String cn = null;

try {
Expand All @@ -111,7 +174,10 @@ private void checkCertificateChain(X509Certificate[] chain) throws CertificateEx
}

// parse CN from subject. CN always comes last in the list.
String instName = this.instanceName.getProjectId() + ":" + this.instanceName.getInstanceId();
String instName =
this.instanceMetadata.getInstanceName().getProjectId()
+ ":"
+ this.instanceMetadata.getInstanceName().getInstanceId();
if (!instName.equals(cn)) {
throw new CertificateException(
"Server certificate CN does not match instance name. Server certificate CN="
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,30 @@ public void setup() throws GeneralSecurityException {
}

MockHttpTransport fakeSuccessHttpTransport(Duration certDuration) {
return fakeSuccessHttpTransport(TestKeys.getServerCertPem(), certDuration, null);
return fakeSuccessHttpTransport(TestKeys.getServerCertPem(), certDuration, null, false, false);
}

MockHttpTransport fakeSuccessHttpTransport(Duration certDuration, String baseUrl) {
return fakeSuccessHttpTransport(TestKeys.getServerCertPem(), certDuration, baseUrl);
return fakeSuccessHttpTransport(
TestKeys.getServerCertPem(), certDuration, baseUrl, false, false);
}

MockHttpTransport fakeSuccessHttpTransport(String serverCert, Duration certDuration) {
return fakeSuccessHttpTransport(serverCert, certDuration, null);
return fakeSuccessHttpTransport(serverCert, certDuration, null, false, false);
}

MockHttpTransport fakeSuccessHttpCasTransport(Duration certDuration) {
return fakeSuccessHttpTransport(
TestKeys.getCasServerCertChainPem(), certDuration, null, true, false);
}

MockHttpTransport fakeSuccessHttpPscCasTransport(Duration certDuration) {
return fakeSuccessHttpTransport(
TestKeys.getCasServerCertChainPem(), certDuration, null, true, true);
}

MockHttpTransport fakeSuccessHttpTransport(
String serverCert, Duration certDuration, String baseUrl) {
String serverCert, Duration certDuration, String baseUrl, boolean cas, boolean psc) {
final JsonFactory jsonFactory = new GsonFactory();
return new MockHttpTransport() {
@Override
Expand All @@ -167,7 +178,10 @@ public LowLevelHttpResponse execute() throws IOException {
new IpMapping().setIpAddress(PRIVATE_IP).setType("PRIVATE")))
.setServerCaCert(new SslCert().setCert(serverCert))
.setDatabaseVersion("POSTGRES14")
.setRegion("myRegion");
.setRegion("myRegion")
.setPscEnabled(psc ? Boolean.TRUE : null)
.setDnsName(cas || psc ? "db.example.com" : null)
.setServerCaMode(cas ? "GOOGLE_MANAGED_CAS_CA" : null);
settings.setFactory(jsonFactory);
response
.setContent(settings.toPrettyString())
Expand Down
35 changes: 35 additions & 0 deletions core/src/test/java/com/google/cloud/sql/core/ConnectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import java.net.Socket;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
Expand Down Expand Up @@ -147,6 +149,39 @@ public void create_successfulPublicConnection() throws IOException, InterruptedE
assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
}

@Test
public void create_successfulPublicCasConnection() throws IOException, InterruptedException {
PrivateKey privateKey = TestKeys.getServerKeyPair().getPrivate();
X509Certificate[] cert = TestKeys.getCasServerCertChain();

FakeSslServer sslServer = new FakeSslServer(privateKey, cert);
ConnectionConfig config =
new ConnectionConfig.Builder()
.withCloudSqlInstance("myProject:myRegion:myInstance")
.withIpTypes("PRIMARY")
.build();

int port = sslServer.start(PUBLIC_IP);

ConnectionInfoRepositoryFactory factory =
new StubConnectionInfoRepositoryFactory(fakeSuccessHttpCasTransport(Duration.ZERO));

Connector connector =
new Connector(
config.getConnectorConfig(),
factory,
stubCredentialFactoryProvider.getInstanceCredentialFactory(config.getConnectorConfig()),
defaultExecutor,
clientKeyPair,
10,
TEST_MAX_REFRESH_MS,
port);

Socket socket = connector.connect(config, TEST_MAX_REFRESH_MS);

assertThat(readLine(socket)).isEqualTo(SERVER_MESSAGE);
}

private boolean isWindows() {
String os = System.getProperty("os.name").toLowerCase();
return os.contains("win");
Expand Down
Loading

0 comments on commit 4332ffc

Please sign in to comment.