Skip to content

Commit

Permalink
Merge pull request #24 from data-integrations/classloader-fix
Browse files Browse the repository at this point in the history
[PLUGIN-1617] Fix python native mode
  • Loading branch information
saimukkamala authored Jul 11, 2023
2 parents 5ee4453 + 0cc2d04 commit 42f697c
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 9 deletions.
6 changes: 3 additions & 3 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
<surefire.redirectTestOutputToFile>true</surefire.redirectTestOutputToFile>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<!-- version properties -->
<cdap.version>6.1.0-SNAPSHOT</cdap.version>
<hydrator.version>2.3.0-SNAPSHOT</hydrator.version>
<cdap.version>6.10.0-SNAPSHOT</cdap.version>
<hydrator.version>2.12.0-SNAPSHOT</hydrator.version>
<junit.version>4.11</junit.version>
<jython.version>2.5.2</jython.version>
<py4j.version>0.10.8.1</py4j.version>
Expand Down Expand Up @@ -109,7 +109,7 @@
</dependency>
<dependency>
<groupId>io.cdap.cdap</groupId>
<artifactId>cdap-data-pipeline</artifactId>
<artifactId>cdap-data-pipeline3_2.12</artifactId>
<version>${cdap.version}</version>
<scope>test</scope>
</dependency>
Expand Down
13 changes: 12 additions & 1 deletion src/main/java/io/cdap/plugin/python/transform/KeyStores.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
package io.cdap.plugin.python.transform;

import org.apache.commons.lang.time.DateUtils;

import sun.security.x509.AlgorithmId;
import sun.security.x509.CertificateAlgorithmId;
import sun.security.x509.CertificateExtensions;
import sun.security.x509.CertificateIssuerName;
import sun.security.x509.CertificateSerialNumber;
import sun.security.x509.CertificateSubjectName;
import sun.security.x509.CertificateValidity;
import sun.security.x509.CertificateVersion;
import sun.security.x509.CertificateX509Key;
import sun.security.x509.GeneralName;
import sun.security.x509.GeneralNames;
import sun.security.x509.IPAddressName;
import sun.security.x509.SubjectAlternativeNameExtension;
import sun.security.x509.X500Name;
import sun.security.x509.X509CertImpl;
import sun.security.x509.X509CertInfo;
Expand Down Expand Up @@ -134,6 +140,10 @@ public static KeyStore generatedCertKeyStore(int validityDays, String password,
private static X509Certificate getCertificate(String dn, KeyPair pair, int days, String algorithm) throws IOException,
CertificateException, NoSuchProviderException, NoSuchAlgorithmException, InvalidKeyException, SignatureException {
// Calculate the validity interval of the certificate
GeneralNames generalNames = new GeneralNames();
generalNames.add(new GeneralName(new IPAddressName("127.0.0.1")));
CertificateExtensions ext = new CertificateExtensions();
ext.set(SubjectAlternativeNameExtension.NAME, new SubjectAlternativeNameExtension(generalNames));
Date from = new Date();
Date to = DateUtils.addDays(from, days);
CertificateValidity interval = new CertificateValidity(from, to);
Expand All @@ -143,8 +153,10 @@ private static X509Certificate getCertificate(String dn, KeyPair pair, int days,
X500Name owner = new X500Name(dn);
// Create an info objects with the provided information, which will be used to create the certificate
X509CertInfo info = new X509CertInfo();
info.set(X509CertInfo.VERSION, new CertificateVersion(CertificateVersion.V3));
info.set(X509CertInfo.VALIDITY, interval);
info.set(X509CertInfo.SERIAL_NUMBER, new CertificateSerialNumber(sn));
info.set(X509CertInfo.EXTENSIONS, ext);
// In java 7, subject is of type CertificateSubjectName and issuer is of type CertificateIssuerName.
// These were changed to X500Name in Java8. So looking at the field type before setting them.
// This certificate will be self signed, hence the subject and the issuer are same.
Expand All @@ -165,7 +177,6 @@ private static X509Certificate getCertificate(String dn, KeyPair pair, int days,
info.set(X509CertInfo.ISSUER, owner);
}
info.set(X509CertInfo.KEY, new CertificateX509Key(pair.getPublic()));
info.set(X509CertInfo.VERSION, new CertificateVersion(CertificateVersion.V3));
AlgorithmId algo = new AlgorithmId(AlgorithmId.sha1WithRSAEncryption_oid);
info.set(X509CertInfo.ALGORITHM_ID, new CertificateAlgorithmId(algo));
// Create the certificate and sign it with the private key
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.cdap.cdap.etl.api.Emitter;
import io.cdap.plugin.common.script.ScriptContext;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.awaitility.Awaitility;
import org.awaitility.core.ConditionTimeoutException;
import org.slf4j.Logger;
Expand All @@ -30,6 +31,7 @@

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.net.URL;
Expand Down Expand Up @@ -106,8 +108,10 @@ private KeyStore generatePemFileAndKeyStore(String transformTempDirString) throw

private File prepareTempFiles() throws IOException, UnrecoverableKeyException,
CertificateEncodingException, NoSuchAlgorithmException, KeyStoreException {
URL url = getClass().getResource("/pythonEvaluator.py");
String scriptText = new String(Files.readAllBytes(Paths.get(url.getPath())), StandardCharsets.UTF_8);
String scriptText;
try (InputStream url = getClass().getResourceAsStream("/pythonEvaluator.py")) {
scriptText = IOUtils.toString(url, StandardCharsets.UTF_8);
}
scriptText = scriptText.replaceAll(USER_CODE_PLACEHOLDER, config.getScript());

Path transformTempDirPath = Files.createTempDirectory("transform");
Expand Down Expand Up @@ -191,8 +195,18 @@ public void initialize(ScriptContext scriptContext) throws IOException,


Class[] entryClasses = new Class[]{Py4jTransport.class};
py4jTransport = (Py4jTransport) gatewayServer.getPythonServerEntryPoint(entryClasses);

// gatewayServer.getPythonServerEntryPoint function uses the current thread classloader (Executor classloader)
// to load classes instead of using Plugin classloader which causes classloading issues.
// To avoid this we are setting the current thread classloader to Plugin classloader before calling
// gatewayServer.getPythonServerEntryPoint function and revert it back to Executor classloader.
ClassLoader exectorClassLoader = Thread.currentThread().getContextClassLoader();
ClassLoader pluginClassloader = Py4jTransport.class.getClassLoader();
Thread.currentThread().setContextClassLoader(pluginClassloader);
try {
py4jTransport = (Py4jTransport) gatewayServer.getPythonServerEntryPoint(entryClasses);
} finally {
Thread.currentThread().setContextClassLoader(exectorClassLoader);
}
LOGGER.debug("Waiting for py4j gateway to start...");

try {
Expand Down
2 changes: 1 addition & 1 deletion src/main/resources/pythonEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Java:

# address must match cert, because we're checking hostnames
gateway_parameters = GatewayParameters(
address='localhost',
address='127.0.0.1',
ssl_context=client_ssl_context)

transform_transport = PythonTransformTransportImpl()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,27 @@
package io.cdap.plugin.python.transform;

import com.google.common.collect.ImmutableMap;
import com.google.common.io.Files;
import io.cdap.cdap.api.data.format.StructuredRecord;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -62,4 +75,30 @@ public void testImportThirdPartyLibrary() throws Exception {

Assert.assertEquals(new HashSet<>(outputRecords), new HashSet<>(INPUT_DEFAULT));
}

@Test
public void testSSLCertificateGeneration() throws UnrecoverableKeyException, CertificateException,
NoSuchAlgorithmException, KeyStoreException, IOException {
KeyStore ks = KeyStores.generatedCertKeyStore(10, "password");
File pemFile = temporaryFolder.newFile("selfsigned.pem");
KeyStores.generatePemFileFromKeyStore(ks, "password", pemFile);

List<String> certFile = Files.readLines(pemFile, StandardCharsets.UTF_8);
certFile.removeIf(String::isEmpty);
// should contain 6 lines
// -----BEGIN RSA PRIVATE KEY-----\n <encoded private key>\n -----END RSA PRIVATE KEY-----\n
// -----BEGIN CERTIFICATE-----\n <encoded public key>\n -----END CERTIFICATE-----
Assert.assertEquals(6, certFile.size());
byte [] decodedPublicKey = Base64.getDecoder().decode(certFile.get(4));
X509Certificate cert = (X509Certificate) CertificateFactory.getInstance("X.509")
.generateCertificate(new ByteArrayInputStream(decodedPublicKey));

Collection<List<?>> sans = cert.getSubjectAlternativeNames();
// Should contain only 1 san (ip: 127.0.0.1)
Assert.assertEquals(1, sans.size());
Integer sanType = (Integer) ((List<?>) sans.toArray()[0]).get(0);
Assert.assertEquals((Integer) 7, sanType); // Enum for IPAddress
String sanValue = (String) ((List<?>) sans.toArray()[0]).get(1);
Assert.assertEquals("127.0.0.1", sanValue);
}
}

0 comments on commit 42f697c

Please sign in to comment.