diff --git a/flyteidl/gen/pb-es/flyteidl/plugins/spark_pb.ts b/flyteidl/gen/pb-es/flyteidl/plugins/spark_pb.ts index 20a5463cbb..02e152bc5d 100644 --- a/flyteidl/gen/pb-es/flyteidl/plugins/spark_pb.ts +++ b/flyteidl/gen/pb-es/flyteidl/plugins/spark_pb.ts @@ -5,6 +5,7 @@ import type { BinaryReadOptions, FieldList, JsonReadOptions, JsonValue, PartialMessage, PlainMessage } from "@bufbuild/protobuf"; import { Message, proto3, Struct } from "@bufbuild/protobuf"; +import { K8sPod } from "../core/tasks_pb.js"; /** * @generated from message flyteidl.plugins.SparkApplication @@ -131,6 +132,20 @@ export class SparkJob extends Message { */ databricksInstance = ""; + /** + * Pod Spec for the Spark driver pod + * + * @generated from field: flyteidl.core.K8sPod driverPod = 10; + */ + driverPod?: K8sPod; + + /** + * Pod Spec for the Spark executor pod + * + * @generated from field: flyteidl.core.K8sPod executorPod = 11; + */ + executorPod?: K8sPod; + constructor(data?: PartialMessage) { super(); proto3.util.initPartial(data, this); @@ -148,6 +163,8 @@ export class SparkJob extends Message { { no: 7, name: "databricksConf", kind: "message", T: Struct }, { no: 8, name: "databricksToken", kind: "scalar", T: 9 /* ScalarType.STRING */ }, { no: 9, name: "databricksInstance", kind: "scalar", T: 9 /* ScalarType.STRING */ }, + { no: 10, name: "driverPod", kind: "message", T: K8sPod }, + { no: 11, name: "executorPod", kind: "message", T: K8sPod }, ]); static fromBinary(bytes: Uint8Array, options?: Partial): SparkJob { diff --git a/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go b/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go index 316d1f1d7b..e610ff5f55 100644 --- a/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go +++ b/flyteidl/gen/pb-go/flyteidl/plugins/spark.pb.go @@ -7,6 +7,7 @@ package plugins import ( + core "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" structpb "google.golang.org/protobuf/types/known/structpb" @@ -132,6 +133,10 @@ type SparkJob struct { // Domain name of your deployment. Use the form .cloud.databricks.com. // This instance name can be set in either flytepropeller or flytekit. DatabricksInstance string `protobuf:"bytes,9,opt,name=databricksInstance,proto3" json:"databricksInstance,omitempty"` + // Pod Spec for the Spark driver pod + DriverPod *core.K8SPod `protobuf:"bytes,10,opt,name=driverPod,proto3" json:"driverPod,omitempty"` + // Pod Spec for the Spark executor pod + ExecutorPod *core.K8SPod `protobuf:"bytes,11,opt,name=executorPod,proto3" json:"executorPod,omitempty"` } func (x *SparkJob) Reset() { @@ -229,6 +234,20 @@ func (x *SparkJob) GetDatabricksInstance() string { return "" } +func (x *SparkJob) GetDriverPod() *core.K8SPod { + if x != nil { + return x.DriverPod + } + return nil +} + +func (x *SparkJob) GetExecutorPod() *core.K8SPod { + if x != nil { + return x.ExecutorPod + } + return nil +} + var File_flyteidl_plugins_spark_proto protoreflect.FileDescriptor var file_flyteidl_plugins_spark_proto_rawDesc = []byte{ @@ -236,64 +255,73 @@ var file_flyteidl_plugins_spark_proto_rawDesc = []byte{ 0x6e, 0x73, 0x2f, 0x73, 0x70, 0x61, 0x72, 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x10, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x1a, 0x1c, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, - 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x42, - 0x0a, 0x10, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x41, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x22, 0x2e, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0a, 0x0a, 0x06, 0x50, 0x59, - 0x54, 0x48, 0x4f, 0x4e, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x4a, 0x41, 0x56, 0x41, 0x10, 0x01, - 0x12, 0x09, 0x0a, 0x05, 0x53, 0x43, 0x41, 0x4c, 0x41, 0x10, 0x02, 0x12, 0x05, 0x0a, 0x01, 0x52, - 0x10, 0x03, 0x22, 0xfe, 0x04, 0x0a, 0x08, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x4a, 0x6f, 0x62, 0x12, - 0x51, 0x0a, 0x0f, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, - 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x27, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, - 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x53, 0x70, 0x61, 0x72, - 0x6b, 0x41, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x54, 0x79, 0x70, - 0x65, 0x52, 0x0f, 0x61, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, - 0x70, 0x65, 0x12, 0x30, 0x0a, 0x13, 0x6d, 0x61, 0x69, 0x6e, 0x41, 0x70, 0x70, 0x6c, 0x69, 0x63, - 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x66, 0x2f, 0x73, 0x74, 0x72, 0x75, 0x63, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x19, + 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x63, 0x6f, 0x72, 0x65, 0x2f, 0x74, 0x61, + 0x73, 0x6b, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x42, 0x0a, 0x10, 0x53, 0x70, 0x61, + 0x72, 0x6b, 0x41, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x2e, 0x0a, + 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0a, 0x0a, 0x06, 0x50, 0x59, 0x54, 0x48, 0x4f, 0x4e, 0x10, + 0x00, 0x12, 0x08, 0x0a, 0x04, 0x4a, 0x41, 0x56, 0x41, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x53, + 0x43, 0x41, 0x4c, 0x41, 0x10, 0x02, 0x12, 0x05, 0x0a, 0x01, 0x52, 0x10, 0x03, 0x22, 0xec, 0x05, + 0x0a, 0x08, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x4a, 0x6f, 0x62, 0x12, 0x51, 0x0a, 0x0f, 0x61, 0x70, + 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x27, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, + 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x41, 0x70, 0x70, 0x6c, + 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x54, 0x79, 0x70, 0x65, 0x52, 0x0f, 0x61, 0x70, + 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x54, 0x79, 0x70, 0x65, 0x12, 0x30, 0x0a, 0x13, 0x6d, 0x61, 0x69, 0x6e, 0x41, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x69, 0x6c, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6c, 0x61, 0x73, - 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6c, 0x61, - 0x73, 0x73, 0x12, 0x47, 0x0a, 0x09, 0x73, 0x70, 0x61, 0x72, 0x6b, 0x43, 0x6f, 0x6e, 0x66, 0x18, - 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x29, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x4a, 0x6f, - 0x62, 0x2e, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x43, 0x6f, 0x6e, 0x66, 0x45, 0x6e, 0x74, 0x72, 0x79, - 0x52, 0x09, 0x73, 0x70, 0x61, 0x72, 0x6b, 0x43, 0x6f, 0x6e, 0x66, 0x12, 0x4a, 0x0a, 0x0a, 0x68, - 0x61, 0x64, 0x6f, 0x6f, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x2a, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, - 0x6e, 0x73, 0x2e, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x4a, 0x6f, 0x62, 0x2e, 0x48, 0x61, 0x64, 0x6f, - 0x6f, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0a, 0x68, 0x61, 0x64, - 0x6f, 0x6f, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x12, 0x22, 0x0a, 0x0c, 0x65, 0x78, 0x65, 0x63, 0x75, - 0x74, 0x6f, 0x72, 0x50, 0x61, 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x65, - 0x78, 0x65, 0x63, 0x75, 0x74, 0x6f, 0x72, 0x50, 0x61, 0x74, 0x68, 0x12, 0x3f, 0x0a, 0x0e, 0x64, - 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x0e, 0x64, 0x61, - 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x12, 0x28, 0x0a, 0x0f, - 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, - 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x12, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, - 0x69, 0x63, 0x6b, 0x73, 0x49, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x63, 0x65, 0x18, 0x09, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x12, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x49, 0x6e, - 0x73, 0x74, 0x61, 0x6e, 0x63, 0x65, 0x1a, 0x3c, 0x0a, 0x0e, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x43, - 0x6f, 0x6e, 0x66, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, - 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, - 0x3a, 0x02, 0x38, 0x01, 0x1a, 0x3d, 0x0a, 0x0f, 0x48, 0x61, 0x64, 0x6f, 0x6f, 0x70, 0x43, 0x6f, - 0x6e, 0x66, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, - 0x02, 0x38, 0x01, 0x42, 0xc2, 0x01, 0x0a, 0x14, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, - 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x42, 0x0a, 0x53, 0x70, - 0x61, 0x72, 0x6b, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, - 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, - 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, - 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, - 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x58, 0xaa, - 0x02, 0x10, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, - 0x6e, 0x73, 0xca, 0x02, 0x10, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, - 0x75, 0x67, 0x69, 0x6e, 0x73, 0xe2, 0x02, 0x1c, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, - 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, - 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x11, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, - 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x13, 0x6d, 0x61, 0x69, 0x6e, + 0x41, 0x70, 0x70, 0x6c, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x69, 0x6c, 0x65, 0x12, + 0x1c, 0x0a, 0x09, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x6d, 0x61, 0x69, 0x6e, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x47, 0x0a, + 0x09, 0x73, 0x70, 0x61, 0x72, 0x6b, 0x43, 0x6f, 0x6e, 0x66, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x29, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, + 0x69, 0x6e, 0x73, 0x2e, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x4a, 0x6f, 0x62, 0x2e, 0x53, 0x70, 0x61, + 0x72, 0x6b, 0x43, 0x6f, 0x6e, 0x66, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x73, 0x70, 0x61, + 0x72, 0x6b, 0x43, 0x6f, 0x6e, 0x66, 0x12, 0x4a, 0x0a, 0x0a, 0x68, 0x61, 0x64, 0x6f, 0x6f, 0x70, + 0x43, 0x6f, 0x6e, 0x66, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2a, 0x2e, 0x66, 0x6c, 0x79, + 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0x2e, 0x53, 0x70, + 0x61, 0x72, 0x6b, 0x4a, 0x6f, 0x62, 0x2e, 0x48, 0x61, 0x64, 0x6f, 0x6f, 0x70, 0x43, 0x6f, 0x6e, + 0x66, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0a, 0x68, 0x61, 0x64, 0x6f, 0x6f, 0x70, 0x43, 0x6f, + 0x6e, 0x66, 0x12, 0x22, 0x0a, 0x0c, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x6f, 0x72, 0x50, 0x61, + 0x74, 0x68, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, + 0x6f, 0x72, 0x50, 0x61, 0x74, 0x68, 0x12, 0x3f, 0x0a, 0x0e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, + 0x69, 0x63, 0x6b, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x52, 0x0e, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, + 0x63, 0x6b, 0x73, 0x43, 0x6f, 0x6e, 0x66, 0x12, 0x28, 0x0a, 0x0f, 0x64, 0x61, 0x74, 0x61, 0x62, + 0x72, 0x69, 0x63, 0x6b, 0x73, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0f, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x12, 0x2e, 0x0a, 0x12, 0x64, 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x49, + 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x63, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x64, + 0x61, 0x74, 0x61, 0x62, 0x72, 0x69, 0x63, 0x6b, 0x73, 0x49, 0x6e, 0x73, 0x74, 0x61, 0x6e, 0x63, + 0x65, 0x12, 0x33, 0x0a, 0x09, 0x64, 0x72, 0x69, 0x76, 0x65, 0x72, 0x50, 0x6f, 0x64, 0x18, 0x0a, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, + 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x4b, 0x38, 0x73, 0x50, 0x6f, 0x64, 0x52, 0x09, 0x64, 0x72, 0x69, + 0x76, 0x65, 0x72, 0x50, 0x6f, 0x64, 0x12, 0x37, 0x0a, 0x0b, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, + 0x6f, 0x72, 0x50, 0x6f, 0x64, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x66, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x63, 0x6f, 0x72, 0x65, 0x2e, 0x4b, 0x38, 0x73, 0x50, + 0x6f, 0x64, 0x52, 0x0b, 0x65, 0x78, 0x65, 0x63, 0x75, 0x74, 0x6f, 0x72, 0x50, 0x6f, 0x64, 0x1a, + 0x3c, 0x0a, 0x0e, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x43, 0x6f, 0x6e, 0x66, 0x45, 0x6e, 0x74, 0x72, + 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x1a, 0x3d, 0x0a, + 0x0f, 0x48, 0x61, 0x64, 0x6f, 0x6f, 0x70, 0x43, 0x6f, 0x6e, 0x66, 0x45, 0x6e, 0x74, 0x72, 0x79, + 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, + 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x42, 0xc2, 0x01, 0x0a, + 0x14, 0x63, 0x6f, 0x6d, 0x2e, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2e, 0x70, 0x6c, + 0x75, 0x67, 0x69, 0x6e, 0x73, 0x42, 0x0a, 0x53, 0x70, 0x61, 0x72, 0x6b, 0x50, 0x72, 0x6f, 0x74, + 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x66, 0x6c, 0x79, 0x74, 0x65, 0x6f, 0x72, 0x67, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x2f, 0x66, + 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x62, 0x2d, 0x67, + 0x6f, 0x2f, 0x66, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x2f, 0x70, 0x6c, 0x75, 0x67, 0x69, + 0x6e, 0x73, 0xa2, 0x02, 0x03, 0x46, 0x50, 0x58, 0xaa, 0x02, 0x10, 0x46, 0x6c, 0x79, 0x74, 0x65, + 0x69, 0x64, 0x6c, 0x2e, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xca, 0x02, 0x10, 0x46, 0x6c, + 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, 0x73, 0xe2, 0x02, + 0x1c, 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x5c, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, + 0x73, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x11, + 0x46, 0x6c, 0x79, 0x74, 0x65, 0x69, 0x64, 0x6c, 0x3a, 0x3a, 0x50, 0x6c, 0x75, 0x67, 0x69, 0x6e, + 0x73, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -317,17 +345,20 @@ var file_flyteidl_plugins_spark_proto_goTypes = []interface{}{ nil, // 3: flyteidl.plugins.SparkJob.SparkConfEntry nil, // 4: flyteidl.plugins.SparkJob.HadoopConfEntry (*structpb.Struct)(nil), // 5: google.protobuf.Struct + (*core.K8SPod)(nil), // 6: flyteidl.core.K8sPod } var file_flyteidl_plugins_spark_proto_depIdxs = []int32{ 0, // 0: flyteidl.plugins.SparkJob.applicationType:type_name -> flyteidl.plugins.SparkApplication.Type 3, // 1: flyteidl.plugins.SparkJob.sparkConf:type_name -> flyteidl.plugins.SparkJob.SparkConfEntry 4, // 2: flyteidl.plugins.SparkJob.hadoopConf:type_name -> flyteidl.plugins.SparkJob.HadoopConfEntry 5, // 3: flyteidl.plugins.SparkJob.databricksConf:type_name -> google.protobuf.Struct - 4, // [4:4] is the sub-list for method output_type - 4, // [4:4] is the sub-list for method input_type - 4, // [4:4] is the sub-list for extension type_name - 4, // [4:4] is the sub-list for extension extendee - 0, // [0:4] is the sub-list for field type_name + 6, // 4: flyteidl.plugins.SparkJob.driverPod:type_name -> flyteidl.core.K8sPod + 6, // 5: flyteidl.plugins.SparkJob.executorPod:type_name -> flyteidl.core.K8sPod + 6, // [6:6] is the sub-list for method output_type + 6, // [6:6] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name } func init() { file_flyteidl_plugins_spark_proto_init() } diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.py b/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.py index 8ee1759390..6b1d892a42 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.py +++ b/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.py @@ -12,9 +12,10 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 +from flyteidl.core import tasks_pb2 as flyteidl_dot_core_dot_tasks__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lyteidl/plugins/spark.proto\x12\x10\x66lyteidl.plugins\x1a\x1cgoogle/protobuf/struct.proto\"B\n\x10SparkApplication\".\n\x04Type\x12\n\n\x06PYTHON\x10\x00\x12\x08\n\x04JAVA\x10\x01\x12\t\n\x05SCALA\x10\x02\x12\x05\n\x01R\x10\x03\"\xfe\x04\n\x08SparkJob\x12Q\n\x0f\x61pplicationType\x18\x01 \x01(\x0e\x32\'.flyteidl.plugins.SparkApplication.TypeR\x0f\x61pplicationType\x12\x30\n\x13mainApplicationFile\x18\x02 \x01(\tR\x13mainApplicationFile\x12\x1c\n\tmainClass\x18\x03 \x01(\tR\tmainClass\x12G\n\tsparkConf\x18\x04 \x03(\x0b\x32).flyteidl.plugins.SparkJob.SparkConfEntryR\tsparkConf\x12J\n\nhadoopConf\x18\x05 \x03(\x0b\x32*.flyteidl.plugins.SparkJob.HadoopConfEntryR\nhadoopConf\x12\"\n\x0c\x65xecutorPath\x18\x06 \x01(\tR\x0c\x65xecutorPath\x12?\n\x0e\x64\x61tabricksConf\x18\x07 \x01(\x0b\x32\x17.google.protobuf.StructR\x0e\x64\x61tabricksConf\x12(\n\x0f\x64\x61tabricksToken\x18\x08 \x01(\tR\x0f\x64\x61tabricksToken\x12.\n\x12\x64\x61tabricksInstance\x18\t \x01(\tR\x12\x64\x61tabricksInstance\x1a<\n\x0eSparkConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a=\n\x0fHadoopConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\xc2\x01\n\x14\x63om.flyteidl.pluginsB\nSparkProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1c\x66lyteidl/plugins/spark.proto\x12\x10\x66lyteidl.plugins\x1a\x1cgoogle/protobuf/struct.proto\x1a\x19\x66lyteidl/core/tasks.proto\"B\n\x10SparkApplication\".\n\x04Type\x12\n\n\x06PYTHON\x10\x00\x12\x08\n\x04JAVA\x10\x01\x12\t\n\x05SCALA\x10\x02\x12\x05\n\x01R\x10\x03\"\xec\x05\n\x08SparkJob\x12Q\n\x0f\x61pplicationType\x18\x01 \x01(\x0e\x32\'.flyteidl.plugins.SparkApplication.TypeR\x0f\x61pplicationType\x12\x30\n\x13mainApplicationFile\x18\x02 \x01(\tR\x13mainApplicationFile\x12\x1c\n\tmainClass\x18\x03 \x01(\tR\tmainClass\x12G\n\tsparkConf\x18\x04 \x03(\x0b\x32).flyteidl.plugins.SparkJob.SparkConfEntryR\tsparkConf\x12J\n\nhadoopConf\x18\x05 \x03(\x0b\x32*.flyteidl.plugins.SparkJob.HadoopConfEntryR\nhadoopConf\x12\"\n\x0c\x65xecutorPath\x18\x06 \x01(\tR\x0c\x65xecutorPath\x12?\n\x0e\x64\x61tabricksConf\x18\x07 \x01(\x0b\x32\x17.google.protobuf.StructR\x0e\x64\x61tabricksConf\x12(\n\x0f\x64\x61tabricksToken\x18\x08 \x01(\tR\x0f\x64\x61tabricksToken\x12.\n\x12\x64\x61tabricksInstance\x18\t \x01(\tR\x12\x64\x61tabricksInstance\x12\x33\n\tdriverPod\x18\n \x01(\x0b\x32\x15.flyteidl.core.K8sPodR\tdriverPod\x12\x37\n\x0b\x65xecutorPod\x18\x0b \x01(\x0b\x32\x15.flyteidl.core.K8sPodR\x0b\x65xecutorPod\x1a<\n\x0eSparkConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a=\n\x0fHadoopConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\xc2\x01\n\x14\x63om.flyteidl.pluginsB\nSparkProtoP\x01Z=github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins\xa2\x02\x03\x46PX\xaa\x02\x10\x46lyteidl.Plugins\xca\x02\x10\x46lyteidl\\Plugins\xe2\x02\x1c\x46lyteidl\\Plugins\\GPBMetadata\xea\x02\x11\x46lyteidl::Pluginsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -27,14 +28,14 @@ _SPARKJOB_SPARKCONFENTRY._serialized_options = b'8\001' _SPARKJOB_HADOOPCONFENTRY._options = None _SPARKJOB_HADOOPCONFENTRY._serialized_options = b'8\001' - _globals['_SPARKAPPLICATION']._serialized_start=80 - _globals['_SPARKAPPLICATION']._serialized_end=146 - _globals['_SPARKAPPLICATION_TYPE']._serialized_start=100 - _globals['_SPARKAPPLICATION_TYPE']._serialized_end=146 - _globals['_SPARKJOB']._serialized_start=149 - _globals['_SPARKJOB']._serialized_end=787 - _globals['_SPARKJOB_SPARKCONFENTRY']._serialized_start=664 - _globals['_SPARKJOB_SPARKCONFENTRY']._serialized_end=724 - _globals['_SPARKJOB_HADOOPCONFENTRY']._serialized_start=726 - _globals['_SPARKJOB_HADOOPCONFENTRY']._serialized_end=787 + _globals['_SPARKAPPLICATION']._serialized_start=107 + _globals['_SPARKAPPLICATION']._serialized_end=173 + _globals['_SPARKAPPLICATION_TYPE']._serialized_start=127 + _globals['_SPARKAPPLICATION_TYPE']._serialized_end=173 + _globals['_SPARKJOB']._serialized_start=176 + _globals['_SPARKJOB']._serialized_end=924 + _globals['_SPARKJOB_SPARKCONFENTRY']._serialized_start=801 + _globals['_SPARKJOB_SPARKCONFENTRY']._serialized_end=861 + _globals['_SPARKJOB_HADOOPCONFENTRY']._serialized_start=863 + _globals['_SPARKJOB_HADOOPCONFENTRY']._serialized_end=924 # @@protoc_insertion_point(module_scope) diff --git a/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.pyi b/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.pyi index e6b9e4eb68..559ad64cf6 100644 --- a/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.pyi +++ b/flyteidl/gen/pb_python/flyteidl/plugins/spark_pb2.pyi @@ -1,4 +1,5 @@ from google.protobuf import struct_pb2 as _struct_pb2 +from flyteidl.core import tasks_pb2 as _tasks_pb2 from google.protobuf.internal import containers as _containers from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor @@ -22,7 +23,7 @@ class SparkApplication(_message.Message): def __init__(self) -> None: ... class SparkJob(_message.Message): - __slots__ = ["applicationType", "mainApplicationFile", "mainClass", "sparkConf", "hadoopConf", "executorPath", "databricksConf", "databricksToken", "databricksInstance"] + __slots__ = ["applicationType", "mainApplicationFile", "mainClass", "sparkConf", "hadoopConf", "executorPath", "databricksConf", "databricksToken", "databricksInstance", "driverPod", "executorPod"] class SparkConfEntry(_message.Message): __slots__ = ["key", "value"] KEY_FIELD_NUMBER: _ClassVar[int] @@ -46,6 +47,8 @@ class SparkJob(_message.Message): DATABRICKSCONF_FIELD_NUMBER: _ClassVar[int] DATABRICKSTOKEN_FIELD_NUMBER: _ClassVar[int] DATABRICKSINSTANCE_FIELD_NUMBER: _ClassVar[int] + DRIVERPOD_FIELD_NUMBER: _ClassVar[int] + EXECUTORPOD_FIELD_NUMBER: _ClassVar[int] applicationType: SparkApplication.Type mainApplicationFile: str mainClass: str @@ -55,4 +58,6 @@ class SparkJob(_message.Message): databricksConf: _struct_pb2.Struct databricksToken: str databricksInstance: str - def __init__(self, applicationType: _Optional[_Union[SparkApplication.Type, str]] = ..., mainApplicationFile: _Optional[str] = ..., mainClass: _Optional[str] = ..., sparkConf: _Optional[_Mapping[str, str]] = ..., hadoopConf: _Optional[_Mapping[str, str]] = ..., executorPath: _Optional[str] = ..., databricksConf: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., databricksToken: _Optional[str] = ..., databricksInstance: _Optional[str] = ...) -> None: ... + driverPod: _tasks_pb2.K8sPod + executorPod: _tasks_pb2.K8sPod + def __init__(self, applicationType: _Optional[_Union[SparkApplication.Type, str]] = ..., mainApplicationFile: _Optional[str] = ..., mainClass: _Optional[str] = ..., sparkConf: _Optional[_Mapping[str, str]] = ..., hadoopConf: _Optional[_Mapping[str, str]] = ..., executorPath: _Optional[str] = ..., databricksConf: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ..., databricksToken: _Optional[str] = ..., databricksInstance: _Optional[str] = ..., driverPod: _Optional[_Union[_tasks_pb2.K8sPod, _Mapping]] = ..., executorPod: _Optional[_Union[_tasks_pb2.K8sPod, _Mapping]] = ...) -> None: ... diff --git a/flyteidl/gen/pb_rust/flyteidl.plugins.rs b/flyteidl/gen/pb_rust/flyteidl.plugins.rs index 65f187c3e0..16589b3e60 100644 --- a/flyteidl/gen/pb_rust/flyteidl.plugins.rs +++ b/flyteidl/gen/pb_rust/flyteidl.plugins.rs @@ -351,6 +351,12 @@ pub struct SparkJob { /// This instance name can be set in either flytepropeller or flytekit. #[prost(string, tag="9")] pub databricks_instance: ::prost::alloc::string::String, + /// Pod Spec for the Spark driver pod + #[prost(message, optional, tag="10")] + pub driver_pod: ::core::option::Option, + /// Pod Spec for the Spark executor pod + #[prost(message, optional, tag="11")] + pub executor_pod: ::core::option::Option, } /// Custom proto for plugin that enables distributed training using #[allow(clippy::derive_partial_eq_without_eq)] diff --git a/flyteidl/protos/flyteidl/plugins/spark.proto b/flyteidl/protos/flyteidl/plugins/spark.proto index 666ea311b2..7968d5be41 100644 --- a/flyteidl/protos/flyteidl/plugins/spark.proto +++ b/flyteidl/protos/flyteidl/plugins/spark.proto @@ -1,10 +1,12 @@ syntax = "proto3"; package flyteidl.plugins; -import "google/protobuf/struct.proto"; option go_package = "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"; +import "google/protobuf/struct.proto"; +import "flyteidl/core/tasks.proto"; + message SparkApplication { enum Type { PYTHON = 0; @@ -31,4 +33,10 @@ message SparkJob { // Domain name of your deployment. Use the form .cloud.databricks.com. // This instance name can be set in either flytepropeller or flytekit. string databricksInstance = 9; + + // Pod Spec for the Spark driver pod + core.K8sPod driverPod = 10; + + // Pod Spec for the Spark executor pod + core.K8sPod executorPod = 11; } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 6beca78f54..4a8c0f50f9 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -280,7 +280,14 @@ func BuildRawPod(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v return nil, nil, "", err } - primaryContainerName = c.Name + // If primaryContainerName is set in taskTemplate config, use it instead + // of c.Name + if val, ok := taskTemplate.GetConfig()[PrimaryContainerKey]; ok { + primaryContainerName = val + c.Name = primaryContainerName + } else { + primaryContainerName = c.Name + } podSpec = &v1.PodSpec{ Containers: []v1.Container{ *c, @@ -563,7 +570,7 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio } // merge podSpec with podTemplate - mergedPodSpec, err := mergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName) + mergedPodSpec, err := MergePodSpecs(&podTemplate.Template.Spec, podSpec, primaryContainerName, primaryInitContainerName) if err != nil { return nil, nil, err } @@ -577,10 +584,10 @@ func MergeWithBasePodTemplate(ctx context.Context, tCtx pluginsCore.TaskExecutio return mergedPodSpec, mergedObjectMeta, nil } -// mergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values +// MergePodSpecs merges the two provided PodSpecs. This process uses the first as the base configuration, where values // set by the first PodSpec are overwritten by the second in the return value. Additionally, this function applies // container-level configuration from the basePodSpec. -func mergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) { +func MergePodSpecs(basePodSpec *v1.PodSpec, podSpec *v1.PodSpec, primaryContainerName string, primaryInitContainerName string) (*v1.PodSpec, error) { if basePodSpec == nil || podSpec == nil { return nil, errors.New("neither the basePodSpec or the podSpec can be nil") } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 0a70cdd895..139ee583dc 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -2047,13 +2047,13 @@ func TestMergeWithBasePodTemplate(t *testing.T) { func TestMergePodSpecs(t *testing.T) { var priority int32 = 1 - podSpec1, _ := mergePodSpecs(nil, nil, "foo", "foo-init") + podSpec1, _ := MergePodSpecs(nil, nil, "foo", "foo-init") assert.Nil(t, podSpec1) - podSpec2, _ := mergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init") + podSpec2, _ := MergePodSpecs(&v1.PodSpec{}, nil, "foo", "foo-init") assert.Nil(t, podSpec2) - podSpec3, _ := mergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init") + podSpec3, _ := MergePodSpecs(nil, &v1.PodSpec{}, "foo", "foo-init") assert.Nil(t, podSpec3) podSpec := v1.PodSpec{ @@ -2141,7 +2141,7 @@ func TestMergePodSpecs(t *testing.T) { }, } - mergedPodSpec, err := mergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init") + mergedPodSpec, err := MergePodSpecs(&podTemplateSpec, &podSpec, "primary", "primary-init") assert.Nil(t, err) // validate a PodTemplate-only field diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go index 6873fc2257..6225f918c4 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark.go @@ -25,7 +25,8 @@ import ( "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/flytek8s/config" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/k8s" "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/tasklog" - "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + pluginsUtils "github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/flyteorg/flyte/flytestdlib/utils" ) const KindSparkApplication = "SparkApplication" @@ -65,7 +66,7 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } sparkJob := plugins.SparkJob{} - err = utils.UnmarshalStruct(taskTemplate.GetCustom(), &sparkJob) + err = utils.UnmarshalStructToPb(taskTemplate.GetCustom(), &sparkJob) if err != nil { return nil, errors.Wrapf(errors.BadTaskSpecification, err, "invalid TaskSpecification [%v], failed to unmarshal", taskTemplate.GetCustom()) } @@ -75,11 +76,11 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo } sparkConfig := getSparkConfig(taskCtx, &sparkJob) - driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig) + driverSpec, err := createDriverSpec(ctx, taskCtx, sparkConfig, &sparkJob) if err != nil { return nil, err } - executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig) + executorSpec, err := createExecutorSpec(ctx, taskCtx, sparkConfig, &sparkJob) if err != nil { return nil, err } @@ -141,9 +142,27 @@ func serviceAccountName(metadata pluginsCore.TaskExecutionMetadata) string { return name } -func createSparkPodSpec(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, container *v1.Container) *sparkOp.SparkPodSpec { - annotations := utils.UnionMaps(config.GetK8sPluginConfig().DefaultAnnotations, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations())) - labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) +func createSparkPodSpec( + taskCtx pluginsCore.TaskExecutionContext, + podSpec *v1.PodSpec, + container *v1.Container, + k8sPod *core.K8SPod, +) *sparkOp.SparkPodSpec { + + annotations := pluginsUtils.UnionMaps( + config.GetK8sPluginConfig().DefaultAnnotations, + pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetAnnotations()), + ) + labels := pluginsUtils.UnionMaps( + config.GetK8sPluginConfig().DefaultLabels, + pluginsUtils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()), + ) + if k8sPod.GetMetadata().GetAnnotations() != nil { + annotations = pluginsUtils.UnionMaps(annotations, k8sPod.GetMetadata().GetAnnotations()) + } + if k8sPod.GetMetadata().GetLabels() != nil { + labels = pluginsUtils.UnionMaps(labels, k8sPod.GetMetadata().GetLabels()) + } sparkEnv := make([]v1.EnvVar, 0) for _, envVar := range container.Env { @@ -171,18 +190,35 @@ type driverSpec struct { sparkSpec *sparkOp.DriverSpec } -func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*driverSpec, error) { +func createDriverSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*driverSpec, error) { // Spark driver pods should always run as non-interruptible nonInterruptibleTaskCtx := flytek8s.NewPluginTaskExecutionContext(taskCtx, flytek8s.WithInterruptible(false)) podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, nonInterruptibleTaskCtx) if err != nil { return nil, err } + + driverPod := sparkJob.GetDriverPod() + if driverPod != nil { + var customPodSpec *v1.PodSpec + + err = utils.UnmarshalStructToObj(driverPod.GetPodSpec(), &customPodSpec) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Unable to unmarshal driver pod spec [%v], Err: [%v]", driverPod.GetPodSpec(), err.Error()) + } + + podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "") + if err != nil { + return nil, err + } + } + primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) if err != nil { return nil, err } - sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer) + sparkPodSpec := createSparkPodSpec(nonInterruptibleTaskCtx, podSpec, primaryContainer, driverPod) serviceAccountName := serviceAccountName(nonInterruptibleTaskCtx.TaskExecutionMetadata()) spec := driverSpec{ &sparkOp.DriverSpec{ @@ -203,16 +239,33 @@ type executorSpec struct { serviceAccountName string } -func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string) (*executorSpec, error) { +func createExecutorSpec(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext, sparkConfig map[string]string, sparkJob *plugins.SparkJob) (*executorSpec, error) { podSpec, _, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, err } + + executorPod := sparkJob.GetExecutorPod() + if executorPod != nil { + var customPodSpec *v1.PodSpec + + err = utils.UnmarshalStructToObj(executorPod.GetPodSpec(), &customPodSpec) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Unable to unmarshal executor pod spec [%v], Err: [%v]", executorPod.GetPodSpec(), err.Error()) + } + + podSpec, err = flytek8s.MergePodSpecs(podSpec, customPodSpec, primaryContainerName, "") + if err != nil { + return nil, err + } + } + primaryContainer, err := flytek8s.GetContainer(podSpec, primaryContainerName) if err != nil { return nil, err } - sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer) + sparkPodSpec := createSparkPodSpec(taskCtx, podSpec, primaryContainer, sparkJob.GetExecutorPod()) serviceAccountName := serviceAccountName(taskCtx.TaskExecutionMetadata()) spec := executorSpec{ primaryContainer, diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 0a6f51d0e2..ed361374e6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -2,6 +2,7 @@ package spark import ( "context" + "encoding/json" "os" "reflect" "strconv" @@ -9,9 +10,9 @@ import ( sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -283,6 +284,19 @@ func dummySparkCustomObj(sparkConf map[string]string) *plugins.SparkJob { return &sparkJob } +func dummySparkCustomObjDriverExecutor(sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *plugins.SparkJob { + sparkJob := plugins.SparkJob{} + + sparkJob.MainClass = sparkMainClass + sparkJob.MainApplicationFile = sparkApplicationFile + sparkJob.SparkConf = sparkConf + sparkJob.ApplicationType = plugins.SparkApplication_PYTHON + + sparkJob.DriverPod = driverPod + sparkJob.ExecutorPod = executorPod + return &sparkJob +} + func dummyPodSpec() *corev1.PodSpec { return &corev1.PodSpec{ InitContainers: []corev1.Container{ @@ -337,7 +351,31 @@ func dummySparkTaskTemplateContainer(id string, sparkConf map[string]string) *co } } +func dummySparkTaskTemplateDriverExecutor(id string, sparkConf map[string]string, driverPod *core.K8SPod, executorPod *core.K8SPod) *core.TaskTemplate { + sparkJob := dummySparkCustomObjDriverExecutor(sparkConf, driverPod, executorPod) + + structObj, err := utils.MarshalObjToStruct(sparkJob) + if err != nil { + panic(err) + } + + return &core.TaskTemplate{ + Id: &core.Identifier{Name: id}, + Type: "container", + Target: &core.TaskTemplate_Container{ + Container: &core.Container{ + Image: testImage, + }, + }, + Config: map[string]string{ + flytek8s.PrimaryContainerKey: "primary", + }, + Custom: structObj, + } +} + func dummySparkTaskTemplatePod(id string, sparkConf map[string]string, podSpec *corev1.PodSpec) *core.TaskTemplate { + // add driver/executor pod below sparkJob := dummySparkCustomObj(sparkConf) sparkJobJSON, err := utils.MarshalToString(sparkJob) if err != nil { @@ -930,3 +968,175 @@ func TestGetPropertiesSpark(t *testing.T) { expected := k8s.PluginProperties{} assert.Equal(t, expected, sparkResourceHandler.GetProperties()) } + +func TestBuildResourceCustomK8SPod(t *testing.T) { + + defaultConfig := defaultPluginConfig() + assert.NoError(t, config.SetK8sPluginConfig(defaultConfig)) + + // add extraDriverToleration and extraExecutorToleration + driverExtraToleration := corev1.Toleration{ + Key: "x/flyte-driver", + Value: "extra-driver", + Operator: "Equal", + } + executorExtraToleration := corev1.Toleration{ + Key: "x/flyte-executor", + Value: "extra-executor", + Operator: "Equal", + } + + // pod for driver and executor + driverPodSpec := dummyPodSpec() + executorPodSpec := dummyPodSpec() + driverPodSpec.Tolerations = append(driverPodSpec.Tolerations, driverExtraToleration) + driverPodSpec.NodeSelector = map[string]string{"x/custom": "foo-driver"} + executorPodSpec.Tolerations = append(executorPodSpec.Tolerations, executorExtraToleration) + executorPodSpec.NodeSelector = map[string]string{"x/custom": "foo-executor"} + + driverK8SPod := &core.K8SPod{ + PodSpec: transformStructToStructPB(t, driverPodSpec), + Metadata: &core.K8SObjectMetadata{ + Annotations: map[string]string{"annotation-driver": "val-driver"}, + Labels: map[string]string{"label-driver": "val-driver"}, + }, + } + executorK8SPod := &core.K8SPod{ + PodSpec: transformStructToStructPB(t, executorPodSpec), + Metadata: &core.K8SObjectMetadata{ + Annotations: map[string]string{"annotation-executor": "val-executor"}, + Labels: map[string]string{"label-executor": "val-executor"}, + }, + } + taskTemplate := dummySparkTaskTemplateDriverExecutor("blah-1", dummySparkConf, driverK8SPod, executorK8SPod) + sparkResourceHandler := sparkResourceHandler{} + + taskCtx := dummySparkTaskContext(taskTemplate, true, k8s.PluginState{}) + resource, err := sparkResourceHandler.BuildResource(context.TODO(), taskCtx) + + assert.Nil(t, err) + assert.NotNil(t, resource) + sparkApp, ok := resource.(*sj.SparkApplication) + assert.True(t, ok) + + // Application + assert.Equal(t, v1.TypeMeta{ + Kind: KindSparkApplication, + APIVersion: sparkOp.SchemeGroupVersion.String(), + }, sparkApp.TypeMeta) + + // Application spec + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.ServiceAccount) + assert.Equal(t, sparkOp.PythonApplicationType, sparkApp.Spec.Type) + assert.Equal(t, testImage, *sparkApp.Spec.Image) + assert.Equal(t, testArgs, sparkApp.Spec.Arguments) + assert.Equal(t, sparkOp.RestartPolicy{ + Type: sparkOp.OnFailure, + OnSubmissionFailureRetries: intPtr(int32(14)), + }, sparkApp.Spec.RestartPolicy) + assert.Equal(t, sparkMainClass, *sparkApp.Spec.MainClass) + assert.Equal(t, sparkApplicationFile, *sparkApp.Spec.MainApplicationFile) + + // Driver + assert.Equal(t, utils.UnionMaps( + defaultConfig.DefaultAnnotations, map[string]string{ + "annotation-1": "val1", + "annotation-driver": "val-driver", + }, + ), sparkApp.Spec.Driver.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{ + "label-1": "val1", + "label-driver": "val-driver", + }), sparkApp.Spec.Driver.Labels) + assert.Equal(t, len(findEnvVarByName(sparkApp.Spec.Driver.Env, "FLYTE_MAX_ATTEMPTS").Value), 1) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Driver.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Driver.Env, "fooEnv").Value) + assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Driver.Env, "SECRET")) + assert.Equal(t, 9, len(sparkApp.Spec.Driver.Env)) + assert.Equal(t, testImage, *sparkApp.Spec.Driver.Image) + assert.Equal(t, flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata()), *sparkApp.Spec.Driver.ServiceAccount) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Driver.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Driver.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Driver.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Driver.SchedulerName) + assert.Equal(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], + driverExtraToleration, + }, sparkApp.Spec.Driver.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/custom": "foo-driver", + }, sparkApp.Spec.Driver.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.NonInterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Driver.Affinity.NodeAffinity) + cores, _ := strconv.ParseInt(dummySparkConf["spark.driver.cores"], 10, 32) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Driver.Cores) + assert.Equal(t, dummySparkConf["spark.driver.memory"], *sparkApp.Spec.Driver.Memory) + + // // Executor + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultAnnotations, map[string]string{ + "annotation-1": "val1", + "annotation-executor": "val-executor", + }), sparkApp.Spec.Executor.Annotations) + assert.Equal(t, utils.UnionMaps(defaultConfig.DefaultLabels, map[string]string{ + "label-1": "val1", + "label-executor": "val-executor", + }), sparkApp.Spec.Executor.Labels) + assert.Equal(t, defaultConfig.DefaultEnvVars["foo"], findEnvVarByName(sparkApp.Spec.Executor.Env, "foo").Value) + assert.Equal(t, defaultConfig.DefaultEnvVars["fooEnv"], findEnvVarByName(sparkApp.Spec.Executor.Env, "fooEnv").Value) + assert.Equal(t, findEnvVarByName(dummyEnvVarsWithSecretRef, "SECRET"), findEnvVarByName(sparkApp.Spec.Executor.Env, "SECRET")) + assert.Equal(t, 9, len(sparkApp.Spec.Executor.Env)) + assert.Equal(t, testImage, *sparkApp.Spec.Executor.Image) + assert.Equal(t, defaultConfig.DefaultPodSecurityContext, sparkApp.Spec.Executor.SecurityContenxt) + assert.Equal(t, defaultConfig.DefaultPodDNSConfig, sparkApp.Spec.Executor.DNSConfig) + assert.Equal(t, defaultConfig.EnableHostNetworkingPod, sparkApp.Spec.Executor.HostNetwork) + assert.Equal(t, defaultConfig.SchedulerName, *sparkApp.Spec.Executor.SchedulerName) + assert.ElementsMatch(t, []corev1.Toleration{ + defaultConfig.DefaultTolerations[0], + executorExtraToleration, + defaultConfig.InterruptibleTolerations[0], + }, sparkApp.Spec.Executor.Tolerations) + assert.Equal(t, map[string]string{ + "x/default": "true", + "x/custom": "foo-executor", + "x/interruptible": "true", + }, sparkApp.Spec.Executor.NodeSelector) + assert.Equal(t, &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + defaultConfig.DefaultAffinity.NodeAffinity.RequiredDuringSchedulingIgnoredDuringExecution.NodeSelectorTerms[0].MatchExpressions[0], + *defaultConfig.InterruptibleNodeSelectorRequirement, + }, + }, + }, + }, + }, sparkApp.Spec.Executor.Affinity.NodeAffinity) + cores, _ = strconv.ParseInt(dummySparkConf["spark.executor.cores"], 10, 32) + instances, _ := strconv.ParseInt(dummySparkConf["spark.executor.instances"], 10, 32) + assert.Equal(t, intPtr(int32(instances)), sparkApp.Spec.Executor.Instances) + assert.Equal(t, intPtr(int32(cores)), sparkApp.Spec.Executor.Cores) + assert.Equal(t, dummySparkConf["spark.executor.memory"], *sparkApp.Spec.Executor.Memory) +} + +func transformStructToStructPB(t *testing.T, obj interface{}) *structpb.Struct { + data, err := json.Marshal(obj) + assert.Nil(t, err) + podSpecMap := make(map[string]interface{}) + err = json.Unmarshal(data, &podSpecMap) + assert.Nil(t, err) + s, err := structpb.NewStruct(podSpecMap) + assert.Nil(t, err) + return s +}