bridge_utils.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os

from py4j.java_gateway import JavaGateway, GatewayClient, java_import
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SparkSession


# Use the JVM gateway to create a Java class instance by its fully-qualified class name.
def _getjvm_class(gateway, fullClassName):
    return gateway.jvm.java.lang.Thread.currentThread().getContextClassLoader().loadClass(fullClassName).newInstance()
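

# Example (a sketch; "com.example.MyHelper" is a hypothetical class name used only
# for illustration). The target class must be on the JVM classpath and must have a
# public no-arg constructor, since newInstance() is used:
#
#     helper = _getjvm_class(gateway, "com.example.MyHelper")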


def UDFRegister(sqlContext, clazz):
    # Registers all SQL user-defined functions in the given class.
    # Delegates to the Scala-side datafu.spark.UDFRegister for the actual registration.
    sc = sqlContext._sc
    gateway = sc._gateway
    ureg = gateway.jvm.datafu.spark.UDFRegister
    ureg.registerObject(clazz, sqlContext._jsqlContext)
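

# Example usage (a sketch; "com.example.udfs.MyUDFs" is an assumed placeholder for
# the fully-qualified name of a Scala object holding the UDFs, passed as a string):
#
#     UDFRegister(sqlContext, "com.example.udfs.MyUDFs")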


class Context(object):

    def __init__(self):
        """Called when running a Python script from Scala: initializes the
        connection to the Java gateway and obtains the Spark context.

        The code is largely copied from:
        https://github.com/apache/zeppelin/blob/master/spark/interpreter/src/main/resources/python/zeppelin_pyspark.py#L30
        """
        if os.environ.get("SPARK_EXECUTOR_URI"):
            SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])

        gateway = JavaGateway(GatewayClient(port=int(os.environ.get("PYSPARK_GATEWAY_PORT"))), auto_convert=True)
        java_import(gateway.jvm, "org.apache.spark.SparkEnv")
        java_import(gateway.jvm, "org.apache.spark.SparkConf")
        java_import(gateway.jvm, "org.apache.spark.api.java.*")
        java_import(gateway.jvm, "org.apache.spark.api.python.*")
        java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
        java_import(gateway.jvm, "org.apache.spark.sql.*")
        java_import(gateway.jvm, "org.apache.spark.sql.hive.*")

        # Pull the Scala-side session, Java Spark context and configuration through
        # the gateway's entry point, then wrap them in their Python counterparts.
        intp = gateway.entry_point
        jSparkSession = intp.pyGetSparkSession()
        jsc = intp.pyGetJSparkContext(jSparkSession)
        jconf = intp.pyGetSparkConf(jsc)
        conf = SparkConf(_jvm=gateway.jvm, _jconf=jconf)
        self.sc = SparkContext(jsc=jsc, gateway=gateway, conf=conf)

        # Spark 2
        self.sparkSession = SparkSession(self.sc, jSparkSession)
        self.sqlContext = self.sparkSession._wrapped


ctx = None


def get_contexts():
    # Lazily create a single shared Context the first time it is requested.
    global ctx
    if not ctx:
        ctx = Context()
    return ctx.sc, ctx.sqlContext, ctx.sparkSession
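

# A minimal usage sketch (assumes the script is launched from the Scala side so
# that PYSPARK_GATEWAY_PORT is set in the environment; otherwise the gateway
# connection above will fail):
if __name__ == "__main__":
    sc, sqlContext, sparkSession = get_contexts()
    df = sparkSession.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
    df.show()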