diff --git a/examples/models/python/config/file_io_matplotlib.yaml b/examples/models/python/config/file_io_matplotlib.yaml new file mode 100644 index 000000000..6b45f1e89 --- /dev/null +++ b/examples/models/python/config/file_io_matplotlib.yaml @@ -0,0 +1,9 @@ + +job: + runModel: + + inputs: + quarterly_sales: "outputs/using_data/profit_by_region.csv" + + outputs: + sales_report: "outputs/file_io/sales_report.svg" diff --git a/examples/models/python/config/file_io_powerpoint.yaml b/examples/models/python/config/file_io_powerpoint.yaml new file mode 100644 index 000000000..ef12028d5 --- /dev/null +++ b/examples/models/python/config/file_io_powerpoint.yaml @@ -0,0 +1,10 @@ + +job: + runModel: + + inputs: + quarterly_sales: "outputs/using_data/profit_by_region.csv" + report_template: "inputs/sales_report_template.pptx" + + outputs: + sales_report: "outputs/file_io/sales_report.pptx" diff --git a/examples/models/python/data/inputs/sales_report_template.pptx b/examples/models/python/data/inputs/sales_report_template.pptx new file mode 100644 index 000000000..ce21df877 Binary files /dev/null and b/examples/models/python/data/inputs/sales_report_template.pptx differ diff --git a/examples/models/python/src/tutorial/file_io_matplotlib.py b/examples/models/python/src/tutorial/file_io_matplotlib.py new file mode 100644 index 000000000..5c0997125 --- /dev/null +++ b/examples/models/python/src/tutorial/file_io_matplotlib.py @@ -0,0 +1,69 @@ +# Licensed to the Fintech Open Source Foundation (FINOS) under one or +# more contributor license agreements. See the NOTICE file distributed +# with this work for additional information regarding copyright ownership. +# FINOS licenses this file to you under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with the +# License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import typing as _tp + +import tracdap.rt.api as trac + +import matplotlib +import matplotlib.pyplot as plt + +import tutorial.schemas as schemas + + +class MatplotlibReport(trac.TracModel): + + def define_parameters(self) -> _tp.Dict[str, trac.ModelParameter]: + + return trac.define_parameters() + + def define_inputs(self) -> _tp.Dict[str, trac.ModelInputSchema]: + + quarterly_sales_schema = trac.load_schema(schemas, "profit_by_region.csv") + quarterly_sales = trac.define_input(quarterly_sales_schema, label="Quarterly sales data") + + return { "quarterly_sales": quarterly_sales } + + def define_outputs(self) -> _tp.Dict[str, trac.ModelOutputSchema]: + + sales_report = trac.define_output(trac.CommonFileTypes.SVG, label="Quarterly sales report") + + return { "sales_report": sales_report } + + def run_model(self, ctx: trac.TracContext): + + matplotlib.use("agg") + + quarterly_sales = ctx.get_pandas_table("quarterly_sales") + + regions = quarterly_sales["region"] + values = quarterly_sales["gross_profit"] + + fig = plt.figure() + + plt.bar(regions, values) + plt.title("Profit by region report") + plt.xlabel("Region") + plt.ylabel("Gross profit") + + with ctx.put_file_stream("sales_report") as sales_report: + plt.savefig(sales_report, format='svg') + + plt.close(fig) + + +if __name__ == "__main__": + import tracdap.rt.launch as launch + launch.launch_model(MatplotlibReport, "config/file_io_matplotlib.yaml", "config/sys_config.yaml") diff --git a/examples/models/python/src/tutorial/file_io_powerpoint.py b/examples/models/python/src/tutorial/file_io_powerpoint.py new file mode 100644 index 000000000..cb995ea6a --- /dev/null +++ b/examples/models/python/src/tutorial/file_io_powerpoint.py @@ -0,0 +1,77 @@ +# Licensed to the Fintech Open Source Foundation (FINOS) under one or +# more contributor license agreements. See the NOTICE file distributed +# with this work for additional information regarding copyright ownership. +# FINOS licenses this file to you under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with the +# License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import typing as _tp + +import tracdap.rt.api as trac + +import pptx +from pptx.chart.data import CategoryChartData +from pptx.enum.chart import XL_CHART_TYPE +from pptx.util import Inches + +import tutorial.schemas as schemas + + +class PowerpointReport(trac.TracModel): + + def define_parameters(self) -> _tp.Dict[str, trac.ModelParameter]: + + return trac.define_parameters() + + def define_inputs(self) -> _tp.Dict[str, trac.ModelInputSchema]: + + quarterly_sales_schema = trac.load_schema(schemas, "profit_by_region.csv") + quarterly_sales = trac.define_input(quarterly_sales_schema, label="Quarterly sales data") + report_template = trac.define_input(trac.CommonFileTypes.POWERPOINT, label="Quarterly sales report template") + + return { + "quarterly_sales": quarterly_sales, + "report_template": report_template } + + def define_outputs(self) -> _tp.Dict[str, trac.ModelOutputSchema]: + + sales_report = trac.define_output(trac.CommonFileTypes.POWERPOINT, label="Quarterly sales report") + + return { "sales_report": sales_report } + + def run_model(self, ctx: trac.TracContext): + + quarterly_sales = ctx.get_pandas_table("quarterly_sales") + + with ctx.get_file_stream("report_template") as report_template: + presentation = pptx.Presentation(report_template) + + slide = presentation.slides.add_slide(presentation.slide_layouts[5]) + + title_frame = slide.shapes[0].text_frame + title_frame.text = 'Profit by Region Report' + + # define chart data --------------------- + chart_data = CategoryChartData() + chart_data.categories = quarterly_sales["region"] + chart_data.add_series('Gross Profit', quarterly_sales["gross_profit"]) + + # add chart to slide -------------------- + x, y, cx, cy = Inches(2), Inches(2), Inches(6), Inches(4.5) + slide.shapes.add_chart(XL_CHART_TYPE.COLUMN_CLUSTERED, x, y, cx, cy, chart_data) + + with ctx.put_file_stream("sales_report") as sales_report: + presentation.save(sales_report) + + +if __name__ == "__main__": + import tracdap.rt.launch as launch + launch.launch_model(PowerpointReport, "config/file_io_powerpoint.yaml", "config/sys_config.yaml") diff --git a/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/file.proto b/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/file.proto index 37ec0f849..8a9b60fab 100644 --- a/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/file.proto +++ b/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/file.proto @@ -37,3 +37,10 @@ message FileDefinition { TagSelector storageId = 5; string dataItem = 6; } + + +message FileType { + + string extension = 1; + string mimeType = 2; +} diff --git a/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/model.proto b/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/model.proto index db70d1ea7..8148290b5 100644 --- a/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/model.proto +++ b/tracdap-api/tracdap-metadata/src/main/proto/tracdap/metadata/model.proto @@ -22,7 +22,9 @@ option java_package = "org.finos.tracdap.metadata"; option java_multiple_files = true; import "tracdap/metadata/type.proto"; +import "tracdap/metadata/object_id.proto"; import "tracdap/metadata/data.proto"; +import "tracdap/metadata/file.proto"; /** @@ -69,7 +71,12 @@ message ModelParameter { */ message ModelInputSchema { - SchemaDefinition schema = 1; + ObjectType objectType = 6; + + oneof requirement { + SchemaDefinition schema = 1; + FileType fileType = 7; + } optional string label = 2; @@ -93,7 +100,12 @@ message ModelInputSchema { */ message 
ModelOutputSchema { - SchemaDefinition schema = 1; + ObjectType objectType = 6; + + oneof requirement { + SchemaDefinition schema = 1; + FileType fileType = 7; + } optional string label = 2; diff --git a/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/config/ServiceProperties.java b/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/config/ServiceProperties.java index 6998df06f..f533c10d3 100644 --- a/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/config/ServiceProperties.java +++ b/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/config/ServiceProperties.java @@ -1,9 +1,10 @@ /* - * Copyright 2024 Accenture Global Solutions Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed to the Fintech Open Source Foundation (FINOS) under one or + * more contributor license agreements. See the NOTICE file distributed + * with this work for additional information regarding copyright ownership. + * FINOS licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * @@ -16,8 +17,6 @@ package org.finos.tracdap.common.config; -import java.util.Map; - public class ServiceProperties { public static final String GATEWAY_HTTP_PREFIX = "gateway.http.prefix"; diff --git a/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/DelayedExecutionInterceptor.java b/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/DelayedExecutionInterceptor.java index c1c575d2c..84539f036 100644 --- a/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/DelayedExecutionInterceptor.java +++ b/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/DelayedExecutionInterceptor.java @@ -1,9 +1,10 @@ /* - * Copyright 2024 Accenture Global Solutions Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at + * Licensed to the Fintech Open Source Foundation (FINOS) under one or + * more contributor license agreements. See the NOTICE file distributed + * with this work for additional information regarding copyright ownership. + * FINOS licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with the + * License. 
You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * diff --git a/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/GrpcErrorMapping.java b/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/GrpcErrorMapping.java index 5497f1f26..edb25203b 100644 --- a/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/GrpcErrorMapping.java +++ b/tracdap-libs/tracdap-lib-common/src/main/java/org/finos/tracdap/common/grpc/GrpcErrorMapping.java @@ -100,7 +100,7 @@ public static StatusRuntimeException processError(Throwable error) { log.warn("No gRPC error code mapping is available for the error {}", error.getClass().getSimpleName()); - var trailers = basicErrorTrailers(Status.Code.INTERNAL, Status.INTERNAL.getDescription()); + var trailers = basicErrorTrailers(Status.Code.INTERNAL, INTERNAL_ERROR_MESSAGE); return Status.fromCode(Status.Code.INTERNAL) .withDescription(Status.INTERNAL.getDescription()) diff --git a/tracdap-libs/tracdap-lib-test/src/main/java/org/finos/tracdap/test/meta/TestData.java b/tracdap-libs/tracdap-lib-test/src/main/java/org/finos/tracdap/test/meta/TestData.java index 551993370..4a9ffc941 100644 --- a/tracdap-libs/tracdap-lib-test/src/main/java/org/finos/tracdap/test/meta/TestData.java +++ b/tracdap-libs/tracdap-lib-test/src/main/java/org/finos/tracdap/test/meta/TestData.java @@ -355,6 +355,7 @@ public static ObjectDefinition dummyModelDef() { .putParameters("param1", ModelParameter.newBuilder().setParamType(TypeSystem.descriptor(BasicType.STRING)).build()) .putParameters("param2", ModelParameter.newBuilder().setParamType(TypeSystem.descriptor(BasicType.INTEGER)).build()) .putInputs("input1", ModelInputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(SchemaDefinition.newBuilder() .setSchemaType(SchemaType.TABLE) .setTable(TableSchema.newBuilder() @@ -373,6 +374,7 @@ public static ObjectDefinition dummyModelDef() { .setFormatCode("GBP")))) .build()) .putOutputs("output1", ModelOutputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(SchemaDefinition.newBuilder() .setSchemaType(SchemaType.TABLE) .setTable(TableSchema.newBuilder() diff --git a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/consistency/JobConsistencyValidator.java b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/consistency/JobConsistencyValidator.java index 328bb8144..2ba24ffbf 100644 --- a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/consistency/JobConsistencyValidator.java +++ b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/consistency/JobConsistencyValidator.java @@ -343,19 +343,19 @@ private static ValidationContext paramMatchesSchema(String paramName, Value para } // Param comes from the job definition - private static ValidationContext paramMatchesSchema(String paramName, Value paramValue, ModelParameter requiredParam, ValidationContext ctx) { + private static ValidationContext paramMatchesSchema(String paramName, Value paramValue, ModelParameter paramDef, ValidationContext ctx) { var paramType = TypeSystem.descriptor(paramValue); - var requiredType = requiredParam.getParamType(); + var requiredType = paramDef.getParamType(); return paramMatchesType(paramName, paramType, requiredType, ctx); } // Param comes from upstream node - private static ValidationContext paramMatchesSchema(String paramName, ModelParameter 
suppliedParam, ModelParameter requiredParam, ValidationContext ctx) { + private static ValidationContext paramMatchesSchema(String paramName, ModelParameter suppliedParam, ModelParameter paramDef, ValidationContext ctx) { var paramType = suppliedParam.getParamType(); - var requiredType = requiredParam.getParamType(); + var requiredType = paramDef.getParamType(); return paramMatchesType(paramName, paramType, requiredType, ctx); } @@ -389,14 +389,14 @@ private static ValidationContext inputMatchesSchema(String inputName, TagSelecto } // Input comes from the job definition - private static ValidationContext inputMatchesSchema(String inputName, TagSelector inputSelector, ModelInputSchema requiredSchema, ValidationContext ctx) { + private static ValidationContext inputMatchesSchema(String inputName, TagSelector inputSelector, ModelInputSchema inputDef, ValidationContext ctx) { var inputObject = ctx.getMetadataBundle().getResource(inputSelector); if (inputObject == null) { // It is fine if an optional input is not supplied - if (requiredSchema.getOptional()) + if (inputDef.getOptional()) return ctx; return ctx.error(String.format( @@ -404,44 +404,62 @@ private static ValidationContext inputMatchesSchema(String inputName, TagSelecto inputName, MetadataUtil.objectKey(inputSelector))); } - if (inputObject.getObjectType() != ObjectType.DATA) { + if (inputObject.getObjectType() != inputDef.getObjectType()) { return ctx.error(String.format( - "Input is not a dataset (expected %s, got %s)", - ObjectType.DATA, inputObject.getObjectType())); + "Input [%s] is not the right object type (expected %s, got %s)", + inputName, inputDef.getObjectType(), inputObject.getObjectType())); } - // In case inference failed, requiredSchema == null so stop here + // In case inference failed, inputDef == null so stop here if (ctx.failed()) return ctx; - if (requiredSchema.getDynamic()) - return checkDynamicDataSchema(inputObject.getData(), requiredSchema.getSchema(), ctx); + if (inputDef.getObjectType() == ObjectType.FILE) + return checkFileType(inputObject.getFile(), inputDef.getFileType(), ctx); + else if (inputDef.getDynamic()) + return checkDynamicDataSchema(inputObject.getData(), inputDef.getSchema(), ctx); else - return checkDataSchema(inputObject.getData(), requiredSchema.getSchema(), ctx); + return checkDataSchema(inputObject.getData(), inputDef.getSchema(), ctx); } // Input comes from upstream node - private static ValidationContext inputMatchesSchema(String inputName, ModelInputSchema inputSchema, ModelInputSchema requiredSchema, ValidationContext ctx) { + private static ValidationContext inputMatchesSchema(String inputName, ModelInputSchema inputSchema, ModelInputSchema inputDef, ValidationContext ctx) { - if (inputSchema.getOptional() && !requiredSchema.getOptional()) + if (inputSchema.getObjectType() != inputDef.getObjectType()) { + return ctx.error(String.format( + "Input [%s] is not the right object type (expected %s, got %s)", + inputName, inputDef.getObjectType(), inputSchema.getObjectType())); + } + + if (inputSchema.getOptional() && !inputDef.getOptional()) ctx.error("Required model input [" + inputName + "] is connected to an optional input"); - if (requiredSchema.getDynamic() || inputSchema.getDynamic()) - return checkDynamicDataSchema(inputSchema.getSchema(), requiredSchema.getSchema(), ctx); + if (inputDef.getObjectType() == ObjectType.FILE) + return checkFileType(inputSchema.getFileType(), inputDef.getFileType(), ctx); + else if (inputDef.getDynamic() || inputSchema.getDynamic()) + return 
checkDynamicDataSchema(inputSchema.getSchema(), inputDef.getSchema(), ctx); else - return checkDataSchema(inputSchema.getSchema(), requiredSchema.getSchema(), ctx); + return checkDataSchema(inputSchema.getSchema(), inputDef.getSchema(), ctx); } // Input comes from upstream node - private static ValidationContext inputMatchesSchema(String inputName, ModelOutputSchema outputSchema, ModelInputSchema requiredSchema, ValidationContext ctx) { + private static ValidationContext inputMatchesSchema(String inputName, ModelOutputSchema outputSchema, ModelInputSchema inputDef, ValidationContext ctx) { + + if (outputSchema.getObjectType() != inputDef.getObjectType()) { + return ctx.error(String.format( + "Input [%s] is not the right object type (expected %s, got %s)", + inputName, inputDef.getObjectType(), outputSchema.getObjectType())); + } - if (outputSchema.getOptional() && !requiredSchema.getOptional()) + if (outputSchema.getOptional() && !inputDef.getOptional()) ctx.error("Required model input [" + inputName + "] is connected to an optional model output"); - if (requiredSchema.getDynamic() || outputSchema.getDynamic()) - return checkDynamicDataSchema(outputSchema.getSchema(), requiredSchema.getSchema(), ctx); + if (inputDef.getObjectType() == ObjectType.FILE) + return checkFileType(outputSchema.getFileType(), inputDef.getFileType(), ctx); + else if (inputDef.getDynamic() || outputSchema.getDynamic()) + return checkDynamicDataSchema(outputSchema.getSchema(), inputDef.getSchema(), ctx); else - return checkDataSchema(outputSchema.getSchema(), requiredSchema.getSchema(), ctx); + return checkDataSchema(outputSchema.getSchema(), inputDef.getSchema(), ctx); } // Output comes from the job definition @@ -454,7 +472,7 @@ private static ValidationContext outputMatchesSchema(String outputName, TagSelec } // Output comes from the job definition - private static ValidationContext outputMatchesSchema(String outputName, TagSelector outputSelector, ModelOutputSchema requiredSchema, ValidationContext ctx) { + private static ValidationContext outputMatchesSchema(String outputName, TagSelector outputSelector, ModelOutputSchema outputDef, ValidationContext ctx) { var outputObject = ctx.getMetadataBundle().getResource(outputSelector); @@ -469,44 +487,62 @@ private static ValidationContext outputMatchesSchema(String outputName, TagSelec outputName, MetadataUtil.objectKey(outputSelector))); } - if (outputObject.getObjectType() != ObjectType.DATA) { + if (outputObject.getObjectType() != outputDef.getObjectType()) { return ctx.error(String.format( - "Output is not a dataset (expected %s, got %s)", - ObjectType.DATA, outputObject.getObjectType())); + "Output [%s] is not the right object type (expected %s, got %s)", + outputName, outputDef.getObjectType(), outputObject.getObjectType())); } - // In case inference failed, requiredSchema == null so stop here + // In case inference failed, outputDef == null so stop here if (ctx.failed()) return ctx; - if (requiredSchema.getDynamic()) - return checkDynamicDataSchema(outputObject.getData(), requiredSchema.getSchema(), ctx); + if (outputDef.getObjectType() == ObjectType.FILE) + return checkFileType(outputObject.getFile(), outputDef.getFileType(), ctx); + else if (outputDef.getDynamic()) + return checkDynamicDataSchema(outputObject.getData(), outputDef.getSchema(), ctx); else - return checkDataSchema(outputObject.getData(), requiredSchema.getSchema(), ctx); + return checkDataSchema(outputObject.getData(), outputDef.getSchema(), ctx); } // Output comes from upstream node - private 
static ValidationContext outputMatchesSchema(String outputName, ModelInputSchema inputSchema, ModelOutputSchema requiredSchema, ValidationContext ctx) { + private static ValidationContext outputMatchesSchema(String outputName, ModelInputSchema inputSchema, ModelOutputSchema outputDef, ValidationContext ctx) { + + if (inputSchema.getObjectType() != outputDef.getObjectType()) { + return ctx.error(String.format( + "Output [%s] is not the right object type (expected %s, got %s)", + outputName, outputDef.getObjectType(), inputSchema.getObjectType())); + } - if (inputSchema.getOptional() && !requiredSchema.getOptional()) + if (inputSchema.getOptional() && !outputDef.getOptional()) ctx.error("Required output [" + outputName + "] is connected to an optional input"); - if (requiredSchema.getDynamic() || inputSchema.getDynamic()) - return checkDynamicDataSchema(inputSchema.getSchema(), requiredSchema.getSchema(), ctx); + if (outputDef.getObjectType() == ObjectType.FILE) + return checkFileType(inputSchema.getFileType(), outputDef.getFileType(), ctx); + else if (outputDef.getDynamic() || inputSchema.getDynamic()) + return checkDynamicDataSchema(inputSchema.getSchema(), outputDef.getSchema(), ctx); else - return checkDataSchema(inputSchema.getSchema(), requiredSchema.getSchema(), ctx); + return checkDataSchema(inputSchema.getSchema(), outputDef.getSchema(), ctx); } // Output comes from upstream node - private static ValidationContext outputMatchesSchema(String outputName, ModelOutputSchema outputSchema, ModelOutputSchema requiredSchema, ValidationContext ctx) { + private static ValidationContext outputMatchesSchema(String outputName, ModelOutputSchema outputSchema, ModelOutputSchema outputDef, ValidationContext ctx) { + + if (outputSchema.getObjectType() != outputDef.getObjectType()) { + return ctx.error(String.format( + "Output [%s] is not the right object type (expected %s, got %s)", + outputName, outputDef.getObjectType(), outputSchema.getObjectType())); + } - if (outputSchema.getOptional() && !requiredSchema.getOptional()) + if (outputSchema.getOptional() && !outputDef.getOptional()) ctx.error("Required output [" + outputName + "] is connected to an optional model output"); - if (requiredSchema.getDynamic() || outputSchema.getDynamic()) - return checkDynamicDataSchema(outputSchema.getSchema(), requiredSchema.getSchema(), ctx); + if (outputDef.getObjectType() == ObjectType.FILE) + return checkFileType(outputSchema.getFileType(), outputDef.getFileType(), ctx); + else if (outputDef.getDynamic() || outputSchema.getDynamic()) + return checkDynamicDataSchema(outputSchema.getSchema(), outputDef.getSchema(), ctx); else - return checkDataSchema(outputSchema.getSchema(), requiredSchema.getSchema(), ctx); + return checkDataSchema(outputSchema.getSchema(), outputDef.getSchema(), ctx); } private static ValidationContext checkDataSchema(DataDefinition suppliedData, SchemaDefinition requiredSchema, ValidationContext ctx) { @@ -630,6 +666,44 @@ private static SchemaDefinition findSchema(DataDefinition dataset, MetadataBundl throw new EUnexpected(); } + private static ValidationContext checkFileType(FileDefinition fileDef, FileType fileType, ValidationContext ctx) { + + if (!fileType.getExtension().equals(fileDef.getExtension())) { + + ctx.error(String.format( + "File extension does not match (expected [%s], got [%s])", + fileType.getExtension(), fileDef.getExtension())); + } + + if (!fileType.getMimeType().equals(fileDef.getMimeType())) { + + ctx.error(String.format( + "Mime type does not match (expected [%s], got
[%s])", + fileType.getMimeType(), fileDef.getMimeType())); + } + + return ctx; + } + + private static ValidationContext checkFileType(FileType suppliedFileType, FileType requiredFileType, ValidationContext ctx) { + + if (!requiredFileType.getExtension().equals(suppliedFileType.getExtension())) { + + ctx.error(String.format( + "File extension does not match (expected [%s], got [%s])", + requiredFileType.getExtension(), suppliedFileType.getExtension())); + } + + if (!requiredFileType.getMimeType().equals(suppliedFileType.getMimeType())) { + + ctx.error(String.format( + "Mime type does not match (expected [%s], got [%s])", + requiredFileType.getMimeType(), suppliedFileType.getMimeType())); + } + + return ctx; + } + // ----------------------------------------------------------------------------------------------------------------- // Graph consistency checks (model and output nodes, tracing upstream) diff --git a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/FileValidator.java b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/FileValidator.java index e249326ef..4c0eba879 100644 --- a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/FileValidator.java +++ b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/FileValidator.java @@ -22,6 +22,7 @@ import org.finos.tracdap.common.validation.core.Validator; import org.finos.tracdap.common.validation.core.ValidatorUtils; import org.finos.tracdap.metadata.FileDefinition; +import org.finos.tracdap.metadata.FileType; import org.finos.tracdap.metadata.ObjectType; import org.finos.tracdap.metadata.TagSelector; import com.google.protobuf.Descriptors; @@ -32,7 +33,8 @@ @Validator(type = ValidationType.STATIC) public class FileValidator { - private static final Pattern EXT_PATTERN = Pattern.compile(".*\\.([^./\\\\]+)"); + private static final Pattern NAME_EXT_PATTERN = Pattern.compile(".*\\.([^./\\\\]+)\\Z"); + private static final Pattern EXT_PATTERN = Pattern.compile("\\A([^./\\\\]+)\\Z"); private static final Descriptors.Descriptor FILE_DEF; private static final Descriptors.FieldDescriptor FD_NAME; @@ -42,6 +44,10 @@ public class FileValidator { private static final Descriptors.FieldDescriptor FD_DATA_ITEM; private static final Descriptors.FieldDescriptor FD_STORAGE_ID; + private static final Descriptors.Descriptor FILE_TYPE; + private static final Descriptors.FieldDescriptor FT_EXTENSION; + private static final Descriptors.FieldDescriptor FT_MIME_TYPE; + static { FILE_DEF = FileDefinition.getDescriptor(); FD_NAME = ValidatorUtils.field(FILE_DEF, FileDefinition.NAME_FIELD_NUMBER); @@ -50,6 +56,10 @@ public class FileValidator { FD_SIZE = ValidatorUtils.field(FILE_DEF, FileDefinition.SIZE_FIELD_NUMBER); FD_DATA_ITEM = ValidatorUtils.field(FILE_DEF, FileDefinition.DATAITEM_FIELD_NUMBER); FD_STORAGE_ID = ValidatorUtils.field(FILE_DEF, FileDefinition.STORAGEID_FIELD_NUMBER); + + FILE_TYPE = FileType.getDescriptor(); + FT_EXTENSION = ValidatorUtils.field(FILE_TYPE, FileType.EXTENSION_FIELD_NUMBER); + FT_MIME_TYPE = ValidatorUtils.field(FILE_TYPE, FileType.MIMETYPE_FIELD_NUMBER); } @Validator @@ -89,9 +99,25 @@ public static ValidationContext file(FileDefinition msg, ValidationContext ctx) return ctx; } + @Validator + public static ValidationContext fileType(FileType msg, ValidationContext ctx) { + + ctx = ctx.push(FT_EXTENSION) + .apply(CommonValidators::required) +
.apply(FileValidator::extensionIsValid) + .pop(); + + ctx = ctx.push(FT_MIME_TYPE) + .apply(CommonValidators::required) + .apply(CommonValidators::mimeType) + .pop(); + + return ctx; + } + private static ValidationContext extensionMatchesName(String extension, String fileName, ValidationContext ctx) { - var nameExtMatch = EXT_PATTERN.matcher(fileName); + var nameExtMatch = NAME_EXT_PATTERN.matcher(fileName); if (nameExtMatch.matches()) { @@ -109,4 +135,14 @@ private static ValidationContext extensionMatchesName(String extension, String f return ctx; } + + private static ValidationContext extensionIsValid(String extension, ValidationContext ctx) { + + var extMatch = EXT_PATTERN.matcher(extension); + + if (!extMatch.matches()) + ctx.error(String.format("File extension [%s] is not valid", extension)); + + return ctx; + } } diff --git a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/JobValidator.java b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/JobValidator.java index 9e8e0a2b1..3dd175fee 100644 --- a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/JobValidator.java +++ b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/JobValidator.java @@ -24,6 +24,7 @@ import com.google.protobuf.Descriptors; +import java.util.List; import java.util.Map; import static org.finos.tracdap.common.validation.core.ValidatorUtils.field; @@ -37,6 +38,8 @@ public class JobValidator { Map.entry(JobDefinition.JobDetailsCase.RUNFLOW, JobType.RUN_FLOW), Map.entry(JobDefinition.JobDetailsCase.IMPORTMODEL, JobType.IMPORT_MODEL)); + private static final List<ObjectType> ALLOWED_IO_TYPES = List.of(ObjectType.DATA, ObjectType.FILE); + private static final Descriptors.Descriptor JOB_DEFINITION; private static final Descriptors.FieldDescriptor JD_JOB_TYPE; private static final Descriptors.OneofDescriptor JD_JOB_DETAILS; @@ -165,7 +168,7 @@ public static ValidationContext runModelOrFlow( .applyMapKeys(CommonValidators::identifier) .applyMapKeys(CommonValidators::notTracReserved) .applyMapValues(ObjectIdValidator::tagSelector, TagSelector.class) - .applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.DATA) + .applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ALLOWED_IO_TYPES) .applyMapValues(ObjectIdValidator::fixedObjectVersion, TagSelector.class) .pop(); @@ -173,7 +176,7 @@ .applyMapKeys(CommonValidators::identifier) .applyMapKeys(CommonValidators::notTracReserved) .applyMapValues(ObjectIdValidator::tagSelector, TagSelector.class) - .applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.DATA) + .applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ALLOWED_IO_TYPES) .applyMapValues(ObjectIdValidator::fixedObjectVersion, TagSelector.class) .pop(); @@ -181,7 +184,7 @@ .applyMapKeys(CommonValidators::identifier) .applyMapKeys(CommonValidators::notTracReserved) .applyMapValues(ObjectIdValidator::tagSelector, TagSelector.class) - .applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.DATA) + .applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ALLOWED_IO_TYPES) .applyMapValues(ObjectIdValidator::fixedObjectVersion, TagSelector.class) .pop(); diff --git
a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ModelValidator.java b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ModelValidator.java index 4770a811b..97447d14e 100644 --- a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ModelValidator.java +++ b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ModelValidator.java @@ -51,14 +51,18 @@ public class ModelValidator { private static final Descriptors.FieldDescriptor MP_PARAM_PROPS; private static final Descriptors.Descriptor MODEL_INPUT_SCHEMA; + private static final Descriptors.FieldDescriptor MIS_OBJECT_TYPE; private static final Descriptors.FieldDescriptor MIS_SCHEMA; + private static final Descriptors.FieldDescriptor MIS_FILE_TYPE; private static final Descriptors.FieldDescriptor MIS_LABEL; private static final Descriptors.FieldDescriptor MIS_OPTIONAL; private static final Descriptors.FieldDescriptor MIS_DYNAMIC; private static final Descriptors.FieldDescriptor MIS_INPUT_PROPS; private static final Descriptors.Descriptor MODEL_OUTPUT_SCHEMA; + private static final Descriptors.FieldDescriptor MOS_OBJECT_TYPE; private static final Descriptors.FieldDescriptor MOS_SCHEMA; + private static final Descriptors.FieldDescriptor MOS_FILE_TYPE; private static final Descriptors.FieldDescriptor MOS_LABEL; private static final Descriptors.FieldDescriptor MOS_OPTIONAL; private static final Descriptors.FieldDescriptor MOS_DYNAMIC; @@ -83,14 +87,18 @@ public class ModelValidator { MP_PARAM_PROPS = field(MODEL_PARAMETER, ModelParameter.PARAMPROPS_FIELD_NUMBER); MODEL_INPUT_SCHEMA = ModelInputSchema.getDescriptor(); + MIS_OBJECT_TYPE = field(MODEL_INPUT_SCHEMA, ModelInputSchema.OBJECTTYPE_FIELD_NUMBER); MIS_SCHEMA = field(MODEL_INPUT_SCHEMA, ModelInputSchema.SCHEMA_FIELD_NUMBER); + MIS_FILE_TYPE = field(MODEL_INPUT_SCHEMA, ModelInputSchema.FILETYPE_FIELD_NUMBER); MIS_LABEL = field(MODEL_INPUT_SCHEMA, ModelInputSchema.LABEL_FIELD_NUMBER); MIS_OPTIONAL = field(MODEL_INPUT_SCHEMA, ModelInputSchema.OPTIONAL_FIELD_NUMBER); MIS_DYNAMIC = field(MODEL_INPUT_SCHEMA, ModelInputSchema.DYNAMIC_FIELD_NUMBER); MIS_INPUT_PROPS = field(MODEL_INPUT_SCHEMA, ModelInputSchema.INPUTPROPS_FIELD_NUMBER); MODEL_OUTPUT_SCHEMA = ModelOutputSchema.getDescriptor(); + MOS_OBJECT_TYPE = field(MODEL_OUTPUT_SCHEMA, ModelOutputSchema.OBJECTTYPE_FIELD_NUMBER); MOS_SCHEMA = field(MODEL_OUTPUT_SCHEMA, ModelOutputSchema.SCHEMA_FIELD_NUMBER); + MOS_FILE_TYPE = field(MODEL_OUTPUT_SCHEMA, ModelOutputSchema.FILETYPE_FIELD_NUMBER); MOS_LABEL = field(MODEL_OUTPUT_SCHEMA, ModelOutputSchema.LABEL_FIELD_NUMBER); MOS_OPTIONAL = field(MODEL_OUTPUT_SCHEMA, ModelOutputSchema.OPTIONAL_FIELD_NUMBER); MOS_DYNAMIC = field(MODEL_OUTPUT_SCHEMA, ModelOutputSchema.DYNAMIC_FIELD_NUMBER); @@ -207,14 +215,35 @@ public static ValidationContext modelParameter(ModelParameter msg, ValidationCon @Validator public static ValidationContext modelInputSchema(ModelInputSchema msg, ValidationContext ctx) { - // Dynamic schemas require different validation logic - - ctx = ctx.push(MIS_SCHEMA) + ctx = ctx.push(MIS_OBJECT_TYPE) .apply(CommonValidators::required) - .applyIf(!msg.getDynamic(), SchemaValidator::schema, SchemaDefinition.class) - .applyIf(msg.getDynamic(), SchemaValidator::dynamicSchema, SchemaDefinition.class) + .apply(CommonValidators::nonZeroEnum, ObjectType.class) .pop(); + if (msg.getObjectType() == ObjectType.DATA) { + + //
Dynamic schemas require different validation logic + + ctx = ctx.push(MIS_SCHEMA) + .apply(CommonValidators::required) + .applyIf(!msg.getDynamic(), SchemaValidator::schema, SchemaDefinition.class) + .applyIf(msg.getDynamic(), SchemaValidator::dynamicSchema, SchemaDefinition.class) + .pop(); + } + else if (msg.getObjectType() == ObjectType.FILE) { + + ctx = ctx.push(MIS_FILE_TYPE) + .apply(CommonValidators::required) + .apply(FileValidator::fileType, FileType.class) + .pop(); + } + else { + + ctx = ctx.push(MIS_OBJECT_TYPE) + .error(String.format("Object type [%s] is not supported", msg.getObjectType())) + .pop(); + } + ctx = ctx.push(MIS_LABEL) .apply(CommonValidators::optional) .apply(CommonValidators::labelLengthLimit) @@ -231,14 +260,35 @@ public static ValidationContext modelOutputSchema(ModelOutputSchema msg, Validatio @Validator public static ValidationContext modelOutputSchema(ModelOutputSchema msg, ValidationContext ctx) { - // Dynamic schemas require different validation logic - - ctx = ctx.push(MOS_SCHEMA) + ctx = ctx.push(MOS_OBJECT_TYPE) .apply(CommonValidators::required) - .applyIf(!msg.getDynamic(), SchemaValidator::schema, SchemaDefinition.class) - .applyIf(msg.getDynamic(), SchemaValidator::dynamicSchema, SchemaDefinition.class) + .apply(CommonValidators::nonZeroEnum, ObjectType.class) .pop(); + if (msg.getObjectType() == ObjectType.DATA) { + + // Dynamic schemas require different validation logic + + ctx = ctx.push(MOS_SCHEMA) + .apply(CommonValidators::required) + .applyIf(!msg.getDynamic(), SchemaValidator::schema, SchemaDefinition.class) + .applyIf(msg.getDynamic(), SchemaValidator::dynamicSchema, SchemaDefinition.class) + .pop(); + } + else if (msg.getObjectType() == ObjectType.FILE) { + + ctx = ctx.push(MOS_FILE_TYPE) + .apply(CommonValidators::required) + .apply(FileValidator::fileType, FileType.class) + .pop(); + } + else { + + ctx = ctx.push(MOS_OBJECT_TYPE) + .error(String.format("Object type [%s] is not supported", msg.getObjectType())) + .pop(); + } + ctx = ctx.push(MOS_LABEL) .apply(CommonValidators::optional) .apply(CommonValidators::labelLengthLimit) diff --git a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ObjectIdValidator.java b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ObjectIdValidator.java index 30b118bed..965b60349 100644 --- a/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ObjectIdValidator.java +++ b/tracdap-libs/tracdap-lib-validation/src/main/java/org/finos/tracdap/common/validation/static_/ObjectIdValidator.java @@ -25,6 +25,9 @@ import org.finos.tracdap.common.validation.core.ValidationContext; import com.google.protobuf.Descriptors; +import java.util.List; +import java.util.stream.Collectors; + import static org.finos.tracdap.common.validation.core.ValidatorUtils.field; @@ -205,6 +208,18 @@ public static ValidationContext selectorType(TagSelector selector, ObjectType re return ctx; } + public static ValidationContext selectorType(TagSelector selector, List<ObjectType> allowedTypes, ValidationContext ctx) { + + if (!allowedTypes.contains(selector.getObjectType())) { + var allowed = allowedTypes.stream().map(Enum::name).collect(Collectors.joining(", ")); + var err = String.format("Wrong object type in [%s] selector: allowed [%s], got [%s]", + ctx.fieldName(), allowed, selector.getObjectType()); + return ctx.error(err); + } + + return ctx; + } + public static ValidationContext selectorForLatest(TagSelector selector,
ValidationContext ctx) { if (!selector.getLatestObject() || !selector.getLatestTag()) { diff --git a/tracdap-libs/tracdap-lib-validation/src/test/java/org/finos/tracdap/common/validation/static_/FlowValidatorTest.java b/tracdap-libs/tracdap-lib-validation/src/test/java/org/finos/tracdap/common/validation/static_/FlowValidatorTest.java index 358b7dc5d..c49081e62 100644 --- a/tracdap-libs/tracdap-lib-validation/src/test/java/org/finos/tracdap/common/validation/static_/FlowValidatorTest.java +++ b/tracdap-libs/tracdap-lib-validation/src/test/java/org/finos/tracdap/common/validation/static_/FlowValidatorTest.java @@ -115,6 +115,7 @@ void basicFlow_ok3() { // Flow schema .putInputs("input_1", ModelInputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(SchemaDefinition.newBuilder() .setSchemaType(SchemaType.TABLE) .setPartType(PartType.PART_ROOT) @@ -132,6 +133,7 @@ void basicFlow_ok3() { .build()) .putInputs("input_2", ModelInputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(SchemaDefinition.newBuilder() .setSchemaType(SchemaType.TABLE) .setPartType(PartType.PART_ROOT) @@ -149,6 +151,7 @@ void basicFlow_ok3() { .build()) .putOutputs("output_1", ModelOutputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(SchemaDefinition.newBuilder() .setSchemaType(SchemaType.TABLE) .setPartType(PartType.PART_ROOT) diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/context.py b/tracdap-runtime/python/src/tracdap/rt/_exec/context.py index 10c1d4c76..da2e07ddf 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/context.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/context.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import copy +import io import logging import pathlib import typing as tp @@ -82,9 +84,9 @@ def get_parameter(self, parameter_name: str) -> tp.Any: _val.validate_signature(self.get_parameter, parameter_name) - self.__val.check_param_valid_identifier(parameter_name) - self.__val.check_param_defined_in_model(parameter_name) - self.__val.check_param_available_in_context(parameter_name) + self.__val.check_item_valid_identifier(parameter_name, TracContextValidator.PARAMETER) + self.__val.check_item_defined_in_model(parameter_name, TracContextValidator.PARAMETER) + self.__val.check_item_available_in_context(parameter_name, TracContextValidator.PARAMETER) value: _meta.Value = self.__local_ctx.get(parameter_name) @@ -96,8 +98,8 @@ def has_dataset(self, dataset_name: str) -> bool: _val.validate_signature(self.has_dataset, dataset_name) - self.__val.check_dataset_valid_identifier(dataset_name) - self.__val.check_dataset_defined_in_model(dataset_name) + self.__val.check_item_valid_identifier(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_defined_in_model(dataset_name, TracContextValidator.DATASET) data_view: _data.DataView = self.__local_ctx.get(dataset_name) @@ -105,6 +107,7 @@ def has_dataset(self, dataset_name: str) -> bool: return False self.__val.check_context_object_type(dataset_name, data_view, _data.DataView) + self.__val.check_context_data_view_type(dataset_name, data_view, _meta.ObjectType.DATA) return not data_view.is_empty() @@ -112,9 +115,9 @@ def get_schema(self, dataset_name: str) -> _meta.SchemaDefinition: _val.validate_signature(self.get_schema, dataset_name) - self.__val.check_dataset_valid_identifier(dataset_name) - self.__val.check_dataset_defined_in_model(dataset_name) - 
self.__val.check_dataset_available_in_context(dataset_name) + self.__val.check_item_valid_identifier(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_defined_in_model(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_available_in_context(dataset_name, TracContextValidator.DATASET) static_schema = self.__get_static_schema(self.__model_def, dataset_name) data_view: _data.DataView = self.__local_ctx.get(dataset_name) @@ -123,6 +126,7 @@ def get_schema(self, dataset_name: str) -> _meta.SchemaDefinition: # This ensures errors are always reported and is consistent with get_pandas_table() self.__val.check_context_object_type(dataset_name, data_view, _data.DataView) + self.__val.check_context_data_view_type(dataset_name, data_view, _meta.ObjectType.DATA) self.__val.check_dataset_schema_defined(dataset_name, data_view) # If a static schema exists, that takes priority @@ -138,9 +142,9 @@ def get_table(self, dataset_name: str, framework: _eapi.DataFramework[_eapi.DATA _val.validate_signature(self.get_table, dataset_name, framework) _val.require_package(framework.protocol_name, framework.api_type) - self.__val.check_dataset_valid_identifier(dataset_name) - self.__val.check_dataset_defined_in_model(dataset_name) - self.__val.check_dataset_available_in_context(dataset_name) + self.__val.check_item_valid_identifier(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_defined_in_model(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_available_in_context(dataset_name, TracContextValidator.DATASET) self.__val.check_data_framework_args(framework, framework_args) static_schema = self.__get_static_schema(self.__model_def, dataset_name) @@ -150,6 +154,7 @@ def get_table(self, dataset_name: str, framework: _eapi.DataFramework[_eapi.DATA converter = _data.DataConverter.for_framework(framework, **framework_args) self.__val.check_context_object_type(dataset_name, data_view, _data.DataView) + self.__val.check_context_data_view_type(dataset_name, data_view, _meta.ObjectType.DATA) self.__val.check_dataset_schema_defined(dataset_name, data_view) self.__val.check_dataset_part_present(dataset_name, data_view, part_key) @@ -173,6 +178,27 @@ def get_pandas_table(self, dataset_name: str, use_temporal_objects: tp.Optional[ def get_polars_table(self, dataset_name: str) -> "_data.polars.DataFrame": return self.get_table(dataset_name, _eapi.POLARS) + + def get_file(self, file_name: str) -> bytes: + + _val.validate_signature(self.get_file, file_name) + + self.__val.check_item_valid_identifier(file_name, TracContextValidator.FILE) + self.__val.check_item_defined_in_model(file_name, TracContextValidator.FILE) + self.__val.check_item_available_in_context(file_name, TracContextValidator.FILE) + + file_view: _data.DataView = self.__local_ctx.get(file_name) + + self.__val.check_context_object_type(file_name, file_view, _data.DataView) + self.__val.check_context_data_view_type(file_name, file_view, _meta.ObjectType.FILE) + self.__val.check_file_content_present(file_name, file_view) + + return file_view.file_item.raw_bytes + + def get_file_stream(self, file_name: str) -> tp.ContextManager[tp.BinaryIO]: + + buffer = self.get_file(file_name) + return contextlib.closing(io.BytesIO(buffer)) def put_schema(self, dataset_name: str, schema: _meta.SchemaDefinition): @@ -182,7 +208,7 @@ def put_schema(self, dataset_name: str, schema: _meta.SchemaDefinition): # If field ordering is not assigned by the model, assign it here (model code will not see the numbers) schema_copy = 
self.__assign_field_order(copy.deepcopy(schema)) - self.__val.check_dataset_valid_identifier(dataset_name) + self.__val.check_item_valid_identifier(dataset_name, TracContextValidator.DATASET) self.__val.check_dataset_is_dynamic_output(dataset_name) self.__val.check_provided_schema_is_valid(dataset_name, schema_copy) @@ -197,6 +223,7 @@ def put_schema(self, dataset_name: str, schema: _meta.SchemaDefinition): # If there is a prior view it must contain nothing and will be replaced self.__val.check_context_object_type(dataset_name, data_view, _data.DataView) + self.__val.check_context_data_view_type(dataset_name, data_view, _meta.ObjectType.DATA) self.__val.check_dataset_schema_not_defined(dataset_name, data_view) self.__val.check_dataset_is_empty(dataset_name, data_view) @@ -216,8 +243,8 @@ def put_table( _val.require_package(framework.protocol_name, framework.api_type) - self.__val.check_dataset_valid_identifier(dataset_name) - self.__val.check_dataset_is_model_output(dataset_name) + self.__val.check_item_valid_identifier(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_is_model_output(dataset_name, TracContextValidator.DATASET) self.__val.check_provided_dataset_type(dataset, framework.api_type) self.__val.check_data_framework_args(framework, framework_args) @@ -234,6 +261,7 @@ def put_table( data_view = _data.DataView.create_empty() self.__val.check_context_object_type(dataset_name, data_view, _data.DataView) + self.__val.check_context_data_view_type(dataset_name, data_view, _meta.ObjectType.DATA) self.__val.check_dataset_schema_defined(dataset_name, data_view) self.__val.check_dataset_part_not_present(dataset_name, data_view, part_key) @@ -246,7 +274,7 @@ def put_table( # Data conformance is applied automatically inside the converter, if schema != None table = converter.to_internal(dataset, schema) - item = _data.DataItem(schema, table) + item = _data.DataItem(_meta.ObjectType.DATA, schema, table) updated_view = _data.DataMapping.add_item_to_view(data_view, part_key, item) @@ -259,6 +287,46 @@ def put_pandas_table(self, dataset_name: str, dataset: "_data.pandas.DataFrame") def put_polars_table(self, dataset_name: str, dataset: "_data.polars.DataFrame"): self.put_table(dataset_name, dataset, _eapi.POLARS) + + def put_file(self, file_name: str, file_content: tp.Union[bytes, bytearray]): + + _val.validate_signature(self.put_file, file_name, file_content) + + self.__val.check_item_valid_identifier(file_name, TracContextValidator.FILE) + self.__val.check_item_is_model_output(file_name, TracContextValidator.FILE) + + file_view: _data.DataView = self.__local_ctx.get(file_name) + + if file_view is None: + file_view = _data.DataView.create_empty(_meta.ObjectType.FILE) + + self.__val.check_context_object_type(file_name, file_view, _data.DataView) + self.__val.check_context_data_view_type(file_name, file_view, _meta.ObjectType.FILE) + self.__val.check_file_content_not_present(file_name, file_view) + + if isinstance(file_content, bytearray): + file_content = bytes(file_content) + + file_item = _data.DataItem.for_file_content(file_content) + self.__local_ctx[file_name] = file_view.with_file_item(file_item) + + def put_file_stream(self, file_name: str) -> tp.ContextManager[tp.BinaryIO]: + + _val.validate_signature(self.put_file_stream, file_name) + + self.__val.check_item_valid_identifier(file_name, TracContextValidator.FILE) + self.__val.check_item_is_model_output(file_name, TracContextValidator.FILE) + + @contextlib.contextmanager + def memory_stream(stream: io.BytesIO): + try: + yield
stream + buffer = stream.getbuffer().tobytes() + self.put_file(file_name, buffer) + finally: + stream.close() + + return memory_stream(io.BytesIO()) def log(self) -> logging.Logger: @@ -310,7 +378,7 @@ def __init__( self.__storage_map = storage_map self.__checkout_directory = checkout_directory - self.__val = self._TracContextImpl__val # noqa + self.__val: TracContextValidator = self._TracContextImpl__val # noqa def get_file_storage(self, storage_key: str) -> _eapi.TracFileStorage: @@ -348,9 +416,9 @@ def add_data_import(self, dataset_name: str): _val.validate_signature(self.add_data_import, dataset_name) - self.__val.check_dataset_valid_identifier(dataset_name) - self.__val.check_dataset_not_defined_in_model(dataset_name) - self.__val.check_dataset_not_available_in_context(dataset_name) + self.__val.check_item_valid_identifier(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_not_defined_in_model(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_not_available_in_context(dataset_name, TracContextValidator.DATASET) self.__local_ctx[dataset_name] = _data.DataView.create_empty() self.__dynamic_outputs.append(dataset_name) @@ -359,8 +427,8 @@ def set_source_metadata(self, dataset_name: str, storage_key: str, source_info: _val.validate_signature(self.set_source_metadata, dataset_name, storage_key, source_info) - self.__val.check_dataset_valid_identifier(dataset_name) - self.__val.check_dataset_available_in_context(dataset_name) + self.__val.check_item_valid_identifier(dataset_name, TracContextValidator.DATASET) + self.__val.check_item_available_in_context(dataset_name, TracContextValidator.DATASET) self.__val.check_storage_valid_identifier(storage_key) self.__val.check_storage_available(self.__storage_map, storage_key) @@ -368,11 +436,11 @@ def set_source_metadata(self, dataset_name: str, storage_key: str, source_info: if isinstance(storage, _eapi.TracFileStorage): if not isinstance(source_info, _eapi.FileStat): - self.__val.report_public_error(f"Expected storage_info to be a FileStat, [{storage_key}] refers to file storage") + self.__val.report_public_error(_ex.ERuntimeValidation(f"Expected storage_info to be a FileStat, [{storage_key}] refers to file storage")) if isinstance(storage, _eapi.TracDataStorage): if not isinstance(source_info, str): - self.__val.report_public_error(f"Expected storage_info to be a table name, [{storage_key}] refers to dadta storage") + self.__val.report_public_error(_ex.ERuntimeValidation(f"Expected storage_info to be a table name, [{storage_key}] refers to data storage")) pass # Not implemented yet, only required when imports are sent back to the platform @@ -684,6 +752,10 @@ def _type_name(type_: type): class TracContextValidator(TracContextErrorReporter): + PARAMETER = "Parameter" + DATASET = "Dataset" + FILE = "File" + def __init__( self, log: logging.Logger, model_def: _meta.ModelDefinition, @@ -697,49 +769,45 @@ def __init__( self.__local_ctx = local_ctx self.__dynamic_outputs = dynamic_outputs - def check_param_valid_identifier(self, param_name: str): - - if param_name is None: - self._report_error(f"Parameter name is null") - - if not self._VALID_IDENTIFIER.match(param_name): - self._report_error(f"Parameter name {param_name} is not a valid identifier") - - def check_param_defined_in_model(self, param_name: str): + def check_item_valid_identifier(self, item_name: str, item_type: str): - if param_name not in self.__model_def.parameters: - self._report_error(f"Parameter {param_name} is not defined in the model") + if
item_name is None: + self._report_error(f"{item_type} name is null") - def check_param_available_in_context(self, param_name: str): + if not self._VALID_IDENTIFIER.match(item_name): + self._report_error(f"{item_type} name {item_name} is not a valid identifier") - if param_name not in self.__local_ctx: - self._report_error(f"Parameter {param_name} is not available in the current context") + def check_item_defined_in_model(self, item_name: str, item_type: str): - def check_dataset_valid_identifier(self, dataset_name: str): + if item_type == self.PARAMETER: + if item_name not in self.__model_def.parameters: + self._report_error(f"{item_type} {item_name} is not defined in the model") + else: + if item_name not in self.__model_def.inputs and item_name not in self.__model_def.outputs: + self._report_error(f"{item_type} {item_name} is not defined in the model") - if dataset_name is None: - self._report_error(f"Dataset name is null") + def check_item_not_defined_in_model(self, item_name: str, item_type: str): - if not self._VALID_IDENTIFIER.match(dataset_name): - self._report_error(f"Dataset name {dataset_name} is not a valid identifier") + if item_name in self.__model_def.inputs or item_name in self.__model_def.outputs: + self._report_error(f"{item_type} {item_name} is already defined in the model") - def check_dataset_not_defined_in_model(self, dataset_name: str): + if item_name in self.__model_def.parameters: + self._report_error(f"{item_type} name {item_name} is already in use as a model parameter") - if dataset_name in self.__model_def.inputs or dataset_name in self.__model_def.outputs: - self._report_error(f"Dataset {dataset_name} is already defined in the model") + def check_item_is_model_output(self, item_name: str, item_type: str): - if dataset_name in self.__model_def.parameters: - self._report_error(f"Dataset name {dataset_name} is already in use as a model parameter") + if item_name not in self.__model_def.outputs and item_name not in self.__dynamic_outputs: + self._report_error(f"{item_type} {item_name} is not defined as a model output") - def check_dataset_defined_in_model(self, dataset_name: str): + def check_item_available_in_context(self, item_name: str, item_type: str): - if dataset_name not in self.__model_def.inputs and dataset_name not in self.__model_def.outputs: - self._report_error(f"Dataset {dataset_name} is not defined in the model") + if item_name not in self.__local_ctx: + self._report_error(f"{item_type} {item_name} is not available in the current context") - def check_dataset_is_model_output(self, dataset_name: str): + def check_item_not_available_in_context(self, item_name: str, item_type: str): - if dataset_name not in self.__model_def.outputs and dataset_name not in self.__dynamic_outputs: - self._report_error(f"Dataset {dataset_name} is not defined as a model output") + if item_name in self.__local_ctx: + self._report_error(f"{item_type} {item_name} already exists in the current context") def check_dataset_is_dynamic_output(self, dataset_name: str): @@ -752,16 +820,6 @@ def check_dataset_is_dynamic_output(self, dataset_name: str): if model_output and not model_output.dynamic: self._report_error(f"Model output {dataset_name} is not a dynamic output") - def check_dataset_available_in_context(self, item_name: str): - - if item_name not in self.__local_ctx: - self._report_error(f"Dataset {item_name} is not available in the current context") - - def check_dataset_not_available_in_context(self, item_name: str): - - if item_name in self.__local_ctx: -
self._report_error(f"Dataset {item_name} already exists in the current context") - def check_dataset_schema_defined(self, dataset_name: str, data_view: _data.DataView): schema = data_view.trac_schema if data_view is not None else None @@ -834,6 +892,14 @@ def check_context_object_type(self, item_name: str, item: tp.Any, expected_type: f"The object referenced by [{item_name}] in the current context has the wrong type" + f" (expected {expected_type_name}, got {actual_type_name})") + def check_context_data_view_type(self, item_name: str, data_view: _data.DataView, expected_type: _meta.ObjectType): + + if data_view.object_type != expected_type: + + self._report_error( + f"The object referenced by [{item_name}] in the current context has the wrong type" + + f" (expected {expected_type.name}, got {data_view.object_type.name})") + def check_data_framework_args(self, framework: _eapi.DataFramework, framework_args: tp.Dict[str, tp.Any]): expected_args = _data.DataConverter.get_framework_args(framework) @@ -861,6 +927,16 @@ def check_data_framework_args(self, framework: _eapi.DataFramework, framework_ar f"Using [{framework}], argument [{arg_name}] has the wrong type" + f" (expected {expected_type_name}, got {actual_type_name})") + def check_file_content_present(self, file_name: str, file_view: _data.DataView): + + if file_view.file_item is None or not file_view.file_item.raw_bytes: + self._report_error(f"File content is missing or empty for [{file_name}] in the current context") + + def check_file_content_not_present(self, file_name: str, file_view: _data.DataView): + + if file_view.file_item is not None and file_view.file_item.raw_bytes: + self._report_error(f"File content is already present for [{file_name}] in the current context") + def check_storage_valid_identifier(self, storage_key): if storage_key is None: @@ -878,7 +954,7 @@ def check_storage_available(self, storage_map: tp.Dict, storage_key: str): def check_storage_type( self, storage_map: tp.Dict, storage_key: str, - storage_type: tp.Union[_eapi.TracFileStorage.__class__]): + storage_type: tp.Union[_eapi.TracFileStorage.__class__, _eapi.TracDataStorage.__class__]): storage_instance = storage_map.get(storage_key) diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/dev_mode.py b/tracdap-runtime/python/src/tracdap/rt/_exec/dev_mode.py index e0bb89b6b..adebebbc4 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/dev_mode.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/dev_mode.py @@ -137,11 +137,14 @@ def _resolve_storage_location(cls, bucket_key, bucket_config, config_mgr: _cfg_p raise _ex.EConfigParse(msg) - def __init__(self, sys_config: _cfg.RuntimeConfig, config_mgr: _cfg_p.ConfigManager, scratch_dir: pathlib.Path): + def __init__( + self, sys_config: _cfg.RuntimeConfig, config_mgr: _cfg_p.ConfigManager, scratch_dir: pathlib.Path = None, + model_loader: _models.ModelLoader = None, storage_manager: _storage.StorageManager = None): + self._sys_config = sys_config self._config_mgr = config_mgr - self._scratch_dir = scratch_dir - self._model_loader: tp.Optional[_models.ModelLoader] = None + self._model_loader = model_loader or _models.ModelLoader(self._sys_config, scratch_dir) + self._storage_manager = storage_manager or _storage.StorageManager(self._sys_config) def translate_job_config( self, job_config: _cfg.JobConfig, @@ -150,8 +153,6 @@ def translate_job_config( try: self._log.info(f"Applying dev mode config translation to job config") - - self._model_loader = _models.ModelLoader(self._sys_config, self._scratch_dir) 
self._model_loader.create_scope("DEV_MODE_TRANSLATION") job_config = copy.deepcopy(job_config) @@ -168,7 +169,6 @@ def translate_job_config( finally: self._model_loader.destroy_scope("DEV_MODE_TRANSLATION") - self._model_loader = None def translate_job_def( self, job_config: _cfg.JobConfig, job_def: _meta.JobDefinition, @@ -694,7 +694,7 @@ def _infer_output_schema( model_selector = job_def.runFlow.models.get(source.node) model_obj = _util.get_job_resource(model_selector, job_config) - model_input = model_obj.model.inputs.get(source.socket) + model_input = model_obj.model.outputs.get(source.socket) model_outputs.append(model_input) if len(model_outputs) == 0: @@ -764,7 +764,7 @@ def _process_parameters_dict( else: p_spec = param_specs[p_name] - cls._log.info(f"Encoding parameter [{p_name}] as {p_spec.paramType.basicType}") + cls._log.info(f"Encoding parameter [{p_name}] as {p_spec.paramType.basicType.name}") encoded_value = _types.MetadataCodec.convert_value(p_value, p_spec.paramType) encoded_values[p_name] = encoded_value @@ -798,38 +798,46 @@ def _process_inputs_and_outputs( if not (isinstance(input_value, str) and input_value in job_resources): model_input = required_inputs[input_key] - input_schema = model_input.schema if model_input and not model_input.dynamic else None - input_id = self._process_input_or_output( - input_key, input_value, job_resources, - new_unique_file=False, schema=input_schema) + if model_input.objectType == _meta.ObjectType.DATA: + schema = model_input.schema if model_input and not model_input.dynamic else None + input_id = self._process_data_socket(input_key, input_value, schema, job_resources, new_unique_file=False) + elif model_input.objectType == _meta.ObjectType.FILE: + file_type = model_input.fileType + input_id = self._process_file_socket(input_key, input_value, file_type, job_resources, new_unique_file=False) + else: + raise _ex.EUnexpected() job_inputs[input_key] = _util.selector_for(input_id) for output_key, output_value in job_outputs.items(): if not (isinstance(output_value, str) and output_value in job_resources): - model_output= required_outputs[output_key] - output_schema = model_output.schema if model_output and not model_output.dynamic else None + model_output = required_outputs[output_key] - output_id = self._process_input_or_output( - output_key, output_value, job_resources, - new_unique_file=True, schema=output_schema) + if model_output.objectType == _meta.ObjectType.DATA: + schema = model_output.schema if model_output and not model_output.dynamic else None + output_id = self._process_data_socket(output_key, output_value, schema, job_resources, new_unique_file=True) + elif model_output.objectType == _meta.ObjectType.FILE: + file_type = model_output.fileType + output_id = self._process_file_socket(output_key, output_value, file_type, job_resources, new_unique_file=True) + else: + raise _ex.EUnexpected() job_outputs[output_key] = _util.selector_for(output_id) return job_config, job_def - def _process_input_or_output( - self, data_key, data_value, - resources: tp.Dict[str, _meta.ObjectDefinition], - new_unique_file=False, - schema: tp.Optional[_meta.SchemaDefinition] = None) \ + def _process_data_socket( + self, data_key, data_value, schema: tp.Optional[_meta.SchemaDefinition], + resources: tp.Dict[str, _meta.ObjectDefinition], new_unique_file=False) \ -> _meta.TagHeader: data_id = _util.new_object_id(_meta.ObjectType.DATA) storage_id = _util.new_object_id(_meta.ObjectType.STORAGE) + self._log.info(f"Generating data definition for [{data_key}] 
with ID = [{_util.object_key(data_id)}]") + if isinstance(data_value, str): storage_path = data_value storage_key = self._sys_config.storage.defaultBucket @@ -850,43 +858,85 @@ def _process_input_or_output( else: raise _ex.EConfigParse(f"Invalid configuration for input '{data_key}'") - self._log.info(f"Generating data definition for [{data_key}] with ID = [{_util.object_key(data_id)}]") - # For unique outputs, increment the snap number to find a new unique snap # These are not incarnations, bc likely in dev mode model code and inputs are changing # Incarnations are for recreation of a dataset using the exact same code path and inputs if new_unique_file: + storage_path, snap_version = self._new_unique_file(data_key, storage_key, storage_path, snap_version) - x_storage_mgr = _storage.StorageManager(self._sys_config) - x_storage = x_storage_mgr.get_file_storage(storage_key) - x_orig_path = pathlib.PurePath(storage_path) - x_name = x_orig_path.name - - if x_storage.exists(str(x_orig_path.parent)): - listing = x_storage.ls(str(x_orig_path.parent)) - existing_files = list(map(lambda stat: stat.file_name, listing)) - else: - existing_files = [] - - while x_name in existing_files: + part_key = _meta.PartKey(opaqueKey="part-root", partType=_meta.PartType.PART_ROOT) + delta_index = 1 + incarnation_index = 1 - snap_version += 1 - x_name = f"{x_orig_path.stem}-{snap_version}" - storage_path = str(x_orig_path.parent.joinpath(x_name)) + # This is also defined in functions.DynamicDataSpecFunc, maybe centralize? + data_item = f"data/table/{data_id.objectId}/{part_key.opaqueKey}/snap-{snap_version}/delta-{delta_index}" - self._log.info(f"Output for [{data_key}] will be snap version {snap_version}") + data_obj = self._generate_data_definition( + part_key, snap_version, delta_index, data_item, + schema, storage_id) - data_obj, storage_obj = self._generate_input_definition( - data_id, storage_id, storage_key, storage_path, storage_format, - snap_index=snap_version, delta_index=1, incarnation_index=1, - schema=schema) + storage_obj = self._generate_storage_definition( + storage_id, storage_key, storage_path, storage_format, + data_item, incarnation_index) resources[_util.object_key(data_id)] = data_obj resources[_util.object_key(storage_id)] = storage_obj return data_id + def _process_file_socket( + self, file_key, file_value, file_type: _meta.FileType, + resources: tp.Dict[str, _meta.ObjectDefinition], new_unique_file=False) \ + -> _meta.TagHeader: + + file_id = _util.new_object_id(_meta.ObjectType.FILE) + storage_id = _util.new_object_id(_meta.ObjectType.STORAGE) + + self._log.info(f"Generating file definition for [{file_key}] with ID = [{_util.object_key(file_id)}]") + + if isinstance(file_value, str): + + storage_key = self._sys_config.storage.defaultBucket + storage_path = file_value + + elif isinstance(file_value, dict): + + storage_key = file_value.get("storageKey") or self._sys_config.storage.defaultBucket + storage_path = file_value.get("path") + + if not storage_path: + raise _ex.EConfigParse(f"Invalid configuration for input [{file_key}] (missing required value 'path')") + + else: + raise _ex.EConfigParse(f"Invalid configuration for input '{file_key}'") + + storage_format = "application/x-binary" + file_version = 1 + + if new_unique_file: + storage_path, file_version = self._new_unique_file(file_key, storage_key, storage_path, file_version) + file_size = 0 + else: + storage = self._storage_manager.get_file_storage(storage_key) + file_size = storage.size(storage_path) + + data_item = 
f"file/{file_id.objectId}/version-{file_version}" + file_name = f"{file_key}.{file_type.extension}" + + file_obj = self._generate_file_definition( + file_name, file_type, file_size, + storage_id, data_item) + + storage_obj = self._generate_storage_definition( + storage_id, storage_key, storage_path, storage_format, + data_item, incarnation_index=1) + + resources[_util.object_key(file_id)] = file_obj + resources[_util.object_key(storage_id)] = storage_obj + + return file_id + @staticmethod def infer_format(storage_path: str, storage_config: _cfg.StorageConfig): @@ -898,20 +948,33 @@ def infer_format(storage_path: str, storage_config: _cfg.StorageConfig): else: return storage_config.defaultFormat - @classmethod - def _generate_input_definition( - cls, data_id: _meta.TagHeader, storage_id: _meta.TagHeader, - storage_key: str, storage_path: str, storage_format: str, - snap_index: int, delta_index: int, incarnation_index: int, - schema: tp.Optional[_meta.SchemaDefinition] = None) \ - -> (_meta.ObjectDefinition, _meta.ObjectDefinition): + def _new_unique_file(self, socket_name, storage_key, storage_path, version): - part_key = _meta.PartKey( - opaqueKey="part-root", - partType=_meta.PartType.PART_ROOT) + x_storage = self._storage_manager.get_file_storage(storage_key) + x_orig_path = pathlib.PurePath(storage_path) + x_name = x_orig_path.name - # This is also defined in functions.DynamicDataSpecFunc, maybe centralize? - data_item = f"data/table/{data_id.objectId}/{part_key.opaqueKey}/snap-{snap_index}/delta-{delta_index}" + if x_storage.exists(str(x_orig_path.parent)): + listing = x_storage.ls(str(x_orig_path.parent)) + existing_files = list(map(lambda stat: stat.file_name, listing)) + else: + existing_files = [] + + while x_name in existing_files: + + version += 1 + x_name = f"{x_orig_path.stem}-{version}{x_orig_path.suffix}" + storage_path = str(x_orig_path.parent.joinpath(x_name)) + + self._log.info(f"Output for [{socket_name}] will be version {version}") + + return storage_path, version + + @classmethod + def _generate_data_definition( + cls, part_key: _meta.PartKey, snap_index: int, delta_index: int, data_item: str, + schema: tp.Optional[_meta.SchemaDefinition], storage_id: _meta.TagHeader) \ + -> (_meta.ObjectDefinition, _meta.ObjectDefinition): delta = _meta.DataDefinition.Delta( deltaIndex=delta_index, @@ -925,17 +988,31 @@ def _generate_input_definition( partKey=part_key, snap=snap) - data_def = _meta.DataDefinition(parts={}) + data_def = _meta.DataDefinition() data_def.parts[part_key.opaqueKey] = part + data_def.schema = schema + data_def.storageId = _util.selector_for(storage_id) - if schema is not None: - data_def.schema = schema - else: - data_def.schema = None + return _meta.ObjectDefinition(objectType=_meta.ObjectType.DATA, data=data_def) - data_def.storageId = _meta.TagSelector( - _meta.ObjectType.STORAGE, storage_id.objectId, - objectVersion=storage_id.objectVersion, latestTag=True) + @classmethod + def _generate_file_definition( + cls, file_name: str, file_type: _meta.FileType, file_size: int, + storage_id: _meta.TagHeader, data_item: str) \ + -> _meta.ObjectDefinition: + + file_def = _meta.FileDefinition( + name=file_name, extension=file_type.extension, mimeType=file_type.mimeType, + storageId=_util.selector_for(storage_id), dataItem=data_item, size=file_size) + + return _meta.ObjectDefinition(objectType=_meta.ObjectType.FILE, file=file_def) + + @classmethod + def _generate_storage_definition( + cls, storage_id: _meta.TagHeader, + storage_key: str, storage_path: str, 
storage_format: str, + data_item: str, incarnation_index: int) \ + -> _meta.ObjectDefinition: storage_copy = _meta.StorageCopy( storageKey=storage_key, @@ -952,16 +1029,14 @@ def _generate_input_definition( storage_item = _meta.StorageItem( incarnations=[storage_incarnation]) - storage_def = _meta.StorageDefinition(dataItems={}) - storage_def.dataItems[delta.dataItem] = storage_item + storage_def = _meta.StorageDefinition() + storage_def.dataItems[data_item] = storage_item if storage_format.lower() == "csv": storage_def.storageOptions["lenient_csv_parser"] = _types.MetadataCodec.encode_value(True) - data_obj = _meta.ObjectDefinition(objectType=_meta.ObjectType.DATA, data=data_def) - storage_obj = _meta.ObjectDefinition(objectType=_meta.ObjectType.STORAGE, storage=storage_def) + return _meta.ObjectDefinition(objectType=_meta.ObjectType.STORAGE, storage=storage_def) - return data_obj, storage_obj DevModeTranslator._log = _util.logger_for_class(DevModeTranslator) diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py b/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py index f1238069f..885fdd16d 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/engine.py @@ -170,7 +170,7 @@ def submit_job( self._log.info(f"Job submitted: [{job_key}]") - job_processor = JobProcessor(self._models, self._storage, job_key, job_config, result_spec, graph_spec=None) + job_processor = JobProcessor(self._sys_config, self._models, self._storage, job_key, job_config, result_spec, graph_spec=None) job_actor_id = self.actors().spawn(job_processor) job_monitor_success = lambda ctx, key, result: self._notify_callback(key, result, None) @@ -190,7 +190,7 @@ def submit_child_job(self, child_id: _meta.TagHeader, child_graph: _graph.Graph, child_key = _util.object_key(child_id) - child_processor = JobProcessor(self._models, self._storage, child_key, None, None, graph_spec=child_graph) # noqa + child_processor = JobProcessor(self._sys_config, self._models, self._storage, child_key, None, None, graph_spec=child_graph) # noqa child_actor_id = self.actors().spawn(child_processor) child_state = _JobState(child_id) @@ -336,7 +336,8 @@ class JobProcessor(_actors.Actor): """ def __init__( - self, models: _models.ModelLoader, storage: _storage.StorageManager, + self, sys_config: _cfg.RuntimeConfig, + models: _models.ModelLoader, storage: _storage.StorageManager, job_key: str, job_config: _cfg.JobConfig, result_spec: _graph.JobResultSpec, graph_spec: tp.Optional[_graph.Graph]): @@ -345,6 +346,7 @@ def __init__( self.job_config = job_config self.result_spec = result_spec self.graph_spec = graph_spec + self._sys_config = sys_config self._models = models self._storage = storage self._resolver = _func.FunctionResolver(models, storage) @@ -358,7 +360,7 @@ def on_start(self): if self.graph_spec is not None: self.actors().send(self.actors().id, "build_graph_succeeded", self.graph_spec) else: - self.actors().spawn(GraphBuilder(self.job_config, self.result_spec)) + self.actors().spawn(GraphBuilder(self._sys_config, self.job_config, self.result_spec)) def on_stop(self): @@ -426,8 +428,9 @@ class GraphBuilder(_actors.Actor): GraphBuilder is a worker (actor) to wrap the GraphBuilder logic from graph_builder.py """ - def __init__(self, job_config: _cfg.JobConfig, result_spec: _graph.JobResultSpec): + def __init__(self, sys_config: _cfg.RuntimeConfig, job_config: _cfg.JobConfig, result_spec: _graph.JobResultSpec): super().__init__() + self.sys_config = sys_config 
self.job_config = job_config self.result_spec = result_spec self._log = _util.logger_for_object(self) @@ -440,8 +443,7 @@ def build_graph(self, job_config: _cfg.JobConfig): self._log.info("Building execution graph") - # TODO: Get sys config, or find a way to pass storage settings - graph_builder = _graph.GraphBuilder(job_config, self.result_spec) + graph_builder = _graph.GraphBuilder(self.sys_config, job_config, self.result_spec) graph_spec = graph_builder.build_job(job_config.job) self.actors().reply("build_graph_succeeded", graph_spec) diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py b/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py index 011753923..dba3b6f7c 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/functions.py @@ -15,6 +15,7 @@ from __future__ import annotations +import copy import datetime import abc import random @@ -296,8 +297,13 @@ def _execute(self, ctx: NodeContext) -> _data.DataView: # Map empty item -> emtpy view (for optional inputs not supplied) if root_item.is_empty(): - return _data.DataView.create_empty() + return _data.DataView.create_empty(root_item.object_type) + # Handle file data views + if root_item.object_type == meta.ObjectType.FILE: + return _data.DataView.for_file_item(root_item) + + # Everything else is a regular data view if self.node.schema is not None and len(self.node.schema.table.fields) > 0: trac_schema = self.node.schema else: @@ -322,7 +328,11 @@ def _execute(self, ctx: NodeContext) -> _data.DataItem: # Map empty view -> emtpy item (for optional outputs not supplied) if data_view.is_empty(): - return _data.DataItem.create_empty() + return _data.DataItem.create_empty(data_view.object_type) + + # Handle file data views + if data_view.object_type == meta.ObjectType.FILE: + return data_view.file_item # TODO: Support selecting data item described by self.node @@ -342,25 +352,24 @@ def __init__(self, node: DataResultNode): def _execute(self, ctx: NodeContext) -> ObjectBundle: - data_item = _ctx_lookup(self.node.data_item_id, ctx) + data_spec = _ctx_lookup(self.node.data_save_id, ctx) - # Do not record output metadata for optional outputs that are empty - if data_item.is_empty(): - return {} + result_bundle = dict() - data_spec = _ctx_lookup(self.node.data_spec_id, ctx) + # Do not record output metadata for optional outputs that are empty + if data_spec.is_empty(): + return result_bundle - # TODO: Check result of save operation - # save_result = _ctx_lookup(self.node.data_save_id, ctx) + if self.node.data_key is not None: + result_bundle[self.node.data_key] = meta.ObjectDefinition(objectType=meta.ObjectType.DATA, data=data_spec.data_def) - data_result = meta.ObjectDefinition(objectType=meta.ObjectType.DATA, data=data_spec.data_def) - storage_result = meta.ObjectDefinition(objectType=meta.ObjectType.STORAGE, storage=data_spec.storage_def) + if self.node.file_key is not None: + result_bundle[self.node.file_key] = meta.ObjectDefinition(objectType=meta.ObjectType.FILE, file=data_spec.file_def) - bundle = { - self.node.data_key: data_result, - self.node.storage_key: storage_result} + if self.node.storage_key is not None: + result_bundle[self.node.storage_key] = meta.ObjectDefinition(objectType=meta.ObjectType.STORAGE, storage=data_spec.storage_def) - return bundle + return result_bundle class DynamicDataSpecFunc(NodeFunction[_data.DataSpec]): @@ -443,11 +452,7 @@ def _execute(self, ctx: NodeContext) -> _data.DataSpec: # Dynamic data def will always use an 
embedded schema (this is no ID for an external schema) - return _data.DataSpec( - data_item, - data_def, - storage_def, - schema_def=None) + return _data.DataSpec.create_data_spec(data_item, data_def, storage_def, schema_def=None) class _LoadSaveDataFunc(abc.ABC): @@ -455,6 +460,16 @@ class _LoadSaveDataFunc(abc.ABC): def __init__(self, storage: _storage.StorageManager): self.storage = storage + @classmethod + def _choose_data_spec(cls, spec_id, spec, ctx: NodeContext): + + if spec_id is not None: + return _ctx_lookup(spec_id, ctx) + elif spec is not None: + return spec + else: + raise _ex.EUnexpected() + def _choose_copy(self, data_item: str, storage_def: meta.StorageDefinition) -> meta.StorageCopy: # Metadata should be checked for consistency before a job is accepted @@ -491,9 +506,19 @@ def __init__(self, node: LoadDataNode, storage: _storage.StorageManager): def _execute(self, ctx: NodeContext) -> _data.DataItem: - data_spec = _ctx_lookup(self.node.spec_id, ctx) + data_spec = self._choose_data_spec(self.node.spec_id, self.node.spec, ctx) data_copy = self._choose_copy(data_spec.data_item, data_spec.storage_def) - data_storage = self.storage.get_data_storage(data_copy.storageKey) + + if data_spec.object_type == _api.ObjectType.DATA: + return self._load_data(data_spec, data_copy) + + elif data_spec.object_type == _api.ObjectType.FILE: + return self._load_file(data_copy) + + else: + raise _ex.EUnexpected() + + def _load_data(self, data_spec, data_copy): trac_schema = data_spec.schema_def if data_spec.schema_def else data_spec.data_def.schema arrow_schema = _data.DataMapping.trac_to_arrow_schema(trac_schema) if trac_schema else None @@ -503,36 +528,52 @@ def _execute(self, ctx: NodeContext) -> _data.DataItem: for opt_key, opt_value in data_spec.storage_def.storageOptions.items(): options[opt_key] = _types.MetadataCodec.decode_value(opt_value) - table = data_storage.read_table( + storage = self.storage.get_data_storage(data_copy.storageKey) + table = storage.read_table( data_copy.storagePath, data_copy.storageFormat, arrow_schema, storage_options=options) - return _data.DataItem(table.schema, table) + return _data.DataItem(_api.ObjectType.DATA, table.schema, table) + + def _load_file(self, data_copy): + + storage = self.storage.get_file_storage(data_copy.storageKey) + raw_bytes = storage.read_bytes(data_copy.storagePath) + return _data.DataItem(_api.ObjectType.FILE, raw_bytes=raw_bytes) -class SaveDataFunc(_LoadSaveDataFunc, NodeFunction[None]): + +class SaveDataFunc(_LoadSaveDataFunc, NodeFunction[_data.DataSpec]): def __init__(self, node: SaveDataNode, storage: _storage.StorageManager): super().__init__(storage) self.node = node - def _execute(self, ctx: NodeContext): + def _execute(self, ctx: NodeContext) -> _data.DataSpec: # Item to be saved should exist in the current context data_item = _ctx_lookup(self.node.data_item_id, ctx) + # Metadata already exists as data_spec but may not contain schema, row count, file size etc. + data_spec = self._choose_data_spec(self.node.spec_id, self.node.spec, ctx) + data_copy = self._choose_copy(data_spec.data_item, data_spec.storage_def) + # Do not save empty outputs (optional outputs that were not produced) if data_item.is_empty(): - return + return _data.DataSpec.create_empty_spec(data_item.object_type) - # This function assumes that metadata has already been generated as the data_spec - # i.e. 
it is already known which incarnation / copy of the data will be created + if data_item.object_type == _api.ObjectType.DATA: + return self._save_data(data_item, data_spec, data_copy) - data_spec = _ctx_lookup(self.node.spec_id, ctx) - data_copy = self._choose_copy(data_spec.data_item, data_spec.storage_def) - data_storage = self.storage.get_data_storage(data_copy.storageKey) + elif data_item.object_type == _api.ObjectType.FILE: + return self._save_file(data_item, data_spec, data_copy) + + else: + raise _ex.EUnexpected() + + def _save_data(self, data_item, data_spec, data_copy): # Current implementation will always put an Arrow table in the data item # Empty tables are allowed, so explicitly check if table is None @@ -546,11 +587,32 @@ def _execute(self, ctx: NodeContext): for opt_key, opt_value in data_spec.storage_def.storageOptions.items(): options[opt_key] = _types.MetadataCodec.decode_value(opt_value) - data_storage.write_table( + storage = self.storage.get_data_storage(data_copy.storageKey) + storage.write_table( data_copy.storagePath, data_copy.storageFormat, data_item.table, storage_options=options, overwrite=False) + data_spec = copy.deepcopy(data_spec) + # TODO: Save row count in metadata + + if data_spec.data_def.schema is None and data_spec.data_def.schemaId is None: + data_spec.data_def.schema = _data.DataMapping.arrow_to_trac_schema(data_item.table.schema) + + return data_spec + + def _save_file(self, data_item, data_spec, data_copy): + + if data_item.raw_bytes is None: + raise _ex.EUnexpected() + + storage = self.storage.get_file_storage(data_copy.storageKey) + storage.write_bytes(data_copy.storagePath, data_item.raw_bytes) + + data_spec = copy.deepcopy(data_spec) + data_spec.file_def.size = len(data_item.raw_bytes) + + return data_spec def _model_def_for_import(import_details: meta.ImportModelJob): diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py b/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py index 840135f92..ee71d320b 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/graph.py @@ -309,20 +309,18 @@ def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]: @_node_type class DataResultNode(Node[ObjectBundle]): + # TODO: Remove this node type + # Either produce metadata in SaveDataNode, or handle DataSpec outputs in result processing nodes + output_name: str - data_item_id: NodeId[_data.DataItem] - data_spec_id: NodeId[_data.DataSpec] - data_save_id: NodeId[type(None)] + data_save_id: NodeId[_data.DataSpec] - data_key: str - storage_key: str + data_key: str = None + file_key: str = None + storage_key: str = None def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]: - - return { - self.data_item_id: DependencyType.HARD, - self.data_spec_id: DependencyType.HARD, - self.data_save_id: DependencyType.HARD} + return {self.data_save_id: DependencyType.HARD} @_node_type @@ -333,24 +331,33 @@ class LoadDataNode(Node[_data.DataItem]): The latest incarnation of the item will be loaded from any available copy """ - spec_id: NodeId[_data.DataSpec] + spec_id: tp.Optional[NodeId[_data.DataSpec]] = None + spec: tp.Optional[_data.DataSpec] = None def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]: - return {self.spec_id: DependencyType.HARD} + deps = dict() + if self.spec_id is not None: + deps[self.spec_id] = DependencyType.HARD + return deps @_node_type -class SaveDataNode(Node[None]): +class SaveDataNode(Node[_data.DataSpec]): """ Save an individual data item to storage 
""" - spec_id: NodeId[_data.DataSpec] data_item_id: NodeId[_data.DataItem] + spec_id: tp.Optional[NodeId[_data.DataSpec]] = None + spec: tp.Optional[_data.DataSpec] = None + def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]: - return {self.spec_id: DependencyType.HARD, self.data_item_id: DependencyType.HARD} + deps = {self.data_item_id: DependencyType.HARD} + if self.spec_id is not None: + deps[self.spec_id] = DependencyType.HARD + return deps @_node_type diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py b/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py index bbc7d5081..3b30b340c 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/graph_builder.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime as _dt + import tracdap.rt.config as config import tracdap.rt.exceptions as _ex import tracdap.rt._impl.data as _data # noqa @@ -33,8 +35,9 @@ class GraphBuilder: __JOB_BUILD_FUNC = tp.Callable[[meta.JobDefinition, NodeId], GraphSection] - def __init__(self, job_config: config.JobConfig, result_spec: JobResultSpec): + def __init__(self, sys_config: config.RuntimeConfig, job_config: config.JobConfig, result_spec: JobResultSpec): + self._sys_config = sys_config self._job_config = job_config self._result_spec = result_spec @@ -45,7 +48,7 @@ def __init__(self, job_config: config.JobConfig, result_spec: JobResultSpec): def _child_builder(self, job_id: meta.TagHeader) -> "GraphBuilder": - builder = GraphBuilder(self._job_config, JobResultSpec(save_result=False)) + builder = GraphBuilder(self._sys_config, self._job_config, JobResultSpec(save_result=False)) builder._job_key = _util.object_key(job_id) builder._job_namespace = NodeNamespace(builder._job_key) @@ -355,58 +358,76 @@ def build_job_inputs( nodes = dict() outputs = set() - must_run = list() - for input_name, input_schema in required_inputs.items(): + for input_name, input_def in required_inputs.items(): + + # Backwards compatibility with pre 0.8 versions + input_type = meta.ObjectType.DATA \ + if input_def.objectType == meta.ObjectType.OBJECT_TYPE_NOT_SET \ + else input_def.objectType + + input_selector = supplied_inputs.get(input_name) - data_selector = supplied_inputs.get(input_name) + if input_selector is None: - if data_selector is None: - if input_schema.optional: + if input_def.optional: data_view_id = NodeId.of(input_name, self._job_namespace, _data.DataView) - nodes[data_view_id] = StaticValueNode(data_view_id, _data.DataView.create_empty()) + data_view = _data.DataView.create_empty(input_type) + nodes[data_view_id] = StaticValueNode(data_view_id, data_view, explicit_deps=explicit_deps) outputs.add(data_view_id) - continue else: self._error(_ex.EJobValidation(f"Missing required input: [{input_name}]")) - continue - # Build a data spec using metadata from the job config - # For now we are always loading the root part, snap 0, delta 0 - data_def = _util.get_job_resource(data_selector, self._job_config).data - storage_def = _util.get_job_resource(data_def.storageId, self._job_config).storage + elif input_type == meta.ObjectType.DATA: + self._build_data_input(input_name, input_selector, nodes, outputs, explicit_deps) + + elif input_type == meta.ObjectType.FILE: + self._build_file_input(input_name, input_selector, nodes, outputs, explicit_deps) - if data_def.schemaId: - schema_def = _util.get_job_resource(data_def.schemaId, 
self._job_config).schema else: - schema_def = data_def.schema + self._error(_ex.EJobValidation(f"Invalid input type [{input_type.name}] for input [{input_name}]")) - root_part_opaque_key = 'part-root' # TODO: Central part names / constants - data_item = data_def.parts[root_part_opaque_key].snap.deltas[0].dataItem - data_spec = _data.DataSpec(data_item, data_def, storage_def, schema_def) + return GraphSection(nodes, outputs=outputs) - # Data spec node is static, using the assembled data spec - data_spec_id = NodeId.of(f"{input_name}:SPEC", self._job_namespace, _data.DataSpec) - data_spec_node = StaticValueNode(data_spec_id, data_spec, explicit_deps=explicit_deps) + def _build_data_input(self, input_name, input_selector, nodes, outputs, explicit_deps): - # Physical load of data items from disk - # Currently one item per input, since inputs are single part/delta - data_load_id = NodeId.of(f"{input_name}:LOAD", self._job_namespace, _data.DataItem) - data_load_node = LoadDataNode(data_load_id, data_spec_id, explicit_deps=explicit_deps) + # Build a data spec using metadata from the job config + # For now we are always loading the root part, snap 0, delta 0 + data_def = _util.get_job_resource(input_selector, self._job_config).data + storage_def = _util.get_job_resource(data_def.storageId, self._job_config).storage - # Input views assembled by mapping one root part to each view - data_view_id = NodeId.of(input_name, self._job_namespace, _data.DataView) - data_view_node = DataViewNode(data_view_id, schema_def, data_load_id) + if data_def.schemaId: + schema_def = _util.get_job_resource(data_def.schemaId, self._job_config).schema + else: + schema_def = data_def.schema - nodes[data_spec_id] = data_spec_node - nodes[data_load_id] = data_load_node - nodes[data_view_id] = data_view_node + root_part_opaque_key = 'part-root' # TODO: Central part names / constants + data_item = data_def.parts[root_part_opaque_key].snap.deltas[0].dataItem + data_spec = _data.DataSpec.create_data_spec(data_item, data_def, storage_def, schema_def) + + # Physical load of data items from disk + # Currently one item per input, since inputs are single part/delta + data_load_id = NodeId.of(f"{input_name}:LOAD", self._job_namespace, _data.DataItem) + nodes[data_load_id] = LoadDataNode(data_load_id, spec=data_spec, explicit_deps=explicit_deps) + + # Input views assembled by mapping one root part to each view + data_view_id = NodeId.of(input_name, self._job_namespace, _data.DataView) + nodes[data_view_id] = DataViewNode(data_view_id, schema_def, data_load_id) + outputs.add(data_view_id) + + def _build_file_input(self, input_name, input_selector, nodes, outputs, explicit_deps): - # Job-level data view is an output of the load operation - outputs.add(data_view_id) - must_run.append(data_spec_id) + file_def = _util.get_job_resource(input_selector, self._job_config).file + storage_def = _util.get_job_resource(file_def.storageId, self._job_config).storage - return GraphSection(nodes, outputs=outputs, must_run=must_run) + file_spec = _data.DataSpec.create_file_spec(file_def.dataItem, file_def, storage_def) + file_load_id = NodeId.of(f"{input_name}:LOAD", self._job_namespace, _data.DataItem) + nodes[file_load_id] = LoadDataNode(file_load_id, spec=file_spec, explicit_deps=explicit_deps) + + # Input views assembled by mapping one root part to each view + file_view_id = NodeId.of(input_name, self._job_namespace, _data.DataView) + nodes[file_view_id] = DataViewNode(file_view_id, None, file_load_id) + outputs.add(file_view_id) def 
build_job_outputs( self, @@ -418,12 +439,21 @@ def build_job_outputs( nodes = {} inputs = set() - for output_name, output_schema in required_outputs.items(): + for output_name, output_def in required_outputs.items(): + + # Output data view must already exist in the namespace, it is an input to the save operation + data_view_id = NodeId.of(output_name, self._job_namespace, _data.DataView) + inputs.add(data_view_id) + + # Backwards compatibility with pre 0.8 versions + output_type = meta.ObjectType.DATA \ + if output_def.objectType == meta.ObjectType.OBJECT_TYPE_NOT_SET \ + else output_def.objectType - data_selector = supplied_outputs.get(output_name) + output_selector = supplied_outputs.get(output_name) - if data_selector is None: - if output_schema.optional: + if output_selector is None: + if output_def.optional: optional_info = "(configuration is required for all optional outputs, in case they are produced)" self._error(_ex.EJobValidation(f"Missing optional output: [{output_name}] {optional_info}")) continue @@ -431,75 +461,129 @@ self._error(_ex.EJobValidation(f"Missing required output: [{output_name}]")) continue - # Output data view must already exist in the namespace - data_view_id = NodeId.of(output_name, self._job_namespace, _data.DataView) - data_spec_id = NodeId.of(f"{output_name}:SPEC", self._job_namespace, _data.DataSpec) + elif output_type == meta.ObjectType.DATA: + self._build_data_output(output_name, output_selector, data_view_id, nodes, explicit_deps) - data_obj = _util.get_job_resource(data_selector, self._job_config, optional=True) + elif output_type == meta.ObjectType.FILE: + self._build_file_output(output_name, output_def, output_selector, data_view_id, nodes, explicit_deps) - if data_obj is not None: + else: + self._error(_ex.EJobValidation(f"Invalid output type [{output_type.name}] for output [{output_name}]")) - # If data def for the output has been built in advance, use a static data spec + return GraphSection(nodes, inputs=inputs) - data_def = data_obj.data - storage_def = _util.get_job_resource(data_def.storageId, self._job_config).storage + def _build_data_output(self, output_name, output_selector, data_view_id, nodes, explicit_deps): - if data_def.schemaId: - schema_def = _util.get_job_resource(data_def.schemaId, self._job_config).schema - else: - schema_def = data_def.schema + # Map one data item from each view, since outputs are single part/delta + data_item_id = NodeId(f"{output_name}:ITEM", self._job_namespace, _data.DataItem) + nodes[data_item_id] = DataItemNode(data_item_id, data_view_id) - root_part_opaque_key = 'part-root' # TODO: Central part names / constants - data_item = data_def.parts[root_part_opaque_key].snap.deltas[0].dataItem - data_spec = _data.DataSpec(data_item, data_def, storage_def, schema_def) + data_obj = _util.get_job_resource(output_selector, self._job_config, optional=True) - data_spec_node = StaticValueNode(data_spec_id, data_spec, explicit_deps=explicit_deps) + if data_obj is not None: - output_data_key = output_name + ":DATA" - output_storage_key = output_name + ":STORAGE" + # If data def for the output has been built in advance, use a static data spec + data_def = data_obj.data + storage_def = _util.get_job_resource(data_def.storageId, self._job_config).storage + + if data_def.schemaId: + schema_def = _util.get_job_resource(data_def.schemaId, self._job_config).schema else: + schema_def = data_def.schema - # If output data def for an output was not supplied in the job, create a dynamic data spec - # Dynamic data 
def will always use an embedded schema (this is no ID for an external schema) + root_part_opaque_key = 'part-root' # TODO: Central part names / constants + data_item = data_def.parts[root_part_opaque_key].snap.deltas[0].dataItem + data_spec = _data.DataSpec.create_data_spec(data_item, data_def, storage_def, schema_def) + + # Create a physical save operation for the data item + data_save_id = NodeId.of(f"{output_name}:SAVE", self._job_namespace, _data.DataSpec) + nodes[data_save_id] = SaveDataNode(data_save_id, data_item_id, spec=data_spec) + + output_key = output_name + storage_key = output_name + ":STORAGE" - data_key = output_name + ":DATA" - data_id = self._job_config.resultMapping[data_key] - storage_key = output_name + ":STORAGE" - storage_id = self._job_config.resultMapping[storage_key] + else: - data_spec_node = DynamicDataSpecNode( - data_spec_id, data_view_id, - data_id, storage_id, - prior_data_spec=None, - explicit_deps=explicit_deps) + # If output data def for an output was not supplied in the job, create a dynamic data spec + # Dynamic data def will always use an embedded schema (this is no ID for an external schema) - output_data_key = _util.object_key(data_id) - output_storage_key = _util.object_key(storage_id) + mapped_output_key = output_name + mapped_storage_key = output_name + ":STORAGE" - # Map one data item from each view, since outputs are single part/delta - data_item_id = NodeId(f"{output_name}:ITEM", self._job_namespace, _data.DataItem) - data_item_node = DataItemNode(data_item_id, data_view_id) + data_id = self._job_config.resultMapping[mapped_output_key] + storage_id = self._job_config.resultMapping[mapped_storage_key] + + data_spec_id = NodeId.of(f"{output_name}:SPEC", self._job_namespace, _data.DataSpec) + nodes[data_spec_id] = DynamicDataSpecNode( + data_spec_id, data_view_id, + data_id, storage_id, + prior_data_spec=None, + explicit_deps=explicit_deps) # Create a physical save operation for the data item - data_save_id = NodeId.of(f"{output_name}:SAVE", self._job_namespace, None) - data_save_node = SaveDataNode(data_save_id, data_spec_id, data_item_id) + data_save_id = NodeId.of(f"{output_name}:SAVE", self._job_namespace, _data.DataSpec) + nodes[data_save_id] = SaveDataNode(data_save_id, data_item_id, spec_id=data_spec_id) - data_result_id = NodeId.of(f"{output_name}:RESULT", self._job_namespace, ObjectBundle) - data_result_node = DataResultNode( - data_result_id, output_name, - data_item_id, data_spec_id, data_save_id, - output_data_key, output_storage_key) + output_key = _util.object_key(data_id) + storage_key = _util.object_key(storage_id) - nodes[data_spec_id] = data_spec_node - nodes[data_item_id] = data_item_node - nodes[data_save_id] = data_save_node - nodes[data_result_id] = data_result_node + data_result_id = NodeId.of(f"{output_name}:RESULT", self._job_namespace, ObjectBundle) + nodes[data_result_id] = DataResultNode( + data_result_id, output_name, data_save_id, + data_key=output_key, + storage_key=storage_key) - # Job-level data view is an input to the save operation - inputs.add(data_view_id) + def _build_file_output(self, output_name, output_def, output_selector, file_view_id, nodes, explicit_deps): - return GraphSection(nodes, inputs=inputs) + mapped_output_key = output_name + mapped_storage_key = output_name + ":STORAGE" + + file_obj = _util.get_job_resource(output_selector, self._job_config, optional=True) + + if file_obj is not None: + + # Definitions already exist (generated by dev mode translator) + + file_def = 
_util.get_job_resource(output_selector, self._job_config).file + storage_def = _util.get_job_resource(file_def.storageId, self._job_config).storage + + resolved_output_key = mapped_output_key + resolved_storage_key = mapped_storage_key + + else: + + # Create new definitions (default behavior for jobs sent from the platform) + + output_id = self._job_config.resultMapping[mapped_output_key] + storage_id = self._job_config.resultMapping[mapped_storage_key] + + file_type = output_def.fileType + timestamp = _dt.datetime.fromisoformat(output_id.objectTimestamp.isoDatetime) + data_item = f"file/{output_id.objectId}/version-{output_id.objectVersion}" + storage_key = self._sys_config.storage.defaultBucket + storage_path = f"file/FILE-{output_id.objectId}/version-{output_id.objectVersion}/{output_name}.{file_type.extension}" + + file_def = self.build_file_def(output_name, file_type, storage_id, data_item) + storage_def = self.build_storage_def(data_item, storage_key, storage_path, file_type.mimeType, timestamp) + + resolved_output_key = _util.object_key(output_id) + resolved_storage_key = _util.object_key(storage_id) + + # Required object defs are available, now build the graph nodes + + file_item_id = NodeId(f"{output_name}:ITEM", self._job_namespace, _data.DataItem) + nodes[file_item_id] = DataItemNode(file_item_id, file_view_id, explicit_deps=explicit_deps) + + file_spec = _data.DataSpec.create_file_spec(file_def.dataItem, file_def, storage_def) + file_save_id = NodeId.of(f"{output_name}:SAVE", self._job_namespace, _data.DataSpec) + nodes[file_save_id] = SaveDataNode(file_save_id, file_item_id, spec=file_spec) + + data_result_id = NodeId.of(f"{output_name}:RESULT", self._job_namespace, ObjectBundle) + nodes[data_result_id] = DataResultNode( + data_result_id, output_name, file_save_id, + file_key=resolved_output_key, + storage_key=resolved_storage_key) @classmethod def build_runtime_outputs(cls, output_names: tp.List[str], job_namespace: NodeNamespace): @@ -519,9 +603,10 @@ def build_runtime_outputs(cls, output_names: tp.List[str], job_namespace: NodeNa data_view_id = NodeId.of(output_name, job_namespace, _data.DataView) data_spec_id = NodeId.of(f"{output_name}:SPEC", job_namespace, _data.DataSpec) - data_key = output_name + ":DATA" + mapped_output_key = output_name + mapped_storage_key = output_name + ":STORAGE" + data_id = _util.new_object_id(meta.ObjectType.DATA) - storage_key = output_name + ":STORAGE" storage_id = _util.new_object_id(meta.ObjectType.STORAGE) data_spec_node = DynamicDataSpecNode( @@ -529,22 +614,21 @@ def build_runtime_outputs(cls, output_names: tp.List[str], job_namespace: NodeNa data_id, storage_id, prior_data_spec=None) - output_data_key = _util.object_key(data_id) - output_storage_key = _util.object_key(storage_id) + output_key = _util.object_key(data_id) + storage_key = _util.object_key(storage_id) # Map one data item from each view, since outputs are single part/delta data_item_id = NodeId(f"{output_name}:ITEM", job_namespace, _data.DataItem) data_item_node = DataItemNode(data_item_id, data_view_id) # Create a physical save operation for the data item - data_save_id = NodeId.of(f"{output_name}:SAVE", job_namespace, None) - data_save_node = SaveDataNode(data_save_id, data_spec_id, data_item_id) + data_save_id = NodeId.of(f"{output_name}:SAVE", job_namespace, _data.DataSpec) + data_save_node = SaveDataNode(data_save_id, data_item_id, spec_id=data_spec_id) data_result_id = NodeId.of(f"{output_name}:RESULT", job_namespace, ObjectBundle) data_result_node = DataResultNode( - 
data_result_id, output_name, - data_item_id, data_spec_id, data_save_id, - output_data_key, output_storage_key) + data_result_id, output_name, data_save_id, + output_key, storage_key) nodes[data_spec_id] = data_spec_node nodes[data_item_id] = data_item_node @@ -563,6 +647,45 @@ def build_runtime_outputs(cls, output_names: tp.List[str], job_namespace: NodeNa return GraphSection(nodes, inputs=inputs, outputs={runtime_outputs_id}) + @classmethod + def build_file_def(cls, file_name, file_type, storage_id, data_item): + + file_def = meta.FileDefinition() + file_def.name = f"{file_name}.{file_type.extension}" + file_def.extension = file_type.extension + file_def.mimeType = file_type.mimeType + file_def.storageId = _util.selector_for_latest(storage_id) + file_def.dataItem = data_item + file_def.size = 0 + + return file_def + + @classmethod + def build_storage_def( + cls, data_item: str, + storage_key, storage_path, storage_format, + timestamp: _dt.datetime): + + first_incarnation = 0 + + storage_copy = meta.StorageCopy( + storage_key, storage_path, storage_format, + copyStatus=meta.CopyStatus.COPY_AVAILABLE, + copyTimestamp=meta.DatetimeValue(timestamp.isoformat())) + + storage_incarnation = meta.StorageIncarnation( + [storage_copy], + incarnationIndex=first_incarnation, + incarnationTimestamp=meta.DatetimeValue(timestamp.isoformat()), + incarnationStatus=meta.IncarnationStatus.INCARNATION_AVAILABLE) + + storage_item = meta.StorageItem([storage_incarnation]) + + storage_def = meta.StorageDefinition() + storage_def.dataItems[data_item] = storage_item + + return storage_def + def build_job_results( self, objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = None, diff --git a/tracdap-runtime/python/src/tracdap/rt/_exec/runtime.py b/tracdap-runtime/python/src/tracdap/rt/_exec/runtime.py index 61f91bac8..57b9a2649 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_exec/runtime.py +++ b/tracdap-runtime/python/src/tracdap/rt/_exec/runtime.py @@ -96,6 +96,7 @@ def __init__( self._scratch_dir_persist = scratch_dir_persist self._plugin_packages = plugin_packages or [] self._dev_mode = dev_mode + self._dev_mode_translator = None # Runtime control self._runtime_lock = threading.Lock() @@ -141,10 +142,6 @@ def pre_start(self): self._log.info(f"Beginning pre-start sequence...") - # Scratch dir is needed during pre-start (at least dev mode translation uses the model loader) - - self._prepare_scratch_dir() - # Plugin manager, static API and guard rails are singletons # Calling these methods multiple times is safe (e.g. 
for embedded or testing scenarios) # However, plugins are never un-registered for the lifetime of the processes @@ -198,9 +195,17 @@ def start(self, wait: bool = False): self._log.info("Starting the engine...") + self._prepare_scratch_dir() + self._models = _models.ModelLoader(self._sys_config, self._scratch_dir) self._storage = _storage.StorageManager(self._sys_config) + if self._dev_mode: + + self._dev_mode_translator = _dev_mode.DevModeTranslator( + self._sys_config, self._config_mgr, self._scratch_dir, + model_loader=self._models, storage_manager=self._storage) + # Enable protection after the initial setup of the runtime is complete # Storage plugins in particular are likely to tigger protected imports # Once the runtime is up, no more plugins should be loaded @@ -323,6 +328,9 @@ def load_job_config( self, job_config: tp.Union[str, pathlib.Path, _cfg.JobConfig], model_class: tp.Optional[_api.TracModel.__class__] = None): + if not self._engine or self._shutdown_requested: + raise _ex.ETracInternal("Engine is not started or shutdown has been requested") + if isinstance(job_config, _cfg.JobConfig): self._log.info("Using embedded job config") @@ -334,13 +342,15 @@ def load_job_config( config_file_name="job") if self._dev_mode: - translator = _dev_mode.DevModeTranslator(self._sys_config, self._config_mgr, self._scratch_dir) - job_config = translator.translate_job_config(job_config, model_class) + job_config = self._dev_mode_translator.translate_job_config(job_config, model_class) return job_config def submit_job(self, job_config: _cfg.JobConfig): + if not self._engine or self._shutdown_requested: + raise _ex.ETracInternal("Engine is not started or shutdown has been requested") + job_key = _util.object_key(job_config.jobId) self._jobs[job_key] = _RuntimeJobInfo() @@ -351,6 +361,9 @@ def submit_job(self, job_config: _cfg.JobConfig): def wait_for_job(self, job_id: _api.TagHeader): + if not self._engine or self._shutdown_requested: + raise _ex.ETracInternal("Engine is not started or shutdown has been requested") + job_key = _util.object_key(job_id) if job_key not in self._jobs: diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/data.py b/tracdap-runtime/python/src/tracdap/rt/_impl/data.py index 6b0e3d93a..a82d58a2a 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/data.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/data.py @@ -14,6 +14,7 @@ # limitations under the License. 
import abc +import copy import dataclasses as dc import typing as tp import datetime as dt @@ -42,11 +43,48 @@ @dc.dataclass(frozen=True) class DataSpec: + object_type: _meta.ObjectType data_item: str + data_def: _meta.DataDefinition + file_def: _meta.FileDefinition storage_def: _meta.StorageDefinition schema_def: tp.Optional[_meta.SchemaDefinition] + @staticmethod + def create_data_spec( + data_item: str, + data_def: _meta.DataDefinition, + storage_def: _meta.StorageDefinition, + schema_def: tp.Optional[_meta.SchemaDefinition] = None) -> "DataSpec": + + return DataSpec( + _meta.ObjectType.DATA, data_item, + data_def, + storage_def=storage_def, + schema_def=schema_def, + file_def=None) + + @staticmethod + def create_file_spec( + data_item: str, + file_def: _meta.FileDefinition, + storage_def: _meta.StorageDefinition) -> "DataSpec": + + return DataSpec( + _meta.ObjectType.FILE, data_item, + file_def=file_def, + storage_def=storage_def, + data_def=None, + schema_def=None) + + @staticmethod + def create_empty_spec(object_type: _meta.ObjectType): + return DataSpec(object_type, None, None, None, None, None) + + def is_empty(self): + return self.data_item is None or len(self.data_item) == 0 + @dc.dataclass(frozen=True) class DataPartKey: @@ -61,44 +99,79 @@ def for_root(cls) -> "DataPartKey": @dc.dataclass(frozen=True) class DataItem: - schema: pa.Schema + object_type: _meta.ObjectType + + schema: pa.Schema = None table: tp.Optional[pa.Table] = None batches: tp.Optional[tp.List[pa.RecordBatch]] = None pandas: "tp.Optional[pandas.DataFrame]" = None pyspark: tp.Any = None + raw_bytes: bytes = None + def is_empty(self) -> bool: - return self.table is None and (self.batches is None or len(self.batches) == 0) + if self.object_type == _meta.ObjectType.FILE: + return self.raw_bytes is None or len(self.raw_bytes) == 0 + else: + return self.table is None and (self.batches is None or len(self.batches) == 0) @staticmethod - def create_empty() -> "DataItem": - return DataItem(pa.schema([])) + def create_empty(object_type: _meta.ObjectType = _meta.ObjectType.DATA) -> "DataItem": + if object_type == _meta.ObjectType.DATA: + return DataItem(_meta.ObjectType.DATA, pa.schema([])) + else: + return DataItem(object_type) + + @staticmethod + def for_file_content(raw_bytes: bytes): + return DataItem(_meta.ObjectType.FILE, raw_bytes=raw_bytes) @dc.dataclass(frozen=True) class DataView: - trac_schema: _meta.SchemaDefinition - arrow_schema: pa.Schema + object_type: _meta.ObjectType - parts: tp.Dict[DataPartKey, tp.List[DataItem]] + trac_schema: _meta.SchemaDefinition = None + arrow_schema: pa.Schema = None + + parts: tp.Dict[DataPartKey, tp.List[DataItem]] = None + file_item: tp.Optional[DataItem] = None @staticmethod - def create_empty() -> "DataView": - return DataView(_meta.SchemaDefinition(), pa.schema([]), dict()) + def create_empty(object_type: _meta.ObjectType = _meta.ObjectType.DATA) -> "DataView": + if object_type == _meta.ObjectType.DATA: + return DataView(object_type, _meta.SchemaDefinition(), pa.schema([]), dict()) + else: + return DataView(object_type) @staticmethod def for_trac_schema(trac_schema: _meta.SchemaDefinition): arrow_schema = DataMapping.trac_to_arrow_schema(trac_schema) - return DataView(trac_schema, arrow_schema, dict()) + return DataView(_meta.ObjectType.DATA, trac_schema, arrow_schema, dict()) + + @staticmethod + def for_file_item(file_item: DataItem): + return DataView(file_item.object_type, file_item=file_item) def with_trac_schema(self, trac_schema: _meta.SchemaDefinition): arrow_schema = 
DataMapping.trac_to_arrow_schema(trac_schema) - return DataView(trac_schema, arrow_schema, self.parts) + return DataView(_meta.ObjectType.DATA, trac_schema, arrow_schema, self.parts) + + def with_part(self, part_key: DataPartKey, part: DataItem): + new_parts = copy.copy(self.parts) + new_parts[part_key] = [part] + return DataView(self.object_type, self.trac_schema, self.arrow_schema, new_parts) + + def with_file_item(self, file_item: DataItem): + return DataView(self.object_type, file_item=file_item) def is_empty(self) -> bool: - return self.parts is None or not any(self.parts.values()) + if self.object_type == _meta.ObjectType.FILE: + return self.file_item is None + else: + return self.parts is None or not any(self.parts.values()) class _DataInternal: @@ -293,7 +366,7 @@ def add_item_to_view(cls, view: DataView, part: DataPartKey, item: DataItem) -> deltas = [*prior_deltas, item] parts = {**view.parts, part: deltas} - return DataView(view.trac_schema, view.arrow_schema, parts) + return DataView(view.object_type, view.trac_schema, view.arrow_schema, parts=parts) @classmethod def view_to_arrow(cls, view: DataView, part: DataPartKey) -> pa.Table: diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/models.py b/tracdap-runtime/python/src/tracdap/rt/_impl/models.py index 86932ee93..4954db161 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/models.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/models.py @@ -226,13 +226,15 @@ def scan_model(self, model_stub: _meta.ModelDefinition, model_class: _api.TracMo self.__log.info(f"Parameter [{name}] - {param.paramType.basicType.name}") param.paramProps = self._encoded_props(param.paramProps, "parameter", name) - for name, schema in model_def.inputs.items(): - self.__log.info(f"Input [{name}] - {schema.schema.schemaType.name}") - schema.inputProps = self._encoded_props(schema.inputProps, "input", name) - - for name, schema in model_def.outputs.items(): - self.__log.info(f"Output [{name}] - {schema.schema.schemaType.name}") - schema.outputProps = self._encoded_props(schema.outputProps, "input", name) + for name, input_def in model_def.inputs.items(): + input_type = input_def.schema.schemaType.name if input_def.objectType == _meta.ObjectType.DATA else input_def.objectType.name + self.__log.info(f"Input [{name}] - {input_type}") + input_def.inputProps = self._encoded_props(input_def.inputProps, "input", name) + + for name, output_def in model_def.outputs.items(): + output_type = output_def.schema.schemaType.name if output_def.objectType == _meta.ObjectType.DATA else output_def.objectType.name + self.__log.info(f"Output [{name}] - {output_type}") + output_def.outputProps = self._encoded_props(output_def.outputProps, "input", name) return model_def diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/static_api.py b/tracdap-runtime/python/src/tracdap/rt/_impl/static_api.py index 457a792ab..a033bc6cb 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/static_api.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/static_api.py @@ -152,14 +152,18 @@ def define_field( def define_schema( self, *fields: _tp.Union[_meta.FieldSchema, _tp.List[_meta.FieldSchema]], - schema_type: _meta.SchemaType = _meta.SchemaType.TABLE) \ + schema_type: _meta.SchemaType = _meta.SchemaType.TABLE, dynamic: bool = False) \ -> _meta.SchemaDefinition: - _val.validate_signature(self.define_schema, *fields, schema_type=schema_type) + _val.validate_signature(self.define_schema, *fields, schema_type=schema_type, dynamic=dynamic) if schema_type == 
_meta.SchemaType.TABLE: - table_schema = self._build_table_schema(*fields) + if dynamic and not fields: + table_schema = None + else: + table_schema = self._build_table_schema(*fields) + return _meta.SchemaDefinition(_meta.SchemaType.TABLE, table=table_schema) raise _ex.ERuntimeValidation(f"Invalid schema type [{schema_type.name}]") @@ -182,51 +186,67 @@ def infer_schema(self, dataset: _api.DATA_API) -> _meta.SchemaDefinition: return converter.infer_schema(dataset) - def define_input_table( - self, *fields: _tp.Union[_meta.FieldSchema, _tp.List[_meta.FieldSchema]], - label: _tp.Optional[str] = None, optional: bool = False, dynamic: bool = False, - input_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None) \ - -> _meta.ModelInputSchema: + def define_file_type(self, extension: str, mime_type: str) -> _meta.FileType: + + _val.validate_signature(self.define_file_type, extension, mime_type) + + return _meta.FileType(extension=extension, mimeType=mime_type) + + def define_input( + self, requirement: _tp.Union[_meta.SchemaDefinition, _meta.FileType], *, + label: _tp.Optional[str] = None, + optional: bool = False, dynamic: bool = False, + input_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None): _val.validate_signature( - self.define_input_table, *fields, + self.define_input, requirement, label=label, optional=optional, dynamic=dynamic, input_props=input_props) - # Do not define details for dynamic schemas + if isinstance(requirement, _meta.SchemaDefinition): - if dynamic: - schema_def = _meta.SchemaDefinition(_meta.SchemaType.TABLE) - else: - schema_def = self.define_schema(*fields, schema_type=_meta.SchemaType.TABLE) + return _meta.ModelInputSchema( + objectType=_meta.ObjectType.DATA, schema=requirement, + label=label, optional=optional, dynamic=dynamic, + inputProps=input_props) - return _meta.ModelInputSchema( - schema=schema_def, label=label, - optional=optional, dynamic=dynamic, - inputProps=input_props) + elif isinstance(requirement, _meta.FileType): - def define_output_table( - self, *fields: _tp.Union[_meta.FieldSchema, _tp.List[_meta.FieldSchema]], - label: _tp.Optional[str] = None, optional: bool = False, dynamic: bool = False, - output_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None) \ - -> _meta.ModelOutputSchema: + return _meta.ModelInputSchema( + objectType=_meta.ObjectType.FILE, fileType=requirement, + label=label, optional=optional, dynamic=dynamic, + inputProps=input_props) + + else: + raise _ex.EUnexpected() + + def define_output( + self, requirement: _tp.Union[_meta.SchemaDefinition, _meta.FileType], *, + label: _tp.Optional[str] = None, + optional: bool = False, dynamic: bool = False, + output_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None): _val.validate_signature( - self.define_output_table, *fields, + self.define_output, requirement, label=label, optional=optional, dynamic=dynamic, output_props=output_props) - # Do not define details for dynamic schemas + if isinstance(requirement, _meta.SchemaDefinition): - if dynamic: - schema_def = _meta.SchemaDefinition(_meta.SchemaType.TABLE) - else: - schema_def = self.define_schema(*fields, schema_type=_meta.SchemaType.TABLE) + return _meta.ModelOutputSchema( + objectType=_meta.ObjectType.DATA, schema=requirement, + label=label, optional=optional, dynamic=dynamic, + outputProps=output_props) - return _meta.ModelOutputSchema( - schema=schema_def, label=label, - optional=optional, dynamic=dynamic, - outputProps=output_props) + elif isinstance(requirement, _meta.FileType): + + return _meta.ModelOutputSchema( + 
objectType=_meta.ObjectType.FILE, fileType=requirement, + label=label, optional=optional, dynamic=dynamic, + outputProps=output_props) + + else: + raise _ex.EUnexpected() @staticmethod def _build_named_dict( diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/util.py b/tracdap-runtime/python/src/tracdap/rt/_impl/util.py index 9a00a384d..c2e6b7f5f 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/util.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/util.py @@ -235,7 +235,7 @@ def get_job_resource( if optional: return None - err = f"Missing required {selector.objectType} resource [{object_key(selector)}]" + err = f"Missing required {selector.objectType.name} resource [{object_key(selector)}]" raise ex.ERuntimeValidation(err) diff --git a/tracdap-runtime/python/src/tracdap/rt/_impl/validation.py b/tracdap-runtime/python/src/tracdap/rt/_impl/validation.py index 49097702f..83d7d9679 100644 --- a/tracdap-runtime/python/src/tracdap/rt/_impl/validation.py +++ b/tracdap-runtime/python/src/tracdap/rt/_impl/validation.py @@ -306,6 +306,9 @@ class StaticValidator: __reserved_identifier_pattern = re.compile("\\A(_|trac_)", re.ASCII) __label_length_limit = 4096 + __file_extension_pattern = re.compile('\\A[a-zA-Z0-9]+\\Z') + __mime_type_pattern = re.compile('\\A\\w+/[-.\\w]+(?:\\+[-.\\w]+)?\\Z') + __PRIMITIVE_TYPES = [ meta.BasicType.BOOLEAN, meta.BasicType.INTEGER, @@ -418,49 +421,72 @@ def _check_parameters(cls, parameters): cls._valid_identifiers(param.paramProps.keys(), "entry in param props") @classmethod - def _check_inputs_or_outputs(cls, inputs_or_outputs): + def _check_inputs_or_outputs(cls, sockets): - for input_name, input_schema in inputs_or_outputs.items(): + for socket_name, socket in sockets.items(): - cls._log.info(f"Checking {input_name}") + if socket.objectType == meta.ObjectType.DATA: + cls._check_socket_schema(socket_name, socket) + elif socket.objectType == meta.ObjectType.FILE: + cls._check_socket_file_type(socket_name, socket) + else: + raise ex.EModelValidation(f"Invalid object type [{socket.objectType.name}] for [{socket_name}]") - if input_schema.dynamic: - if input_schema.schema and input_schema.schema.table: - error = "Dynamic schemas must have schema.table = None" - cls._fail(f"Invalid schema for [{input_name}]: {error}") - else: - continue + label = socket.label + cls._check_label(label, socket_name) + + if isinstance(socket, meta.ModelInputSchema): + if socket.inputProps is not None: + cls._valid_identifiers(socket.inputProps.keys(), "entry in input props") + else: + if socket.outputProps is not None: + cls._valid_identifiers(socket.outputProps.keys(), "entry in output props") - fields = input_schema.schema.table.fields - field_names = list(map(lambda f: f.fieldName, fields)) - property_type = f"field in [{input_name}]" + @classmethod + def _check_socket_schema(cls, socket_name, socket): - if len(fields) == 0: - cls._fail(f"Invalid schema for [{input_name}]: No fields defined") + if socket.schema is None: + cls._fail(f"Missing schema requirement for [{socket_name}]") + return - cls._valid_identifiers(field_names, property_type) - cls._case_insensitive_duplicates(field_names, property_type) + if socket.dynamic: + if socket.schema and socket.schema.table: + error = "Dynamic schemas must have schema.table = None" + cls._fail(f"Invalid schema for [{socket_name}]: {error}") + else: + return - for field in fields: - cls._check_single_field(field, property_type) + fields = socket.schema.table.fields + field_names = list(map(lambda f: f.fieldName, fields)) + 
property_type = f"field in [{socket_name}]" - label = input_schema.label - cls._check_label(label, input_name) + if len(fields) == 0: + cls._fail(f"Invalid schema for [{socket_name}]: No fields defined") - if isinstance(input_schema, meta.ModelInputSchema): - if input_schema.inputProps is not None: - cls._valid_identifiers(input_schema.inputProps.keys(), "entry in input props") - else: - if input_schema.outputProps is not None: - cls._valid_identifiers(input_schema.outputProps.keys(), "entry in output props") + cls._valid_identifiers(field_names, property_type) + cls._case_insensitive_duplicates(field_names, property_type) + + for field in fields: + cls._check_single_field(field, property_type) + + @classmethod + def _check_socket_file_type(cls, socket_name, socket): + + if socket.fileType is None: + cls._fail(f"Missing file type requirement for [{socket_name}]") + return + + if not cls.__file_extension_pattern.match(socket.fileType.extension): + cls._fail(f"Invalid extension [{socket.fileType.extension}] for [{socket_name}]") + + if not cls.__mime_type_pattern.match(socket.fileType.mimeType): + cls._fail(f"Invalid mime type [{socket.fileType.mimeType}] for [{socket_name}]") @classmethod def _check_single_field(cls, field: meta.FieldSchema, property_type): # Valid identifier and not trac reserved checked separately - cls._log.info(field.fieldName) - if field.fieldOrder < 0: cls._fail(f"Invalid {property_type}: [{field.fieldName}] fieldOrder < 0") diff --git a/tracdap-runtime/python/src/tracdap/rt/api/__init__.py b/tracdap-runtime/python/src/tracdap/rt/api/__init__.py index d53e3d95d..277a0ea0c 100644 --- a/tracdap-runtime/python/src/tracdap/rt/api/__init__.py +++ b/tracdap-runtime/python/src/tracdap/rt/api/__init__.py @@ -17,13 +17,16 @@ TRAC model API for Python """ -from .model_api import * -from .static_api import * - # Make metadata classes available to client code when importing the API package # Remove this import when generating docs, so metadata classes are only documented once from tracdap.rt.metadata import * # noqa DOCGEN_REMOVE +# static_api overrides some metadata types for backwards compatibility with pre-0.8 versions +# Make sure it is last in the list +from .file_types import * +from .model_api import * +from .static_api import * + # Map basic types into the root of the API package BOOLEAN = BasicType.BOOLEAN diff --git a/tracdap-runtime/python/src/tracdap/rt/api/file_types.py b/tracdap-runtime/python/src/tracdap/rt/api/file_types.py new file mode 100644 index 000000000..7ac27fa7b --- /dev/null +++ b/tracdap-runtime/python/src/tracdap/rt/api/file_types.py @@ -0,0 +1,29 @@ +# Licensed to the Fintech Open Source Foundation (FINOS) under one or +# more contributor license agreements. See the NOTICE file distributed +# with this work for additional information regarding copyright ownership. +# FINOS licenses this file to you under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with the +# License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from tracdap.rt.metadata import * # DOCGEN_REMOVE + + +class CommonFileTypes: + + TXT = FileType("txt", "text/plain") + + JPG = FileType("jpg", "image/jpeg") + PNG = FileType("png", "image/png") + SVG = FileType("svg", "image/svg+xml") + + WORD = FileType("docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document") + EXCEL = FileType("xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet") + POWERPOINT = FileType("pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation") diff --git a/tracdap-runtime/python/src/tracdap/rt/api/hook.py b/tracdap-runtime/python/src/tracdap/rt/api/hook.py index 975db4b7b..601bc19f3 100644 --- a/tracdap-runtime/python/src/tracdap/rt/api/hook.py +++ b/tracdap-runtime/python/src/tracdap/rt/api/hook.py @@ -116,7 +116,7 @@ def define_field( @_abc.abstractmethod def define_schema( self, *fields: _tp.Union[_meta.FieldSchema, _tp.List[_meta.FieldSchema]], - schema_type: _meta.SchemaType = _meta.SchemaType.TABLE) \ + schema_type: _meta.SchemaType = _meta.SchemaType.TABLE, dynamic: bool = False) \ -> _meta.SchemaDefinition: pass @@ -131,21 +131,29 @@ def load_schema( @_abc.abstractmethod def infer_schema(self, dataset: _tp.Any) -> _meta.SchemaDefinition: + pass @_abc.abstractmethod - def define_input_table( - self, *fields: _tp.Union[_meta.FieldSchema, _tp.List[_meta.FieldSchema]], - label: _tp.Optional[str] = None, optional: bool = False, dynamic: bool = False, + def define_file_type(self, extension: str, mime_type: str) -> _meta.FileType: + + pass + + @_abc.abstractmethod + def define_input( + self, requirement: _tp.Union[_meta.SchemaDefinition, _meta.FileType], *, + label: _tp.Optional[str] = None, + optional: bool = False, dynamic: bool = False, input_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None) \ -> _meta.ModelInputSchema: pass @_abc.abstractmethod - def define_output_table( - self, *fields: _tp.Union[_meta.FieldSchema, _tp.List[_meta.FieldSchema]], - label: _tp.Optional[str] = None, optional: bool = False, dynamic: bool = False, + def define_output( + self, requirement: _tp.Union[_meta.SchemaDefinition, _meta.FileType], *, + label: _tp.Optional[str] = None, + optional: bool = False, dynamic: bool = False, output_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None) \ -> _meta.ModelOutputSchema: diff --git a/tracdap-runtime/python/src/tracdap/rt/api/model_api.py b/tracdap-runtime/python/src/tracdap/rt/api/model_api.py index 2cc7e5d2d..903002fb1 100644 --- a/tracdap-runtime/python/src/tracdap/rt/api/model_api.py +++ b/tracdap-runtime/python/src/tracdap/rt/api/model_api.py @@ -194,6 +194,14 @@ def get_polars_table(self, dataset_name: str) -> "polars.DataFrame": pass + def get_file(self, file_name: str) -> bytes: + + pass + + def get_file_stream(self, file_name: str) -> _tp.ContextManager[_tp.BinaryIO]: + + pass + def put_schema(self, dataset_name: str, schema: SchemaDefinition): """ @@ -283,6 +291,14 @@ def put_polars_table(self, dataset_name: str, dataset: "polars.DataFrame"): pass + def put_file(self, file_name: str, file_content: _tp.Union[bytes, bytearray]): + + pass + + def put_file_stream(self, file_name: str) -> _tp.ContextManager[_tp.BinaryIO]: + + pass + def log(self) -> _logging.Logger: """ diff --git a/tracdap-runtime/python/src/tracdap/rt/api/static_api.py b/tracdap-runtime/python/src/tracdap/rt/api/static_api.py index d0e4bc3ca..eeca75bb1 100644 --- a/tracdap-runtime/python/src/tracdap/rt/api/static_api.py +++ b/tracdap-runtime/python/src/tracdap/rt/api/static_api.py @@ -162,35 
+162,6 @@ def define_parameter( return sa.define_parameter(param_name, param_type, label, default_value, param_props=param_props) -def declare_parameter( - param_name: str, - param_type: _tp.Union[BasicType, TypeDescriptor], - label: str, - default_value: _tp.Optional[_tp.Any] = None) \ - -> _Named[ModelParameter]: - - """ - .. deprecated:: 0.4.4 - Use :py:func:`define_parameter` or :py:func:`P` instead. - - This function is deprecated and will be removed in a future version. - Please use :py:func:`define_parameter() ` instead. - - :type param_name: str - :type param_type: :py:class:`BasicType ` | - :py:class:`TypeDescriptor ` - - :type label: str - :type default_value: Any | None - :rtype: _Named[:py:class:`ModelParameter `] - - :display: False - """ - - print("TRAC Warning: declare_parameter() is deprecated, please use define_parameter()", file=sys.stderr) - - return define_parameter(param_name, param_type, label, default_value) - def P( # noqa param_name: str, @@ -241,29 +212,6 @@ def define_parameters( return sa.define_parameters(*parameters) -def declare_parameters( - *params: _tp.Union[_Named[ModelParameter], _tp.List[_Named[ModelParameter]]]) \ - -> _tp.Dict[str, ModelParameter]: - - """ - .. deprecated:: 0.4.4 - Use :py:func:`define_parameters` instead - - This function is deprecated and will be removed in a future version. - Please use :py:func:`define_parameters() ` instead. - - :type params: _Named[:py:class:`ModelParameter `] | - List[_Named[:py:class:`ModelParameter `]] - :rtype: Dict[str, :py:class:`ModelParameter `] - - :display: False - """ - - print("TRAC Warning: declare_parameters() is deprecated, please use define_parameters()", file=sys.stderr) - - return define_parameters(*params) - - def define_field( field_name: str, field_type: BasicType, @@ -324,45 +272,6 @@ def define_field( format_code, field_order) -def declare_field( - field_name: str, - field_type: BasicType, - label: str, - business_key: bool = False, - categorical: bool = False, - not_null: _tp.Optional[bool] = None, - format_code: _tp.Optional[str] = None, - field_order: _tp.Optional[int] = None) \ - -> FieldSchema: - - """ - .. deprecated:: 0.4.4 - Use :py:func:`define_field` or :py:func:`F` instead. - - This function is deprecated and will be removed in a future version. - Please use :py:func:`define_field() ` instead. 
- - :type field_name: str - :type field_type: :py:class:`BasicType ` - :type label: str - :type business_key: bool - :type categorical: bool - :type not_null: bool | None - :type format_code: str | None - :type field_order: int | None - :rtype: :py:class:`FieldSchema ` - - :display: False - """ - - print("TRAC Warning: declare_field() is deprecated, please use define_field()", file=sys.stderr) - - return define_field( - field_name, field_type, label, - business_key, categorical, not_null, - format_code, field_order) - - def F( # noqa field_name: str, field_type: BasicType, @@ -396,7 +305,7 @@ def F( # noqa def define_schema( *fields: _tp.Union[FieldSchema, _tp.List[FieldSchema]], - schema_type: SchemaType = SchemaType.TABLE) \ + schema_type: SchemaType = SchemaType.TABLE, dynamic: bool = False) \ -> SchemaDefinition: """ @@ -416,16 +325,18 @@ def define_schema( :param fields: The list of fields to include in the schema :param schema_type: The type of schema to create (currently only TABLE schemas are supported) + :param dynamic: Define a dynamic schema (fields list should be empty) :return: A schema definition built from the supplied fields :type fields: :py:class:`FieldSchema ` | List[:py:class:`FieldSchema `] :type schema_type: :py:class:`SchemaType ` + :type dynamic: bool :rtype: :py:class:`SchemaDefinition ` """ sa = _StaticApiHook.get_instance() - return sa.define_schema(*fields, schema_type=schema_type) + return sa.define_schema(*fields, schema_type=schema_type, dynamic=dynamic) def load_schema( @@ -471,6 +382,32 @@ def load_schema( return sa.load_schema(package, schema_file, schema_type=schema_type) +def define_file_type(extension: str, mime_type: str) -> FileType: + + sa = _StaticApiHook.get_instance() + return sa.define_file_type(extension, mime_type) + + +def define_input( + requirement: _tp.Union[SchemaDefinition, FileType], *, + label: _tp.Optional[str] = None, + optional: bool = False, dynamic: bool = False, + input_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None): + + sa = _StaticApiHook.get_instance() + return sa.define_input(requirement, label=label, optional=optional, dynamic=dynamic, input_props=input_props) + + +def define_output( + requirement: _tp.Union[SchemaDefinition, FileType], *, + label: _tp.Optional[str] = None, + optional: bool = False, dynamic: bool = False, + output_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None): + + sa = _StaticApiHook.get_instance() + return sa.define_output(requirement, label=label, optional=optional, dynamic=dynamic, output_props=output_props) + + def define_input_table( *fields: _tp.Union[FieldSchema, _tp.List[FieldSchema]], label: _tp.Optional[str] = None, optional: bool = False, dynamic: bool = False, @@ -512,34 +449,8 @@ def define_input_table( :rtype: :py:class:`ModelInputSchema ` """ - sa = _StaticApiHook.get_instance() - - return sa.define_input_table( - *fields, label=label, optional=optional, dynamic=dynamic, - input_props=input_props) - - -def declare_input_table( - *fields: _tp.Union[FieldSchema, _tp.List[FieldSchema]]) \ - -> ModelInputSchema: - - """ - .. deprecated:: 0.4.4 - Use :py:func:`define_input_table` instead. - - This function is deprecated and will be removed in a future version. - Please use :py:func:`define_input_table() ` instead. 
- - :type fields: :py:class:`FieldSchema ` | - List[:py:class:`FieldSchema `] - :rtype: :py:class:`ModelInputSchema ` - - :display: False - """ - - print("TRAC Warning: declare_input_table() is deprecated, please use define_input_table()", file=sys.stderr) - - return define_input_table(*fields) + schema = define_schema(*fields, schema_type=SchemaType.TABLE, dynamic=dynamic) + return define_input(schema, label=label, optional=optional, dynamic=dynamic, input_props=input_props) def define_output_table( @@ -581,11 +492,186 @@ def define_output_table( :rtype: :py:class:`ModelOutputSchema ` """ - sa = _StaticApiHook.get_instance() + schema = define_schema(*fields, schema_type=SchemaType.TABLE, dynamic=dynamic) + return define_output(schema, label=label, optional=optional, dynamic=dynamic, output_props=output_props) + + +def define_input_file( + extension: str, mime_type: str, *, + label: _tp.Optional[str] = None, optional: bool = False, + input_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None) \ + -> ModelInputSchema: + + file_type = define_file_type(extension, mime_type) + return define_input(file_type, label=label, optional=optional, input_props=input_props) + - return sa.define_output_table( - *fields, label=label, optional=optional, dynamic=dynamic, - output_props=output_props) +def define_output_file( + extension: str, mime_type: str, *, + label: _tp.Optional[str] = None, optional: bool = False, + output_props: _tp.Optional[_tp.Dict[str, _tp.Any]] = None) \ + -> ModelOutputSchema: + + file_type = define_file_type(extension, mime_type) + return define_output(file_type, label=label, optional=optional, output_props=output_props) + + +def ModelInputSchema( # noqa + schema: SchemaDefinition, + label: _tp.Optional[str] = None, + optional: bool = False, + dynamic: bool = False, + inputProps: _tp.Optional[_tp.Dict[str, Value]] = None): # noqa + + """ + .. deprecated:: 0.8.0 + Use :py:func:`define_input` instead. + + This function is provided for compatibility with TRAC versions before 0.8.0. + Please use :py:func:`define_input() ` instead. + + :display: False + """ + + input_props = inputProps or dict() + return define_input(schema, label=label, optional=optional, dynamic=dynamic, input_props=input_props) + + +def ModelOutputSchema( # noqa + schema: SchemaDefinition, + label: _tp.Optional[str] = None, + optional: bool = False, + dynamic: bool = False, + outputProps: _tp.Optional[_tp.Dict[str, Value]] = None): # noqa + + """ + .. deprecated:: 0.8.0 + Use :py:func:`define_output` instead. + + This function is provided for compatibility with TRAC versions before 0.8.0. + Please use :py:func:`define_output() ` instead. + + :display: False + """ + + output_props = outputProps or dict() + return define_output(schema, label=label, optional=optional, dynamic=dynamic, output_props=output_props) + + + +def declare_parameter( + param_name: str, + param_type: _tp.Union[BasicType, TypeDescriptor], + label: str, + default_value: _tp.Optional[_tp.Any] = None) \ + -> _Named[ModelParameter]: + + """ + .. deprecated:: 0.4.4 + Use :py:func:`define_parameter` or :py:func:`P` instead. + + This function is deprecated and will be removed in a future version. + Please use :py:func:`define_parameter() ` instead. 
+ + :type param_name: str + :type param_type: :py:class:`BasicType ` | + :py:class:`TypeDescriptor ` + + :type label: str + :type default_value: Any | None + :rtype: _Named[:py:class:`ModelParameter `] + + :display: False + """ + + print("TRAC Warning: declare_parameter() is deprecated, please use define_parameter()", file=sys.stderr) + + return define_parameter(param_name, param_type, label, default_value) + + +def declare_parameters( + *params: _tp.Union[_Named[ModelParameter], _tp.List[_Named[ModelParameter]]]) \ + -> _tp.Dict[str, ModelParameter]: + + """ + .. deprecated:: 0.4.4 + Use :py:func:`define_parameters` instead + + This function is deprecated and will be removed in a future version. + Please use :py:func:`define_parameters() ` instead. + + :type params: _Named[:py:class:`ModelParameter `] | + List[_Named[:py:class:`ModelParameter `]] + :rtype: Dict[str, :py:class:`ModelParameter `] + + :display: False + """ + + print("TRAC Warning: declare_parameters() is deprecated, please use define_parameters()", file=sys.stderr) + + return define_parameters(*params) + + +def declare_field( + field_name: str, + field_type: BasicType, + label: str, + business_key: bool = False, + categorical: bool = False, + not_null: _tp.Optional[bool] = None, + format_code: _tp.Optional[str] = None, + field_order: _tp.Optional[int] = None) \ + -> FieldSchema: + + """ + .. deprecated:: 0.4.4 + Use :py:func:`define_field` or :py:func:`F` instead. + + This function is deprecated and will be removed in a future version. + Please use :py:func:`define_field() ` instead. + + :type field_name: str + :type field_type: :py:class:`BasicType ` + :type label: str + :type business_key: bool + :type categorical: bool + :type not_null: bool | None + :type format_code: str | None + :type field_order: int | None + :rtype: :py:class:`FieldSchema ` + + :display: False + """ + + print("TRAC Warning: declare_field() is deprecated, please use define_field()", file=sys.stderr) + + return define_field( + field_name, field_type, label, + business_key, categorical, not_null, + format_code, field_order) + + +def declare_input_table( + *fields: _tp.Union[FieldSchema, _tp.List[FieldSchema]]) \ + -> ModelInputSchema: + + """ + .. deprecated:: 0.4.4 + Use :py:func:`define_input_table` instead. + + This function is deprecated and will be removed in a future version. + Please use :py:func:`define_input_table() ` instead. 
+ + :type fields: :py:class:`FieldSchema ` | + List[:py:class:`FieldSchema `] + :rtype: :py:class:`ModelInputSchema ` + + :display: False + """ + + print("TRAC Warning: declare_input_table() is deprecated, please use define_input_table()", file=sys.stderr) + + return define_input_table(*fields) def declare_output_table( diff --git a/tracdap-runtime/python/src/tracdap/rt/launch/launch.py b/tracdap-runtime/python/src/tracdap/rt/launch/launch.py index b2fa7ba2c..6ef44d407 100644 --- a/tracdap-runtime/python/src/tracdap/rt/launch/launch.py +++ b/tracdap-runtime/python/src/tracdap/rt/launch/launch.py @@ -112,9 +112,10 @@ def launch_model( runtime_instance = _runtime.TracRuntime(_sys_config, dev_mode=True, plugin_packages=plugin_packages) runtime_instance.pre_start() - job = runtime_instance.load_job_config(_job_config, model_class=model_class) - with runtime_instance as rt: + + job = rt.load_job_config(_job_config, model_class=model_class) + rt.submit_job(job) rt.wait_for_job(job.jobId) @@ -160,9 +161,10 @@ def launch_job( runtime_instance = _runtime.TracRuntime(_sys_config, dev_mode=dev_mode, plugin_packages=plugin_packages) runtime_instance.pre_start() - job = runtime_instance.load_job_config(_job_config) - with runtime_instance as rt: + + job = rt.load_job_config(_job_config) + rt.submit_job(job) rt.wait_for_job(job.jobId) @@ -184,6 +186,7 @@ def launch_cli(programmatic_args: _tp.Optional[_tp.List[str]] = None): launch_args = _cli_args() _sys_config = _resolve_config_file(launch_args.sys_config, None) + _job_config = _resolve_config_file(launch_args.job_config, None) if launch_args.job_config else None runtime_instance = _runtime.TracRuntime( _sys_config, @@ -196,15 +199,10 @@ def launch_cli(programmatic_args: _tp.Optional[_tp.List[str]] = None): runtime_instance.pre_start() - if launch_args.job_config is not None: - _job_config = _resolve_config_file(launch_args.job_config, None) - job = runtime_instance.load_job_config(_job_config) - else: - job = None - with runtime_instance as rt: - if job is not None: + if _job_config is not None: + job = rt.load_job_config(_job_config) rt.submit_job(job) if rt.is_oneshot(): diff --git a/tracdap-runtime/python/test/tracdap_test/rt/exec/test_context.py b/tracdap-runtime/python/test/tracdap_test/rt/exec/test_context.py index b36e6ffd9..6d829bc56 100644 --- a/tracdap-runtime/python/test/tracdap_test/rt/exec/test_context.py +++ b/tracdap-runtime/python/test/tracdap_test/rt/exec/test_context.py @@ -99,7 +99,7 @@ def setUp(self): customer_loans_view = _data.DataMapping.add_item_to_view( customer_loans_view, _data.DataPartKey.for_root(), - _data.DataItem(customer_loans_view.arrow_schema, customer_loans_delta0)) + _data.DataItem(_api.ObjectType.DATA, customer_loans_view.arrow_schema, customer_loans_delta0)) profit_by_region_schema = _test_model_def.outputs.get("profit_by_region").schema profit_by_region_view = _data.DataView.for_trac_schema(profit_by_region_schema) diff --git a/tracdap-runtime/python/test/tracdap_test/rt/impl/test_models.py b/tracdap-runtime/python/test/tracdap_test/rt/impl/test_models.py index 2a2b5e22c..53c5ee943 100644 --- a/tracdap-runtime/python/test/tracdap_test/rt/impl/test_models.py +++ b/tracdap-runtime/python/test/tracdap_test/rt/impl/test_models.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pathlib import tempfile import typing as tp import unittest -import pathlib -import subprocess as sp import tracdap.rt.api as api import tracdap.rt.metadata as meta @@ -64,62 +63,24 @@ class ImportModelTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: + plugins.PluginManager.register_core_plugins() api_hook.StaticApiImpl.register_impl() util.configure_logging() def setUp(self) -> None: - self.test_scope = f"{self.__class__.__name__}.{self._testMethodName}" - - repo_url_proc = sp.run(["git", "config", "--get", "remote.origin.url"], stdout=sp.PIPE) - commit_hash_proc = sp.run(["git", "rev-parse", "HEAD"], stdout=sp.PIPE) - - if repo_url_proc.returncode != 0 or commit_hash_proc.returncode != 0: - raise RuntimeError("Could not discover details of the current git repo") - - self.repo_url = repo_url_proc.stdout.decode('utf-8').strip() - self.commit_hash = commit_hash_proc.stdout.decode('utf-8').strip() + self.test_scope = f"{self.__class__.__name__}.{self._testMethodName}" self.scratch_dir = pathlib.Path(tempfile.mkdtemp()) def tearDown(self) -> None: util.try_clean_dir(self.scratch_dir, remove=True) - def test_load_integrated_ok(self): - - sys_config = config.RuntimeConfig() - sys_config.repositories["trac_integrated"] = config.PluginConfig(protocol="integrated") - - stub_model_def = meta.ModelDefinition( - language="python", - repository="trac_integrated", - entryPoint="tracdap_test.rt.impl.test_models.SampleModel" - ) - - loader = models.ModelLoader(sys_config, self.scratch_dir) - loader.create_scope(self.test_scope) + def test_load_model(self, override_scope = None): - model_class = loader.load_model_class(self.test_scope, stub_model_def) - model = model_class() - - self.assertIsInstance(model_class, api.TracModel.__class__) - self.assertIsInstance(model, model_class) - self.assertIsInstance(model, api.TracModel) - - loader.destroy_scope(self.test_scope) - - def test_load_local_ok(self): - - self._test_load_local(self.test_scope) - - def test_load_local_long_path_ok(self): - - long_path_scope = "long_" + "A" * 250 - - self._test_load_local(long_path_scope) - - def _test_load_local(self, test_scope): + if override_scope: + self.test_scope = override_scope example_repo_url = pathlib.Path(__file__) \ .parent \ @@ -141,32 +102,34 @@ def _test_load_local(self, test_scope): ) loader = models.ModelLoader(sys_config, self.scratch_dir) - loader.create_scope(test_scope) + loader.create_scope(self.test_scope) - model_class = loader.load_model_class(test_scope, stub_model_def) + model_class = loader.load_model_class(self.test_scope, stub_model_def) model = model_class() self.assertIsInstance(model_class, api.TracModel.__class__) - self.assertIsInstance(model, model_class) self.assertIsInstance(model, api.TracModel) - loader.destroy_scope(test_scope) + loader.destroy_scope(self.test_scope) - def test_load_git_ok(self): + def test_load_model_long_path(self): - example_repo_config = config.PluginConfig( - protocol="git", - properties={"repoUrl": self.repo_url}) + long_path_scope = "long_" + "A" * 250 + self.test_load_model(long_path_scope) + + def test_load_model_integrated(self): + + # Integrated repo uses a different loader mechanism so include a test here + # All other repo types copy into the loader scope, so loader behavior is the same as local + # Also, tests for remote repo types are integration tests sys_config = config.RuntimeConfig() - sys_config.repositories["example_repo"] = example_repo_config + sys_config.repositories["trac_integrated"] = 
config.PluginConfig(protocol="integrated") stub_model_def = meta.ModelDefinition( language="python", - repository="example_repo", - path="examples/models/python/src", - entryPoint="tutorial.hello_world.HelloWorldModel", - version=self.commit_hash + repository="trac_integrated", + entryPoint="tracdap_test.rt.impl.test_models.SampleModel" ) loader = models.ModelLoader(sys_config, self.scratch_dir) @@ -176,12 +139,11 @@ def test_load_git_ok(self): model = model_class() self.assertIsInstance(model_class, api.TracModel.__class__) - self.assertIsInstance(model, model_class) self.assertIsInstance(model, api.TracModel) loader.destroy_scope(self.test_scope) - def test_scan_model_ok(self): + def test_scan_model(self): def _td(basic_type: meta.BasicType) -> meta.TypeDescriptor: return meta.TypeDescriptor(basic_type) diff --git a/tracdap-runtime/python/test/tracdap_test/rt/impl/test_repos.py b/tracdap-runtime/python/test/tracdap_test/rt/impl/test_repos.py index 827ce9fc0..4b114a09d 100644 --- a/tracdap-runtime/python/test/tracdap_test/rt/impl/test_repos.py +++ b/tracdap-runtime/python/test/tracdap_test/rt/impl/test_repos.py @@ -120,7 +120,8 @@ def _test_checkout_git_native(self, checkout_key): sys_config.repositories["git_test"] = config.PluginConfig( protocol="git", properties={ - "repoUrl": "https://github.com/finos/tracdap"}) + "repoUrl": "https://github.com/finos/tracdap", + "nativeGit": "true"}) model_def = meta.ModelDefinition( language="python", @@ -163,8 +164,7 @@ def _test_checkout_git_python(self, checkout_key): sys_config.repositories["git_test"] = config.PluginConfig( protocol="git", properties={ - "repoUrl": "https://github.com/finos/tracdap", - "nativeGit": "false"}) + "repoUrl": "https://github.com/finos/tracdap"}) # On macOS, SSL certificates are not set up correctly by default in urllib3 # We can reconfigure them by passing Git config properties into the pure python Git client diff --git a/tracdap-runtime/python/test/tracdap_test/rt/jobs/test_core_jobs.py b/tracdap-runtime/python/test/tracdap_test/rt/jobs/test_core_jobs.py index cf42c2d92..60b76683f 100644 --- a/tracdap-runtime/python/test/tracdap_test/rt/jobs/test_core_jobs.py +++ b/tracdap-runtime/python/test/tracdap_test/rt/jobs/test_core_jobs.py @@ -183,7 +183,9 @@ def _build_run_model_job_config(self): "filter_defaults": meta.ModelParameter(paramType=types.TypeMapping.python_to_trac(bool)), }, inputs={ - "customer_loans": meta.ModelInputSchema(meta.SchemaDefinition( + "customer_loans": meta.ModelInputSchema( + objectType=meta.ObjectType.DATA, + schema=meta.SchemaDefinition( schemaType=meta.SchemaType.TABLE, table=meta.TableSchema(fields=[ meta.FieldSchema("id", fieldType=meta.BasicType.STRING, businessKey=True), @@ -194,7 +196,9 @@ def _build_run_model_job_config(self): ]))) }, outputs={ - "profit_by_region": meta.ModelOutputSchema(meta.SchemaDefinition( + "profit_by_region": meta.ModelOutputSchema( + objectType=meta.ObjectType.DATA, + schema=meta.SchemaDefinition( schemaType=meta.SchemaType.TABLE, table=meta.TableSchema(fields=[ meta.FieldSchema("region", fieldType=meta.BasicType.STRING, categorical=True), diff --git a/tracdap-runtime/python/test/tracdap_test/rt/plugins/test_plugin_pacakge.py b/tracdap-runtime/python/test/tracdap_test/rt/plugins/test_plugin_pacakge.py index 1ac401530..f37e90058 100644 --- a/tracdap-runtime/python/test/tracdap_test/rt/plugins/test_plugin_pacakge.py +++ b/tracdap-runtime/python/test/tracdap_test/rt/plugins/test_plugin_pacakge.py @@ -84,31 +84,34 @@ def test_ext_config_loader_job_ok(self): 
plugin_package = "tracdap_test.rt.plugins.test_ext" trac_runtime = runtime.TracRuntime(self.sys_config, plugin_packages=[plugin_package], dev_mode=True) - trac_runtime.pre_start() - # Load a config object that exists - job_config = trac_runtime.load_job_config("test-ext:job_config_A1-6") - self.assertIsInstance(job_config, cfg.JobConfig) + with trac_runtime as rt: + + # Load a config object that exists + job_config = rt.load_job_config("test-ext:job_config_A1-6") + self.assertIsInstance(job_config, cfg.JobConfig) def test_ext_config_loader_job_not_found(self): plugin_package = "tracdap_test.rt.plugins.test_ext" trac_runtime = runtime.TracRuntime(self.sys_config, plugin_packages=[plugin_package], dev_mode=True) - trac_runtime.pre_start() - # Load a config object that does not exist - self.assertRaises(ex.EConfigLoad, lambda: trac_runtime.load_job_config("test-ext:job_config_B1-9")) + with trac_runtime as rt: + + # Load a config object that does not exist + self.assertRaises(ex.EConfigLoad, lambda: rt.load_job_config("test-ext:job_config_B1-9")) def test_ext_config_loader_wrong_protocol(self): plugin_package = "tracdap_test.rt.plugins.test_ext" trac_runtime = runtime.TracRuntime(self.sys_config, plugin_packages=[plugin_package], dev_mode=True) - trac_runtime.pre_start() - # Load a config object with the wrong protocol - self.assertRaises(ex.EConfigLoad, lambda: trac_runtime.load_job_config("test-ext-2:job_config_B1-9")) + with trac_runtime as rt: + + # Load a config object with the wrong protocol + self.assertRaises(ex.EConfigLoad, lambda: rt.load_job_config("test-ext-2:job_config_B1-9")) def test_launch_model(self): diff --git a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunFlowJob.java b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunFlowJob.java index 5f92df3d1..0b96ff2a9 100644 --- a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunFlowJob.java +++ b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunFlowJob.java @@ -88,8 +88,12 @@ public Map newResultIds( var runFlow = job.getRunFlow(); + var flowKey = MetadataUtil.objectKey(runFlow.getFlow()); + var flowId = resourceMapping.get(flowKey); + var flowDef = resources.get(MetadataUtil.objectKey(flowId)).getFlow(); + var outputFlowNodes = getFlowOutputNodes(runFlow.getFlow(), resources, resourceMapping); - var outputs = getFlowOutputNames(outputFlowNodes); + var outputs = getFlowOutputs(outputFlowNodes, flowDef); return newResultIds(tenant, outputs, runFlow.getPriorOutputsMap()); } @@ -112,6 +116,12 @@ private static Set getFlowOutputNames(Map outputFlowNo return new HashSet<>(outputFlowNodes.keySet()); } + private static Map getFlowOutputs(Map outputFlowNodes, FlowDefinition flow) { + + return outputFlowNodes.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> flow.getOutputsOrThrow(e.getKey()))); + } + private static Map getFlowOutputNodes( TagSelector flowSelector, Map resources, diff --git a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelJob.java b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelJob.java index 693c79507..382ed53d4 100644 --- a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelJob.java +++ b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelJob.java @@ -74,7 +74,7 @@ public Map newResultIds( var modelDef = 
resources.get(MetadataUtil.objectKey(modelId)).getModel(); return newResultIds( - tenant, modelDef.getOutputsMap().keySet(), + tenant, modelDef.getOutputsMap(), runModel.getPriorOutputsMap()); } diff --git a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelOrFlow.java b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelOrFlow.java index 32401c68a..03adcc2bd 100644 --- a/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelOrFlow.java +++ b/tracdap-services/tracdap-svc-orch/src/main/java/org/finos/tracdap/svc/orch/jobs/RunModelOrFlow.java @@ -21,6 +21,7 @@ import org.finos.tracdap.api.MetadataWriteRequest; import org.finos.tracdap.api.internal.RuntimeJobResult; import org.finos.tracdap.common.exception.EConsistencyValidation; +import org.finos.tracdap.common.exception.EUnexpected; import org.finos.tracdap.config.JobConfig; import org.finos.tracdap.metadata.*; import org.finos.tracdap.common.metadata.MetadataCodec; @@ -38,36 +39,42 @@ public List requiredMetadata(Map newResou for (var obj : newResources.values()) { - if (obj.getObjectType() != ObjectType.DATA) - continue; + if (obj.getObjectType() == ObjectType.DATA) { + + var dataDef = obj.getData(); + resources.add(dataDef.getStorageId()); - var dataDef = obj.getData(); - resources.add(dataDef.getStorageId()); + if (dataDef.hasSchemaId()) + resources.add(dataDef.getSchemaId()); + } - if (dataDef.hasSchemaId()) - resources.add(dataDef.getSchemaId()); + else if (obj.getObjectType() == ObjectType.FILE) { + + var fileDef = obj.getFile(); + resources.add(fileDef.getStorageId()); + } } return resources; } public Map newResultIds( - String tenant, Set outputKeys, + String tenant, Map outputRequirements, Map priorOutputsMap) { var resultMapping = new HashMap(); - for (var outputKey : outputKeys) { + for (var output : outputRequirements.entrySet()) { - if (priorOutputsMap.containsKey(outputKey)) + if (priorOutputsMap.containsKey(output.getKey())) continue; - var dataKey = String.format("%s:%s", outputKey, ObjectType.DATA); + var outputKey = output.getKey(); var storageKey = String.format("%s:%s", outputKey, ObjectType.STORAGE); - var dataReq = MetadataWriteRequest.newBuilder() + var outputReq = MetadataWriteRequest.newBuilder() .setTenant(tenant) - .setObjectType(ObjectType.DATA) + .setObjectType(output.getValue().getObjectType()) .build(); var storageReq = MetadataWriteRequest.newBuilder() @@ -75,7 +82,7 @@ public Map newResultIds( .setObjectType(ObjectType.STORAGE) .build(); - resultMapping.put(dataKey, dataReq); + resultMapping.put(outputKey, outputReq); resultMapping.put(storageKey, storageReq); } @@ -95,17 +102,25 @@ public Map priorResultIds( if (priorOutput == null) continue; - var priorDataKey = MetadataUtil.objectKey(priorOutput); - var priorDataId = resourceMapping.get(priorDataKey); - var priorDataDef = resources.get(MetadataUtil.objectKey(priorDataId)).getData(); + var priorOutputKey = MetadataUtil.objectKey(priorOutput); + var priorOutputId = resourceMapping.get(priorOutputKey); + var priorOutputDef = resources.get(MetadataUtil.objectKey(priorOutputId)); + + TagSelector priorStorageSelector; - var priorStorageKey = MetadataUtil.objectKey(priorDataDef.getStorageId()); + if (priorOutputDef.getObjectType() == ObjectType.DATA) + priorStorageSelector = priorOutputDef.getData().getStorageId(); + else if (priorOutputDef.getObjectType() == ObjectType.FILE) + priorStorageSelector = priorOutputDef.getFile().getStorageId(); + else + 
throw new EUnexpected(); + + var priorStorageKey = MetadataUtil.objectKey(priorStorageSelector); var priorStorageId = resourceMapping.get(priorStorageKey); - var dataKey = String.format("%s:%s", outputKey, ObjectType.DATA); var storageKey = String.format("%s:%s", outputKey, ObjectType.STORAGE); - resultMapping.put(dataKey, priorDataId); + resultMapping.put(outputKey, priorOutputId); resultMapping.put(storageKey, priorStorageId); } @@ -120,10 +135,9 @@ public Map setResultIds( for (var outputKey : outputKeys) { - var dataKey = String.format("%s:%s", outputKey, ObjectType.DATA); - var dataId = resultMapping.get(dataKey); - var dataSelector = MetadataUtil.selectorFor(dataId); - outputSelectors.put(outputKey, dataSelector); + var outputId = resultMapping.get(outputKey); + var outputSelector = MetadataUtil.selectorFor(outputId); + outputSelectors.put(outputKey, outputSelector); } return outputSelectors; @@ -140,35 +154,34 @@ public List buildResultMetadata( for (var output: expectedOutputs.entrySet()) { var outputName = output.getKey(); - var outputSchema = output.getValue(); - - // TODO: String constants + var outputDef = output.getValue(); - var dataIdLookup = outputName + ":DATA"; - var dataId = jobConfig.getResultMappingOrThrow(dataIdLookup); - var dataKey = MetadataUtil.objectKey(dataId); + var outputId = jobConfig.getResultMappingOrThrow(outputName); + var outputKey = MetadataUtil.objectKey(outputId); - if (!jobResult.containsResults(dataKey)) { - if (outputSchema.getOptional()) + if (!jobResult.containsResults(outputKey)) { + if (outputDef.getOptional()) continue; else throw new EConsistencyValidation(String.format("Missing required output [%s]", outputName)); } - var storageIdLookup = outputName + ":STORAGE"; - var storageId = jobConfig.getResultMappingOrThrow(storageIdLookup); + // TODO: Preallocated IDs for storage outputs + + var storageMapping = outputName + ":STORAGE"; + var storageId = jobConfig.getResultMappingOrThrow(storageMapping); var storageKey = MetadataUtil.objectKey(storageId); - var dataObj = jobResult.getResultsOrThrow(dataKey); + var outputObj = jobResult.getResultsOrThrow(outputKey); var storageObj = jobResult.getResultsOrThrow(storageKey); - var priorDataSelector = priorOutputs.containsKey(outputName) + var priorOutputSelector = priorOutputs.containsKey(outputName) ? priorOutputs.get(outputName) : MetadataUtil.preallocated(outputs.get(outputName)); var priorStorageSelector = priorOutputs.containsKey(outputName) - ? priorStorageSelector(priorDataSelector, jobConfig) - : MetadataUtil.preallocated(dataObj.getData().getStorageId()); + ? 
priorStorageSelector(priorOutputSelector, jobConfig) + : preallocatedStorageSelector(outputObj); var controlledAttrs = List.of( TagUpdate.newBuilder() @@ -181,27 +194,27 @@ public List buildResultMetadata( nodeOutputAttrs = List.of(); } - var dataUpdate = MetadataWriteRequest.newBuilder() + var outputUpdate = MetadataWriteRequest.newBuilder() .setTenant(tenant) - .setObjectType(ObjectType.DATA) - .setPriorVersion(priorDataSelector) - .setDefinition(dataObj) + .setObjectType(outputObj.getObjectType()) + .setPriorVersion(priorOutputSelector) + .setDefinition(outputObj) .addAllTagUpdates(controlledAttrs) .addAllTagUpdates(outputAttrs) .addAllTagUpdates(nodeOutputAttrs) .build(); - updates.add(dataUpdate); + updates.add(outputUpdate); var storageAttrs = List.of( TagUpdate.newBuilder() .setAttrName(MetadataConstants.TRAC_STORAGE_OBJECT_ATTR) - .setValue(MetadataCodec.encodeValue(dataKey)) + .setValue(MetadataCodec.encodeValue(outputKey)) .build()); var storageUpdate = MetadataWriteRequest.newBuilder() .setTenant(tenant) - .setObjectType(ObjectType.STORAGE) + .setObjectType(storageObj.getObjectType()) .setPriorVersion(priorStorageSelector) .setDefinition(storageObj) .addAllTagUpdates(storageAttrs) @@ -213,27 +226,50 @@ public List buildResultMetadata( return updates; } - private TagSelector priorStorageSelector(TagSelector priorDataSelector, JobConfig jobConfig) { + private TagSelector priorStorageSelector(TagSelector priorOutputSelector, JobConfig jobConfig) { - var dataKey = MetadataUtil.objectKey(priorDataSelector); + var mappedOutputKey = MetadataUtil.objectKey(priorOutputSelector); + String outputKey; - if (jobConfig.containsResourceMapping(dataKey)) { - var dataId = jobConfig.getResourceMappingOrDefault(dataKey, null); - var dataSelector = MetadataUtil.selectorFor(dataId); - dataKey = MetadataUtil.objectKey(dataSelector); + if (jobConfig.containsResourceMapping(mappedOutputKey)) { + var outputId = jobConfig.getResourceMappingOrDefault(mappedOutputKey, null); + var outputSelector = MetadataUtil.selectorFor(outputId); + outputKey = MetadataUtil.objectKey(outputSelector); } + else + outputKey = mappedOutputKey; + + var outputObj = jobConfig.getResourcesOrThrow(outputKey); + + TagSelector storageSelector; - var dataObj = jobConfig.getResourcesOrThrow(dataKey); + if (outputObj.getObjectType() == ObjectType.DATA) + storageSelector = outputObj.getData().getStorageId(); + else if (outputObj.getObjectType() == ObjectType.FILE) + storageSelector = outputObj.getFile().getStorageId(); + else + throw new EUnexpected(); - var storageSelector = dataObj.getData().getStorageId(); var storageKey = MetadataUtil.objectKey(storageSelector); if (jobConfig.containsResourceMapping(storageKey)) { var storageId = jobConfig.getResourceMappingOrDefault(storageKey, null); - storageSelector = MetadataUtil.selectorFor(storageId); + return MetadataUtil.selectorFor(storageId); } + else + return storageSelector; + } + + private TagSelector preallocatedStorageSelector(ObjectDefinition outputObj) { + + if (outputObj.getObjectType() == ObjectType.DATA) + return MetadataUtil.preallocated(outputObj.getData().getStorageId()); + + else if (outputObj.getObjectType() == ObjectType.FILE) + return MetadataUtil.preallocated(outputObj.getFile().getStorageId()); - return storageSelector; + else + throw new EUnexpected(); } } diff --git a/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/api/JobValidationTest.java 
b/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/api/JobValidationTest.java index 1e055f19f..cae1a124b 100644 --- a/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/api/JobValidationTest.java +++ b/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/api/JobValidationTest.java @@ -192,9 +192,11 @@ private TagSelector createBasicModel(SchemaDefinition inputSchema, SchemaDefinit .setParamType(TypeSystem.descriptor(BasicType.FLOAT)) .build()) .putInputs("basic_input", ModelInputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(inputSchema) .build()) .putOutputs("enriched_output", ModelOutputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(outputSchema) .build()) .build(); @@ -436,12 +438,14 @@ private TagSelector createFlowModel( for (var input : inputSchemas.entrySet()) { modelDef.putInputs(input.getKey(), ModelInputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(input.getValue()) .build()); } for (var output : outputSchemas.entrySet()) { modelDef.putOutputs(output.getKey(), ModelOutputSchema.newBuilder() + .setObjectType(ObjectType.DATA) .setSchema(output.getValue()) .build()); } diff --git a/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/RunModelTest.java b/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/RunModelTest.java index 315de3d1d..41674aa5d 100644 --- a/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/RunModelTest.java +++ b/tracdap-services/tracdap-svc-orch/src/test/java/org/finos/tracdap/svc/orch/jobs/RunModelTest.java @@ -29,6 +29,8 @@ import org.junit.jupiter.api.*; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -73,6 +75,12 @@ public class RunModelTest { static TagHeader dynamicIoModelId; static TagHeader dynamicIoOutputDataId; + static TagHeader inputFileId; + static long inputFileSize; + static TagHeader fileIoModelId; + static TagHeader outputFileId; + static TagHeader outputFileIdStream; + @Test @Order(1) void loadInputData() throws Exception { @@ -275,7 +283,7 @@ void checkOutputData() { @Test @Order(5) void optionalIO_importModel() throws Exception { - log.info("Running IMPORT_MODEL job..."); + log.info("Running IMPORT_MODEL job for optional IO..."); var modelVersion = GitHelpers.getCurrentCommit(); var modelStub = ModelDefinition.newBuilder() @@ -379,7 +387,7 @@ void optionalIO_runModel() { @Test @Order(7) void optionalIO_checkOutputData() { - log.info("Checking output data..."); + log.info("Checking output data for optional IO..."); var dataClient = platform.dataClientBlocking(); @@ -408,10 +416,7 @@ void optionalIO_checkOutputData() { @Test @Order(8) void dynamicIO_importModel() throws Exception { - log.info("Running IMPORT_MODEL job..."); - - var metaClient = platform.metaClientBlocking(); - var orchClient = platform.orchClientBlocking(); + log.info("Running IMPORT_MODEL job for dynamic IO..."); var modelVersion = GitHelpers.getCurrentCommit(); var modelStub = ModelDefinition.newBuilder() @@ -514,7 +519,7 @@ void dynamicIO_runModel() { @Test @Order(10) void dynamicIO_checkOutputData() { - log.info("Checking output data..."); + log.info("Checking output data for dynamic IO..."); var dataClient = platform.dataClientBlocking(); @@ -542,4 +547,172 
@@ void dynamicIO_checkOutputData() { Assertions.assertFalse(csvText.contains("munster")); } + @Test @Order(11) + void fileIO_loadInputData() { + + log.info("Loading input file..."); + + var sampleData = "Some text in a file\r\n".getBytes(StandardCharsets.UTF_8); + var sampleDataSize = sampleData.length; + + var metaClient = platform.metaClientBlocking(); + var dataClient = platform.dataClientBlocking(); + + var writeRequest = FileWriteRequest.newBuilder() + .setTenant(TEST_TENANT) + .setName("input_file.txt") + .setMimeType("text/plain") + .setContent(ByteString.copyFrom(sampleData)) + .setSize(sampleDataSize) + .addTagUpdates(TagUpdate.newBuilder() + .setAttrName("e2e_test_file") + .setValue(MetadataCodec.encodeValue("run_model:file_io"))) + .build(); + + inputFileId = dataClient.createSmallFile(writeRequest); + inputFileSize = sampleDataSize; + + var fileSelector = MetadataUtil.selectorFor(inputFileId); + var fileRequest = MetadataReadRequest.newBuilder() + .setTenant(TEST_TENANT) + .setSelector(fileSelector) + .build(); + + var fileTag = metaClient.readObject(fileRequest); + + var fileAttr = fileTag.getAttrsOrThrow("e2e_test_file"); + var fileDef = fileTag.getDefinition().getFile(); + + Assertions.assertEquals("run_model:file_io", MetadataCodec.decodeStringValue(fileAttr)); + Assertions.assertEquals(sampleDataSize, fileDef.getSize()); + + log.info("Input file loaded, data ID = [{}]", fileTag.getHeader().getObjectId()); + } + + @Test @Order(12) + void fileIO_importModel() throws Exception { + + log.info("Running IMPORT_MODEL job for file IO..."); + + var modelVersion = GitHelpers.getCurrentCommit(); + var modelStub = ModelDefinition.newBuilder() + .setLanguage("python") + .setRepository(useTracRepo()) + .setPath("tracdap-services/tracdap-svc-orch/src/test/resources") + .setEntryPoint("file_io.FileIOModel") + .setVersion(modelVersion) + .build(); + + var modelAttrs = List.of(TagUpdate.newBuilder() + .setAttrName("e2e_test_model") + .setValue(MetadataCodec.encodeValue("run_model:file_io")) + .build()); + + var jobAttrs = List.of(TagUpdate.newBuilder() + .setAttrName("e2e_test_job") + .setValue(MetadataCodec.encodeValue("run_model:file_io_import_model")) + .build()); + + var modelTag = ImportModelTest.doImportModel(platform, TEST_TENANT, modelStub, modelAttrs, jobAttrs); + var modelDef = modelTag.getDefinition().getModel(); + var modelAttr = modelTag.getAttrsOrThrow("e2e_test_model"); + + Assertions.assertEquals("run_model:file_io", MetadataCodec.decodeStringValue(modelAttr)); + Assertions.assertEquals("file_io.FileIOModel", modelDef.getEntryPoint()); + Assertions.assertTrue(modelDef.getInputsMap().containsKey("file_input")); + Assertions.assertTrue(modelDef.getOutputsMap().containsKey("file_output")); + + fileIoModelId = modelTag.getHeader(); + } + + @ParameterizedTest() @Order(13) + @ValueSource(booleans= {true, false}) + void fileIO_runModel(boolean useStreams) { + + var metaClient = platform.metaClientBlocking(); + var orchClient = platform.orchClientBlocking(); + + var runModel = RunModelJob.newBuilder() + .setModel(MetadataUtil.selectorFor(fileIoModelId)) + .putParameters("n_copies", MetadataCodec.encodeValue(3)) + .putParameters("use_streams", MetadataCodec.encodeValue(useStreams)) + .putInputs("file_input", MetadataUtil.selectorFor(inputFileId)) + .addOutputAttrs(TagUpdate.newBuilder() + .setAttrName("e2e_test_data") + .setValue(MetadataCodec.encodeValue("run_model:file_io_data_output"))) + .build(); + + var jobRequest = JobRequest.newBuilder() + .setTenant(TEST_TENANT) + 
.setJob(JobDefinition.newBuilder() + .setJobType(JobType.RUN_MODEL) + .setRunModel(runModel)) + .addJobAttrs(TagUpdate.newBuilder() + .setAttrName("e2e_test_job") + .setValue(MetadataCodec.encodeValue("run_model:file_io_run_model"))) + .build(); + + var jobStatus = runJob(orchClient, jobRequest); + var jobKey = MetadataUtil.objectKey(jobStatus.getJobId()); + + Assertions.assertEquals(JobStatusCode.SUCCEEDED, jobStatus.getStatusCode()); + + var fileSearch = MetadataSearchRequest.newBuilder() + .setTenant(TEST_TENANT) + .setSearchParams(SearchParameters.newBuilder() + .setObjectType(ObjectType.FILE) + .setSearch(SearchExpression.newBuilder() + .setTerm(SearchTerm.newBuilder() + .setAttrName("trac_create_job") + .setAttrType(BasicType.STRING) + .setOperator(SearchOperator.EQ) + .setSearchValue(MetadataCodec.encodeValue(jobKey))))) + .build(); + + var dataSearchResult = metaClient.search(fileSearch); + + Assertions.assertEquals(1, dataSearchResult.getSearchResultCount()); + + var searchResult = dataSearchResult.getSearchResult(0); + var fileReq = MetadataReadRequest.newBuilder() + .setTenant(TEST_TENANT) + .setSelector(MetadataUtil.selectorFor(searchResult.getHeader())) + .build(); + + var fileTag = metaClient.readObject(fileReq); + var fileDef = fileTag.getDefinition().getFile(); + var outputAttr = fileTag.getAttrsOrThrow("e2e_test_data"); + + Assertions.assertEquals("run_model:file_io_data_output", MetadataCodec.decodeStringValue(outputAttr)); + Assertions.assertEquals(inputFileSize * 3, fileDef.getSize()); + + if (useStreams) + outputFileIdStream = fileTag.getHeader(); + else + outputFileId = fileTag.getHeader(); + } + + @ParameterizedTest @Order(14) + @ValueSource(booleans= {true, false}) + void fileIO_checkOutputData(boolean useStreams) { + + log.info("Checking output data for file IO..."); + + var dataClient = platform.dataClientBlocking(); + + var fileId = useStreams ? outputFileIdStream : outputFileId; + + var readRequest = FileReadRequest.newBuilder() + .setTenant(TEST_TENANT) + .setSelector(MetadataUtil.selectorFor(fileId)) + .build(); + + var readResponse = dataClient.readSmallFile(readRequest); + + var expectedContent = "Some text in a file\r\n".repeat(3); + var fileContents = readResponse.getContent().toString(StandardCharsets.UTF_8); + + Assertions.assertEquals(expectedContent, fileContents); + } + } diff --git a/tracdap-services/tracdap-svc-orch/src/test/resources/file_io.py b/tracdap-services/tracdap-svc-orch/src/test/resources/file_io.py new file mode 100644 index 000000000..a3307e0c7 --- /dev/null +++ b/tracdap-services/tracdap-svc-orch/src/test/resources/file_io.py @@ -0,0 +1,60 @@ +# Licensed to the Fintech Open Source Foundation (FINOS) under one or +# more contributor license agreements. See the NOTICE file distributed +# with this work for additional information regarding copyright ownership. +# FINOS licenses this file to you under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with the +# License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import typing as tp
+
+import tracdap.rt.api as trac
+
+
+class FileIOModel(trac.TracModel):
+
+    def define_parameters(self) -> tp.Dict[str, trac.ModelParameter]:
+
+        return trac.define_parameters(
+            trac.P("n_copies", trac.BasicType.INTEGER, label="Number of times to copy the input data"),
+            trac.P("use_streams", trac.BasicType.BOOLEAN, default_value=False, label="Flag to enable streams for file IO"))
+
+    def define_inputs(self) -> tp.Dict[str, trac.ModelInputSchema]:
+
+        file_input = trac.define_input(trac.CommonFileTypes.TXT, label="Quarterly sales report")
+
+        return {"file_input": file_input}
+
+    def define_outputs(self) -> tp.Dict[str, trac.ModelOutputSchema]:
+
+        file_output = trac.define_output(trac.CommonFileTypes.TXT, label="Quarterly sales report")
+
+        return {"file_output": file_output}
+
+    def run_model(self, ctx: trac.TracContext):
+
+        n_copies = ctx.get_parameter("n_copies")
+        use_streams = ctx.get_parameter("use_streams")
+
+        if use_streams:
+
+            with ctx.get_file_stream("file_input") as in_stream:
+                in_data = in_stream.read()
+
+            out_data = in_data * n_copies
+
+            with ctx.put_file_stream("file_output") as out_stream:
+                out_stream.write(out_data)
+
+        else:
+
+            in_data = ctx.get_file("file_input")
+            out_data = in_data * n_copies
+            ctx.put_file("file_output", out_data)
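
Taken together, the runtime changes in this diff generalise model inputs and outputs from tables only to tables or files: define_input() and define_output() accept either a SchemaDefinition or a FileType, define_file_type() and the CommonFileTypes constants supply file type requirements, and TracContext gains get_file()/put_file() alongside the stream variants. The sketch below is an editorial illustration only, not part of the diff; the model name, dataset names, field names and labels are hypothetical, and the define_* calls are intended to run inside a model loaded by the TRAC runtime, which registers the static API hook they rely on.

import typing as tp

import tracdap.rt.api as trac


class LoanExtractModel(trac.TracModel):

    # Illustrative sketch (editor's addition): all names below are hypothetical,
    # chosen only to show how the new define_input / define_output API is used

    def define_parameters(self) -> tp.Dict[str, trac.ModelParameter]:
        # No parameters are needed for this illustration
        return trac.define_parameters()

    def define_inputs(self) -> tp.Dict[str, trac.ModelInputSchema]:

        # Tabular input: a SchemaDefinition passed to define_input(), equivalent to
        # the older define_input_table() helper, which now delegates to this path
        loans_schema = trac.define_schema(
            trac.F("id", trac.BasicType.STRING, label="Customer ID", business_key=True),
            trac.F("loan_amount", trac.BasicType.FLOAT, label="Loan amount"))

        loans = trac.define_input(loans_schema, label="Customer loans")

        # File input: a FileType requirement instead of a schema
        template = trac.define_input(trac.CommonFileTypes.POWERPOINT, label="Report template")

        return {"customer_loans": loans, "report_template": template}

    def define_outputs(self) -> tp.Dict[str, trac.ModelOutputSchema]:

        # File types not covered by CommonFileTypes can be declared explicitly
        csv_type = trac.define_file_type("csv", "text/csv")
        extract = trac.define_output(csv_type, label="Raw extract file")

        return {"raw_extract": extract}

    def run_model(self, ctx: trac.TracContext):

        loans = ctx.get_pandas_table("customer_loans")

        # get_file() returns the whole file content as bytes
        template = ctx.get_file("report_template")
        ctx.log().info(f"Report template is {len(template)} bytes")

        # put_file() takes bytes; put_file_stream() provides a binary stream instead
        ctx.put_file("raw_extract", loans.to_csv(index=False).encode("utf-8"))

Under this change define_input_table() and define_output_table() remain available and now delegate to define_schema() plus define_input()/define_output(), so existing models keep working unchanged; define_input_file()/define_output_file() cover the simple file cases, while the declare_*() functions and the ModelInputSchema/ModelOutputSchema shims are retained purely for backwards compatibility and remain deprecated.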