diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b899f3f..2e21449e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [2.7.0](https://github.com/privacysandbox/aggregation-service/compare/v2.6.0...v2.7.0) (2024-08-01) + +- Added support for aggregating reports belonging to multiple reporting origins under the same + reporting site in a single aggregation job. +- [GCP Only] Updated coordinator endpoints to the new Google/Third-Party coordinator pair. + ## [2.6.0](https://github.com/privacysandbox/aggregation-service/compare/v2.5.0...v2.6.0) (2024-07-19) - Enabled support for diff --git a/VERSION b/VERSION index e70b4523..24ba9a38 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.6.0 +2.7.0 diff --git a/build-scripts/DEBIAN_CONTAINER_DIGEST b/build-scripts/DEBIAN_CONTAINER_DIGEST index 9a04a244..81df05b0 100644 --- a/build-scripts/DEBIAN_CONTAINER_DIGEST +++ b/build-scripts/DEBIAN_CONTAINER_DIGEST @@ -1 +1 @@ -sha256:39868a6f452462b70cf720a8daff250c63e7342970e749059c105bf7c1e8eeaf +sha256:16112ae93b810eb1ec6d1db6e01835d2444c8ca99aa678e03dd104ea3ec68408 diff --git a/build-scripts/gcp/cloudbuild.yaml b/build-scripts/gcp/cloudbuild.yaml index dc8ba284..af59f7b4 100644 --- a/build-scripts/gcp/cloudbuild.yaml +++ b/build-scripts/gcp/cloudbuild.yaml @@ -15,7 +15,7 @@ steps: - name: '$_BUILD_IMAGE_REPO_PATH/bazel-build-container:$_VERSION' script: | - bazel run worker/gcp:worker_mp_gcp_prod -- -dst "$_IMAGE_REPO_PATH/$_IMAGE_NAME:$_IMAGE_TAG" + bazel run worker/gcp:worker_mp_gcp_g3p_prod -- -dst "$_IMAGE_REPO_PATH/$_IMAGE_NAME:$_IMAGE_TAG" bazel run //terraform/gcp:frontend_service_http_cloud_function_release \ --//terraform/gcp:bucket_flag=$_JARS_PUBLISH_BUCKET --//terraform/gcp:bucket_path_flag=$_JARS_PUBLISH_BUCKET_PATH \ -- --version=$_VERSION diff --git a/build_defs/container_dependencies.bzl b/build_defs/container_dependencies.bzl index 7c270a6f..153cb435 100644 --- a/build_defs/container_dependencies.bzl +++ b/build_defs/container_dependencies.bzl @@ -24,11 +24,11 @@ # - java_base: Distroless image for running Java. ################################################################################ -# Updated as of: 2024-07-19 +# Updated as of: 2024-07-26 CONTAINER_DEPS = { "amazonlinux_2": { - "digest": "sha256:7081389e0a1d55d5c05a6bab72fb8a82b37c72a724365c6104c7fbc5bcdb2e09", + "digest": "sha256:b2ed30084a71c34c0f41a5add7dd623a2e623f2c3b50117c720bbc02d7653fa1", "registry": "index.docker.io", "repository": "amazonlinux", }, diff --git a/docs/api.md b/docs/api.md index ad8dd809..c46ab1c7 100644 --- a/docs/api.md +++ b/docs/api.md @@ -68,6 +68,14 @@ POST // This should be same as the reporting_origin present in the reports' shared_info. "attribution_report_to": , + // [Optional] Reporting Site. + // This should be the reporting site that is onboarded to the aggregation service. + // Note: All reports in the request should have reporting origins that + // belong to the reporting site mentioned in this parameter. This parameter + // and the "attribution_report_to" parameter are mutually exclusive; exactly + // one of the two parameters should be provided in the request. + "reporting_site": "" + // [Optional] Differential privacy epsilon value to be used // for this job. 0.0 < debug_privacy_epsilon <= 64.0. The // value can be varied so that tests with different epsilon @@ -156,6 +164,10 @@ These are the validations that are done before the aggregation begins. ATTRIBUTION_REPORT_TO_MISMATCH error counter. 
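For illustration, a minimal `job_parameters` fragment for a `createJob` request using the new field might look as follows; this is a sketch in the style of the request bodies above, the bucket and site values are hypothetical, and `reporting_site` is supplied instead of `attribution_report_to`, never alongside it (see validations 4 and 5 below):

```
"job_parameters": {
  "output_domain_blob_prefix": "domain/domain.avro",
  "output_domain_bucket_name": "<output_domain_bucket_name>",
  // The site (scheme plus eTLD+1) that covers every
  // reporting origin in the job's input reports.
  "reporting_site": "https://example.com"
}
```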
Aggregatable report validations and error counters can be found in the [Input Aggregatable Report Validations](#input-aggregatable-report-validations) below +4. Job request's `job_parameters` should contain exactly one of `attribution_report_to` and + `reporting_site`. +5. If `job_parameters.reporting_site` is provided, `shared_info.reporting_origin` of all + aggregatable reports should belong to this reporting site. Return code: [INVALID_JOB](java/com/google/aggregate/adtech/worker/AggregationWorkerReturnCode.java#L38) @@ -227,6 +239,8 @@ Not found: 404 Not Found "output_domain_bucket_name": , // Reporting URL "attribution_report_to" : , + // [Optional] Reporting site value from the CreateJob request, if provided. + "reporting_site": // [Optional] differential privacy epsilon value to be used // for this job. 0.0 < debug_privacy_epsilon <= 64.0. The // value can be varied so that tests with different epsilon diff --git a/docs/gcp-aggregation-service.md b/docs/gcp-aggregation-service.md index 94f8d495..f8652c21 100644 --- a/docs/gcp-aggregation-service.md +++ b/docs/gcp-aggregation-service.md @@ -342,3 +342,12 @@ _Note: If you use self-built artifacts described in [build-scripts/gcp](/build-scripts/gcp/README.md), run `bash fetch_terraform.sh` instead of `bash download_prebuilt_dependencies.sh` and make sure you updated your dependencies in the `jars` folder._ + +_Note: When migrating to the new coordinator pair from version 2.[4|5|6].z to 2.7.z or later, ensure the +file `/terraform/gcp/environments/shared/release_params.auto.tfvars` is updated with the following +values:_ + +```sh +coordinator_a_impersonate_service_account = "a-opallowedusr@ps-msmt-coord-prd-g3p-svcacc.iam.gserviceaccount.com" +coordinator_b_impersonate_service_account = "b-opallowedusr@ps-prod-msmt-type2-e541.iam.gserviceaccount.com" +``` diff --git a/java/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLogger.java b/java/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLogger.java index 6e9c134e..7d01f2d1 100644 --- a/java/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLogger.java +++ b/java/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLogger.java @@ -17,7 +17,6 @@ package com.google.aggregate.adtech.worker; import static com.google.aggregate.adtech.worker.util.DebugSupportHelper.getDebugFilePrefix; -import static com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient.BlobStorageClientException; import static com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient.getDataLocation; import static com.google.scp.operator.shared.model.BackendModelUtil.toJobKeyString; import static java.lang.annotation.ElementType.FIELD; @@ -31,10 +30,8 @@ import com.google.aggregate.adtech.worker.Annotations.ResultWriter; import com.google.aggregate.adtech.worker.exceptions.ResultLogException; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.aggregate.adtech.worker.util.OutputShardFileHelper; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter; -import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter.FileWriteException; import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -45,7 +42,6 @@ import com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient; import com.google.scp.operator.cpio.blobstorageclient.model.DataLocation; import 
com.google.scp.operator.cpio.jobclient.model.Job; -import java.io.IOException; import java.lang.annotation.Retention; import java.lang.annotation.Target; import java.nio.file.Files; @@ -69,7 +65,6 @@ public final class LocalFileToCloudStorageLogger implements ResultLogger { private final BlobStorageClient blobStorageClient; private final Path workingDirectory; private final ListeningExecutorService blockingThreadPool; - private static final String reencryptedReportFileNamePrefix = "reencrypted-"; @Inject LocalFileToCloudStorageLogger( @@ -165,26 +160,6 @@ public void logResults(ImmutableList results, Job ctx, boolean i } } - @Override - public void logReports(ImmutableList reports, Job ctx, String shardNumber) - throws ResultLogException { - String localFileName = - toJobKeyString(ctx.jobKey()) - + "-" - + reencryptedReportFileNamePrefix - + shardNumber - + ".avro"; - Path localReportsFilePath = - workingDirectory - .getFileSystem() - .getPath(Paths.get(workingDirectory.toString(), localFileName).toString()); - try { - writeReportsToCloud(reports.stream(), ctx, localReportsFilePath, localResultFileWriter); - } catch (FileWriteException | BlobStorageClientException | IOException e) { - throw new ResultLogException(e); - } - } - @SuppressWarnings("UnstableApiUsage") private ListenableFuture writeFile( Stream aggregatedFacts, @@ -220,21 +195,6 @@ private ListenableFuture writeFile( blockingThreadPool); } - private void writeReportsToCloud( - Stream reports, Job ctx, Path localFilepath, LocalResultFileWriter writer) - throws IOException, FileWriteException, BlobStorageClientException { - Files.createDirectories(workingDirectory); - writer.writeLocalReportFile(reports, localFilepath); - - String outputDataBlobBucket = ctx.requestInfo().getOutputDataBucketName(); - String outputDataBlobPrefix = localFilepath.getFileName().toString(); - - DataLocation resultLocation = getDataLocation(outputDataBlobBucket, outputDataBlobPrefix); - - blobStorageClient.putBlob(resultLocation, localFilepath); - Files.deleteIfExists(localFilepath); - } - /** * The local file name has a random UUID in it to prevent cases where an item is processed twice * by the same worker and clobbers other files being written. diff --git a/java/com/google/aggregate/adtech/worker/LocalResultLogger.java b/java/com/google/aggregate/adtech/worker/LocalResultLogger.java index 9d304b0a..60ee2995 100644 --- a/java/com/google/aggregate/adtech/worker/LocalResultLogger.java +++ b/java/com/google/aggregate/adtech/worker/LocalResultLogger.java @@ -23,7 +23,6 @@ import com.google.aggregate.adtech.worker.LibraryAnnotations.LocalOutputDirectory; import com.google.aggregate.adtech.worker.exceptions.ResultLogException; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter.FileWriteException; import com.google.common.collect.ImmutableList; @@ -68,23 +67,6 @@ public void logResults(ImmutableList results, Job ctx, boolean i isDebugRun ? 
localDebugResultFileWriter : localResultFileWriter); } - // TODO(b/315199032): Add local runner test - @Override - public void logReports(ImmutableList reports, Job ctx, String shardNumber) - throws ResultLogException { - String localFileName = "reencrypted_report.avro"; - Path localReportsFilePath = - workingDirectory - .getFileSystem() - .getPath(Paths.get(workingDirectory.toString(), localFileName).toString()); - try { - Files.createDirectories(workingDirectory); - localResultFileWriter.writeLocalReportFile(reports.stream(), localReportsFilePath); - } catch (IOException | FileWriteException e) { - throw new ResultLogException(e); - } - } - private DataLocation writeFile( Stream results, Job ctx, Path filePath, LocalResultFileWriter writer) throws ResultLogException { diff --git a/java/com/google/aggregate/adtech/worker/ResultLogger.java b/java/com/google/aggregate/adtech/worker/ResultLogger.java index 0a1662bc..d631dfa4 100644 --- a/java/com/google/aggregate/adtech/worker/ResultLogger.java +++ b/java/com/google/aggregate/adtech/worker/ResultLogger.java @@ -18,7 +18,6 @@ import com.google.aggregate.adtech.worker.exceptions.ResultLogException; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.common.collect.ImmutableList; import com.google.scp.operator.cpio.jobclient.model.Job; @@ -28,8 +27,4 @@ public interface ResultLogger { /** Takes the aggregation results and logs them to results. */ void logResults(ImmutableList results, Job ctx, boolean isDebugRun) throws ResultLogException; - - /** Logs encrypted aggregatable reports. */ - void logReports(ImmutableList results, Job ctx, String shardNumber) - throws ResultLogException; } diff --git a/java/com/google/aggregate/adtech/worker/aggregation/concurrent/ConcurrentAggregationProcessor.java b/java/com/google/aggregate/adtech/worker/aggregation/concurrent/ConcurrentAggregationProcessor.java index 7c3a1644..41e101a4 100644 --- a/java/com/google/aggregate/adtech/worker/aggregation/concurrent/ConcurrentAggregationProcessor.java +++ b/java/com/google/aggregate/adtech/worker/aggregation/concurrent/ConcurrentAggregationProcessor.java @@ -33,7 +33,6 @@ import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_OUTPUT_DOMAIN_BUCKET_NAME; import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_REPORT_ERROR_THRESHOLD_PERCENTAGE; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.util.concurrent.Futures.immediateFuture; import static com.google.scp.operator.shared.model.BackendModelUtil.toJobKeyString; import com.google.aggregate.adtech.worker.AggregationWorkerReturnCode; @@ -76,8 +75,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.primitives.UnsignedLong; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.errorprone.annotations.Var; import com.google.privacysandbox.otel.OTelConfiguration; @@ -94,7 +91,6 @@ import io.reactivex.rxjava3.schedulers.Schedulers; import java.io.IOException; import java.io.InputStream; -import java.math.BigInteger; import java.security.AccessControlException; import java.util.List; import java.util.Map; @@ -115,6 +111,8 @@ public final class ConcurrentAggregationProcessor implements JobProcessor { public static final String 
JOB_PARAM_ATTRIBUTION_REPORT_TO = "attribution_report_to"; // Key to indicate whether this is a debug job public static final String JOB_PARAM_DEBUG_RUN = "debug_run"; + // Key for user provided reporting site value in the job params of the job request. + public static final String JOB_PARAM_REPORTING_SITE = "reporting_site"; private static final int NUM_CPUS = Runtime.getRuntime().availableProcessors(); // In aggregation service, reading is much faster than decryption, and most of the time, it waits @@ -153,6 +151,8 @@ public final class ConcurrentAggregationProcessor implements JobProcessor { private final boolean enablePrivacyBudgetKeyFiltering; private final OTelConfiguration oTelConfiguration; private final double defaultReportErrorThresholdPercentage; + + // TODO(b/338219415): Reuse this flag to enable full streaming approach. private final Boolean streamingOutputDomainProcessing; @Inject @@ -223,7 +223,7 @@ public JobResult process(Job job) if (jobParams.containsKey(JOB_PARAM_OUTPUT_DOMAIN_BUCKET_NAME) && jobParams.containsKey(JOB_PARAM_OUTPUT_DOMAIN_BLOB_PREFIX) && (!jobParams.get(JOB_PARAM_OUTPUT_DOMAIN_BUCKET_NAME).isEmpty() - || !jobParams.get(JOB_PARAM_OUTPUT_DOMAIN_BLOB_PREFIX).isEmpty())) { + || !jobParams.get(JOB_PARAM_OUTPUT_DOMAIN_BLOB_PREFIX).isEmpty())) { outputDomainLocation = Optional.of( BlobStorageClient.getDataLocation( @@ -297,7 +297,6 @@ public JobResult process(Job job) NoisedAggregatedResultSet noisedResultSet; try { - if (streamingOutputDomainProcessing) { noisedResultSet = conflateWithDomainAndAddNoiseStreaming( outputDomainLocation, @@ -305,26 +304,9 @@ public JobResult process(Job job) aggregationEngine, debugPrivacyEpsilon, debugRun); - - } else { - noisedResultSet = - conflateWithDomainAndAddNoise( - outputDomainLocation, - outputDomainShards, - aggregationEngine, - debugPrivacyEpsilon, - debugRun); - } } catch (DomainReadException e) { throw new AggregationJobProcessException( INPUT_DATA_READ_FAILED, "Exception while reading domain input data.", e.getCause()); - } catch (ExecutionException e) { - if (e.getCause() instanceof DomainReadException) { - throw new AggregationJobProcessException( - INPUT_DATA_READ_FAILED, "Exception while reading domain input data.", e.getCause()); - } - throw new AggregationJobProcessException( - INTERNAL_ERROR, "Exception in processing domain.", e); } processingStopwatch.stop(); @@ -399,34 +381,6 @@ private NoisedAggregatedResultSet conflateWithDomainAndAddNoiseStreaming( debugRun); } - private NoisedAggregatedResultSet conflateWithDomainAndAddNoise( - Optional outputDomainLocation, - ImmutableList outputDomainShards, - AggregationEngine engine, - Optional debugPrivacyEpsilon, - Boolean debugRun) - throws DomainReadException, ExecutionException, InterruptedException { - @Var - ListenableFuture> outputDomainFuture = - outputDomainLocation - .map(loc -> outputDomainProcessor.readAndDedupeDomain(loc, outputDomainShards)) - .orElse(immediateFuture(ImmutableSet.of())); - - ListenableFuture aggregationFinalFuture = - Futures.transform( - outputDomainFuture, - outputDomain -> - outputDomainProcessor.adjustAggregationWithDomainAndNoise( - noisedAggregationRunner, - outputDomain, - engine.makeAggregation(), - debugPrivacyEpsilon, - debugRun), - nonBlockingThreadPool); - - return aggregationFinalFuture.get(); - } - private double getReportErrorThresholdPercentage(Map jobParams) { String jobParamsReportErrorThresholdPercentage = jobParams.getOrDefault(JOB_PARAM_REPORT_ERROR_THRESHOLD_PERCENTAGE, null); @@ -449,21 +403,34 @@ private void 
consumePrivacyBudgetUnits(ImmutableList budgetsT return; } + String claimedIdentity; + // Validations ensure that at least one of the parameters will always exist. + if (job.requestInfo().getJobParametersMap().containsKey(JOB_PARAM_REPORTING_SITE)) { + claimedIdentity = job.requestInfo().getJobParametersMap().get(JOB_PARAM_REPORTING_SITE); + } else { + try { + claimedIdentity = + ReportingOriginUtils.convertReportingOriginToSite( + job.requestInfo().getJobParametersMap().get(JOB_PARAM_ATTRIBUTION_REPORT_TO)); + } catch (InvalidReportingOriginException e) { + // This should never happen due to validations ensuring that the reporting origin is always + // valid. + throw new IllegalStateException( + "Invalid reporting origin found while consuming budget, this should not happen as job" + + " validations ensure the reporting origin is always valid.", + e); + } + } + ImmutableList missingPrivacyBudgetUnits; try { try (Timer t = oTelConfiguration.createDebugTimerStarted("pbs_latency", toJobKeyString(job.jobKey()))) { final String reportingOrigin = job.requestInfo().getJobParametersMap().get(JOB_PARAM_ATTRIBUTION_REPORT_TO); - final String claimedIdentity = - ReportingOriginUtils.convertReportingOriginToSite(reportingOrigin); missingPrivacyBudgetUnits = - privacyBudgetingServiceBridge.consumePrivacyBudget(budgetsToConsume, claimedIdentity); - } catch (InvalidReportingOriginException e) { - throw new AggregationJobProcessException( - INVALID_JOB, - "The attribution_report_to parameter specified in the CreateJob request is not under a" - + " known public suffix."); + privacyBudgetingServiceBridge.consumePrivacyBudget( + budgetsToConsume, claimedIdentity); } } catch (PrivacyBudgetingServiceBridgeException e) { if (e.getStatusCode() != null) { diff --git a/java/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessor.java b/java/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessor.java index f917573f..a46de36e 100644 --- a/java/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessor.java +++ b/java/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessor.java @@ -16,27 +16,19 @@ package com.google.aggregate.adtech.worker.aggregation.domain; -import static com.google.common.collect.ImmutableList.toImmutableList; - import com.google.aggregate.adtech.worker.Annotations.BlockingThreadPool; import com.google.aggregate.adtech.worker.Annotations.DomainOptional; import com.google.aggregate.adtech.worker.Annotations.EnableThresholding; import com.google.aggregate.adtech.worker.Annotations.NonBlockingThreadPool; import com.google.aggregate.adtech.worker.exceptions.DomainReadException; import com.google.aggregate.perf.StopwatchRegistry; -import com.google.aggregate.protocol.avro.AvroOutputDomainReader; import com.google.aggregate.protocol.avro.AvroOutputDomainReaderFactory; import com.google.aggregate.protocol.avro.AvroOutputDomainRecord; -import com.google.common.base.Stopwatch; -import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient; -import com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient.BlobStorageClientException; -import com.google.scp.operator.cpio.blobstorageclient.model.DataLocation; import java.io.IOException; import java.io.InputStream; import java.math.BigInteger; -import java.util.UUID; import java.util.stream.Stream; import javax.inject.Inject; import 
org.apache.avro.AvroRuntimeException; @@ -69,32 +61,6 @@ public AvroOutputDomainProcessor( this.stopwatches = stopwatches; } - @Override - protected ImmutableList readShard(DataLocation outputDomainLocation) { - Stopwatch stopwatch = - stopwatches.createStopwatch(String.format("domain-shard-read-%s", UUID.randomUUID())); - stopwatch.start(); - try { - if (blobStorageClient.getBlobSize(outputDomainLocation) <= 0) { - stopwatch.stop(); - return ImmutableList.of(); - } - try (InputStream domainStream = blobStorageClient.getBlob(outputDomainLocation)) { - AvroOutputDomainReader outputDomainReader = avroReaderFactory.create(domainStream); - ImmutableList shard = - outputDomainReader - .streamRecords() - .map(AvroOutputDomainRecord::bucket) - .collect(toImmutableList()); - stopwatch.stop(); - return shard; - } - } catch (IOException | BlobStorageClientException | AvroRuntimeException e) { - stopwatch.stop(); // stop the stopwatch if an exception occurs - throw new DomainReadException(e); - } - } - @Override public Stream readInputStream(InputStream shardInputStream) { try { diff --git a/java/com/google/aggregate/adtech/worker/aggregation/domain/OutputDomainProcessor.java b/java/com/google/aggregate/adtech/worker/aggregation/domain/OutputDomainProcessor.java index b434c380..0fc70173 100644 --- a/java/com/google/aggregate/adtech/worker/aggregation/domain/OutputDomainProcessor.java +++ b/java/com/google/aggregate/adtech/worker/aggregation/domain/OutputDomainProcessor.java @@ -26,17 +26,8 @@ import com.google.aggregate.privacy.noise.NoisedAggregationRunner; import com.google.aggregate.privacy.noise.model.NoisedAggregatedResultSet; import com.google.aggregate.privacy.noise.model.NoisedAggregationResult; -import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.common.collect.MapDifference; -import com.google.common.collect.MapDifference.ValueDifference; -import com.google.common.collect.Maps; import com.google.common.collect.Sets; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient; import com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient.BlobStorageClientException; @@ -48,15 +39,12 @@ import java.io.InputStream; import java.math.BigInteger; import java.util.ArrayList; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.UUID; import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; import org.slf4j.Logger; @@ -96,54 +84,6 @@ public abstract class OutputDomainProcessor { this.enableThresholding = enableThresholding; } - /** - * Asynchronously reads output domain from {@link DataLocation} shards and returns a deduped set - * of buckets in output domain as {@link BigInteger}. The input data location can contain many - * shards. - * - *
<p>
Shards are read asynchronously. If there is an error reading the shards the future will - * complete with an exception. - * - * @return ListenableFuture containing the output domain buckets in a set - * @throws DomainReadException (unchecked) if there is an error listing the shards or the location - * provided has no shards present. - */ - public ListenableFuture> readAndDedupeDomain( - DataLocation outputDomainLocation, ImmutableList shards) { - ImmutableList>> futureShardReads = - shards.stream() - .map(shard -> blockingThreadPool.submit(() -> readShard(shard))) - .collect(ImmutableList.toImmutableList()); - - ListenableFuture>> allFutureShards = - Futures.allAsList(futureShardReads); - - return Futures.transform( - allFutureShards, - readShards -> { - Stopwatch stopwatch = - stopwatches.createStopwatch( - String.format("domain-combine-shards-%s", UUID.randomUUID())); - stopwatch.start(); - ImmutableSet domain = - readShards.stream() - .flatMap(Collection::stream) - .collect(ImmutableSet.toImmutableSet()); - stopwatch.stop(); - if (domain.isEmpty()) { - throw new DomainReadException( - new IllegalArgumentException( - String.format( - "No output domain provided in the location. : %s. Please refer to the API" - + " documentation for output domain parameters at" - + " https://github.com/privacysandbox/aggregation-service/blob/main/docs/api.md", - outputDomainLocation))); - } - return domain; - }, - nonBlockingThreadPool); - } - /** * Read all shards at {@link DataLocation} on the cloud storage provider. * @@ -335,98 +275,5 @@ private Flowable readShardData(DataLocation shard) { InputStream::close); } - /** - * Conflate aggregated facts with the output domain and noise results using the Maps.Difference - * API. - * - * @return NoisedAggregatedResultSet containing the combined and noised Aggregatable reports and - * output domain buckets. - */ - public NoisedAggregatedResultSet adjustAggregationWithDomainAndNoise( - NoisedAggregationRunner noisedAggregationRunner, - ImmutableSet outputDomain, - ImmutableMap reportsAggregatedMap, - Optional debugPrivacyEpsilon, - Boolean debugRun) { - // This pseudo-aggregation has all zeroes for the output domain. If a key is present in the - // output domain, but not in the aggregation itself, a zero is inserted which will later be - // noised to some value. - ImmutableMap outputDomainPseudoAggregation = - outputDomain.stream() - .collect( - ImmutableMap.toImmutableMap( - Function.identity(), key -> AggregatedFact.create(key, /* metric= */ 0))); - - // Difference by key is computed so that the output can be adjusted for the output domain. - // Keys that are in the aggregation data, but not in the output domain, are subject to both - // noising and thresholding. - // Otherwise, the data is subject to noising only. - MapDifference pseudoDiff = - Maps.difference(reportsAggregatedMap, outputDomainPseudoAggregation); - - // The values for common keys should in theory be differing, since the pseudo aggregation will - // have all zeroes, while the 'real' aggregation will have non-zeroes, but just in case to - // cover overlapping zeroes, matching keys are also processed. - // `overlappingZeroes` includes all the keys present in both domain and reports but - // the values are 0. - Iterable overlappingZeroes = pseudoDiff.entriesInCommon().values(); - // `overlappingNonZeroes` includes all the keys present in both domain and reports, and the - // value is non-zero in reports. 
- Iterable overlappingNonZeroes = - Maps.transformValues(pseudoDiff.entriesDiffering(), ValueDifference::leftValue).values(); - // `domainOutputOnlyZeroes` only includes keys in domain. - Iterable domainOutputOnlyZeroes = pseudoDiff.entriesOnlyOnRight().values(); - - NoisedAggregationResult noisedOverlappingNoThreshold = - noisedAggregationRunner.noise( - Iterables.concat(overlappingZeroes, overlappingNonZeroes), debugPrivacyEpsilon); - - NoisedAggregationResult noisedDomainOnlyNoThreshold = - noisedAggregationRunner.noise(domainOutputOnlyZeroes, debugPrivacyEpsilon); - - NoisedAggregationResult noisedDomainNoThreshold = - NoisedAggregationResult.merge(noisedOverlappingNoThreshold, noisedDomainOnlyNoThreshold); - - NoisedAggregationResult noisedReportsOnlyNoThreshold = null; - if (debugRun || domainOptional) { - noisedReportsOnlyNoThreshold = - noisedAggregationRunner.noise( - pseudoDiff.entriesOnlyOnLeft().values(), debugPrivacyEpsilon); - } - - NoisedAggregatedResultSet.Builder noisedResultSetBuilder = NoisedAggregatedResultSet.builder(); - - if (debugRun) { - noisedResultSetBuilder.setNoisedDebugResult( - getAnnotatedDebugResults( - noisedReportsOnlyNoThreshold, - noisedDomainOnlyNoThreshold, - noisedOverlappingNoThreshold)); - } - - if (domainOptional) { - NoisedAggregationResult noisedReportsDomainOptional = - enableThresholding - ? noisedAggregationRunner.threshold( - noisedReportsOnlyNoThreshold.noisedAggregatedFacts(), debugPrivacyEpsilon) - : noisedReportsOnlyNoThreshold; - - return noisedResultSetBuilder - .setNoisedResult( - NoisedAggregationResult.merge(noisedDomainNoThreshold, noisedReportsDomainOptional)) - .build(); - } else { - return noisedResultSetBuilder.setNoisedResult(noisedDomainNoThreshold).build(); - } - } - - /** - * Reads a given shard of the output domain - * - * @param shardLocation the location of the file to read - * @return the contents of the shard as a {@link ImmutableList} - */ - protected abstract ImmutableList readShard(DataLocation shardLocation); - public abstract Stream readInputStream(InputStream shardInputStream); } diff --git a/java/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessor.java b/java/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessor.java index 32538fe4..ac9a7df4 100644 --- a/java/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessor.java +++ b/java/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessor.java @@ -25,17 +25,13 @@ import com.google.aggregate.adtech.worker.exceptions.DomainReadException; import com.google.aggregate.adtech.worker.util.NumericConversions; import com.google.aggregate.perf.StopwatchRegistry; -import com.google.common.base.Stopwatch; import com.google.common.collect.ImmutableList; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient; -import com.google.scp.operator.cpio.blobstorageclient.BlobStorageClient.BlobStorageClientException; -import com.google.scp.operator.cpio.blobstorageclient.model.DataLocation; import java.io.IOException; import java.io.InputStream; import java.math.BigInteger; -import java.util.UUID; import java.util.stream.Stream; import javax.inject.Inject; @@ -64,30 +60,6 @@ public TextOutputDomainProcessor( this.stopwatches = stopwatches; } - public ImmutableList readShard(DataLocation outputDomainLocation) { - Stopwatch stopwatch = - 
stopwatches.createStopwatch(String.format("domain-shard-read-%s", UUID.randomUUID())); - stopwatch.start(); - try { - if (blobStorageClient.getBlobSize(outputDomainLocation) <= 0) { - return ImmutableList.of(); - } - try (InputStream domainStream = blobStorageClient.getBlob(outputDomainLocation)) { - byte[] bytes = ByteStreams.toByteArray(domainStream); - try (Stream fileLines = - NumericConversions.createStringFromByteArray(bytes).lines()) { - ImmutableList shard = - fileLines.map(NumericConversions::createBucketFromString).collect(toImmutableList()); - return shard; - } - } - } catch (IOException | BlobStorageClientException | IllegalArgumentException e) { - throw new DomainReadException(e); - } finally { - stopwatch.stop(); - } - } - public Stream readInputStream(InputStream shardInputStream) { try { byte[] bytes = ByteStreams.toByteArray(shardInputStream); diff --git a/java/com/google/aggregate/adtech/worker/encryption/NoopRecordEncrypter.java b/java/com/google/aggregate/adtech/worker/encryption/NoopRecordEncrypter.java index bf1d724d..1aa47d6a 100644 --- a/java/com/google/aggregate/adtech/worker/encryption/NoopRecordEncrypter.java +++ b/java/com/google/aggregate/adtech/worker/encryption/NoopRecordEncrypter.java @@ -17,7 +17,6 @@ package com.google.aggregate.adtech.worker.encryption; import com.google.aggregate.adtech.worker.model.EncryptedReport; -import com.google.aggregate.adtech.worker.model.Report; import com.google.common.io.ByteSource; public final class NoopRecordEncrypter implements RecordEncrypter { @@ -27,10 +26,4 @@ public EncryptedReport encryptSingleReport( ByteSource report, String sharedInfo, String reportVersion) throws EncryptionException { return null; } - - @Override - public EncryptedReport encryptReport(Report report, String publicKeyUri) - throws EncryptionException { - return null; - } } diff --git a/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypter.java b/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypter.java index 6636c9a6..19fa6c66 100644 --- a/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypter.java +++ b/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypter.java @@ -17,7 +17,6 @@ package com.google.aggregate.adtech.worker.encryption; import com.google.aggregate.adtech.worker.model.EncryptedReport; -import com.google.aggregate.adtech.worker.model.Report; import com.google.common.io.ByteSource; /** Interface that does encryption for any provided encryption algorithm. */ @@ -29,16 +28,6 @@ public interface RecordEncrypter { EncryptedReport encryptSingleReport(ByteSource report, String sharedInfo, String reportVersion) throws EncryptionException; - /** - * Encrypts a deserialized Report with keys provided by the publicKeyUri. 
- * - * @param report - * @param publicKeyUri - * @return EncryptedReport - * @throws EncryptionException - */ - EncryptedReport encryptReport(Report report, String publicKeyUri) throws EncryptionException; - final class EncryptionException extends Exception { public EncryptionException(Throwable cause) { diff --git a/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImpl.java b/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImpl.java index 362f357b..4b591d71 100644 --- a/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImpl.java +++ b/java/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImpl.java @@ -21,37 +21,21 @@ import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKey; import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKeyService; import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKeyService.KeyFetchException; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.ReEncryptionKeyService; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.ReEncryptionKeyService.ReencryptionKeyFetchException; import com.google.aggregate.adtech.worker.model.EncryptedReport; -import com.google.aggregate.adtech.worker.model.Report; -import com.google.aggregate.adtech.worker.model.serdes.PayloadSerdes; -import com.google.aggregate.adtech.worker.model.serdes.SharedInfoSerdes; import com.google.common.io.ByteSource; import com.google.inject.Inject; -import java.util.Optional; /** {@link RecordEncrypter} implementation. */ public final class RecordEncrypterImpl implements RecordEncrypter { private final EncryptionCipherFactory encryptionCipherFactory; private final EncryptionKeyService encryptionKeyService; - private final ReEncryptionKeyService reEncryptionKeyService; - private final PayloadSerdes payloadSerdes; - private final SharedInfoSerdes sharedInfoSerdes; @Inject public RecordEncrypterImpl( - EncryptionCipherFactory encryptionCipherFactory, - EncryptionKeyService encryptionKeyService, - ReEncryptionKeyService reEncryptionKeyService, - PayloadSerdes payloadSerdes, - SharedInfoSerdes sharedInfoSerdes) { + EncryptionCipherFactory encryptionCipherFactory, EncryptionKeyService encryptionKeyService) { this.encryptionCipherFactory = encryptionCipherFactory; this.encryptionKeyService = encryptionKeyService; - this.reEncryptionKeyService = reEncryptionKeyService; - this.payloadSerdes = payloadSerdes; - this.sharedInfoSerdes = sharedInfoSerdes; } @Override @@ -71,30 +55,4 @@ public EncryptedReport encryptSingleReport( throw new EncryptionException(e); } } - - @Override - public EncryptedReport encryptReport(Report report, String cloudEncryptionKeyVendingUri) - throws EncryptionException { - try { - EncryptionKey encryptionKey = - reEncryptionKeyService.getEncryptionPublicKey(cloudEncryptionKeyVendingUri); - EncryptionCipher encryptionCipher = - encryptionCipherFactory.encryptionCipherFor(encryptionKey.key()); - String sharedInfoString = - sharedInfoSerdes.reverse().convert(Optional.of(report.sharedInfo())); - return EncryptedReport.builder() - .setPayload( - encryptionCipher.encryptReport( - payloadSerdes.reverse().convert(Optional.of(report.payload())), - sharedInfoString, - report.sharedInfo().version())) - .setKeyId(encryptionKey.id()) - .setSharedInfo(sharedInfoString) - .build(); - } catch (CipherCreationException | ReencryptionKeyFetchException e) { - throw new EncryptionException(e); - } catch (PayloadEncryptionException e) { - throw new 
EncryptionException("Encountered PayloadEncryptionException."); - } - } } diff --git a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/BUILD b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/BUILD index 88138c69..f3102177 100644 --- a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/BUILD +++ b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/BUILD @@ -21,7 +21,6 @@ java_library( srcs = [ "EncryptionKey.java", "EncryptionKeyService.java", - "ReEncryptionKeyService.java", ], javacopts = ["-Xep:Var"], deps = [ diff --git a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/ReEncryptionKeyService.java b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/ReEncryptionKeyService.java deleted file mode 100644 index 59a726a2..00000000 --- a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/ReEncryptionKeyService.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.aggregate.adtech.worker.encryption.hybrid.key; - -/** Interface for retrieving public encryption keys from provided public key hosting URI */ -public interface ReEncryptionKeyService { - - /** Retrieve a key from the aggregate service KMS */ - EncryptionKey getEncryptionPublicKey(String keyVendingUri) throws ReencryptionKeyFetchException; - - final class ReencryptionKeyFetchException extends Exception { - public ReencryptionKeyFetchException(Throwable cause) { - super(cause); - } - - public ReencryptionKeyFetchException(String message) { - super(message); - } - } -} diff --git a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD index 47f3053f..d32ad8d1 100644 --- a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD +++ b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD @@ -21,12 +21,10 @@ java_library( srcs = [ "CloudEncryptionKeyModule.java", "CloudEncryptionKeyService.java", - "CloudReEncryptionKeyService.java", ], javacopts = ["-Xep:Var"], deps = [ "//java/com/google/aggregate/adtech/worker/encryption/hybrid/key", - "//java/com/google/aggregate/adtech/worker/encryption/publickeyuri:encryption_key_config", "//java/com/google/aggregate/shared/mapper", "//java/external:apache_httpclient", "//java/external:apache_httpcore", diff --git a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/CloudReEncryptionKeyService.java b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/CloudReEncryptionKeyService.java deleted file mode 100644 index d9379b52..00000000 --- a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/CloudReEncryptionKeyService.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.aggregate.adtech.worker.encryption.hybrid.key.cloud; - -import static com.google.aggregate.adtech.worker.encryption.publickeyuri.CloudEncryptionKeyConfig.NUM_ENCRYPTION_KEYS; - -import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKey; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKeyService; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.ReEncryptionKeyService; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.CacheLoader; -import com.google.common.cache.LoadingCache; -import com.google.common.collect.ImmutableList; -import com.google.common.primitives.Ints; -import com.google.inject.Inject; -import com.google.protobuf.util.JsonFormat; -import com.google.scp.coordinator.protos.keymanagement.keyhosting.api.v1.EncodedPublicKeyProto.EncodedPublicKey; -import com.google.scp.coordinator.protos.keymanagement.keyhosting.api.v1.GetActivePublicKeysResponseProto.GetActivePublicKeysResponse; -import com.google.scp.shared.api.util.HttpClientResponse; -import com.google.scp.shared.api.util.HttpClientWrapper; -import com.google.scp.shared.util.PublicKeyConversionUtil; -import java.net.URI; -import java.security.GeneralSecurityException; -import java.time.Duration; -import java.util.Random; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import org.apache.http.client.config.RequestConfig; -import org.apache.http.client.methods.HttpGet; - -/** - * TODO(b/321088264): Merge CloudReEncryptionKeyService and CloudEncryptionKeyService {@link - * EncryptionKeyService} implementation to get a random encryption key from public key endpoint for - * reencryption. - */ -public final class CloudReEncryptionKeyService implements ReEncryptionKeyService { - - private static final int REQUEST_TIMEOUT_DURATION = - Ints.checkedCast(Duration.ofMinutes(1).toMillis()); - private static final RequestConfig REQUEST_CONFIG = - RequestConfig.custom() - .setConnectionRequestTimeout(REQUEST_TIMEOUT_DURATION) - .setConnectTimeout(REQUEST_TIMEOUT_DURATION) - .setSocketTimeout(REQUEST_TIMEOUT_DURATION) - .build(); - private static final int MAX_CACHE_SIZE = 5; - private static final long CACHE_ENTRY_TTL_SEC = 3600; - private static final Random RANDOM = new Random(); - private final HttpClientWrapper httpClient; - private final LoadingCache> encryptionKeysCache = - CacheBuilder.newBuilder() - .maximumSize(MAX_CACHE_SIZE) - .expireAfterWrite(CACHE_ENTRY_TTL_SEC, TimeUnit.SECONDS) - .concurrencyLevel(Runtime.getRuntime().availableProcessors()) - .build( - new CacheLoader<>() { - @Override - public ImmutableList load(String uri) - throws ReencryptionKeyFetchException { - return getPublicKeysFromService(uri); - } - }); - - @Inject - public CloudReEncryptionKeyService(HttpClientWrapper httpClient) { - this.httpClient = httpClient; - } - - /** Throws ReencryptionKeyFetchException. 
*/ - @Override - public EncryptionKey getEncryptionPublicKey(String keyVendingUri) - throws ReencryptionKeyFetchException { - try { - ImmutableList publicKeys = encryptionKeysCache.get(keyVendingUri); - EncodedPublicKey publicKey = publicKeys.get(randomIndex()); - return EncryptionKey.builder() - .setKey(PublicKeyConversionUtil.getKeysetHandle(publicKey.getKey())) - .setId(publicKey.getId()) - .build(); - } catch (GeneralSecurityException | ExecutionException e) { - throw new ReencryptionKeyFetchException(e); - } - } - - private ImmutableList getPublicKeysFromService(String publicKeyServiceUri) - throws ReencryptionKeyFetchException { - try { - HttpGet request = new HttpGet(URI.create(publicKeyServiceUri)); - request.setConfig(REQUEST_CONFIG); - HttpClientResponse response = httpClient.execute(request); - if (response.statusCode() != 200) { - throw new ReencryptionKeyFetchException(response.responseBody()); - } - GetActivePublicKeysResponse.Builder builder = GetActivePublicKeysResponse.newBuilder(); - JsonFormat.parser().merge(response.responseBody(), builder); - GetActivePublicKeysResponse keys = builder.build(); - return ImmutableList.copyOf(keys.getKeysList()); - } catch (Exception e) { - throw new ReencryptionKeyFetchException(e); - } - } - - private int randomIndex() { - return RANDOM.nextInt(NUM_ENCRYPTION_KEYS); - } -} diff --git a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/testing/BUILD b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/testing/BUILD index 6a7c1c6e..01f8199d 100644 --- a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/testing/BUILD +++ b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/testing/BUILD @@ -20,7 +20,6 @@ java_library( name = "testing", srcs = [ "FakeEncryptionKeyService.java", - "FakeReEncryptionKeyService.java", ], javacopts = ["-Xep:Var"], deps = [ diff --git a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/testing/FakeReEncryptionKeyService.java b/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/testing/FakeReEncryptionKeyService.java deleted file mode 100644 index b0294356..00000000 --- a/java/com/google/aggregate/adtech/worker/encryption/hybrid/key/testing/FakeReEncryptionKeyService.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.aggregate.adtech.worker.encryption.hybrid.key.testing; - -import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKey; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.ReEncryptionKeyService; -import com.google.crypto.tink.KeysetHandle; -import com.google.inject.Inject; -import java.security.GeneralSecurityException; - -/** Fake implementation of {@link ReEncryptionKeyService} for testing. 
*/ -public final class FakeReEncryptionKeyService implements ReEncryptionKeyService { - - private final KeysetHandle keysetHandle; - - private static final String ENCRYPTION_KEY_ID = "00000000-0000-0000-0000-000000000000"; - - @Inject - FakeReEncryptionKeyService(KeysetHandle keysetHandle) { - this.keysetHandle = keysetHandle; - } - - @Override - public EncryptionKey getEncryptionPublicKey(String keyVendingUri) - throws ReencryptionKeyFetchException { - try { - return EncryptionKey.builder() - .setKey(keysetHandle.getPublicKeysetHandle()) - .setId(ENCRYPTION_KEY_ID) - .build(); - } catch (GeneralSecurityException e) { - throw new ReencryptionKeyFetchException(e); - } - } -} diff --git a/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/BUILD b/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/BUILD deleted file mode 100644 index 11ac4bb3..00000000 --- a/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/BUILD +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@rules_java//java:defs.bzl", "java_library") - -package(default_visibility = ["//visibility:public"]) - -java_library( - name = "encryption_key_config", - srcs = [ - "CloudEncryptionKeyConfig.java", - "EncryptionKeyConfigFactory.java", - ], - javacopts = ["-Xep:Var"], - deps = [ - "//java/external:autovalue", - "//java/external:autovalue_annotations", - "//protocol/proto:encryption_key_config_java_proto", - ], -) diff --git a/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/CloudEncryptionKeyConfig.java b/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/CloudEncryptionKeyConfig.java deleted file mode 100644 index fca6b682..00000000 --- a/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/CloudEncryptionKeyConfig.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.aggregate.adtech.worker.encryption.publickeyuri; - -import com.google.auto.value.AutoValue; - -/** Class for configuring encryption key config. 
*/ -@AutoValue -public abstract class CloudEncryptionKeyConfig { - public static Builder builder() { - return new AutoValue_CloudEncryptionKeyConfig.Builder(); - } - - public static final int NUM_ENCRYPTION_KEYS = 5; - - public abstract String keyVendingServiceUri(); - - @AutoValue.Builder - public abstract static class Builder { - public abstract Builder setKeyVendingServiceUri(String value); - - public abstract CloudEncryptionKeyConfig build(); - } -} diff --git a/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/EncryptionKeyConfigFactory.java b/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/EncryptionKeyConfigFactory.java deleted file mode 100644 index e242ec8f..00000000 --- a/java/com/google/aggregate/adtech/worker/encryption/publickeyuri/EncryptionKeyConfigFactory.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.aggregate.adtech.worker.encryption.publickeyuri; - -import java.util.HashMap; -import java.util.Map; - -/** Factory class to get EncryptionKeyConfig based on public key uri cloud provider type. */ -public final class EncryptionKeyConfigFactory { - public static String CLOUD_PROVIDER_NAME_GCP = "GCP"; - private static final Map cloudEncryptionKeyConfigMap = - new HashMap<>(); - - /** - * Returns EncryptionKeyConfigType for the given cloud provider. - * - * @throws IllegalArgumentException when invalid cloud provider name is provided. - */ - public static CloudEncryptionKeyConfig getCloudEncryptionKeyConfig(String cloudProviderName) { - if (cloudProviderName.isEmpty()) { - throw new IllegalArgumentException("Cloud provider name not set."); - } else if (cloudProviderName.equals(CLOUD_PROVIDER_NAME_GCP)) { - cloudEncryptionKeyConfigMap.putIfAbsent( - CLOUD_PROVIDER_NAME_GCP, - CloudEncryptionKeyConfig.builder() - .setKeyVendingServiceUri( - "https://publickeyservice-a.postsb-a.test.aggregationhelper.com/.well-known/aggregation-service/v1/public-keys") - .build()); - return cloudEncryptionKeyConfigMap.get(CLOUD_PROVIDER_NAME_GCP); - } else { - throw new IllegalArgumentException("Invalid cloud provider."); - } - } -} diff --git a/java/com/google/aggregate/adtech/worker/model/ErrorCounter.java b/java/com/google/aggregate/adtech/worker/model/ErrorCounter.java index 37044c8c..5ae0bfe0 100644 --- a/java/com/google/aggregate/adtech/worker/model/ErrorCounter.java +++ b/java/com/google/aggregate/adtech/worker/model/ErrorCounter.java @@ -51,6 +51,11 @@ public enum ErrorCounter { + " days.", SharedInfo.MAX_REPORT_AGE.toDays())), INTERNAL_ERROR("Internal error occurred during operation."), + REPORTING_SITE_MISMATCH( + "Report's shared_info.reporting_origin value does not belong to the reporting_site value set" + + " in the Aggregation job parameters. 
Aggregation request job parameters must have" + + " reporting_site set to the site which corresponds to the shared_info.reporting_origin" + + " value."), UNSUPPORTED_OPERATION( String.format( "Report's operation is unsupported. Supported operations are %s.", diff --git a/java/com/google/aggregate/adtech/worker/testing/AvroReportsFileReader.java b/java/com/google/aggregate/adtech/worker/testing/AvroReportsFileReader.java deleted file mode 100644 index 6bec0cbb..00000000 --- a/java/com/google/aggregate/adtech/worker/testing/AvroReportsFileReader.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright 2022 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.aggregate.adtech.worker.testing; - -import static com.google.common.collect.ImmutableList.toImmutableList; - -import com.google.aggregate.adtech.worker.model.EncryptedReport; -import com.google.aggregate.protocol.avro.AvroReportsSchemaSupplier; -import com.google.common.collect.ImmutableList; -import com.google.common.io.ByteSource; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Optional; -import java.util.stream.Stream; -import javax.inject.Inject; -import org.apache.avro.file.DataFileStream; -import org.apache.avro.generic.GenericDatumReader; -import org.apache.avro.generic.GenericRecord; -import org.apache.avro.io.DatumReader; - -/** Simple utility to read an Avro reports file, used for testing. */ -public final class AvroReportsFileReader { - - private final AvroReportsSchemaSupplier avroReportsSchemaSupplier; - - @Inject - AvroReportsFileReader(AvroReportsSchemaSupplier avroReportsSchemaSupplier) { - this.avroReportsSchemaSupplier = avroReportsSchemaSupplier; - } - - /** Reads the Avro results file at the path given to a list. 
*/ - public ImmutableList readAvroReportsFile(Path path) throws IOException { - DatumReader datumReader = - new GenericDatumReader<>(avroReportsSchemaSupplier.get()); - DataFileStream streamReader = - new DataFileStream<>(Files.newInputStream(path), datumReader); - - return Stream.generate(() -> readRecordToEncryptedReport(streamReader)) - .takeWhile(Optional::isPresent) - .map(Optional::get) - .collect(toImmutableList()); - } - - private static Optional readRecordToEncryptedReport( - DataFileStream streamReader) { - if (streamReader.hasNext()) { - GenericRecord genericRecord = streamReader.next(); - ByteSource payload = ByteSource.wrap(((ByteBuffer) genericRecord.get("payload")).array()); - String keyId = genericRecord.get("key_id").toString(); - String sharedInfo = genericRecord.get("shared_info").toString(); - return Optional.of( - EncryptedReport.builder() - .setPayload(payload) - .setKeyId(keyId) - .setSharedInfo(sharedInfo) - .build()); - } - - return Optional.empty(); - } -} diff --git a/java/com/google/aggregate/adtech/worker/testing/BUILD b/java/com/google/aggregate/adtech/worker/testing/BUILD index 5d6765f2..17a3d6cc 100644 --- a/java/com/google/aggregate/adtech/worker/testing/BUILD +++ b/java/com/google/aggregate/adtech/worker/testing/BUILD @@ -151,20 +151,6 @@ java_library( ], ) -java_library( - name = "avro_reports_file_reader", - srcs = ["AvroReportsFileReader.java"], - javacopts = ["-Xep:Var"], - deps = [ - "//java/com/google/aggregate/adtech/worker/model", - "//java/com/google/aggregate/adtech/worker/util", - "//java/com/google/aggregate/protocol/avro:avro_reports_schema_supplier", - "//java/external:avro", - "//java/external:guava", - "//java/external:javax_inject", - ], -) - java_library( name = "local_aggregation_worker_runner", srcs = ["LocalAggregationWorkerRunner.java"], diff --git a/java/com/google/aggregate/adtech/worker/testing/InMemoryResultLogger.java b/java/com/google/aggregate/adtech/worker/testing/InMemoryResultLogger.java index 96cc8987..d71569dd 100644 --- a/java/com/google/aggregate/adtech/worker/testing/InMemoryResultLogger.java +++ b/java/com/google/aggregate/adtech/worker/testing/InMemoryResultLogger.java @@ -19,10 +19,8 @@ import com.google.aggregate.adtech.worker.ResultLogger; import com.google.aggregate.adtech.worker.exceptions.ResultLogException; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.common.collect.ImmutableList; import com.google.scp.operator.cpio.jobclient.model.Job; -import java.util.Optional; /** * {@link ResultLogger} implementation to materialized and store aggregation results in memory for @@ -32,7 +30,6 @@ public final class InMemoryResultLogger implements ResultLogger { private MaterializedAggregationResults materializedAggregations; private MaterializedAggregationResults materializedDebugAggregations; - private Optional> materializedEncryptedReports = Optional.empty(); private boolean shouldThrow; private volatile boolean hasLogged; @@ -64,17 +61,6 @@ public void logResults(ImmutableList results, Job unused, boolea } } - @Override - public void logReports(ImmutableList reports, Job unused, String shardNumber) - throws ResultLogException { - if (shouldThrow) { - throw new ResultLogException( - new IllegalStateException("Was set to throw while logging reports.")); - } - materializedEncryptedReports = Optional.of(reports); - System.out.println("In memory encrypted reports:" + reports); - } - /** * Gets materialized aggregation results 
as an ImmutableList of {@link AggregatedFact} * @@ -106,21 +92,6 @@ public MaterializedAggregationResults getMaterializedDebugAggregationResults() return materializedDebugAggregations; } - /** - * Gets materialized encrypted reports as an ImmutableList of {@link EncryptedReport} - * - * @throws ResultLogException if results were not logged prior to calling this method. - */ - public ImmutableList getMaterializedEncryptedReports() - throws ResultLogException { - if (materializedEncryptedReports.isEmpty()) { - throw new ResultLogException( - new IllegalStateException( - "MaterializedEncryptionReports is null. Maybe results did not get logged.")); - } - return materializedEncryptedReports.get(); - } - public void setShouldThrow(boolean shouldThrow) { this.shouldThrow = shouldThrow; } diff --git a/java/com/google/aggregate/adtech/worker/util/JobUtils.java b/java/com/google/aggregate/adtech/worker/util/JobUtils.java index 861ed641..5d123f3b 100644 --- a/java/com/google/aggregate/adtech/worker/util/JobUtils.java +++ b/java/com/google/aggregate/adtech/worker/util/JobUtils.java @@ -32,5 +32,9 @@ public final class JobUtils { public static final String JOB_PARAM_FILTERING_IDS_DELIMITER = ","; + public static final String JOB_PARAM_ATTRIBUTION_REPORT_TO = "attribution_report_to"; + + public static final String JOB_PARAM_REPORTING_SITE = "reporting_site"; + private JobUtils() {} } diff --git a/java/com/google/aggregate/adtech/worker/validation/JobValidator.java b/java/com/google/aggregate/adtech/worker/validation/JobValidator.java index 2759284f..6f188e3a 100644 --- a/java/com/google/aggregate/adtech/worker/validation/JobValidator.java +++ b/java/com/google/aggregate/adtech/worker/validation/JobValidator.java @@ -16,11 +16,13 @@ package com.google.aggregate.adtech.worker.validation; +import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_ATTRIBUTION_REPORT_TO; import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_FILTERING_IDS; import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_FILTERING_IDS_DELIMITER; import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_INPUT_REPORT_COUNT; import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_OUTPUT_DOMAIN_BLOB_PREFIX; import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_OUTPUT_DOMAIN_BUCKET_NAME; +import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_REPORTING_SITE; import static com.google.aggregate.adtech.worker.util.JobUtils.JOB_PARAM_REPORT_ERROR_THRESHOLD_PERCENTAGE; import static com.google.common.base.Preconditions.checkArgument; import static com.google.scp.operator.shared.model.BackendModelUtil.toJobKeyString; @@ -43,16 +45,7 @@ public final class JobValidator { public static void validate(Optional job, boolean domainOptional) { checkArgument(job.isPresent(), "Job metadata not found."); String jobKey = toJobKeyString(job.get().jobKey()); - checkArgument( - job.get().requestInfo().getJobParametersMap().containsKey("attribution_report_to") - && !job.get() - .requestInfo() - .getJobParametersMap() - .get("attribution_report_to") - .trim() - .isEmpty(), - String.format( - "Job parameters does not have an attribution_report_to field for the Job %s.", jobKey)); + validateReportingOriginAndSite(job.get()); Map jobParams = job.get().requestInfo().getJobParametersMap(); checkArgument( domainOptional @@ -90,6 +83,49 @@ public static void validate(Optional job, boolean domainOptional) { jobKey)); } + /** + * Validates that exactly 
one of the two fields 'attribution_report_to' and + 'reporting_site' is specified and that the specified field is non-empty. */ + private static void validateReportingOriginAndSite(Job job) { + Map<String, String> jobParams = job.requestInfo().getJobParametersMap(); + String jobKey = toJobKeyString(job.jobKey()); + boolean bothSiteAndOriginSpecified = + jobParams.containsKey(JOB_PARAM_ATTRIBUTION_REPORT_TO) + && jobParams.containsKey(JOB_PARAM_REPORTING_SITE); + boolean neitherSiteNorOriginSpecified = + !jobParams.containsKey(JOB_PARAM_ATTRIBUTION_REPORT_TO) + && !jobParams.containsKey(JOB_PARAM_REPORTING_SITE); + if (bothSiteAndOriginSpecified || neitherSiteNorOriginSpecified) { + throw new IllegalArgumentException( + String.format( + "Exactly one of the 'attribution_report_to' and 'reporting_site' fields should be" + + " specified for the Job %s. It is recommended to use the 'reporting_site'" + + " parameter. The 'attribution_report_to' parameter will be deprecated in the" + + " next major version upgrade of the API.", + jobKey)); + } + // Verify that the field 'attribution_report_to', if specified, is non-empty. + boolean emptyAttributionReportToSpecified = + jobParams.containsKey(JOB_PARAM_ATTRIBUTION_REPORT_TO) + && jobParams.get(JOB_PARAM_ATTRIBUTION_REPORT_TO).trim().isEmpty(); + checkArgument( + !emptyAttributionReportToSpecified, + String.format( + "The 'attribution_report_to' field in the Job parameters is empty for the Job %s.", + jobKey)); + // Verify that the field 'reporting_site', if specified, is non-empty. + boolean emptyReportingSiteSpecified = + jobParams.containsKey(JOB_PARAM_REPORTING_SITE) + && jobParams.get(JOB_PARAM_REPORTING_SITE).trim().isEmpty(); + checkArgument( + !emptyReportingSiteSpecified, + String.format( + "The 'reporting_site' field in the Job parameters is empty for the Job %s.", + jobKey)); + }
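
To make the new contract concrete, here is a minimal standalone sketch of the rule the method above enforces. This is not the production `JobValidator` API; the class name and map literals are illustrative only:

```java
import java.util.Map;

final class ReportingParamsRuleSketch {
  /** Mirrors validateReportingOriginAndSite: exactly one key, and its value must be non-blank. */
  static boolean isValid(Map<String, String> jobParams) {
    boolean hasOrigin = jobParams.containsKey("attribution_report_to");
    boolean hasSite = jobParams.containsKey("reporting_site");
    if (hasOrigin == hasSite) {
      return false; // Both present, or neither present: rejected.
    }
    String value =
        hasOrigin ? jobParams.get("attribution_report_to") : jobParams.get("reporting_site");
    return !value.trim().isEmpty();
  }

  public static void main(String[] args) {
    System.out.println(isValid(Map.of("reporting_site", "https://fakeurl.com"))); // true
    System.out.println(isValid(Map.of(
        "attribution_report_to", "https://subdomain.fakeurl.com",
        "reporting_site", "https://fakeurl.com"))); // false: mutually exclusive
    System.out.println(isValid(Map.of("reporting_site", " "))); // false: blank value
  }
}
```

+ /** Checks if the string represents a non-negative number or is empty.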
*/ private static boolean isAValidCount(String countInString) { return countInString == null diff --git a/java/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidator.java b/java/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidator.java index a408c2b3..8c8a8aa1 100644 --- a/java/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidator.java +++ b/java/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidator.java @@ -18,11 +18,20 @@ import static com.google.aggregate.adtech.worker.model.ErrorCounter.ATTRIBUTION_REPORT_TO_MISMATCH; import static com.google.aggregate.adtech.worker.validation.ValidatorHelper.createErrorMessage; +import static com.google.aggregate.adtech.worker.model.ErrorCounter.ATTRIBUTION_REPORT_TO_MALFORMED; +import static com.google.aggregate.adtech.worker.model.ErrorCounter.REPORTING_SITE_MISMATCH; import com.google.aggregate.adtech.worker.model.ErrorMessage; import com.google.aggregate.adtech.worker.model.Report; import com.google.scp.operator.cpio.jobclient.model.Job; import java.util.Optional; +import com.google.aggregate.adtech.worker.util.ReportingOriginUtils; +import com.google.aggregate.adtech.worker.util.ReportingOriginUtils.InvalidReportingOriginException; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; /** * Validates that the report's reportingOrigin matches the attributionReportTo or the reportingSite provided in @@ -30,14 +39,45 @@ */ public final class ReportingOriginMatchesRequestValidator implements ReportValidator { + private static final int MAX_CACHE_SIZE = 100; + private static final long CACHE_ENTRY_TTL_SEC = 3600; + private final LoadingCache<String, String> originToSiteMap = + CacheBuilder.newBuilder() + .maximumSize(MAX_CACHE_SIZE) + .expireAfterWrite(CACHE_ENTRY_TTL_SEC, TimeUnit.SECONDS) + .concurrencyLevel(Runtime.getRuntime().availableProcessors()) + .build( + new CacheLoader<>() { + @Override + public String load(final String reportingOrigin) + throws InvalidReportingOriginException { + return ReportingOriginUtils.convertReportingOriginToSite(reportingOrigin); + } + }); + @Override public Optional<ErrorMessage> validate(Report report, Job ctx) { - String attributionReportTo = - ctx.requestInfo().getJobParametersMap().get("attribution_report_to"); - if (report.sharedInfo().reportingOrigin().equals(attributionReportTo)) { - return Optional.empty(); - } + Optional<String> optionalSiteValue = + Optional.ofNullable(ctx.requestInfo().getJobParametersMap().get("reporting_site")); + if (optionalSiteValue.isPresent()) { + try { + String reportingSiteParameterValue = optionalSiteValue.get(); + String siteForReportingOrigin = originToSiteMap.get(report.sharedInfo().reportingOrigin()); + if (!reportingSiteParameterValue.equals(siteForReportingOrigin)) { + return createErrorMessage(REPORTING_SITE_MISMATCH); + } + return Optional.empty(); + } catch (ExecutionException e) { + return createErrorMessage(ATTRIBUTION_REPORT_TO_MALFORMED); + } + } else { + String attributionReportTo = + ctx.requestInfo().getJobParametersMap().get("attribution_report_to"); + if (report.sharedInfo().reportingOrigin().equals(attributionReportTo)) { + return Optional.empty(); + } - return createErrorMessage(ATTRIBUTION_REPORT_TO_MISMATCH); + return createErrorMessage(ATTRIBUTION_REPORT_TO_MISMATCH); + } } }
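
A note on the cache introduced above: `convertReportingOriginToSite` is expected to reduce an origin to its site (scheme plus eTLD+1), and the `LoadingCache` memoizes that conversion per origin. Below is a self-contained sketch of the same pattern, using Guava's `InternetDomainName` as a stand-in for `ReportingOriginUtils`, whose exact conversion rules are not shown in this diff:

```java
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.net.InternetDomainName;
import java.net.URI;
import java.util.concurrent.TimeUnit;

final class OriginToSiteCacheSketch {
  // Memoizes origin -> site so repeated reports from the same origin skip re-conversion.
  private static final LoadingCache<String, String> ORIGIN_TO_SITE =
      CacheBuilder.newBuilder()
          .maximumSize(100) // MAX_CACHE_SIZE in the validator above
          .expireAfterWrite(3600, TimeUnit.SECONDS) // CACHE_ENTRY_TTL_SEC in the validator above
          .build(
              new CacheLoader<>() {
                @Override
                public String load(String reportingOrigin) {
                  // Stand-in conversion: "https://" plus the host's eTLD+1. The real
                  // ReportingOriginUtils.convertReportingOriginToSite may differ in detail.
                  String host = URI.create(reportingOrigin).getHost();
                  return "https://" + InternetDomainName.from(host).topPrivateDomain();
                }
              });

  public static void main(String[] args) throws Exception {
    // Prints "https://fakeurl.com" twice; the second call is served from the cache.
    System.out.println(ORIGIN_TO_SITE.get("https://subdomain.fakeurl.com"));
    System.out.println(ORIGIN_TO_SITE.get("https://subdomain.fakeurl.com"));
  }
}
```

diff --git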
a/java/com/google/aggregate/adtech/worker/writer/LocalResultFileWriter.java b/java/com/google/aggregate/adtech/worker/writer/LocalResultFileWriter.java index df49580a..3f60fac3 100644 --- a/java/com/google/aggregate/adtech/worker/writer/LocalResultFileWriter.java +++ b/java/com/google/aggregate/adtech/worker/writer/LocalResultFileWriter.java @@ -17,7 +17,6 @@ package com.google.aggregate.adtech.worker.writer; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import java.nio.file.Path; import java.util.stream.Stream; @@ -27,10 +26,6 @@ public interface LocalResultFileWriter { /** Write the file to the local filesystem */ void writeLocalFile(Stream results, Path resultFile) throws FileWriteException; - /** Writes list of encrypted reports to a local file. */ - void writeLocalReportFile(Stream reports, Path resultFilePath) - throws FileWriteException; - /** Returns the file extension for the file type written */ String getFileExtension(); diff --git a/java/com/google/aggregate/adtech/worker/writer/avro/BUILD b/java/com/google/aggregate/adtech/worker/writer/avro/BUILD index 7b2b9095..61119a70 100644 --- a/java/com/google/aggregate/adtech/worker/writer/avro/BUILD +++ b/java/com/google/aggregate/adtech/worker/writer/avro/BUILD @@ -29,8 +29,6 @@ java_library( "//java/com/google/aggregate/adtech/worker/writer", "//java/com/google/aggregate/protocol/avro:avro_debug_results", "//java/com/google/aggregate/protocol/avro:avro_record_writer", - "//java/com/google/aggregate/protocol/avro:avro_report", - "//java/com/google/aggregate/protocol/avro:avro_reports_schema_supplier", "//java/com/google/aggregate/protocol/avro:avro_results_schema_supplier", "//java/external:avro", "//java/external:guava", diff --git a/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroDebugResultFileWriter.java b/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroDebugResultFileWriter.java index bb9e4265..e383606f 100644 --- a/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroDebugResultFileWriter.java +++ b/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroDebugResultFileWriter.java @@ -20,7 +20,6 @@ import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter; import com.google.aggregate.protocol.avro.AvroDebugResultsRecord; import com.google.aggregate.protocol.avro.AvroDebugResultsWriter; @@ -66,13 +65,6 @@ public void writeLocalFile(Stream results, Path resultFilePath) } } - @Override - public void writeLocalReportFile(Stream reports, Path resultFilePath) - throws UnsupportedOperationException { - throw new UnsupportedOperationException( - "LocalAvroDebugResultFileWriter cannot write Avro report file."); - } - @Override public String getFileExtension() { return ".avro"; diff --git a/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriter.java b/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriter.java index 9970f78e..0b289e3a 100644 --- a/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriter.java +++ b/java/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriter.java @@ -16,22 +16,14 @@ package com.google.aggregate.adtech.worker.writer.avro; -import static com.google.common.collect.ImmutableList.toImmutableList; 
import static java.nio.file.StandardOpenOption.CREATE; import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.aggregate.adtech.worker.util.NumericConversions; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter; -import com.google.aggregate.protocol.avro.AvroRecordWriter.MetadataElement; -import com.google.aggregate.protocol.avro.AvroReportRecord; -import com.google.aggregate.protocol.avro.AvroReportWriter; -import com.google.aggregate.protocol.avro.AvroReportWriterFactory; import com.google.aggregate.protocol.avro.AvroResultsSchemaSupplier; -import com.google.common.collect.ImmutableList; import java.io.IOException; -import java.io.OutputStream; import java.nio.ByteBuffer; import java.nio.file.Files; import java.nio.file.Path; @@ -49,13 +41,10 @@ public final class LocalAvroResultFileWriter implements LocalResultFileWriter { private AvroResultsSchemaSupplier schemaSupplier; - private final AvroReportWriterFactory reportsWriterFactory; @Inject - LocalAvroResultFileWriter( - AvroResultsSchemaSupplier schemaSupplier, AvroReportWriterFactory reportsWriterFactory) { + LocalAvroResultFileWriter(AvroResultsSchemaSupplier schemaSupplier) { this.schemaSupplier = schemaSupplier; - this.reportsWriterFactory = reportsWriterFactory; } /** @@ -94,23 +83,6 @@ public void writeLocalFile(Stream results, Path resultFilePath) } } - @Override - public void writeLocalReportFile(Stream reports, Path resultFilePath) - throws FileWriteException { - try (OutputStream outputAvroStream = - Files.newOutputStream(resultFilePath, CREATE, TRUNCATE_EXISTING); - AvroReportWriter avroReportWriter = reportsWriterFactory.create(outputAvroStream)) { - ImmutableList metaData = ImmutableList.of(); - Stream reportsRecords = - reports.map( - (report -> - AvroReportRecord.create(report.payload(), report.keyId(), report.sharedInfo()))); - avroReportWriter.writeRecords(metaData, reportsRecords.collect(toImmutableList())); - } catch (IOException e) { - throw new FileWriteException("Failed to write local Avro report file.", e); - } - } - @Override public String getFileExtension() { return ".avro"; diff --git a/java/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriter.java b/java/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriter.java index 2273f26e..68450d1e 100644 --- a/java/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriter.java +++ b/java/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriter.java @@ -19,13 +19,9 @@ import static java.nio.file.StandardOpenOption.CREATE; import static java.nio.file.StandardOpenOption.TRUNCATE_EXISTING; -import com.fasterxml.jackson.core.JsonGenerator; import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.module.SimpleModule; -import com.fasterxml.jackson.databind.ser.std.StdSerializer; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.aggregate.adtech.worker.util.NumericConversions; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter; import com.google.aggregate.protocol.avro.AvroResultsSchemaSupplier; @@ -35,10 +31,7 @@ import java.nio.ByteBuffer; import java.nio.file.Files; import java.nio.file.Path; -import 
java.nio.file.StandardOpenOption; import java.util.Iterator; -import java.util.List; -import java.util.stream.Collectors; import java.util.stream.Stream; import javax.inject.Inject; import org.apache.avro.Schema; @@ -62,7 +55,6 @@ public final class LocalJsonResultFileWriter implements LocalResultFileWriter { @Inject LocalJsonResultFileWriter(AvroResultsSchemaSupplier schemaSupplier) { this.schemaSupplier = schemaSupplier; - module.addSerializer(EncryptedReport.class, new EncryptedReportSerializer()); mapper.registerModule(module); } @@ -99,19 +91,6 @@ public void writeLocalFile(Stream results, Path resultFilePath) } } - @Override - public void writeLocalReportFile(Stream reports, Path resultFilePath) - throws FileWriteException { - try { - List encryptedReportsList = reports.collect(Collectors.toList()); - String prettyJson = - mapper.writerWithDefaultPrettyPrinter().writeValueAsString(encryptedReportsList); - Files.writeString(resultFilePath, prettyJson, StandardOpenOption.CREATE); - } catch (Exception e) { - throw new FileWriteException("Failed to write reports to local Json file", e); - } - } - @Override public String getFileExtension() { return ".json"; @@ -125,26 +104,4 @@ private GenericRecord aggregatedFactToGenericRecord(AggregatedFact aggregatedFac genericRecord.put("metric", aggregatedFact.getMetric()); return genericRecord; } - - private static class EncryptedReportSerializer extends StdSerializer { - - EncryptedReportSerializer() { - super(EncryptedReport.class); - } - - EncryptedReportSerializer(Class t) { - super(t); - } - - @Override - public void serialize( - EncryptedReport encryptedReport, JsonGenerator jgen, SerializerProvider serializerProvider) - throws IOException { - jgen.writeStartObject(); - jgen.writeStringField("key_id", encryptedReport.keyId()); - jgen.writeBinaryField("payload", encryptedReport.payload().read()); - jgen.writeStringField("shared_info", encryptedReport.sharedInfo()); - jgen.writeEndObject(); - } - } } diff --git a/javatests/com/google/aggregate/adtech/worker/AwsOTelTest.java b/javatests/com/google/aggregate/adtech/worker/AwsOTelTest.java index 3676e34b..a6092857 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsOTelTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsOTelTest.java @@ -116,7 +116,7 @@ public void createJobE2ETest() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( testDataBucket, inputKey, testDataBucket, diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerAutoScalingTest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerAutoScalingTest.java index 46980b2e..231d91bd 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerAutoScalingTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerAutoScalingTest.java @@ -17,7 +17,6 @@ package com.google.aggregate.adtech.worker; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.KOKORO_BUILD_ID; -import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.createJobRequest; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.submitJob; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.waitForJobCompletions; @@ -110,7 +109,7 @@ private CreateJobRequest createE2EJob(Integer jobCount) throws Exception { String.format( "%s/test-outputs/%s/%s", 
TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID, outputFile); CreateJobRequest createJobRequest = - createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), INPUT_DATA_PATH, getTestDataBucket(), diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousDiffTest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousDiffTest.java index 8203339a..c2bec4f3 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousDiffTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousDiffTest.java @@ -18,7 +18,6 @@ import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.AWS_S3_BUCKET_REGION; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.KOKORO_BUILD_ID; -import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.createJobRequest; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.getOutputFileName; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.readResultsFromS3; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.submitJobAndWaitForResult; @@ -111,7 +110,7 @@ public void e2eDiffTest() throws Exception { String goldenLocation = "testdata/golden/2022_10_18/10k_diff_test.avro.golden"; CreateJobRequest createJobRequest = - createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -165,9 +164,7 @@ protected void configure() { .httpClient(UrlConnectionHttpClient.builder().build()) .build()); bind(S3AsyncClient.class) - .toInstance( - S3AsyncClient.builder() - .region(AWS_S3_BUCKET_REGION).build()); + .toInstance(S3AsyncClient.builder().region(AWS_S3_BUCKET_REGION).build()); bind(Boolean.class).annotatedWith(S3UsePartialRequests.class).toInstance(false); bind(Integer.class).annotatedWith(PartialRequestBufferSize.class).toInstance(20); } diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousInvalidCredentialsTest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousInvalidCredentialsTest.java index c6396df6..b1744d02 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousInvalidCredentialsTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousInvalidCredentialsTest.java @@ -18,7 +18,6 @@ import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.AWS_S3_BUCKET_REGION; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.KOKORO_BUILD_ID; -import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.createJobRequest; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.submitJob; import static com.google.common.truth.Truth.assertThat; import static com.google.scp.operator.protos.frontend.api.v1.ReturnCodeProto.ReturnCode.RETRIES_EXHAUSTED; @@ -82,7 +81,7 @@ public void e2ePerfTest() throws Exception { // TODO(b/228085828): Modify e2e tests to use output domain CreateJobRequest createJobRequest = - createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( TESTING_BUCKET, INPUT_DATA_PATH, TESTING_BUCKET, diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousOutOfMemoryTest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousOutOfMemoryTest.java index 639f495b..1cea1170 100644 --- 
a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousOutOfMemoryTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousOutOfMemoryTest.java @@ -95,7 +95,7 @@ public void createJobE2ETest() throws Exception { "%s/%s/test-outputs/OOM_test_output.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest1 = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -119,7 +119,7 @@ public void createJobE2ETest() throws Exception { "%s/%s/test-outputs/OOM_test_output.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest2 = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -156,9 +156,7 @@ protected void configure() { .httpClient(UrlConnectionHttpClient.builder().build()) .build()); bind(S3AsyncClient.class) - .toInstance( - S3AsyncClient.builder() - .region(AWS_S3_BUCKET_REGION).build()); + .toInstance(S3AsyncClient.builder().region(AWS_S3_BUCKET_REGION).build()); bind(Boolean.class).annotatedWith(S3UsePartialRequests.class).toInstance(false); bind(Integer.class).annotatedWith(PartialRequestBufferSize.class).toInstance(20); } diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousPerfTest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousPerfTest.java index b8dc6d3b..49d92ac7 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousPerfTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousPerfTest.java @@ -18,7 +18,6 @@ import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.AWS_S3_BUCKET_REGION; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.KOKORO_BUILD_ID; -import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.createJobRequest; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.getAndWriteStopwatchesFromS3; import static com.google.aggregate.adtech.worker.AwsWorkerContinuousTestHelper.submitJobAndWaitForResult; import static com.google.common.truth.Truth.assertThat; @@ -48,10 +47,8 @@ @RunWith(JUnit4.class) public class AwsWorkerContinuousPerfTest { - @Rule - public final Acai acai = new Acai(TestEnv.class); - @Rule - public final TestName name = new TestName(); + @Rule public final Acai acai = new Acai(TestEnv.class); + @Rule public final TestName name = new TestName(); private static final Duration completionTimeout = Duration.of(60, ChronoUnit.MINUTES); private static final String TESTING_BUCKET = "aggregation-service-testing"; @@ -88,8 +85,7 @@ public class AwsWorkerContinuousPerfTest { private static final String OUTPUT_DOMAIN_PREFIX = "testdata/1m_staging_2022_08_08_sharded_domain/shard"; - @Inject - S3BlobStorageClient s3BlobStorageClient; + @Inject S3BlobStorageClient s3BlobStorageClient; @Test public void e2ePerfTest() throws Exception { @@ -104,7 +100,7 @@ public void e2ePerfTest() throws Exception { "e2e_test_outputs/%s/%s", KOKORO_BUILD_ID, "createJobE2EperfTest-reports1m.avro"); CreateJobRequest createJobRequest = - createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( TESTING_BUCKET, INPUT_REPORTS_PREFIX, TESTING_BUCKET, @@ -138,9 +134,7 @@ protected void configure() { 
.httpClient(UrlConnectionHttpClient.builder().build()) .build()); bind(S3AsyncClient.class) - .toInstance( - S3AsyncClient.builder() - .region(AWS_S3_BUCKET_REGION).build()); + .toInstance(S3AsyncClient.builder().region(AWS_S3_BUCKET_REGION).build()); bind(Boolean.class).annotatedWith(S3UsePartialRequests.class).toInstance(false); bind(Integer.class).annotatedWith(PartialRequestBufferSize.class).toInstance(20); } diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTest.java index dbd5c46b..ef114c29 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTest.java @@ -130,7 +130,7 @@ public void createJobE2ETest() throws Exception { "%s/%s/test-outputs/10k_test_output.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -154,34 +154,39 @@ public void createJobE2ETest() throws Exception { } /** - * This test includes sending a non-debug job and aggregatable reports with debug mode enabled. + * This test includes sending a job with reporting site only. Verifies that jobs with only + * reporting site are successful. */ @Test - public void createNotDebugJobE2EReportDebugEnabledTest() throws Exception { + public void createJobE2ETestWithReportingSite() throws Exception { var inputKey = String.format( - "%s/%s/test-inputs/10k_attribution_report_test_input_debug.avro", + "%s/%s/test-inputs/10k_test_input_reporting_site.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); var domainKey = String.format( - "%s/%s/test-inputs/10k_attribution_report_test_domain_debug.avro", + "%s/%s/test-inputs/10k_test_domain_reporting_site.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); var outputKey = String.format( - "%s/%s/test-outputs/10k_attribution_report_test_output_notDebugJob_debugEnabled.avro", + "%s/%s/test-outputs/10k_test_output_reporting_site.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithReportingSite( getTestDataBucket(), inputKey, getTestDataBucket(), outputKey, - /* debugRun= */ false, /* jobId= */ getClass().getSimpleName() + "::" + name.getMethodName(), /* outputDomainBucketName= */ Optional.of(getTestDataBucket()), /* outputDomainPrefix= */ Optional.of(domainKey)); - assertResponseForCode(createJobRequest, AggregationWorkerReturnCode.SUCCESS); + JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT); + + assertThat(result.get("result_info").get("return_code").asText()) + .isEqualTo(AggregationWorkerReturnCode.SUCCESS.name()); + assertThat(result.get("result_info").get("error_summary").get("error_counts").isEmpty()) + .isTrue(); // Read output avro from s3. ImmutableList aggregatedFacts = @@ -191,44 +196,41 @@ public void createNotDebugJobE2EReportDebugEnabledTest() throws Exception { getTestDataBucket(), getOutputFileName(outputKey)); - // If the domainOptional is true, the aggregatedFact keys would be more than domain keys - // Otherwise, aggregatedFact keys would be equal to domain keys - // The "isAtLeast" assert is set here to accommodate both conditions. 
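
The new reporting-site tests above assert directly on the getJob response JSON. A minimal Jackson sketch of that shape (the payload literal here is illustrative, not a captured response):

```java
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

final class ResultInfoAssertSketch {
  public static void main(String[] args) throws Exception {
    // Illustrative result_info payload; in the tests, the JsonNode comes from
    // submitJobAndWaitForResult rather than a string literal.
    String json =
        "{\"result_info\":{\"return_code\":\"SUCCESS\","
            + "\"error_summary\":{\"error_counts\":[]}}}";
    JsonNode result = new ObjectMapper().readTree(json);
    boolean success =
        result.get("result_info").get("return_code").asText().equals("SUCCESS")
            && result.get("result_info").get("error_summary").get("error_counts").isEmpty();
    System.out.println(success); // true
  }
}
```
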
- assertThat(aggregatedFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE); - // The debug file shouldn't exist because it's not debug run - assertThat( - AwsWorkerContinuousTestHelper.checkS3FileExists( - s3BlobStorageClient, getTestDataBucket(), getDebugFilePrefix(outputKey))) - .isFalse(); + assertThat(aggregatedFacts.size()).isGreaterThan(10); } - /** This test includes sending a debug job and aggregatable reports with debug mode enabled. */ + /** + * This test includes sending a job with reports from multiple reporting origins belonging to the + * same reporting site. Verifies that all the reports are processed successfully. + */ @Test - public void createDebugJobE2EReportDebugModeEnabledTest() throws Exception { + public void createJobE2ETestWithMultipleReportingOrigins() throws Exception { var inputKey = - String.format( - "%s/%s/test-inputs/10k_attribution_report_test_input_debug_enabled_nondebug_run.avro", - TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + String.format("%s/%s/test-inputs/same-site/", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); var domainKey = String.format( - "%s/%s/test-inputs/10k_attribution_report_test_domain_debug_enabled_nondebug_run.avro", + "%s/%s/test-inputs/10k_test_domain_multiple_origins_same_site.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); var outputKey = String.format( - "%s/%s/test-outputs/10k_attribution_report_test_output_DebugJob_debugEnabled.avro", + "%s/%s/test-outputs/10k_test_output_multiple_origins_same_site.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithReportingSite( getTestDataBucket(), inputKey, getTestDataBucket(), outputKey, - /* debugRun= */ true, /* jobId= */ getClass().getSimpleName() + "::" + name.getMethodName(), /* outputDomainBucketName= */ Optional.of(getTestDataBucket()), /* outputDomainPrefix= */ Optional.of(domainKey)); - assertResponseForCode(createJobRequest, AggregationWorkerReturnCode.SUCCESS); + JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT); + + assertThat(result.get("result_info").get("return_code").asText()) + .isEqualTo(AggregationWorkerReturnCode.SUCCESS.name()); + assertThat(result.get("result_info").get("error_summary").get("error_counts").isEmpty()) + .isTrue(); // Read output avro from s3. ImmutableList aggregatedFacts = @@ -238,46 +240,34 @@ public void createDebugJobE2EReportDebugModeEnabledTest() throws Exception { getTestDataBucket(), getOutputFileName(outputKey)); - // The "isAtLeast" assert is set here to accommodate domainOptional(True/False) conditions. - assertThat(aggregatedFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE); - - // Read debug results avro from s3. - ImmutableList aggregatedDebugFacts = - readDebugResultsFromS3( - s3BlobStorageClient, - readerFactory, - getTestDataBucket(), - getOutputFileName(getDebugFilePrefix(outputKey))); - - // Debug facts count should be greater than or equal to the summary facts count because some - // keys are filtered out due to thresholding or not in domain. - assertThat(aggregatedDebugFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE); + assertThat(aggregatedFacts.size()).isGreaterThan(10); } /** - * This test includes sending a debug job and aggregatable reports with debug mode disabled. Uses - * the same data as the normal e2e test. + * This test includes sending a job with reports from multiple reporting origins belonging to + * different reporting sites. 
It is expected that the 5k reports with a different reporting site + * will fail and come up in the error counts. */ @Test - public void createDebugJobE2EReportDebugModeDisabledTest() throws Exception { + public void createJobE2ETestWithSomeReportsHavingDifferentReportingOrigins() throws Exception { var inputKey = String.format( - "%s/%s/test-inputs/10k_test_input_2.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + "%s/%s/test-inputs/different-site/", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); var domainKey = String.format( - "%s/%s/test-inputs/10k_test_domain_2.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + "%s/%s/test-inputs/10k_test_domain_multiple_origins_different_site.avro", + TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); var outputKey = String.format( - "%s/%s/test-outputs/10k_test_output_DebugJob_debugDisabled.avro", + "%s/%s/test-outputs/10k_test_output_multiple_origins_different_site.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithReportingSite( getTestDataBucket(), inputKey, getTestDataBucket(), outputKey, - /* debugRun= */ true, /* jobId= */ getClass().getSimpleName() + "::" + name.getMethodName(), /* outputDomainBucketName= */ Optional.of(getTestDataBucket()), /* outputDomainPrefix= */ Optional.of(domainKey)); @@ -293,7 +283,7 @@ public void createDebugJobE2EReportDebugModeDisabledTest() throws Exception { .get(0) .get("count") .asInt()) - .isEqualTo(10000); + .isEqualTo(5000); assertThat( result .get("result_info") @@ -302,16 +292,95 @@ public void createDebugJobE2EReportDebugModeDisabledTest() throws Exception { .get(0) .get("category") .asText()) - .isEqualTo(ErrorCounter.DEBUG_NOT_ENABLED.name()); + .isEqualTo(ErrorCounter.REPORTING_SITE_MISMATCH.name()); + + // Read output avro from s3. + ImmutableList aggregatedFacts = + readResultsFromS3( + s3BlobStorageClient, + avroResultsFileReader, + getTestDataBucket(), + getOutputFileName(outputKey)); + + assertThat(aggregatedFacts.size()).isGreaterThan(10); + } + + /** + * This test includes sending a non-debug job and aggregatable reports with debug mode enabled. + */ + @Test + public void createNotDebugJobE2EReportDebugEnabledTest() throws Exception { + var inputKey = + String.format( + "%s/%s/test-inputs/10k_attribution_report_test_input_debug.avro", + TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + var domainKey = + String.format( + "%s/%s/test-inputs/10k_attribution_report_test_domain_debug.avro", + TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + var outputKey = + String.format( + "%s/%s/test-outputs/10k_attribution_report_test_output_notDebugJob_debugEnabled.avro", + TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + + CreateJobRequest createJobRequest = + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( + getTestDataBucket(), + inputKey, + getTestDataBucket(), + outputKey, + /* debugRun= */ false, + /* jobId= */ getClass().getSimpleName() + "::" + name.getMethodName(), + /* outputDomainBucketName= */ Optional.of(getTestDataBucket()), + /* outputDomainPrefix= */ Optional.of(domainKey)); + assertResponseForCode(createJobRequest, AggregationWorkerReturnCode.SUCCESS); + + // Read output avro from s3. 
+ ImmutableList aggregatedFacts = + readResultsFromS3( + s3BlobStorageClient, + avroResultsFileReader, + getTestDataBucket(), + getOutputFileName(outputKey)); + + // If the domainOptional is true, the aggregatedFact keys would be more than domain keys + // Otherwise, aggregatedFact keys would be equal to domain keys + // The "isAtLeast" assert is set here to accommodate both conditions. + assertThat(aggregatedFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE); + // The debug file shouldn't exist because it's not debug run assertThat( - result - .get("result_info") - .get("error_summary") - .get("error_counts") - .get(0) - .get("description") - .asText()) - .isEqualTo(ErrorCounter.DEBUG_NOT_ENABLED.getDescription()); + AwsWorkerContinuousTestHelper.checkS3FileExists( + s3BlobStorageClient, getTestDataBucket(), getDebugFilePrefix(outputKey))) + .isFalse(); + } + + /** This test includes sending a debug job and aggregatable reports with debug mode enabled. */ + @Test + public void createDebugJobE2EReportDebugModeEnabledTest() throws Exception { + var inputKey = + String.format( + "%s/%s/test-inputs/10k_attribution_report_test_input_debug_enabled_nondebug_run.avro", + TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + var domainKey = + String.format( + "%s/%s/test-inputs/10k_attribution_report_test_domain_debug_enabled_nondebug_run.avro", + TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + var outputKey = + String.format( + "%s/%s/test-outputs/10k_attribution_report_test_output_DebugJob_debugEnabled.avro", + TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); + + CreateJobRequest createJobRequest = + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( + getTestDataBucket(), + inputKey, + getTestDataBucket(), + outputKey, + /* debugRun= */ true, + /* jobId= */ getClass().getSimpleName() + "::" + name.getMethodName(), + /* outputDomainBucketName= */ Optional.of(getTestDataBucket()), + /* outputDomainPrefix= */ Optional.of(domainKey)); + assertResponseForCode(createJobRequest, AggregationWorkerReturnCode.SUCCESS); // Read output avro from s3. ImmutableList aggregatedFacts = @@ -321,9 +390,10 @@ public void createDebugJobE2EReportDebugModeDisabledTest() throws Exception { getTestDataBucket(), getOutputFileName(outputKey)); - assertThat(aggregatedFacts.size()).isEqualTo(DEBUG_DOMAIN_KEY_SIZE); + // The "isAtLeast" assert is set here to accommodate domainOptional(True/False) conditions. + assertThat(aggregatedFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE); - // Read debug result from s3. + // Read debug results avro from s3. ImmutableList aggregatedDebugFacts = readDebugResultsFromS3( s3BlobStorageClient, @@ -331,12 +401,9 @@ public void createDebugJobE2EReportDebugModeDisabledTest() throws Exception { getTestDataBucket(), getOutputFileName(getDebugFilePrefix(outputKey))); - // Only contains keys in domain because all reports are filtered out. - assertThat(aggregatedDebugFacts.size()).isEqualTo(DEBUG_DOMAIN_KEY_SIZE); - // The unnoisedMetric of aggregatedDebugFacts should be 0 for all keys because - // all reports are filtered out. - // Noised metric in both debug reports and summary reports should be noise value instead of 0. - aggregatedDebugFacts.forEach(fact -> assertThat(fact.getUnnoisedMetric().get()).isEqualTo(0)); + // Debug facts count should be greater than or equal to the summary facts count because some + // keys are filtered out due to thresholding or not in domain. 
+ assertThat(aggregatedDebugFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE); } @Test @@ -354,7 +421,7 @@ public void aggregate_withDebugReportsInNonDebugMode_errorsExceedsThreshold_quit TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -426,7 +493,7 @@ public void createJobE2EAggregateReportingDebugTest() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -469,7 +536,7 @@ public void createJobE2ETestPrivacyBudgetExhausted() throws Exception { "%s/%s/test-outputs/10k_test_output.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest1 = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -520,7 +587,7 @@ public void createJobE2EFledgeTest() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -569,7 +636,7 @@ public void createJobE2ESharedStorageTest() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -607,7 +674,7 @@ public void createDebugJobE2ETestPrivacyBudgetExhausted() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest1 = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -656,7 +723,7 @@ public void createJobE2ETestWithMultiOutputShard() throws Exception { "%s/%s/test-outputs/30k_test_output.avro", TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -704,7 +771,7 @@ public void createJobE2ETestWithInvalidReports() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -716,9 +783,8 @@ public void createJobE2ETestWithInvalidReports() throws Exception { // The job should be completed before the completion timeout. assertThat(result.get("job_status").asText()).isEqualTo("FINISHED"); - // The threshold is 100%, so we get SUCCESS_WITH_ERRORS. 
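
Note on the assertion that follows: the attribution-report-to helper's default `report_error_threshold_percentage` changed from 100 to 0 (see the test-helper diff below), so a job whose reports all fail now returns REPORTS_WITH_ERRORS_EXCEEDED_THRESHOLD instead of SUCCESS_WITH_ERRORS. A hedged sketch of the assumed decision rule (the worker's real accounting lives elsewhere and may differ at the boundary):

```java
final class ErrorThresholdSketch {
  enum ReturnCode { SUCCESS, SUCCESS_WITH_ERRORS, REPORTS_WITH_ERRORS_EXCEEDED_THRESHOLD }

  // Assumed rule: the job fails once the error percentage strictly exceeds the threshold.
  static ReturnCode resolve(long errorCount, long totalReports, double thresholdPercentage) {
    if (errorCount == 0) {
      return ReturnCode.SUCCESS;
    }
    double errorPercentage = 100.0 * errorCount / totalReports;
    return errorPercentage > thresholdPercentage
        ? ReturnCode.REPORTS_WITH_ERRORS_EXCEEDED_THRESHOLD
        : ReturnCode.SUCCESS_WITH_ERRORS;
  }

  public static void main(String[] args) {
    System.out.println(resolve(10_000, 10_000, 100)); // SUCCESS_WITH_ERRORS (old default)
    System.out.println(resolve(10_000, 10_000, 0)); // REPORTS_WITH_ERRORS_EXCEEDED_THRESHOLD (new)
  }
}
```
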
assertThat(result.get("result_info").get("return_code").asText()) - .isEqualTo(AggregationWorkerReturnCode.SUCCESS_WITH_ERRORS.name()); + .isEqualTo(AggregationWorkerReturnCode.REPORTS_WITH_ERRORS_EXCEEDED_THRESHOLD.name()); } @Test @@ -747,7 +813,7 @@ public void createJob_withFilteringId() throws Exception { @Var Set filteringIds = ImmutableSet.of(); @Var CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -776,7 +842,7 @@ public void createJob_withFilteringId() throws Exception { filteringIds = ImmutableSet.of(UnsignedLong.valueOf("18446744073709551615"), UnsignedLong.valueOf(65536)); createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -804,7 +870,7 @@ public void createJob_withFilteringId() throws Exception { filteringIds = ImmutableSet.of(UnsignedLong.valueOf(5), UnsignedLong.ZERO); createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTestChromeReports.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTestChromeReports.java index 9cbdaf28..35e7c764 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTestChromeReports.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousSmokeTestChromeReports.java @@ -100,7 +100,7 @@ public void createJobE2ETest() throws Exception { // Create the job and wait for the result CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( inputBucket, inputKey, outputBucket, @@ -148,7 +148,7 @@ public void createDebugJobE2ETest() throws Exception { // Create the job and wait for the result CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( inputBucket, inputKey, outputBucket, @@ -212,9 +212,7 @@ protected void configure() { .httpClient(UrlConnectionHttpClient.builder().build()) .build()); bind(S3AsyncClient.class) - .toInstance( - S3AsyncClient.builder() - .region(AWS_S3_BUCKET_REGION).build()); + .toInstance(S3AsyncClient.builder().region(AWS_S3_BUCKET_REGION).build()); bind(Boolean.class).annotatedWith(S3UsePartialRequests.class).toInstance(false); bind(Integer.class).annotatedWith(PartialRequestBufferSize.class).toInstance(20); } diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousTestHelper.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousTestHelper.java index e03b1c1f..44b8a004 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousTestHelper.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerContinuousTestHelper.java @@ -77,13 +77,15 @@ public class AwsWorkerContinuousTestHelper { public static final Region AWS_API_GATEWAY_REGION = Region.US_EAST_1; public static final Region AWS_S3_BUCKET_REGION = Region.US_EAST_1; + + public static final String DEFAULT_ATTRIBUTION_REPORT_TO = "https://subdomain.fakeurl.com"; + public static final String DEFAULT_REPORTING_SITE = 
"https://fakeurl.com"; + + public static final String ENV_ATTRIBUTION_REPORT_TO = System.getenv("ATTRIBUTION_REPORT_TO"); + public static final String ENV_REPORTING_SITE = System.getenv("REPORTING_SITE"); public static final String FRONTEND_API = System.getenv("FRONTEND_API"); public static final String KOKORO_BUILD_ID = System.getenv("KOKORO_BUILD_ID"); - // The attribution_report_to in job params should be configurable because this needs to match - // allowed_principals_map in coordinator setting which would be different in different test - // environments. - public static final String ENV_ATTRIBUTION_REPORT_TO = System.getenv("ATTRIBUTION_REPORT_TO"); - public static final String DEFAULT_ATTRIBUTION_REPORT_TO = "https://foo.com"; + public static final String CREATE_JOB_URI_PATTERN = "https://%s.execute-api.us-east-1.amazonaws.com/%s/%s/createJob"; public static final String GET_JOB_URI_PATTERN = @@ -114,6 +116,13 @@ private static String getAttributionReportTo() { return DEFAULT_ATTRIBUTION_REPORT_TO; } + private static String getReportingSite() { + if (ENV_REPORTING_SITE != null) { + return ENV_REPORTING_SITE; + } + return DEFAULT_REPORTING_SITE; + } + /** Helper for extracting a bucket name from an S3 URI. */ public static String getS3Bucket(String s3Uri) { return parseS3Uri(s3Uri).group("bucket"); @@ -137,7 +146,7 @@ public static String getOutputFileName(String outputKey, int shardId, int numSha : outputKey + outputSuffix; } - public static CreateJobRequest createJobRequest( + public static CreateJobRequest createJobRequestWithAttributionReportTo( String inputDataBlobBucket, String inputDataBlobPrefix, String outputDataBlobBucket, @@ -152,17 +161,42 @@ public static CreateJobRequest createJobRequest( outputDataBlobPrefix, jobId) .putAllJobParameters( - getJobParams( + getJobParamsWithAttributionReportTo( false, outputDomainBucketName, outputDomainPrefix, - /* reportErrorThresholdPercentage= */ 100, + /* reportErrorThresholdPercentage= */ 0, /* inputReportCount= */ Optional.empty(), /* filteringIds= */ Optional.empty())) .build(); } - public static CreateJobRequest createJobRequest( + public static CreateJobRequest createJobRequestWithReportingSite( + String inputDataBlobBucket, + String inputDataBlobPrefix, + String outputDataBlobBucket, + String outputDataBlobPrefix, + String jobId, + Optional outputDomainBucketName, + Optional outputDomainPrefix) { + ImmutableMap jobParams = + getJobParamsWithReportingSite( + false, + outputDomainBucketName, + outputDomainPrefix, + /* reportErrorThresholdPercentage= */ 100, + /* inputReportCount= */ Optional.empty()); + return createDefaultJobRequestBuilder( + inputDataBlobBucket, + inputDataBlobPrefix, + outputDataBlobBucket, + outputDataBlobPrefix, + jobId) + .putAllJobParameters(jobParams) + .build(); + } + + public static CreateJobRequest createJobRequestWithAttributionReportTo( String inputDataBlobBucket, String inputDataBlobPrefix, String outputDataBlobBucket, @@ -178,17 +212,17 @@ public static CreateJobRequest createJobRequest( outputDataBlobPrefix, jobId) .putAllJobParameters( - getJobParams( + getJobParamsWithAttributionReportTo( debugRun, outputDomainBucketName, outputDomainPrefix, - /* reportErrorThresholdPercentage= */ 100, + /* reportErrorThresholdPercentage= */ 0, /* inputReportCount= */ Optional.empty(), /* filteringIds= */ Optional.empty())) .build(); } - public static CreateJobRequest createJobRequest( + public static CreateJobRequest createJobRequestWithAttributionReportTo( String inputDataBlobBucket, String inputDataBlobPrefix, 
String outputDataBlobBucket, @@ -207,7 +241,7 @@ public static CreateJobRequest createJobRequest( outputDataBlobPrefix, jobId) .putAllJobParameters( - getJobParams( + getJobParamsWithAttributionReportTo( debugRun, outputDomainBucketName, outputDomainPrefix, @@ -217,7 +251,7 @@ public static CreateJobRequest createJobRequest( .build(); } - public static CreateJobRequest createJobRequest( + public static CreateJobRequest createJobRequestWithAttributionReportTo( String inputDataBlobBucket, String inputDataBlobPrefix, String outputDataBlobBucket, @@ -235,7 +269,7 @@ public static CreateJobRequest createJobRequest( outputDataBlobPrefix, jobId) .putAllJobParameters( - getJobParams( + getJobParamsWithAttributionReportTo( debugRun, outputDomainBucketName, outputDomainPrefix, @@ -268,7 +302,7 @@ private static CreateJobRequest.Builder createDefaultJobRequestBuilder( .putAllJobParameters(ImmutableMap.of()); } - private static ImmutableMap getJobParams( + private static ImmutableMap getJobParamsWithAttributionReportTo( Boolean debugRun, Optional outputDomainBucketName, Optional outputDomainPrefix, @@ -308,6 +342,34 @@ private static ImmutableMap getJobParams( return jobParams.build(); } + private static ImmutableMap getJobParamsWithReportingSite( + Boolean debugRun, + Optional outputDomainBucketName, + Optional outputDomainPrefix, + int reportErrorThresholdPercentage, + Optional inputReportCountOptional) { + ImmutableMap.Builder jobParams = ImmutableMap.builder(); + jobParams.put("reporting_site", getReportingSite()); + if (debugRun) { + jobParams.put("debug_run", "true"); + } + inputReportCountOptional.ifPresent( + inputReportCount -> + jobParams.put(JobUtils.JOB_PARAM_INPUT_REPORT_COUNT, String.valueOf(inputReportCount))); + jobParams.put( + "report_error_threshold_percentage", String.valueOf(reportErrorThresholdPercentage)); + if (outputDomainPrefix.isPresent() && outputDomainBucketName.isPresent()) { + jobParams.put("output_domain_blob_prefix", outputDomainPrefix.get()); + jobParams.put("output_domain_bucket_name", outputDomainBucketName.get()); + return jobParams.build(); + } else if (outputDomainPrefix.isEmpty() && outputDomainBucketName.isEmpty()) { + return jobParams.build(); + } else { + throw new IllegalStateException( + "outputDomainPrefix and outputDomainBucketName must both be provided or both be empty."); + } + } + public static JsonNode submitJobAndWaitForResult( CreateJobRequest createJobRequest, Duration timeout) throws IOException, InterruptedException { diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerPerformanceRegressionTest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerPerformanceRegressionTest.java index 6b997b75..6b477aec 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerPerformanceRegressionTest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerPerformanceRegressionTest.java @@ -98,7 +98,7 @@ public void aggregateARA500kTransient() throws Exception { "test-data/%s/test-outputs/500k_report_%s_500k_domain_output.avro", KOKORO_BUILD_ID, i); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -141,7 +141,7 @@ public void aggregateARA500kReports500kDomainWarmup() throws Exception { "test-data/%s/test-outputs/500k_report_%s_500k_domain_warmup_output.avro", KOKORO_BUILD_ID, i); CreateJobRequest createJobRequest = - 
AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -194,7 +194,7 @@ public void aggregateARA500kReports500kDomainTransient() throws Exception { "test-data/%s/test-outputs/500k_report_%s_500k_domain_transient_output.avro", KOKORO_BUILD_ID, i); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), diff --git a/javatests/com/google/aggregate/adtech/worker/AwsWorkerPrivateAggregationAPITest.java b/javatests/com/google/aggregate/adtech/worker/AwsWorkerPrivateAggregationAPITest.java index 9169e9a7..8aa6fdff 100644 --- a/javatests/com/google/aggregate/adtech/worker/AwsWorkerPrivateAggregationAPITest.java +++ b/javatests/com/google/aggregate/adtech/worker/AwsWorkerPrivateAggregationAPITest.java @@ -53,10 +53,8 @@ @RunWith(JUnit4.class) public class AwsWorkerPrivateAggregationAPITest { - @Rule - public final Acai acai = new Acai(TestEnv.class); - @Rule - public final TestName name = new TestName(); + @Rule public final Acai acai = new Acai(TestEnv.class); + @Rule public final TestName name = new TestName(); private static final Duration COMPLETION_TIMEOUT = Duration.of(10, ChronoUnit.MINUTES); @@ -64,10 +62,8 @@ public class AwsWorkerPrivateAggregationAPITest { private static final String TEST_DATA_S3_KEY_PREFIX = "generated-test-data"; - @Inject - S3BlobStorageClient s3BlobStorageClient; - @Inject - AvroResultsFileReader avroResultsFileReader; + @Inject S3BlobStorageClient s3BlobStorageClient; + @Inject AvroResultsFileReader avroResultsFileReader; private static String getTestDataBucket() { if (System.getenv("TEST_DATA_BUCKET") != null) { @@ -111,7 +107,7 @@ public void createJobE2EProtectedAudienceTest() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -123,7 +119,7 @@ public void createJobE2EProtectedAudienceTest() throws Exception { /* Debug job */ CreateJobRequest createDebugJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKeyDebug, getTestDataBucket(), @@ -204,7 +200,7 @@ public void createJobE2ESharedStorageTest() throws Exception { TEST_DATA_S3_KEY_PREFIX, KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -216,7 +212,7 @@ public void createJobE2ESharedStorageTest() throws Exception { /* Debug job */ CreateJobRequest createDebugJobRequest = - AwsWorkerContinuousTestHelper.createJobRequest( + AwsWorkerContinuousTestHelper.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKeyDebug, getTestDataBucket(), @@ -273,9 +269,7 @@ protected void configure() { .httpClient(UrlConnectionHttpClient.builder().build()) .build()); bind(S3AsyncClient.class) - .toInstance( - S3AsyncClient.builder() - .region(AWS_S3_BUCKET_REGION).build()); + .toInstance(S3AsyncClient.builder().region(AWS_S3_BUCKET_REGION).build()); bind(Boolean.class).annotatedWith(S3UsePartialRequests.class).toInstance(false); 
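
As a reference for the reporting-site plumbing in `getJobParamsWithReportingSite` above, here is a hedged standalone sketch of the parameters map that flavor of request carries. The key names mirror the job_parameters fields; the class and method are illustrative, not the helper itself:

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;

final class ReportingSiteParamsSketch {
  // Mirrors getJobParamsWithReportingSite: reporting_site replaces attribution_report_to,
  // and the output-domain bucket/prefix must be given together or not at all.
  static Map<String, String> build(
      String reportingSite, Optional<String> domainBucket, Optional<String> domainPrefix) {
    Map<String, String> params = new LinkedHashMap<>();
    params.put("reporting_site", reportingSite);
    params.put("report_error_threshold_percentage", "100");
    if (domainBucket.isPresent() != domainPrefix.isPresent()) {
      throw new IllegalStateException(
          "outputDomainPrefix and outputDomainBucketName must both be provided or both be empty.");
    }
    domainBucket.ifPresent(b -> params.put("output_domain_bucket_name", b));
    domainPrefix.ifPresent(p -> params.put("output_domain_blob_prefix", p));
    return params;
  }

  public static void main(String[] args) {
    System.out.println(
        build("https://fakeurl.com", Optional.of("my-bucket"), Optional.of("domain.avro")));
  }
}
```
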
bind(Integer.class).annotatedWith(PartialRequestBufferSize.class).toInstance(20); } diff --git a/javatests/com/google/aggregate/adtech/worker/BUILD b/javatests/com/google/aggregate/adtech/worker/BUILD index d866af3e..1898e2d4 100644 --- a/javatests/com/google/aggregate/adtech/worker/BUILD +++ b/javatests/com/google/aggregate/adtech/worker/BUILD @@ -448,14 +448,12 @@ java_test( "//java/com/google/aggregate/adtech/worker", "//java/com/google/aggregate/adtech/worker/exceptions", "//java/com/google/aggregate/adtech/worker/model", - "//java/com/google/aggregate/adtech/worker/testing:avro_reports_file_reader", "//java/com/google/aggregate/adtech/worker/testing:avro_results_file_reader", "//java/com/google/aggregate/adtech/worker/util", "//java/com/google/aggregate/adtech/worker/writer", "//java/com/google/aggregate/adtech/worker/writer/avro", "//java/com/google/aggregate/protocol/avro:avro_debug_results", "//java/com/google/aggregate/protocol/avro:avro_debug_results_schema_supplier", - "//java/com/google/aggregate/protocol/avro:avro_report", "//java/com/google/aggregate/protocol/avro:avro_results_schema_supplier", "//java/external:acai", "//java/external:clients_blobstorageclient_aws", @@ -570,6 +568,7 @@ java_library( deps = [ "//java/com/google/aggregate/adtech/worker/model", "//java/com/google/aggregate/adtech/worker/testing:avro_results_file_reader", + "//java/com/google/aggregate/adtech/worker/util", "//java/com/google/aggregate/protocol/avro:avro_debug_results", "//java/com/google/aggregate/protocol/avro:avro_debug_results_schema_supplier", "//java/com/google/aggregate/protocol/avro:avro_results_schema_supplier", diff --git a/javatests/com/google/aggregate/adtech/worker/GcpOTelTest.java b/javatests/com/google/aggregate/adtech/worker/GcpOTelTest.java index 70246117..d3082a47 100644 --- a/javatests/com/google/aggregate/adtech/worker/GcpOTelTest.java +++ b/javatests/com/google/aggregate/adtech/worker/GcpOTelTest.java @@ -106,7 +106,7 @@ public void createJobE2ETest() throws Exception { String.format("%s/test-outputs/otel_test.avro.result", KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - SmokeTestBase.createJobRequest( + SmokeTestBase.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputDataPrefix, getTestDataBucket(), diff --git a/javatests/com/google/aggregate/adtech/worker/GcpWorkerAutoScalingTest.java b/javatests/com/google/aggregate/adtech/worker/GcpWorkerAutoScalingTest.java index 902eb0ac..b7d1eac7 100644 --- a/javatests/com/google/aggregate/adtech/worker/GcpWorkerAutoScalingTest.java +++ b/javatests/com/google/aggregate/adtech/worker/GcpWorkerAutoScalingTest.java @@ -49,18 +49,16 @@ @RunWith(JUnit4.class) public class GcpWorkerAutoScalingTest { - @Rule - public final Acai acai = new Acai(TestEnv.class); + @Rule public final Acai acai = new Acai(TestEnv.class); private static final Duration SUBMIT_JOB_TIMEOUT = Duration.of(1, ChronoUnit.SECONDS); - private static final Duration SCALE_ACTION_COMPLETION_TIMEOUT = Duration.of(20, - ChronoUnit.MINUTES); + private static final Duration SCALE_ACTION_COMPLETION_TIMEOUT = + Duration.of(20, ChronoUnit.MINUTES); private static final Duration COMPLETION_TIMEOUT = Duration.of(15, ChronoUnit.MINUTES); private static final Integer MIN_INSTANCES = 1; public static final int CONCURRENT_JOBS = 5; - @Inject - InstancesClient gcpInstancesClient; + @Inject InstancesClient gcpInstancesClient; @Test public void autoscalingE2ETest() throws Exception { @@ -74,13 +72,14 @@ public void autoscalingE2ETest() throws Exception { String 
outputFile = String.format("100k_auto_scale_job_%d.avro.test", jobNum); String outputDataPrefix = String.format("%s/test-outputs/%s", KOKORO_BUILD_ID, outputFile); - CreateJobRequest jobRequest = SmokeTestBase.createJobRequest( - getTestDataBucket(), - inputDataPrefix, - getTestDataBucket(), - outputDataPrefix, - Optional.of(getTestDataBucket()), - Optional.of(domainDataPrefix)); + CreateJobRequest jobRequest = + SmokeTestBase.createJobRequestWithAttributionReportTo( + getTestDataBucket(), + inputDataPrefix, + getTestDataBucket(), + outputDataPrefix, + Optional.of(getTestDataBucket()), + Optional.of(domainDataPrefix)); SmokeTestBase.submitJob(jobRequest, SUBMIT_JOB_TIMEOUT, false); @@ -104,7 +103,8 @@ private void waitForInstanceScaleAction(boolean isScaleOut) throws InterruptedEx while (!scaleSuccessful && Instant.now().isBefore(waitMax)) { instanceCount = getInstanceCount(); System.out.println( - "Verifying instance count. Is scale out: " + isScaleOut + "Verifying instance count. Is scale out: " + + isScaleOut + ". Current instance count: " + instanceCount); if ((!isScaleOut && instanceCount == MIN_INSTANCES) @@ -130,13 +130,15 @@ private void waitForInstanceScaleAction(boolean isScaleOut) throws InterruptedEx } private int getInstanceCount() { - AggregatedListPagedResponse pagedResponse = gcpInstancesClient.aggregatedList( - getTestProjectId()); + AggregatedListPagedResponse pagedResponse = + gcpInstancesClient.aggregatedList(getTestProjectId()); int instancesCount = 0; for (Entry entry : pagedResponse.iterateAll()) { - instancesCount += entry.getValue().getInstancesList().stream() - .filter(i -> i.getName().contains(getEnvironmentName())).count(); + instancesCount += + entry.getValue().getInstancesList().stream() + .filter(i -> i.getName().contains(getEnvironmentName())) + .count(); } return instancesCount; } @@ -165,8 +167,7 @@ protected void configure() { .getService()); try { - bind(InstancesClient.class) - .toInstance(InstancesClient.create()); + bind(InstancesClient.class).toInstance(InstancesClient.create()); } catch (IOException e) { throw new RuntimeException("Unable to instantiate GCP Instances client: ", e); } diff --git a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousDiffTest.java b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousDiffTest.java index 9ab1422a..d3db481b 100644 --- a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousDiffTest.java +++ b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousDiffTest.java @@ -55,15 +55,12 @@ @RunWith(JUnit4.class) public class GcpWorkerContinuousDiffTest { - @Rule - public final Acai acai = new Acai(TestEnv.class); + @Rule public final Acai acai = new Acai(TestEnv.class); private static final Duration COMPLETION_TIMEOUT = Duration.of(10, ChronoUnit.MINUTES); - @Inject - GcsBlobStorageClient gcsBlobStorageClient; - @Inject - AvroResultsFileReader avroResultsFileReader; + @Inject GcsBlobStorageClient gcsBlobStorageClient; + @Inject AvroResultsFileReader avroResultsFileReader; @Before public void checkBuildEnv() { @@ -86,7 +83,7 @@ public void e2eDiffTest() throws Exception { "%s/test-outputs/10k_diff_test_output.avro.result", SmokeTestBase.KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - SmokeTestBase.createJobRequest( + SmokeTestBase.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -114,12 +111,12 @@ public void e2eDiffTest() throws Exception { MapDifference diffs = ResultDiffer.diffResults(aggregatedFacts.stream(), 
goldenAggregatedFacts.stream()); assertWithMessage( - String.format( - "Found (%s) diffs between left(test) and right(golden). Found (%s) entries only on" - + " left(test) and (%s) entries only on right(golden).", - diffs.entriesDiffering().size(), - diffs.entriesOnlyOnLeft().size(), - diffs.entriesOnlyOnRight().size())) + String.format( + "Found (%s) diffs between left(test) and right(golden). Found (%s) entries only on" + + " left(test) and (%s) entries only on right(golden).", + diffs.entriesDiffering().size(), + diffs.entriesOnlyOnLeft().size(), + diffs.entriesOnlyOnRight().size())) .that(diffs.areEqual()) .isTrue(); diff --git a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousOutOfMemoryTest.java b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousOutOfMemoryTest.java index 85adcb3f..823436cb 100644 --- a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousOutOfMemoryTest.java +++ b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousOutOfMemoryTest.java @@ -77,7 +77,7 @@ public void createJobE2EOOMTest() throws Exception { "%s/test-outputs/OOM_test_output_1.avro.result", SmokeTestBase.KOKORO_BUILD_ID); CreateJobRequest createJobRequest1 = - SmokeTestBase.createJobRequest( + SmokeTestBase.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), @@ -101,7 +101,7 @@ public void createJobE2EOOMTest() throws Exception { "%s/test-outputs/OOM_test_output_2.avro.result", SmokeTestBase.KOKORO_BUILD_ID); CreateJobRequest createJobRequest2 = - SmokeTestBase.createJobRequest( + SmokeTestBase.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputKey, getTestDataBucket(), diff --git a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java index b2fe2630..8df9b654 100644 --- a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java +++ b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java @@ -90,7 +90,7 @@ public void createJobE2ETest() throws Exception { String.format("%s/test-outputs/10k_test_domain_1.avro.result", KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - SmokeTestBase.createJobRequest( + SmokeTestBase.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputDataPrefix, getTestDataBucket(), @@ -122,7 +122,7 @@ public void createNotDebugJobE2EReportDebugEnabledTest() throws Exception { String.format("%s/test-outputs/10k_test_input_non_debug.avro.result", KOKORO_BUILD_ID); CreateJobRequest createJobRequest = - SmokeTestBase.createJobRequest( + SmokeTestBase.createJobRequestWithAttributionReportTo( getTestDataBucket(), inputDataPrefix, getTestDataBucket(), @@ -130,7 +130,7 @@ public void createNotDebugJobE2EReportDebugEnabledTest() throws Exception { false, Optional.of(getTestDataBucket()), Optional.of(domainDataPrefix)); - JsonNode result = SmokeTestBase.submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT); + JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT); checkJobExecutionResult(result, SUCCESS.name(), 0); // Read output avro from GCS. 
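Note: the `createJobRequest*` renames above are mechanical, but they split what used to be a single `getJobParams` into two mutually exclusive identity styles. Below is a minimal sketch of the invariant the new helpers maintain, namely that exactly one of `attribution_report_to` or `reporting_site` appears in the job parameters. The `validateIdentityParams` helper is hypothetical and exists only for illustration; the parameter keys and default values mirror those used in this patch.

```java
import com.google.common.collect.ImmutableMap;

final class JobParamsSketch {
  // Hypothetical check: a job-parameter map must carry exactly one identity key.
  static void validateIdentityParams(ImmutableMap<String, String> jobParams) {
    boolean hasOrigin = jobParams.containsKey("attribution_report_to");
    boolean hasSite = jobParams.containsKey("reporting_site");
    if (hasOrigin == hasSite) { // both present or both absent
      throw new IllegalArgumentException(
          "Exactly one of attribution_report_to and reporting_site must be set.");
    }
  }

  public static void main(String[] args) {
    // Both of these pass; a map with both keys (or neither) would throw.
    validateIdentityParams(
        ImmutableMap.of("attribution_report_to", "https://subdomain.fakeurl.com"));
    validateIdentityParams(ImmutableMap.of("reporting_site", "https://fakeurl.com"));
  }
}
```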
diff --git a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java
index b2fe2630..8df9b654 100644
--- a/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java
+++ b/javatests/com/google/aggregate/adtech/worker/GcpWorkerContinuousSmokeTest.java
@@ -90,7 +90,7 @@ public void createJobE2ETest() throws Exception {
         String.format("%s/test-outputs/10k_test_domain_1.avro.result", KOKORO_BUILD_ID);

     CreateJobRequest createJobRequest =
-        SmokeTestBase.createJobRequest(
+        SmokeTestBase.createJobRequestWithAttributionReportTo(
             getTestDataBucket(),
             inputDataPrefix,
             getTestDataBucket(),
@@ -122,7 +122,7 @@ public void createNotDebugJobE2EReportDebugEnabledTest() throws Exception {
         String.format("%s/test-outputs/10k_test_input_non_debug.avro.result", KOKORO_BUILD_ID);

     CreateJobRequest createJobRequest =
-        SmokeTestBase.createJobRequest(
+        SmokeTestBase.createJobRequestWithAttributionReportTo(
             getTestDataBucket(),
             inputDataPrefix,
             getTestDataBucket(),
@@ -130,7 +130,7 @@ public void createNotDebugJobE2EReportDebugEnabledTest() throws Exception {
             false,
             Optional.of(getTestDataBucket()),
             Optional.of(domainDataPrefix));
-    JsonNode result = SmokeTestBase.submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);
+    JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);
     checkJobExecutionResult(result, SUCCESS.name(), 0);

     // Read output avro from GCS.
@@ -169,7 +169,7 @@ public void createDebugJobE2EReportDebugModeEnabledTest() throws Exception {
             "%s/test-outputs/10k_test_input_debug_for_debug_disabled.avro.result", KOKORO_BUILD_ID);

     CreateJobRequest createJobRequest =
-        SmokeTestBase.createJobRequest(
+        SmokeTestBase.createJobRequestWithAttributionReportTo(
             getTestDataBucket(),
             inputDataPrefix,
             getTestDataBucket(),
@@ -177,7 +177,7 @@ public void createDebugJobE2EReportDebugModeEnabledTest() throws Exception {
             true,
             Optional.of(getTestDataBucket()),
             Optional.of(domainDataPrefix));
-    JsonNode result = SmokeTestBase.submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);
+    JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);
     checkJobExecutionResult(result, SUCCESS.name(), 0);

     // Read output avro from GCS.
@@ -215,27 +215,33 @@ public void createDebugJobE2EReportDebugModeDisabledTest() throws Exception {
         String.format("%s/test-outputs/10k_test_input_2.avro.result", KOKORO_BUILD_ID);

     CreateJobRequest createJobRequest =
-        SmokeTestBase.createJobRequest(
+        SmokeTestBase.createJobRequestWithAttributionReportTo(
             getTestDataBucket(),
             inputDataPrefix,
             getTestDataBucket(),
             outputDataPrefix,
             true,
             Optional.of(getTestDataBucket()),
-            Optional.of(domainDataPrefix));
+            Optional.of(domainDataPrefix),
+            /* totalReportsCount= */ 10000,
+            /* reportErrorThreshold= */ 10);
     JsonNode result = SmokeTestBase.submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);

     assertThat(result.get("result_info").get("return_code").asText())
-        .isEqualTo(AggregationWorkerReturnCode.SUCCESS_WITH_ERRORS.name());
-    assertThat(
-        result
-            .get("result_info")
-            .get("error_summary")
-            .get("error_counts")
-            .get(0)
-            .get("count")
-            .asInt())
-        .isEqualTo(10000);
+        .isEqualTo(AggregationWorkerReturnCode.REPORTS_WITH_ERRORS_EXCEEDED_THRESHOLD.name());
+    // Due to parallel aggregation, processing may stop a little over the threshold. The
+    // assertions below therefore check that processing stopped somewhere above the threshold
+    // but before all 10K reports were processed.
+    int erroringReportCount =
+        result
+            .get("result_info")
+            .get("error_summary")
+            .get("error_counts")
+            .get(0)
+            .get("count")
+            .asInt();
+    assertThat(erroringReportCount).isAtLeast(1000);
+    assertThat(erroringReportCount).isLessThan(10000);
     assertThat(
         result
             .get("result_info")
@@ -254,31 +260,6 @@ public void createDebugJobE2EReportDebugModeDisabledTest() throws Exception {
             .get("description")
             .asText())
         .isEqualTo(ErrorCounter.DEBUG_NOT_ENABLED.getDescription());
-
-    // Read output avro from s3.
-    ImmutableList<AggregatedFact> aggregatedFacts =
-        readResultsFromCloud(
-            gcsBlobStorageClient,
-            avroResultsFileReader,
-            getTestDataBucket(),
-            outputDataPrefix + OUTPUT_DATA_PREFIX_NAME);
-
-    assertThat(aggregatedFacts.size()).isEqualTo(DEBUG_DOMAIN_KEY_SIZE);
-
-    // Read debug result from s3.
-    ImmutableList aggregatedDebugFacts =
-        readDebugResultsFromCloud(
-            gcsBlobStorageClient,
-            readerFactory,
-            getTestDataBucket(),
-            getDebugFilePrefix(outputDataPrefix + OUTPUT_DATA_PREFIX_NAME));
-
-    // Only contains keys in domain because all reports are filtered out.
-    assertThat(aggregatedDebugFacts.size()).isEqualTo(DEBUG_DOMAIN_KEY_SIZE);
-    // The unnoisedMetric of aggregatedDebugFacts should be 0 for all keys because
-    // all reports are filtered out.
-    // Noised metric in both debug reports and summary reports should be noise value instead of 0.
-    aggregatedDebugFacts.forEach(fact -> assertThat(fact.getUnnoisedMetric().get()).isEqualTo(0));
   }

   /**
@@ -297,7 +278,7 @@ public void createJobE2EAggregateReportingDebugTest() throws Exception {
             "%s/test-outputs/10k_test_output_attribution_debug.avro.result", KOKORO_BUILD_ID);

     CreateJobRequest createJobRequest =
-        SmokeTestBase.createJobRequest(
+        SmokeTestBase.createJobRequestWithAttributionReportTo(
             getTestDataBucket(),
             inputDataPrefix,
             getTestDataBucket(),
@@ -308,14 +289,141 @@ public void createJobE2EAggregateReportingDebugTest() throws Exception {
     checkJobExecutionResult(result, SUCCESS.name(), 0);

     ImmutableList<AggregatedFact> aggregatedFacts =
-        readResultsFromCloud(
-            gcsBlobStorageClient,
-            avroResultsFileReader,
-            getTestDataBucket(),
-            outputDataPrefix + OUTPUT_DATA_PREFIX_NAME);
+        readResultsFromCloud(
+            gcsBlobStorageClient,
+            avroResultsFileReader,
+            getTestDataBucket(),
+            outputDataPrefix + OUTPUT_DATA_PREFIX_NAME);
+    assertThat(aggregatedFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE);
+  }
+
+  /**
+   * This test includes sending a job with reporting site only. Verifies that jobs with only
+   * reporting site are successful.
+   */
+  @Test
+  public void createJobE2ETestWithReportingSite() throws Exception {
+    var inputDataPrefix =
+        String.format("%s/test-inputs/10k_test_input_reporting_site.avro", KOKORO_BUILD_ID);
+    var domainDataPrefix =
+        String.format("%s/test-inputs/10k_test_domain_reporting_site.avro", KOKORO_BUILD_ID);
+    var outputDataPrefix =
+        String.format(
+            "%s/test-outputs/10k_test_output_reporting_site.avro.result", KOKORO_BUILD_ID);
+
+    CreateJobRequest createJobRequest =
+        SmokeTestBase.createJobRequestWithReportingSite(
+            getTestDataBucket(),
+            inputDataPrefix,
+            getTestDataBucket(),
+            outputDataPrefix,
+            /* outputDomainBucketName= */ Optional.of(getTestDataBucket()),
+            /* outputDomainPrefix= */ Optional.of(domainDataPrefix));
+    JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);
+
+    checkJobExecutionResult(result, SUCCESS.name(), 0);
+
+    ImmutableList<AggregatedFact> aggregatedFacts =
+        readResultsFromCloud(
+            gcsBlobStorageClient,
+            avroResultsFileReader,
+            getTestDataBucket(),
+            outputDataPrefix + OUTPUT_DATA_PREFIX_NAME);
+    assertThat(aggregatedFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE);
+  }
+
+  /**
+   * This test includes sending a job with reports from multiple reporting origins belonging to the
+   * same reporting site. Verifies that all the reports are processed successfully.
+   */
+  @Test
+  public void createJobE2ETestWithMultipleReportingOrigins() throws Exception {
+    var inputDataPrefix = String.format("%s/test-inputs/same-site/", KOKORO_BUILD_ID);
+    var domainDataPrefix =
+        String.format(
+            "%s/test-inputs/10k_test_domain_multiple_origins_same_site.avro", KOKORO_BUILD_ID);
+    var outputDataPrefix =
+        String.format(
+            "%s/test-outputs/10k_test_output_multiple_origins_same_site.avro.result",
+            KOKORO_BUILD_ID);
+
+    CreateJobRequest createJobRequest =
+        SmokeTestBase.createJobRequestWithReportingSite(
+            getTestDataBucket(),
+            inputDataPrefix,
+            getTestDataBucket(),
+            outputDataPrefix,
+            /* outputDomainBucketName= */ Optional.of(getTestDataBucket()),
+            /* outputDomainPrefix= */ Optional.of(domainDataPrefix));
+    JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);
+
+    checkJobExecutionResult(result, SUCCESS.name(), 0);
+
+    ImmutableList<AggregatedFact> aggregatedFacts =
+        readResultsFromCloud(
+            gcsBlobStorageClient,
+            avroResultsFileReader,
+            getTestDataBucket(),
+            outputDataPrefix + OUTPUT_DATA_PREFIX_NAME);
     assertThat(aggregatedFacts.size()).isAtLeast(DEBUG_DOMAIN_KEY_SIZE);
   }

+  /**
+   * This test includes sending a job with reports from multiple reporting origins belonging to
+   * different reporting sites. It is expected that the 5k reports with a different reporting site
+   * will fail and show up in the error counts.
+   */
+  @Test
+  public void createJobE2ETestWithSomeReportsHavingDifferentReportingOrigins() throws Exception {
+    var inputDataPrefix = String.format("%s/test-inputs/different-site/", KOKORO_BUILD_ID);
+    var domainDataPrefix =
+        String.format(
+            "%s/test-inputs/10k_test_domain_multiple_origins_different_site.avro", KOKORO_BUILD_ID);
+    var outputDataPrefix =
+        String.format(
+            "%s/test-outputs/10k_test_output_multiple_origins_different_site.avro.result",
+            KOKORO_BUILD_ID);
+
+    CreateJobRequest createJobRequest =
+        SmokeTestBase.createJobRequestWithReportingSite(
+            getTestDataBucket(),
+            inputDataPrefix,
+            getTestDataBucket(),
+            outputDataPrefix,
+            /* outputDomainBucketName= */ Optional.of(getTestDataBucket()),
+            /* outputDomainPrefix= */ Optional.of(domainDataPrefix));
+    JsonNode result = submitJobAndWaitForResult(createJobRequest, COMPLETION_TIMEOUT);
+
+    assertThat(result.get("result_info").get("return_code").asText())
+        .isEqualTo(AggregationWorkerReturnCode.SUCCESS_WITH_ERRORS.name());
+    assertThat(
+            result
+                .get("result_info")
+                .get("error_summary")
+                .get("error_counts")
+                .get(0)
+                .get("count")
+                .asInt())
+        .isEqualTo(5000);
+    assertThat(
+            result
+                .get("result_info")
+                .get("error_summary")
+                .get("error_counts")
+                .get(0)
+                .get("category")
+                .asText())
+        .isEqualTo(ErrorCounter.REPORTING_SITE_MISMATCH.name());
+
+    ImmutableList<AggregatedFact> aggregatedFacts =
+        readResultsFromCloud(
+            gcsBlobStorageClient,
+            avroResultsFileReader,
+            getTestDataBucket(),
+            outputDataPrefix + OUTPUT_DATA_PREFIX_NAME);
+    assertThat(aggregatedFacts.size()).isAtLeast(5000);
+  }
+
   /*
   Creates a job and waits for successful completion. Then creates another job for the same data
   and verifies the privacy budget is exhausted.
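Note: the threshold assertions in `createDebugJobE2EReportDebugModeDisabledTest` above follow from simple arithmetic, assuming `report_error_threshold_percentage` is evaluated against the job's declared input report count. A worked sketch (not code from the repository):

```java
final class ThresholdMath {
  public static void main(String[] args) {
    long inputReportCount = 10_000; // totalReportsCount passed to the helper above
    int reportErrorThresholdPercentage = 10; // reportErrorThreshold passed above
    // The job may halt once this many reports have errored; parallel workers can
    // overshoot slightly, which is why the test asserts a range rather than equality.
    long thresholdCount = inputReportCount * reportErrorThresholdPercentage / 100;
    System.out.println(thresholdCount); // 1000: matches isAtLeast(1000) / isLessThan(10000)
  }
}
```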
@@ -329,15 +437,14 @@ public void createJobE2ETestPrivacyBudgetExhausted() throws Exception {
         String.format("%s/test-outputs/10k_test_input_3.avro.result", KOKORO_BUILD_ID);

     CreateJobRequest createJobRequest1 =
-        SmokeTestBase.createJobRequest(
+        SmokeTestBase.createJobRequestWithAttributionReportTo(
             getTestDataBucket(),
             inputDataPrefix,
             getTestDataBucket(),
             outputDataPrefix,
             Optional.of(getTestDataBucket()),
             Optional.of(domainDataPrefix));
-    JsonNode result =
-        SmokeTestBase.submitJobAndWaitForResult(createJobRequest1, COMPLETION_TIMEOUT);
+    JsonNode result = submitJobAndWaitForResult(createJobRequest1, COMPLETION_TIMEOUT);
     assertThat(result.get("result_info").get("return_code").asText())
         .isEqualTo(AggregationWorkerReturnCode.SUCCESS.name());
     assertThat(result.get("result_info").get("error_summary").get("error_counts").isEmpty())
diff --git a/javatests/com/google/aggregate/adtech/worker/GcpWorkerKhsLoadtest.java b/javatests/com/google/aggregate/adtech/worker/GcpWorkerKhsLoadtest.java
index ea1f42d3..328b515c 100644
--- a/javatests/com/google/aggregate/adtech/worker/GcpWorkerKhsLoadtest.java
+++ b/javatests/com/google/aggregate/adtech/worker/GcpWorkerKhsLoadtest.java
@@ -47,19 +47,14 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;

-/**
- * GCP KHS loadtest implementation
- */
+/** GCP KHS loadtest implementation */
 @RunWith(JUnit4.class)
 public final class GcpWorkerKhsLoadtest {

-  @Rule
-  public final Acai acai = new Acai(TestEnv.class);
-  @Rule
-  public TestName name = new TestName();
+  @Rule public final Acai acai = new Acai(TestEnv.class);
+  @Rule public TestName name = new TestName();

-  private static final String KHS_LOADTEST_DATA_BUCKET =
-      "loadtest_data";
+  private static final String KHS_LOADTEST_DATA_BUCKET = "loadtest_data";
   private static final int NUM_RUNS = 5;
   private static final Duration COMPLETION_TIMEOUT = Duration.of(30, ChronoUnit.MINUTES);
@@ -71,31 +66,29 @@ public void checkBuildEnv() {
     }
   }

-  /**
-   * Run Aggregation job for KHS loadtest.
-   */
+  /** Run Aggregation job for KHS loadtest. */
   @Test
   public void aggregateKhsLoadTest() throws Exception {
     ArrayList<CreateJobRequest> jobRequests = new ArrayList<>(NUM_RUNS);
     ArrayList<CreateJobRequest> jobRequestsDeepCopy = new ArrayList<>(NUM_RUNS);

     for (int i = 1; i <= NUM_RUNS; i++) {
-      var inputKey = String.format("test-data/%s/test-inputs/loadtest_report.avro", KOKORO_BUILD_ID);
-      var domainKey = String.format("test-data/%s/test-inputs/loadtest_domain.avro", KOKORO_BUILD_ID);
+      var inputKey =
+          String.format("test-data/%s/test-inputs/loadtest_report.avro", KOKORO_BUILD_ID);
+      var domainKey =
+          String.format("test-data/%s/test-inputs/loadtest_domain.avro", KOKORO_BUILD_ID);
       var outputKey =
-          String.format(
-              "test-data/%s/test-outputs/loadtest_%s_output.avro",
-              KOKORO_BUILD_ID, i);
+          String.format("test-data/%s/test-outputs/loadtest_%s_output.avro", KOKORO_BUILD_ID, i);

       CreateJobRequest createJobRequest =
-          SmokeTestBase.createJobRequest(
+          SmokeTestBase.createJobRequestWithAttributionReportTo(
              getTestDataBucket(KHS_LOADTEST_DATA_BUCKET),
              inputKey,
              getTestDataBucket(KHS_LOADTEST_DATA_BUCKET),
              outputKey,
              /* debugRun= */ true,
              /* jobId= */ UUID.randomUUID().toString(),
-             /* outputDomainBucketName= */
-             Optional.of(getTestDataBucket(KHS_LOADTEST_DATA_BUCKET)),
+             /* outputDomainBucketName= */ Optional.of(
+                 getTestDataBucket(KHS_LOADTEST_DATA_BUCKET)),
              /* outputDomainPrefix= */ Optional.of(domainKey));

       createJob(createJobRequest);
diff --git a/javatests/com/google/aggregate/adtech/worker/GcpWorkerPerformanceRegressionTest.java b/javatests/com/google/aggregate/adtech/worker/GcpWorkerPerformanceRegressionTest.java
index 494e6c35..b442a6e2 100644
--- a/javatests/com/google/aggregate/adtech/worker/GcpWorkerPerformanceRegressionTest.java
+++ b/javatests/com/google/aggregate/adtech/worker/GcpWorkerPerformanceRegressionTest.java
@@ -50,23 +50,16 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;

-/**
- * GCP performance regression test implementation
- */
+/** GCP performance regression test implementation */
 @RunWith(JUnit4.class)
 public final class GcpWorkerPerformanceRegressionTest {

-  @Rule
-  public final Acai acai = new Acai(TestEnv.class);
-  @Rule
-  public TestName name = new TestName();
-
-  @Inject
-  GcsBlobStorageClient gcsBlobStorageClient;
-  @Inject
-  AvroResultsFileReader avroResultsFileReader;
-  @Inject
-  private AvroDebugResultsReaderFactory readerFactory;
+  @Rule public final Acai acai = new Acai(TestEnv.class);
+  @Rule public TestName name = new TestName();
+
+  @Inject GcsBlobStorageClient gcsBlobStorageClient;
+  @Inject AvroResultsFileReader avroResultsFileReader;
+  @Inject private AvroDebugResultsReaderFactory readerFactory;

   private static final String PERFORMANCE_REGRESSION_DATA_BUCKET =
       "gcp_performance_regression_test_data";
   private static final int NUM_WARMUP_RUNS = 5;
@@ -99,7 +92,7 @@ public void aggregateARA500kReports500kDomainWarmup() throws Exception {
               "test-data/%s/test-outputs/500k_report_%s_500k_domain_warmup_output.avro",
               KOKORO_BUILD_ID, i);
       CreateJobRequest createJobRequest =
-          SmokeTestBase.createJobRequest(
+          SmokeTestBase.createJobRequestWithAttributionReportTo(
               getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET),
               inputKey,
               getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET),
@@ -110,8 +103,8 @@ public void aggregateARA500kReports500kDomainWarmup() throws Exception {
                   + name.getMethodName()
                   + "_warmup-"
                   + i,
-              /* outputDomainBucketName= */
-              Optional.of(getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET)),
+              /* outputDomainBucketName= */ Optional.of(
+                  getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET)),
               /* outputDomainPrefix= */ Optional.of(domainKey));

       createJob(createJobRequest);
@@ -119,7 +112,7 @@ public void aggregateARA500kReports500kDomainWarmup() throws Exception {
       warmUpJobRequestsDeepCopy.add(createJobRequest);
     }

-    waitForJobCompletions(warmUpJobRequestsDeepCopy, COMPLETION_TIMEOUT);
+    waitForJobCompletions(warmUpJobRequestsDeepCopy, COMPLETION_TIMEOUT, true);

     for (int i = 1; i <= NUM_WARMUP_RUNS; i++) {
       String outputKey =
@@ -153,7 +146,7 @@ public void aggregateARA500kReports500kDomainTransient() throws Exception {
               "test-data/%s/test-outputs/500k_report_%s_500k_domain_transient_output.avro",
               KOKORO_BUILD_ID, i);
       CreateJobRequest createJobRequest =
-          SmokeTestBase.createJobRequest(
+          SmokeTestBase.createJobRequestWithAttributionReportTo(
               getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET),
               inputKey,
               getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET),
@@ -164,8 +157,8 @@ public void aggregateARA500kReports500kDomainTransient() throws Exception {
                   + name.getMethodName()
                   + "_transient-"
                   + i,
-              /* outputDomainBucketName= */
-              Optional.of(getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET)),
+              /* outputDomainBucketName= */ Optional.of(
+                  getTestDataBucket(PERFORMANCE_REGRESSION_DATA_BUCKET)),
               /* outputDomainPrefix= */ Optional.of(domainKey));
       createJob(createJobRequest);
       transientJobRequests.add(createJobRequest);
@@ -173,7 +166,7 @@ public void aggregateARA500kReports500kDomainTransient() throws Exception {
     }

     waitForJobCompletions(
-        transientJobRequestsDeepCopy, COMPLETION_TIMEOUT, false);
+        transientJobRequestsDeepCopy, COMPLETION_TIMEOUT, true);

     for (int i = 1; i <= NUM_TRANSIENT_RUNS; i++) {
       var outputKey =
diff --git a/javatests/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLoggerTest.java b/javatests/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLoggerTest.java
index 93dd25ec..bd99ac8e 100644
--- a/javatests/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLoggerTest.java
+++ b/javatests/com/google/aggregate/adtech/worker/LocalFileToCloudStorageLoggerTest.java
@@ -18,7 +18,6 @@
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.truth.Truth.assertThat;
-import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
@@ -32,7 +31,6 @@
 import com.google.aggregate.adtech.worker.exceptions.ResultLogException;
 import com.google.aggregate.adtech.worker.model.AggregatedFact;
 import com.google.aggregate.adtech.worker.model.DebugBucketAnnotation;
-import com.google.aggregate.adtech.worker.model.EncryptedReport;
 import com.google.aggregate.adtech.worker.testing.AvroResultsFileReader;
 import com.google.aggregate.adtech.worker.util.OutputShardFileHelper;
 import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter;
@@ -41,11 +39,7 @@
 import com.google.aggregate.protocol.avro.AvroDebugResultsReader;
 import com.google.aggregate.protocol.avro.AvroDebugResultsReaderFactory;
 import com.google.aggregate.protocol.avro.AvroDebugResultsRecord;
-import com.google.aggregate.protocol.avro.AvroReportRecord;
-import com.google.aggregate.protocol.avro.AvroReportsReader;
-import com.google.aggregate.protocol.avro.AvroReportsReaderFactory;
 import com.google.common.collect.ImmutableList;
-import com.google.common.io.ByteSource;
 import com.google.common.jimfs.Configuration;
 import com.google.common.jimfs.Jimfs;
 import com.google.common.util.concurrent.ListeningExecutorService;
@@ -66,7 +60,6 @@
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
-import java.util.UUID;
 import java.util.concurrent.Executors;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -119,37 +112,12 @@ public class LocalFileToCloudStorageLoggerTest {
           AggregatedFact.create(BigInteger.valueOf(3789), 90L, 80L, annotationReportOnly),
           AggregatedFact.create(BigInteger.valueOf(4123), 100L, 70L, annotationReportOnly));

-  private static final ByteSource encryptedReport1Payload =
-      ByteSource.wrap(new byte[] {0x00, 0x01});
-  private static final ByteSource encryptedReport2Payload =
-      ByteSource.wrap(new byte[] {0x01, 0x02});
-  private static final String encryptedReport1KeyId = UUID.randomUUID().toString();
-  private static final String encryptedReport2KeyId = UUID.randomUUID().toString();
-  private static final String encryptedReport1SharedInfo = "foo";
-  private static final String encryptedReport2SharedInfo = "bar";
-  private static final EncryptedReport report1 =
-      EncryptedReport.builder()
-          .setPayload(encryptedReport1Payload)
-          .setKeyId(encryptedReport1KeyId)
-          .setSharedInfo(encryptedReport1SharedInfo)
-          .build();
-
-  private static final EncryptedReport report2 =
-      EncryptedReport.builder()
-          .setPayload(encryptedReport2Payload)
-          .setKeyId(encryptedReport2KeyId)
-          .setSharedInfo(encryptedReport2SharedInfo)
-          .build();
-
-  private static final ImmutableList<EncryptedReport> reportsList =
-      ImmutableList.of(report1, report2);
   // Under test
   @Inject private Provider<LocalFileToCloudStorageLogger> localFileToCloudStorageLogger;
   @Inject private FSBlobStorageClient blobStorageClient;
   @Inject private AvroResultsFileReader avroResultsFileReader;
   @Inject private AvroDebugResultsReaderFactory readerFactory;
-  @Inject private AvroReportsReaderFactory reportReaderFactory;
   @Inject private ParallelUploadFlagHelper uploadFlagHelper;
   @Inject private FileSystem testFS;
   @Inject @ResultWorkingDirectory private Path workingDirectory;
@@ -178,43 +146,6 @@ public void logResultsTest_singleThreaded() throws Exception {
     logResultsTest();
   }

-  @Test
-  public void logReports_writesReports() throws Exception {
-    localFileToCloudStorageLogger.get().logReports(reportsList, ctx, "1");
-
-    Path reportsFilePath = blobStorageClient.getLastWrittenFile();
-    Stream<AvroReportRecord> writtenFile;
-    try (AvroReportsReader reader = getReportsReader(reportsFilePath)) {
-      writtenFile = reader.streamRecords();
-    }
-    Stream<EncryptedReport> writtenFileEncryptedReports =
-        writtenFile.map(
-            report ->
-                EncryptedReport.builder()
-                    .setKeyId(report.keyId())
-                    .setPayload(report.payload())
-                    .setSharedInfo(report.sharedInfo())
-                    .build());
-
-    List<EncryptedReport> encryptedReportsList =
-        writtenFileEncryptedReports.collect(toImmutableList());
-
-    // check reencrypted reports file name
-    assertThat(reportsFilePath.toString()).isEqualTo("/bucket/abc123-reencrypted-1.avro");
-    // Check the output reports
-    assertThat(encryptedReportsList.get(0).keyId()).isEqualTo(encryptedReport1KeyId);
-    assertTrue(encryptedReportsList.get(0).payload().contentEquals(encryptedReport1Payload));
-    assertThat(encryptedReportsList.get(0).sharedInfo()).isEqualTo(encryptedReport1SharedInfo);
-
-    assertThat(encryptedReportsList.get(1).keyId()).isEqualTo(encryptedReport2KeyId);
-    assertTrue(encryptedReportsList.get(1).payload().contentEquals(encryptedReport2Payload));
-    assertThat(encryptedReportsList.get(1).sharedInfo()).isEqualTo(encryptedReport2SharedInfo);
-  }
-
-  private AvroReportsReader getReportsReader(Path avroFile) throws Exception {
-    return reportReaderFactory.create(Files.newInputStream(avroFile));
-  }
-
   private void logResultsTest() throws Exception {
     OutputShardFileHelper.setOutputShardFileSizeBytes(100_000_000L);
diff --git a/javatests/com/google/aggregate/adtech/worker/SmokeTestBase.java b/javatests/com/google/aggregate/adtech/worker/SmokeTestBase.java
index ce6eee37..bea3248a 100644
--- a/javatests/com/google/aggregate/adtech/worker/SmokeTestBase.java
+++ b/javatests/com/google/aggregate/adtech/worker/SmokeTestBase.java
@@ -23,6 +23,7 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.aggregate.adtech.worker.model.AggregatedFact;
 import com.google.aggregate.adtech.worker.testing.AvroResultsFileReader;
+import com.google.aggregate.adtech.worker.util.JobUtils;
 import com.google.aggregate.protocol.avro.AvroDebugResultsReader;
 import com.google.aggregate.protocol.avro.AvroDebugResultsReaderFactory;
 import com.google.aggregate.protocol.avro.AvroDebugResultsRecord;
@@ -58,13 +59,14 @@
 import org.apache.http.impl.client.HttpClients;
 import org.apache.http.util.EntityUtils;

-/**
- * Smoke test base class
- */
+/** Smoke test base class */
 public abstract class SmokeTestBase {

   public static final String ENV_ATTRIBUTION_REPORT_TO = System.getenv("ATTRIBUTION_REPORT_TO");
-  public static final String DEFAULT_ATTRIBUTION_REPORT_TO = "foo.com";
+  public static final String ENV_REPORTING_SITE = System.getenv("REPORTING_SITE");
+  public static final String DEFAULT_ATTRIBUTION_REPORT_TO = "https://subdomain.fakeurl.com";
+  public static final String DEFAULT_REPORTING_SITE = "https://fakeurl.com";
+
   public static final String FRONTEND_CLOUDFUNCTION_URL =
       System.getenv("FRONTEND_CLOUDFUNCTION_URL");
   public static final String KOKORO_BUILD_ID = System.getenv("KOKORO_BUILD_ID");
@@ -75,56 +77,104 @@ public abstract class SmokeTestBase {
   public static final String FRONTEND_API = System.getenv("FRONTEND_API");
   public static final String API_GATEWAY_STAGE = "stage";
   public static final String GCP_ACCESS_TOKEN = System.getenv("GCP_ACCESS_TOKEN");
-  public static final String DEFAULT_DEPLOY_SA = "deploy-sa@ps-msmt-aggserv-test.iam.gserviceaccount.com";
+  public static final String DEFAULT_DEPLOY_SA =
+      "deploy-sa@ps-msmt-aggserv-test.iam.gserviceaccount.com";
   public static final String DEFAULT_TEST_DATA_BUCKET = "test_reports_data";
   public static final String DEFAULT_PROJECT_ID = "ps-msmt-aggserv-test";
   public static final String DEFAULT_ENVIRONMENT_NAME = "continuous_mp";

   protected CreateJobRequest createJobRequest;

-  public static CreateJobRequest createJobRequest(
+  public static CreateJobRequest createJobRequestWithAttributionReportTo(
       String inputDataBlobBucket,
       String inputDataBlobPrefix,
       String outputDataBlobBucket,
       String outputDataBlobPrefix,
       String jobId,
       Optional<String> outputDomainBucketName,
+      Optional<String> outputDomainPrefix,
+      long totalReportsCount,
+      int reportErrorThreshold) {
+    return createDefaultJobRequestBuilder(
+            inputDataBlobBucket,
+            inputDataBlobPrefix,
+            outputDataBlobBucket,
+            outputDataBlobPrefix,
+            jobId)
+        .putAllJobParameters(
+            getJobParamsWithAttributionReportTo(
+                false,
+                outputDomainBucketName,
+                outputDomainPrefix,
+                Optional.of(totalReportsCount),
+                reportErrorThreshold))
+        .build();
+  }
+
+  public static CreateJobRequest createJobRequestWithAttributionReportTo(
+      String inputDataBlobBucket,
+      String inputDataBlobPrefix,
+      String outputDataBlobBucket,
+      String outputDataBlobPrefix,
+      Boolean debugRun,
+      Optional<String> outputDomainBucketName,
       Optional<String> outputDomainPrefix) {
     return createDefaultJobRequestBuilder(
-        inputDataBlobBucket,
-        inputDataBlobPrefix,
-        outputDataBlobBucket,
-        outputDataBlobPrefix,
-        jobId)
-        .putAllJobParameters(getJobParams(false, outputDomainBucketName, outputDomainPrefix, 100))
+            inputDataBlobBucket, inputDataBlobPrefix, outputDataBlobBucket, outputDataBlobPrefix)
+        .putAllJobParameters(
+            getJobParamsWithAttributionReportTo(
+                debugRun,
+                outputDomainBucketName,
+                outputDomainPrefix,
+                /* inputReportCount= */ Optional.empty(),
+                /* reportErrorThreshold= */ 0))
         .build();
   }

-  public static CreateJobRequest createJobRequest(
+  public static CreateJobRequest createJobRequestWithAttributionReportTo(
       String inputDataBlobBucket,
       String inputDataBlobPrefix,
       String outputDataBlobBucket,
       String outputDataBlobPrefix,
       Boolean debugRun,
+      String jobId,
       Optional<String> outputDomainBucketName,
       Optional<String> outputDomainPrefix) {
+    return createDefaultJobRequestBuilder(
+            inputDataBlobBucket,
+            inputDataBlobPrefix,
+            outputDataBlobBucket,
+            outputDataBlobPrefix,
+            jobId)
+        .putAllJobParameters(
+            getJobParamsWithAttributionReportTo(
+                debugRun, outputDomainBucketName, outputDomainPrefix, Optional.empty(), 0))
+        .build();
+  }
+
+  public static CreateJobRequest createJobRequestWithAttributionReportTo(
+      String inputDataBlobBucket,
+      String inputDataBlobPrefix,
+      String outputDataBlobBucket,
+      String outputDataBlobPrefix,
+      Boolean debugRun,
+      Optional<String> outputDomainBucketName,
+      Optional<String> outputDomainPrefix,
+      long totalReportsCount,
+      int reportErrorThreshold) {
     return createDefaultJobRequestBuilder(
             inputDataBlobBucket, inputDataBlobPrefix, outputDataBlobBucket, outputDataBlobPrefix)
         .putAllJobParameters(
-            getJobParams(
+            getJobParamsWithAttributionReportTo(
                 debugRun,
                 outputDomainBucketName,
                 outputDomainPrefix,
-                /* reportErrorThreshold= */ 100))
+                Optional.of(totalReportsCount),
+                reportErrorThreshold))
         .build();
   }

-  public static CreateJobRequest createJobRequest(
+  public static CreateJobRequest createJobRequestWithAttributionReportTo(
       String inputDataBlobBucket,
       String inputDataBlobPrefix,
       String outputDataBlobBucket,
       String outputDataBlobPrefix,
-      Boolean debugRun,
       String jobId,
       Optional<String> outputDomainBucketName,
       Optional<String> outputDomainPrefix) {
@@ -135,7 +185,12 @@ public static CreateJobRequest createJobRequest(
         outputDataBlobPrefix,
         jobId)
         .putAllJobParameters(
-            getJobParams(debugRun, outputDomainBucketName, outputDomainPrefix, 100))
+            getJobParamsWithAttributionReportTo(
+                false,
+                outputDomainBucketName,
+                outputDomainPrefix,
+                /* inputReportCount= */ Optional.empty(),
+                /* reportErrorThreshold= */ 0))
         .build();
   }
@@ -159,7 +214,7 @@ private static CreateJobRequest.Builder createDefaultJobRequestBuilder(
         .putAllJobParameters(ImmutableMap.of());
   }

-  public static CreateJobRequest createJobRequest(
+  public static CreateJobRequest createJobRequestWithAttributionReportTo(
       String inputDataBlobBucket,
       String inputDataBlobPrefix,
       String outputDataBlobBucket,
@@ -167,9 +222,24 @@ public static CreateJobRequest createJobRequest(
       Optional<String> outputDomainBucketName,
       Optional<String> outputDomainPrefix) {
     return createDefaultJobRequestBuilder(
-        inputDataBlobBucket, inputDataBlobPrefix, outputDataBlobBucket, outputDataBlobPrefix)
+            inputDataBlobBucket, inputDataBlobPrefix, outputDataBlobBucket, outputDataBlobPrefix)
+        .putAllJobParameters(
+            getJobParamsWithAttributionReportTo(
+                false,
+                outputDomainBucketName,
+                outputDomainPrefix,
+                /* inputReportCount= */ Optional.empty(),
+                /* reportErrorThreshold= */ 0))
+        .build();
+  }
+
+  public static CreateJobRequest createJobRequestWithReportingSite(
+      String inputDataBlobBucket,
+      String inputDataBlobPrefix,
+      String outputDataBlobBucket,
+      String outputDataBlobPrefix,
+      Optional<String> outputDomainBucketName,
+      Optional<String> outputDomainPrefix) {
+    return createDefaultJobRequestBuilder(
+            inputDataBlobBucket, inputDataBlobPrefix, outputDataBlobBucket, outputDataBlobPrefix)
         .putAllJobParameters(
-            getJobParams(
+            getJobParamsWithReportingSite(
                 false, outputDomainBucketName, outputDomainPrefix, /* reportErrorThreshold= */ 100))
         .build();
   }
@@ -404,9 +474,9 @@ protected static ImmutableList<AggregatedFact> readResultsFromCloud(
       throws Exception {
     Path tempResultFile = Files.createTempFile(/* prefix= */ "results", /* suffix= */ "avro");
     try (InputStream resultStream =
-        blobStorageClient.getBlob(
-            DataLocation.ofBlobStoreDataLocation(
-                BlobStoreDataLocation.create(outputBucket, outputPrefix)));
+            blobStorageClient.getBlob(
+                DataLocation.ofBlobStoreDataLocation(
+                    BlobStoreDataLocation.create(outputBucket, outputPrefix)));
         OutputStream outputStream = Files.newOutputStream(tempResultFile)) {
       ByteStreams.copy(resultStream, outputStream);
       outputStream.flush();
@@ -419,18 +489,18 @@ protected static ImmutableList<AggregatedFact> readResultsFromCloud(
   }

   protected static <T extends BlobStorageClient>
-      ImmutableList<AvroDebugResultsRecord> readDebugResultsFromCloud(
-          T blobStorageClient,
-          AvroDebugResultsReaderFactory readerFactory,
-          String outputBucket,
-          String outputPrefix)
-          throws Exception {
+          ImmutableList<AvroDebugResultsRecord> readDebugResultsFromCloud(
+              T blobStorageClient,
+              AvroDebugResultsReaderFactory readerFactory,
+              String outputBucket,
+              String outputPrefix)
+              throws Exception {
     Stream<AvroDebugResultsRecord> writtenResults;
     Path tempResultFile = Files.createTempFile(/* prefix= */ "debug_results", /* suffix= */ "avro");
     try (InputStream resultStream =
-        blobStorageClient.getBlob(
-            DataLocation.ofBlobStoreDataLocation(
-                BlobStoreDataLocation.create(outputBucket, outputPrefix)));
+            blobStorageClient.getBlob(
+                DataLocation.ofBlobStoreDataLocation(
+                    BlobStoreDataLocation.create(outputBucket, outputPrefix)));
         OutputStream outputStream = Files.newOutputStream(tempResultFile)) {
       ByteStreams.copy(resultStream, outputStream);
       outputStream.flush();
@@ -465,17 +535,45 @@ protected static boolean checkFileExists(
     }
   }

-  private static ImmutableMap<String, String> getJobParams(
+  private static ImmutableMap<String, String> getJobParamsWithReportingSite(
       Boolean debugRun,
       Optional<String> outputDomainBucketName,
       Optional<String> outputDomainPrefix,
       int reportErrorThresholdPercentage) {
     ImmutableMap.Builder<String, String> jobParams = ImmutableMap.builder();
+    jobParams.put("reporting_site", getReportingSite());
+    if (debugRun) {
+      jobParams.put("debug_run", "true");
+    }
+    jobParams.put(
+        "report_error_threshold_percentage", String.valueOf(reportErrorThresholdPercentage));
+    if (outputDomainPrefix.isPresent() && outputDomainBucketName.isPresent()) {
+      jobParams.put("output_domain_blob_prefix", outputDomainPrefix.get());
+      jobParams.put("output_domain_bucket_name", outputDomainBucketName.get());
+      return jobParams.build();
+    } else if (outputDomainPrefix.isEmpty() && outputDomainBucketName.isEmpty()) {
+      return jobParams.build();
+    } else {
+      throw new IllegalStateException(
+          "outputDomainPrefix and outputDomainBucketName must both be provided or both be empty.");
+    }
+  }
+
+  private static ImmutableMap<String, String> getJobParamsWithAttributionReportTo(
+      Boolean debugRun,
+      Optional<String> outputDomainBucketName,
+      Optional<String> outputDomainPrefix,
+      Optional<Long> inputReportCountOptional,
+      int reportErrorThresholdPercentage) {
+    ImmutableMap.Builder<String, String> jobParams = ImmutableMap.builder();
     jobParams.put("attribution_report_to", getAttributionReportTo());
     if (debugRun) {
       jobParams.put("debug_run", "true");
     }
+    inputReportCountOptional.ifPresent(
+        inputReportCount ->
+            jobParams.put(JobUtils.JOB_PARAM_INPUT_REPORT_COUNT, String.valueOf(inputReportCount)));
     jobParams.put(
         "report_error_threshold_percentage", String.valueOf(reportErrorThresholdPercentage));
     if (outputDomainPrefix.isPresent() && outputDomainBucketName.isPresent()) {
@@ -497,6 +595,13 @@ private static String getAttributionReportTo() {
     return DEFAULT_ATTRIBUTION_REPORT_TO;
   }

+  private static String getReportingSite() {
+    if (ENV_REPORTING_SITE != null) {
+      return ENV_REPORTING_SITE;
+    }
+    return DEFAULT_REPORTING_SITE;
+  }
+
   protected static void checkJobExecutionResult(
       JsonNode result, String returnCode, int errorCount) {
     assertThat(result.get("result_info").get("return_code").asText()).isEqualTo(returnCode);
@@ -505,13 +610,13 @@ protected static void checkJobExecutionResult(
           .isTrue();
     } else {
       assertThat(
-          result
-              .get("result_info")
-              .get("error_summary")
-              .get("error_counts")
-              .get(0)
-              .get("count")
-              .asInt())
+              result
+                  .get("result_info")
+                  .get("error_summary")
+                  .get("error_counts")
+                  .get(0)
+                  .get("count")
+                  .asInt())
           .isEqualTo(errorCount);
     }
   }
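Note: both `getJobParams*` builders above enforce the same pairing rule for the output-domain parameters. A self-contained sketch of just that rule, with a hypothetical `DomainParamsSketch` class that mirrors the exception message used above:

```java
import java.util.Optional;

final class DomainParamsSketch {
  // output_domain_bucket_name and output_domain_blob_prefix must be supplied
  // together or not at all; a lone value is a caller error.
  static void check(Optional<String> bucket, Optional<String> prefix) {
    if (bucket.isPresent() != prefix.isPresent()) {
      throw new IllegalStateException(
          "outputDomainPrefix and outputDomainBucketName must both be provided or both be empty.");
    }
  }

  public static void main(String[] args) {
    check(Optional.of("test_reports_data"), Optional.of("domain.avro")); // ok
    check(Optional.empty(), Optional.empty()); // ok
    check(Optional.of("test_reports_data"), Optional.empty()); // throws IllegalStateException
  }
}
```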
 import com.google.scp.operator.protos.shared.backend.ResultInfoProto.ResultInfo;
 import com.google.scp.shared.proto.ProtoUtil;
@@ -169,6 +172,7 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.time.Clock;
+import java.time.Duration;
 import java.time.Instant;
 import java.time.ZoneId;
 import java.util.Arrays;
@@ -191,6 +195,10 @@ public class ConcurrentAggregationProcessorTest {

   private static final Instant FIXED_TIME = Instant.parse("2021-01-01T00:00:00Z");
+  private static final Instant REQUEST_RECEIVED_AT = Instant.parse("2019-10-01T08:25:24.00Z");
+  private static final Instant REQUEST_PROCESSING_STARTED_AT =
+      Instant.parse("2019-10-01T08:29:24.00Z");
+  private static final Instant REQUEST_UPDATED_AT = Instant.parse("2019-10-01T08:29:24.00Z");

   @Rule public final Acai acai = new Acai(TestEnv.class);
   @Rule public final TemporaryFolder testWorkingDir = new TemporaryFolder();
@@ -289,7 +297,7 @@ public void setUpInputData() throws Exception {
     Files.createDirectory(reportsDirectory);
     Files.createDirectory(invalidReportsDirectory);

-    ctx = FakeJobGenerator.generateBuilder("foo").build();
+    ctx = generateJob("foo", Optional.of("https://example.foo.com"), Optional.empty());
     ctx =
         ctx.toBuilder()
             .setRequestInfo(
@@ -304,7 +312,7 @@ public void setUpInputData() throws Exception {
     expectedJobResult = makeExpectedJobResult();

     // Job context for job with invalid version input report.
-    ctxInvalidReport = FakeJobGenerator.generateBuilder("bar").build();
+    ctxInvalidReport = generateJob("bar", Optional.of("https://example.foo.com"), Optional.empty());
     ctxInvalidReport =
         ctxInvalidReport.toBuilder()
             .setRequestInfo(
@@ -389,6 +397,23 @@ public void aggregate_noOutputDomain_thresholding() throws Exception {
             AggregatedFact.create(/* bucket= */ createBucketFromInt(2), /* metric= */ 5, 8L));
   }

+  @Test
+  public void aggregate_reportingSiteProvided() throws Exception {
+    ctx = generateJob("foo", Optional.empty(), Optional.of("https://foo.com"));
+    ctx =
+        ctx.toBuilder()
+            .setRequestInfo(
+                getRequestInfoWithInputDataBucketName(ctx.requestInfo(), reportsDirectory))
+            .build();
+    JobResult jobResultProcessor = processor.get().process(ctx);
+
+    assertThat(jobResultProcessor).isEqualTo(expectedJobResult);
+    assertThat(resultLogger.getMaterializedAggregationResults().getMaterializedAggregations())
+        .containsExactly(
+            AggregatedFact.create(/* bucket= */ createBucketFromInt(1), /* metric= */ 2, 2L),
+            AggregatedFact.create(/* bucket= */ createBucketFromInt(2), /* metric= */ 8, 8L));
+  }
+
   @Test
   public void aggregate_withOutputDomain_overlappingDomainKeysInResults() throws Exception {
     outputDomainProcessorHelper.setDomainOptional(false);
@@ -1640,7 +1665,7 @@ public void processingWithWrongSharedInfo() throws Exception {
     String keyId = UUID.randomUUID().toString();
     Report report =
         FakeReportGenerator.generateWithParam(
-            1, /* reportVersion */ LATEST_VERSION, "https://foo.com");
+            1, /* reportVersion */ LATEST_VERSION, "https://example.foo.com");
     // Encrypt with a different sharedInfo than what is provided with the report so that decryption
     // fails
     String sharedInfoForEncryption = "foobarbaz";
@@ -1719,8 +1744,11 @@ public void aggregate_withPrivacyBudgeting() throws Exception {
             AggregatedFact.create(
                 /* bucket= */ createBucketFromInt(2), /* metric= */ 8, /* unnoisedMetric= */ 8L));
     // Check that the right attributionReportTo and debugPrivacyBudgetLimit were sent to the bridge
+    String claimedIdentity =
+        ReportingOriginUtils.convertReportingOriginToSite(
+            ctx.requestInfo().getJobParametersMap().get(JOB_PARAM_ATTRIBUTION_REPORT_TO));
     assertThat(fakePrivacyBudgetingServiceBridge.getLastAttributionReportToSent())
-        .hasValue(ctx.requestInfo().getJobParametersMap().get(JOB_PARAM_ATTRIBUTION_REPORT_TO));
+        .hasValue(claimedIdentity);
   }

   @Test
@@ -1773,13 +1801,12 @@ public void aggregate_withPrivacyBudgeting_invalidReportingOriginException_failJ
     privacyBudgetingServiceBridge.setPrivacyBudgetingServiceBridgeImpl(
         fakePrivacyBudgetingServiceBridge);

-    AggregationJobProcessException ex =
-        assertThrows(AggregationJobProcessException.class, () -> processor.get().process(ctx));
-    assertThat(ex.getCode()).isEqualTo(INVALID_JOB);
+    IllegalStateException ex =
+        assertThrows(IllegalStateException.class, () -> processor.get().process(ctx));
     assertThat(ex.getMessage())
         .isEqualTo(
-            "The attribution_report_to parameter specified in the CreateJob request is not under a"
-                + " known public suffix.");
+            "Invalid reporting origin found while consuming budget, this should not happen as job"
+                + " validations ensure the reporting origin is always valid.");
   }

   @Test
@@ -1787,7 +1814,7 @@ public void aggregate_withPrivacyBudgeting_oneBudgetMissing() {
     FakePrivacyBudgetingServiceBridge fakePrivacyBudgetingServiceBridge =
         new FakePrivacyBudgetingServiceBridge();
     fakePrivacyBudgetingServiceBridge.setPrivacyBudget(
-        PrivacyBudgetUnit.create("1", Instant.ofEpochMilli(0), "foo.com"), 1);
+        PrivacyBudgetUnit.create("1", Instant.ofEpochMilli(0), "https://example.foo.com"), 1);
     // Missing budget for the second report.
     privacyBudgetingServiceBridge.setPrivacyBudgetingServiceBridgeImpl(
         fakePrivacyBudgetingServiceBridge);
@@ -1858,7 +1885,7 @@ public void aggregate_withDebugRunAndPrivacyBudgetFailure_succeedsWithErrorCode(
     // Privacy Budget failure via thrown exception
     fakePrivacyBudgetingServiceBridge.setShouldThrow();
     fakePrivacyBudgetingServiceBridge.setPrivacyBudget(
-        PrivacyBudgetUnit.create("1", Instant.ofEpochMilli(0), "foo.com"), 1);
+        PrivacyBudgetUnit.create("1", Instant.ofEpochMilli(0), "https://example.foo.com"), 1);
     // Missing budget for the second report.
     privacyBudgetingServiceBridge.setPrivacyBudgetingServiceBridgeImpl(
         fakePrivacyBudgetingServiceBridge);
@@ -1888,7 +1915,7 @@ public void aggregateDebug_withPrivacyBudgetExhausted() throws Exception {
     FakePrivacyBudgetingServiceBridge fakePrivacyBudgetingServiceBridge =
         new FakePrivacyBudgetingServiceBridge();
     fakePrivacyBudgetingServiceBridge.setPrivacyBudget(
-        PrivacyBudgetUnit.create("1", Instant.ofEpochMilli(0), "foo.com"), 1);
+        PrivacyBudgetUnit.create("1", Instant.ofEpochMilli(0), "https://example.foo.com"), 1);
     // Missing budget for the second report.
     privacyBudgetingServiceBridge.setPrivacyBudgetingServiceBridgeImpl(
         fakePrivacyBudgetingServiceBridge);
@@ -1978,7 +2005,11 @@ private PrivacyBudgetUnit getPrivacyBudgetUnit(
         privacyBudgetKeyGeneratorFactory.getPrivacyBudgetKeyGenerator(privacyBudgetKeyInput).get();
     PrivacyBudgetUnit privacyBudgetUnit =
         PrivacyBudgetUnit.create(
-            privacyBudgetKeyGenerator.generatePrivacyBudgetKey(privacyBudgetKeyInput),
+            privacyBudgetKeyGenerator.generatePrivacyBudgetKey(
+                PrivacyBudgetKeyGenerator.PrivacyBudgetKeyInput.builder()
+                    .setFilteringId(filteringId)
+                    .setSharedInfo(sharedInfo)
+                    .build()),
             Instant.ofEpochMilli(0),
             sharedInfo.reportingOrigin());
     return privacyBudgetUnit;
@@ -2096,6 +2127,45 @@ void setEnablePrivacyBudgetKeyFiltering(boolean enablePrivacyBudgetKeyFiltering)
     }
   }

+  public static Job generateJob(
+      String id, Optional<String> attributionReportTo, Optional<String> reportingSite) {
+    if (attributionReportTo.isEmpty() && reportingSite.isEmpty()) {
+      throw new RuntimeException(
+          "At least one of attributionReportTo and reportingSite should be provided");
+    }
+    RequestInfo.Builder requestInfoBuilder =
+        RequestInfo.newBuilder()
+            .setJobRequestId(id)
+            .setInputDataBlobPrefix("dataHandle")
+            .setInputDataBucketName("bucket")
+            .setOutputDataBlobPrefix("dataHandle")
+            .setOutputDataBucketName("bucket")
+            .setPostbackUrl("http://postback.com");
+    RequestInfo requestInfo;
+    if (attributionReportTo.isPresent()) {
+      requestInfo =
+          requestInfoBuilder
+              .putAllJobParameters(
+                  ImmutableMap.of(JOB_PARAM_ATTRIBUTION_REPORT_TO, attributionReportTo.get()))
+              .build();
+    } else {
+      requestInfo =
+          requestInfoBuilder
+              .putAllJobParameters(ImmutableMap.of(JOB_PARAM_REPORTING_SITE, reportingSite.get()))
+              .build();
+    }
+    return Job.builder()
+        .setJobKey(JobKey.newBuilder().setJobRequestId(id).build())
+        .setJobProcessingTimeout(Duration.ofSeconds(3600))
+        .setRequestInfo(requestInfo)
+        .setCreateTime(REQUEST_RECEIVED_AT)
+        .setUpdateTime(REQUEST_UPDATED_AT)
+        .setProcessingStartTime(Optional.of(REQUEST_PROCESSING_STARTED_AT))
+        .setJobStatus(JobStatus.IN_PROGRESS)
+        .setNumAttempts(0)
+        .build();
+  }
+
   private static final class TestEnv extends AbstractModule {

     OutputDomainProcessorHelper helper = new OutputDomainProcessorHelper();
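Note: `ReportingOriginUtils.convertReportingOriginToSite` itself is not shown in this patch. The fixtures above (origin `https://example.foo.com`, site `https://foo.com`) suggest it maps a reporting origin to its site, that is, the scheme plus the registrable domain. A rough sketch of that mapping using Guava's `InternetDomainName`; this is an assumption about the real implementation, for illustration only:

```java
import com.google.common.net.InternetDomainName;
import java.net.URI;

final class ReportingSiteSketch {
  // Hypothetical stand-in: "https://" + eTLD+1 of the origin's host.
  static String toSite(String origin) {
    String host = URI.create(origin).getHost();
    return "https://" + InternetDomainName.from(host).topPrivateDomain().toString();
  }

  public static void main(String[] args) {
    // Matches the fixtures in the tests above.
    System.out.println(toSite("https://example.foo.com")); // https://foo.com
  }
}
```

This mapping is also why the privacy-budget assertions above now expect the claimed identity to be the site rather than the raw reporting origin.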
new AvroOutputDomainProcessor( - blockingThreadPool, - nonBlockingThreadPool, - blobStorageClient, - avroOutputDomainReaderFactory, - stopwatchRegistry, - domainOptional, - enableThresholding) + blockingThreadPool, + nonBlockingThreadPool, + blobStorageClient, + avroOutputDomainReaderFactory, + stopwatchRegistry, + domainOptional, + enableThresholding) : new TextOutputDomainProcessor( blockingThreadPool, nonBlockingThreadPool, diff --git a/javatests/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessorTest.java b/javatests/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessorTest.java index 1b5748ad..7f44dbc7 100644 --- a/javatests/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessorTest.java +++ b/javatests/com/google/aggregate/adtech/worker/aggregation/domain/AvroOutputDomainProcessorTest.java @@ -25,11 +25,32 @@ import static org.junit.Assert.assertThrows; import com.google.acai.Acai; +import com.google.acai.TestScoped; import com.google.aggregate.adtech.worker.Annotations.BlockingThreadPool; +import com.google.aggregate.adtech.worker.Annotations.CustomForkJoinThreadPool; import com.google.aggregate.adtech.worker.Annotations.DomainOptional; import com.google.aggregate.adtech.worker.Annotations.EnableThresholding; import com.google.aggregate.adtech.worker.Annotations.NonBlockingThreadPool; +import com.google.aggregate.adtech.worker.Annotations.ParallelAggregatedFactNoising; +import com.google.aggregate.adtech.worker.aggregation.engine.AggregationEngine; +import com.google.aggregate.adtech.worker.aggregation.engine.AggregationEngineFactory; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingDelta; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingDistribution; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingEpsilon; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingL1Sensitivity; import com.google.aggregate.adtech.worker.exceptions.DomainReadException; +import com.google.aggregate.adtech.worker.model.AggregatedFact; +import com.google.aggregate.privacy.budgeting.budgetkeygenerator.PrivacyBudgetKeyGeneratorModule; +import com.google.aggregate.privacy.noise.Annotations.Threshold; +import com.google.aggregate.privacy.noise.NoiseApplier; +import com.google.aggregate.privacy.noise.NoisedAggregationRunner; +import com.google.aggregate.privacy.noise.NoisedAggregationRunnerImpl; +import com.google.aggregate.privacy.noise.model.NoisedAggregatedResultSet; +import com.google.aggregate.privacy.noise.proto.Params.NoiseParameters.Distribution; +import com.google.aggregate.privacy.noise.proto.Params.PrivacyParameters; +import com.google.aggregate.privacy.noise.testing.ConstantNoiseModule.ConstantNoiseApplier; +import com.google.aggregate.privacy.noise.testing.FakeNoiseApplierSupplier; import com.google.aggregate.protocol.avro.AvroOutputDomainRecord; import com.google.aggregate.protocol.avro.AvroOutputDomainWriter; import com.google.aggregate.protocol.avro.AvroOutputDomainWriterFactory; @@ -39,6 +60,7 @@ import com.google.common.util.concurrent.ListeningExecutorService; import com.google.inject.AbstractModule; import com.google.inject.Provides; +import com.google.inject.Singleton; import com.google.scp.operator.cpio.blobstorageclient.model.DataLocation; import 
com.google.scp.operator.cpio.blobstorageclient.model.DataLocation.BlobStoreDataLocation; import com.google.scp.operator.cpio.blobstorageclient.testing.FSBlobStorageClientModule; @@ -50,7 +72,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.List; -import java.util.concurrent.ExecutionException; +import java.util.Optional; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; import javax.inject.Inject; @@ -69,8 +92,12 @@ public class AvroOutputDomainProcessorTest { @Inject AvroOutputDomainWriterFactory avroOutputDomainWriterFactory; // Under test @Inject AvroOutputDomainProcessor outputDomainProcessor; + @Inject AggregationEngineFactory aggregationEngineFactory; + @Inject FakeNoiseApplierSupplier fakeNoiseApplierSupplier; + @Inject NoisedAggregationRunnerImpl noisedAggregationRunner; private Path outputDomainDirectory; private DataLocation outputDomainLocation; + private AggregationEngine aggregationEngine; @Before public void setUp() throws Exception { @@ -80,6 +107,8 @@ public void setUp() throws Exception { DataLocation.ofBlobStoreDataLocation( BlobStoreDataLocation.create( /* bucket= */ outputDomainDirectory.toAbsolutePath().toString(), /* key= */ "")); + aggregationEngine = aggregationEngineFactory.create(ImmutableSet.of()); + fakeNoiseApplierSupplier.setFakeNoiseApplier(new ConstantNoiseApplier(0)); } @Test @@ -152,9 +181,7 @@ public void skipsZeroByteDomains() throws Exception { @Test public void ioProblem() { // No file written, path pointing to a non-existing file, this should be an IO exception. - ExecutionException error = assertThrows(ExecutionException.class, this::readOutputDomain); - - assertThat(error).hasCauseThat().isInstanceOf(DomainReadException.class); + assertThrows(DomainReadException.class, this::readOutputDomain); } @Test @@ -162,10 +189,9 @@ public void readOutputDomain_emptyOutputDomain_throwsException() throws Exceptio writeOutputDomain(outputDomainDirectory.resolve("domain_1.avro"), Stream.of()); writeOutputDomain(outputDomainDirectory.resolve("domain_2.avro"), Stream.of()); - ExecutionException error = assertThrows(ExecutionException.class, this::readOutputDomain); + DomainReadException error = assertThrows(DomainReadException.class, this::readOutputDomain); - assertThat(error).hasCauseThat().isInstanceOf(DomainReadException.class); - assertThat(error.getCause()).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + assertThat(error).hasCauseThat().isInstanceOf(IllegalArgumentException.class); assertThat(error.getCause()) .hasMessageThat() .containsMatch("No output domain provided in the location.*"); @@ -175,17 +201,22 @@ public void readOutputDomain_emptyOutputDomain_throwsException() throws Exceptio public void readOutputDomain_notReadableOutputDomain_throwsException() throws Exception { writeOutputDomainTextFile(outputDomainDirectory.resolve("domain_1.avro"), "bad domain"); - ExecutionException error = assertThrows(ExecutionException.class, this::readOutputDomain); - - assertThat(error).hasCauseThat().isInstanceOf(DomainReadException.class); + assertThrows(DomainReadException.class, this::readOutputDomain); } - private ImmutableSet<BigInteger> readOutputDomain() - throws ExecutionException, InterruptedException { - return outputDomainProcessor - .readAndDedupeDomain( - outputDomainLocation, outputDomainProcessor.listShards(outputDomainLocation)) - .get(); + private ImmutableSet<BigInteger> readOutputDomain() { + NoisedAggregatedResultSet noisedResultset =
outputDomainProcessor.adjustAggregationWithDomainAndNoiseStreaming( + aggregationEngine, + Optional.of(outputDomainLocation), + outputDomainProcessor.listShards(outputDomainLocation), + noisedAggregationRunner, + Optional.empty(), + false); + + return noisedResultset.noisedResult().noisedAggregatedFacts().stream() + .map(AggregatedFact::getBucket) + .collect(ImmutableSet.toImmutableSet()); } private void writeOutputDomain(Path path, Stream<BigInteger> keys) throws IOException { @@ -207,10 +238,22 @@ private static final class TestEnv extends AbstractModule { @Override protected void configure() { install(new FSBlobStorageClientModule()); + install(new PrivacyBudgetKeyGeneratorModule()); + bind(FileSystem.class).toInstance(FileSystems.getDefault()); bind(OutputDomainProcessor.class).to(AvroOutputDomainProcessor.class); bind(Boolean.class).annotatedWith(DomainOptional.class).toInstance(true); bind(Boolean.class).annotatedWith(EnableThresholding.class).toInstance(true); + + bind(FakeNoiseApplierSupplier.class).in(TestScoped.class); + bind(NoisedAggregationRunner.class).to(NoisedAggregationRunnerImpl.class); + bind(boolean.class).annotatedWith(ParallelAggregatedFactNoising.class).toInstance(true); + bind(Distribution.class) + .annotatedWith(NoisingDistribution.class) + .toInstance(Distribution.LAPLACE); + bind(double.class).annotatedWith(NoisingEpsilon.class).toInstance(0.1); + bind(long.class).annotatedWith(NoisingL1Sensitivity.class).toInstance(4L); + bind(double.class).annotatedWith(NoisingDelta.class).toInstance(5.00); } @Provides @@ -225,6 +268,30 @@ ListeningExecutorService provideBlockingThreadPool() { return newDirectExecutorService(); } + @Provides + @Threshold + Supplier<Double> provideThreshold() { + return () -> 0.0; + } + + @Provides + Supplier<NoiseApplier> provideNoiseApplierSupplier( + FakeNoiseApplierSupplier fakeNoiseApplierSupplier) { + return fakeNoiseApplierSupplier; + } + + @Provides + @Singleton + @CustomForkJoinThreadPool + ListeningExecutorService provideCustomForkJoinThreadPool() { + return newDirectExecutorService(); + } + + @Provides + Supplier<PrivacyParameters> providePrivacyParamConfig(PrivacyParametersSupplier supplier) { + return () -> supplier.get().toBuilder().setDelta(1e-5).build(); + } + @Provides Ticker provideTimingTicker() { return Ticker.systemTicker(); diff --git a/javatests/com/google/aggregate/adtech/worker/aggregation/domain/BUILD b/javatests/com/google/aggregate/adtech/worker/aggregation/domain/BUILD index 516a9ce8..029c5fce 100644 --- a/javatests/com/google/aggregate/adtech/worker/aggregation/domain/BUILD +++ b/javatests/com/google/aggregate/adtech/worker/aggregation/domain/BUILD @@ -23,9 +23,16 @@ java_test( "//java/com/google/aggregate/adtech/worker", "//java/com/google/aggregate/adtech/worker/aggregation/domain", "//java/com/google/aggregate/adtech/worker/aggregation/domain:text_domain", + "//java/com/google/aggregate/adtech/worker/aggregation/engine", + "//java/com/google/aggregate/adtech/worker/configs", "//java/com/google/aggregate/adtech/worker/exceptions", "//java/com/google/aggregate/adtech/worker/model", "//java/com/google/aggregate/adtech/worker/util", + "//java/com/google/aggregate/privacy/budgeting/budgetkeygenerator:privacy_budget_key_generator", + "//java/com/google/aggregate/privacy/noise", + "//java/com/google/aggregate/privacy/noise/model", + "//java/com/google/aggregate/privacy/noise/proto:privacy_parameters_java_proto", + "//java/com/google/aggregate/privacy/noise/testing", "//java/external:acai", "//java/external:clients_blobstorageclient_model", "//java/external:google_truth",
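Both domain-processor tests in this release install the same noising scaffolding and replace the old `readAndDedupeDomain().get()` call with the streaming path shown above. The sketch below condenses that shared pattern into one helper; the wrapper class name, parameter-name comments, and import paths are illustrative assumptions, while the call chain itself is copied from the hunks above.

```java
import com.google.aggregate.adtech.worker.aggregation.domain.OutputDomainProcessor;
import com.google.aggregate.adtech.worker.aggregation.engine.AggregationEngine;
import com.google.aggregate.adtech.worker.model.AggregatedFact;
import com.google.aggregate.privacy.noise.NoisedAggregationRunner;
import com.google.aggregate.privacy.noise.model.NoisedAggregatedResultSet;
import com.google.common.collect.ImmutableSet;
import com.google.scp.operator.cpio.blobstorageclient.model.DataLocation;
import java.math.BigInteger;
import java.util.Optional;

/** Sketch (not repo code) of the domain-read helper both processor tests now share. */
final class DomainReadPattern {
  private DomainReadPattern() {}

  // With ConstantNoiseApplier(0) and the 0.0 threshold bound in TestEnv, noising
  // and thresholding are no-ops, so the buckets returned here are exactly the
  // de-duplicated output domain keys that were written to domainLocation.
  static ImmutableSet<BigInteger> readDomainBuckets(
      OutputDomainProcessor processor,
      AggregationEngine engine, // created empty: aggregationEngineFactory.create(ImmutableSet.of())
      DataLocation domainLocation,
      NoisedAggregationRunner noisedAggregationRunner) {
    NoisedAggregatedResultSet resultSet =
        processor.adjustAggregationWithDomainAndNoiseStreaming(
            engine,
            Optional.of(domainLocation),
            processor.listShards(domainLocation),
            noisedAggregationRunner,
            /* debugPrivacyEpsilon= */ Optional.empty(), // assumed parameter names
            /* debugRun= */ false);
    return resultSet.noisedResult().noisedAggregatedFacts().stream()
        .map(AggregatedFact::getBucket)
        .collect(ImmutableSet.toImmutableSet());
  }
}
```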
@@ -43,8 +50,15 @@ java_test( "//java/com/google/aggregate/adtech/worker", "//java/com/google/aggregate/adtech/worker/aggregation/domain", "//java/com/google/aggregate/adtech/worker/aggregation/domain:avro_domain", + "//java/com/google/aggregate/adtech/worker/aggregation/engine", + "//java/com/google/aggregate/adtech/worker/configs", "//java/com/google/aggregate/adtech/worker/exceptions", "//java/com/google/aggregate/adtech/worker/model", + "//java/com/google/aggregate/privacy/budgeting/budgetkeygenerator:privacy_budget_key_generator", + "//java/com/google/aggregate/privacy/noise", + "//java/com/google/aggregate/privacy/noise/model", + "//java/com/google/aggregate/privacy/noise/proto:privacy_parameters_java_proto", + "//java/com/google/aggregate/privacy/noise/testing", "//java/com/google/aggregate/protocol/avro:avro_output_domain", "//java/external:acai", "//java/external:clients_blobstorageclient_model", diff --git a/javatests/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessorTest.java b/javatests/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessorTest.java index c301b501..711dd92d 100644 --- a/javatests/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessorTest.java +++ b/javatests/com/google/aggregate/adtech/worker/aggregation/domain/TextOutputDomainProcessorTest.java @@ -26,17 +26,39 @@ import static org.junit.Assert.assertThrows; import com.google.acai.Acai; +import com.google.acai.TestScoped; import com.google.aggregate.adtech.worker.Annotations.BlockingThreadPool; +import com.google.aggregate.adtech.worker.Annotations.CustomForkJoinThreadPool; import com.google.aggregate.adtech.worker.Annotations.DomainOptional; import com.google.aggregate.adtech.worker.Annotations.EnableThresholding; import com.google.aggregate.adtech.worker.Annotations.NonBlockingThreadPool; +import com.google.aggregate.adtech.worker.Annotations.ParallelAggregatedFactNoising; +import com.google.aggregate.adtech.worker.aggregation.engine.AggregationEngine; +import com.google.aggregate.adtech.worker.aggregation.engine.AggregationEngineFactory; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingDelta; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingDistribution; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingEpsilon; +import com.google.aggregate.adtech.worker.configs.PrivacyParametersSupplier.NoisingL1Sensitivity; import com.google.aggregate.adtech.worker.exceptions.DomainReadException; +import com.google.aggregate.adtech.worker.model.AggregatedFact; +import com.google.aggregate.privacy.budgeting.budgetkeygenerator.PrivacyBudgetKeyGeneratorModule; +import com.google.aggregate.privacy.noise.Annotations.Threshold; +import com.google.aggregate.privacy.noise.NoiseApplier; +import com.google.aggregate.privacy.noise.NoisedAggregationRunner; +import com.google.aggregate.privacy.noise.NoisedAggregationRunnerImpl; +import com.google.aggregate.privacy.noise.model.NoisedAggregatedResultSet; +import com.google.aggregate.privacy.noise.proto.Params.NoiseParameters.Distribution; +import com.google.aggregate.privacy.noise.proto.Params.PrivacyParameters; +import com.google.aggregate.privacy.noise.testing.ConstantNoiseModule.ConstantNoiseApplier; +import com.google.aggregate.privacy.noise.testing.FakeNoiseApplierSupplier; import com.google.common.base.Ticker; import 
com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.ListeningExecutorService; import com.google.inject.AbstractModule; import com.google.inject.Provides; +import com.google.inject.Singleton; import com.google.scp.operator.cpio.blobstorageclient.model.DataLocation; import com.google.scp.operator.cpio.blobstorageclient.model.DataLocation.BlobStoreDataLocation; import com.google.scp.operator.cpio.blobstorageclient.testing.FSBlobStorageClientModule; @@ -48,7 +70,8 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.List; -import java.util.concurrent.ExecutionException; +import java.util.Optional; +import java.util.function.Supplier; import java.util.stream.Collectors; import javax.inject.Inject; import org.junit.Before; @@ -65,6 +88,10 @@ public class TextOutputDomainProcessorTest { @Rule public final Acai acai = new Acai(TestEnv.class); // Under test @Inject TextOutputDomainProcessor outputDomainProcessor; + @Inject AggregationEngineFactory aggregationEngineFactory; + @Inject FakeNoiseApplierSupplier fakeNoiseApplierSupplier; + @Inject NoisedAggregationRunnerImpl noisedAggregationRunner; + private AggregationEngine aggregationEngine; private Path outputDomainDirectory; private DataLocation outputDomainLocation; @@ -76,6 +103,8 @@ public void setUp() throws Exception { DataLocation.ofBlobStoreDataLocation( BlobStoreDataLocation.create( /* bucket= */ outputDomainDirectory.toAbsolutePath().toString(), /* key= */ "")); + aggregationEngine = aggregationEngineFactory.create(ImmutableSet.of()); + fakeNoiseApplierSupplier.setFakeNoiseApplier(new ConstantNoiseApplier(0)); } @Test @@ -164,9 +193,7 @@ public void skipZeroByteDomains() throws Exception { @Test public void ioProblem() throws Exception { // No file written, path pointing to a non-existing file, this should be an IO exception. - ExecutionException error = assertThrows(ExecutionException.class, () -> readOutputDomain()); - - assertThat(error).hasCauseThat().isInstanceOf(DomainReadException.class); + assertThrows(DomainReadException.class, this::readOutputDomain); } @Test @@ -174,17 +201,22 @@ public void readDomain_notReadableTextFile() throws Exception { String badString = "abcdabcdabcdabcdabcdabcdabcdabcd"; writeOutputDomain(outputDomainDirectory.resolve("domain_1.txt"), badString); - ExecutionException error = assertThrows(ExecutionException.class, () -> readOutputDomain()); - - assertThat(error).hasCauseThat().isInstanceOf(DomainReadException.class); + assertThrows(DomainReadException.class, this::readOutputDomain); } - private ImmutableSet<BigInteger> readOutputDomain() - throws ExecutionException, InterruptedException { - return outputDomainProcessor - .readAndDedupeDomain( - outputDomainLocation, outputDomainProcessor.listShards(outputDomainLocation)) - .get(); + private ImmutableSet<BigInteger> readOutputDomain() { + NoisedAggregatedResultSet noisedResultset = + outputDomainProcessor.adjustAggregationWithDomainAndNoiseStreaming( + aggregationEngine, + Optional.of(outputDomainLocation), + outputDomainProcessor.listShards(outputDomainLocation), + noisedAggregationRunner, + Optional.empty(), + false); + + return noisedResultset.noisedResult().noisedAggregatedFacts().stream() + .map(AggregatedFact::getBucket) + .collect(ImmutableSet.toImmutableSet()); } private void writeOutputDomain(Path path, String...
keys) throws IOException { @@ -196,10 +228,46 @@ private static final class TestEnv extends AbstractModule { @Override protected void configure() { install(new FSBlobStorageClientModule()); + install(new PrivacyBudgetKeyGeneratorModule()); + bind(FileSystem.class).toInstance(FileSystems.getDefault()); bind(OutputDomainProcessor.class).to(TextOutputDomainProcessor.class); bind(Boolean.class).annotatedWith(DomainOptional.class).toInstance(true); bind(Boolean.class).annotatedWith(EnableThresholding.class).toInstance(true); + + bind(FakeNoiseApplierSupplier.class).in(TestScoped.class); + bind(NoisedAggregationRunner.class).to(NoisedAggregationRunnerImpl.class); + bind(boolean.class).annotatedWith(ParallelAggregatedFactNoising.class).toInstance(true); + bind(Distribution.class) + .annotatedWith(NoisingDistribution.class) + .toInstance(Distribution.LAPLACE); + bind(double.class).annotatedWith(NoisingEpsilon.class).toInstance(0.1); + bind(long.class).annotatedWith(NoisingL1Sensitivity.class).toInstance(4L); + bind(double.class).annotatedWith(NoisingDelta.class).toInstance(5.00); + } + + @Provides + @Threshold + Supplier<Double> provideThreshold() { + return () -> 0.0; + } + + @Provides + Supplier<NoiseApplier> provideNoiseApplierSupplier( + FakeNoiseApplierSupplier fakeNoiseApplierSupplier) { + return fakeNoiseApplierSupplier; + } + + @Provides + @Singleton + @CustomForkJoinThreadPool + ListeningExecutorService provideCustomForkJoinThreadPool() { + return newDirectExecutorService(); + } + + @Provides + Supplier<PrivacyParameters> providePrivacyParamConfig(PrivacyParametersSupplier supplier) { + return () -> supplier.get().toBuilder().setDelta(1e-5).build(); } @Provides diff --git a/javatests/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImplTest.java b/javatests/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImplTest.java index a6345261..d0cba832 100644 --- a/javatests/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImplTest.java +++ b/javatests/com/google/aggregate/adtech/worker/encryption/RecordEncrypterImplTest.java @@ -16,10 +16,8 @@ package com.google.aggregate.adtech.worker.encryption; -import static com.google.aggregate.adtech.worker.model.SharedInfo.ATTRIBUTION_REPORTING_API; import static com.google.aggregate.adtech.worker.model.SharedInfo.LATEST_VERSION; import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertTrue; import com.google.acai.Acai; import com.google.aggregate.adtech.worker.decryption.DecryptionCipher.PayloadDecryptionException; @@ -27,18 +25,8 @@ import com.google.aggregate.adtech.worker.encryption.RecordEncrypter.EncryptionException; import com.google.aggregate.adtech.worker.encryption.hybrid.HybridCipherModule; import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKeyService; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.ReEncryptionKeyService; import com.google.aggregate.adtech.worker.encryption.hybrid.key.testing.FakeEncryptionKeyService; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.testing.FakeReEncryptionKeyService; import com.google.aggregate.adtech.worker.model.EncryptedReport; -import com.google.aggregate.adtech.worker.model.Fact; -import com.google.aggregate.adtech.worker.model.Payload; -import com.google.aggregate.adtech.worker.model.Report; -import com.google.aggregate.adtech.worker.model.SharedInfo; -import com.google.aggregate.adtech.worker.model.serdes.PayloadSerdes; -import com.google.aggregate.adtech.worker.model.serdes.SharedInfoSerdes; -import
com.google.aggregate.adtech.worker.model.serdes.cbor.CborPayloadSerdes; -import com.google.common.collect.ImmutableList; import com.google.common.io.ByteSource; import com.google.crypto.tink.HybridDecrypt; import com.google.crypto.tink.KeysetHandle; @@ -49,12 +37,7 @@ import com.google.inject.Provides; import com.google.inject.Singleton; import com.google.scp.operator.shared.testing.StringToByteSourceConverter; -import java.io.IOException; -import java.math.BigInteger; import java.security.GeneralSecurityException; -import java.time.Instant; -import java.util.Optional; -import java.util.UUID; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -65,14 +48,8 @@ public class RecordEncrypterImplTest { @Rule public final Acai acai = new Acai(TestEnv.class); @Inject RecordEncrypter recordEncrypter; @Inject KeysetHandle keysetHandle; - @Inject PayloadSerdes payloadSerdes; - @Inject SharedInfoSerdes sharedInfoSerdes; private StringToByteSourceConverter converter; - private static final String DESTINATION = "dest.com"; - private static final UUID FIXED_UUID = UUID.randomUUID(); - private static final String REPORTING_ORIGIN = "foo.com"; - @Before public void setUp() throws GeneralSecurityException { converter = new StringToByteSourceConverter(); @@ -92,39 +69,6 @@ public void encryptSingleReport() assertThat(encryptedReport.sharedInfo()).isEqualTo(sharedInfo); } - @Test - public void encryptSerializedReport_succeeds() - throws EncryptionException, - PayloadDecryptionException, - GeneralSecurityException, - IOException { - ImmutableList factList = - ImmutableList.of(Fact.builder().setBucket(BigInteger.valueOf(123)).setValue(5).build()); - Payload payload = Payload.builder().addAllFact(factList).build(); - SharedInfo sharedInfo = - SharedInfo.builder() - .setSourceRegistrationTime(Instant.now()) - .setDestination(DESTINATION) - .setScheduledReportTime(Instant.now()) - .setReportId(FIXED_UUID.toString()) - .setVersion(LATEST_VERSION) - .setApi(ATTRIBUTION_REPORTING_API) - .setReportingOrigin(REPORTING_ORIGIN) - .build(); - Report deserializedReport = - Report.builder().setPayload(payload).setSharedInfo(sharedInfo).build(); - ByteSource serializedPayload = payloadSerdes.reverse().convert(Optional.of(payload)); - String serializedSharedInfo = sharedInfoSerdes.reverse().convert(Optional.of(sharedInfo)); - - EncryptedReport generatedReport = - recordEncrypter.encryptReport(deserializedReport, "fakeuri.com"); - - assertThat(serializedPayload).isNotNull(); - assertTrue(decryptReport(generatedReport).contentEquals(serializedPayload)); - assertThat(generatedReport.keyId()).isEqualTo(ENCRYPTION_KEY_ID); - assertTrue(generatedReport.sharedInfo().contentEquals(serializedSharedInfo)); - } - private ByteSource decryptReport(EncryptedReport encryptedReport) throws PayloadDecryptionException, GeneralSecurityException { return HybridDecryptionCipher.of(keysetHandle.getPrimitive(HybridDecrypt.class)) @@ -138,8 +82,6 @@ protected void configure() { install(new HybridCipherModule()); bind(EncryptionKeyService.class).to(FakeEncryptionKeyService.class); bind(RecordEncrypter.class).to(RecordEncrypterImpl.class); - bind(ReEncryptionKeyService.class).to(FakeReEncryptionKeyService.class); - bind(PayloadSerdes.class).to(CborPayloadSerdes.class); } @Provides diff --git a/javatests/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD b/javatests/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD index b738ac66..179e4848 100644 --- 
a/javatests/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD +++ b/javatests/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/BUILD @@ -35,24 +35,3 @@ java_test( "//protocol/proto:encryption_key_config_java_proto", ], ) - -java_test( - name = "CloudReEncryptionKeyServiceTest", - srcs = ["CloudReEncryptionKeyServiceTest.java"], - deps = [ - "//java/com/google/aggregate/adtech/worker/encryption/hybrid/key", - "//java/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud", - "//java/com/google/aggregate/adtech/worker/encryption/publickeyuri:encryption_key_config", - "//java/external:acai", - "//java/external:apache_httpclient", - "//java/external:apache_httpcore", - "//java/external:api_shared_util", - "//java/external:aws_dynamodb", - "//java/external:google_truth", - "//java/external:guava", - "//java/external:guice", - "//java/external:mockito", - "//java/external:tink", - "//protocol/proto:encryption_key_config_java_proto", - ], -) diff --git a/javatests/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/CloudReEncryptionKeyServiceTest.java b/javatests/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/CloudReEncryptionKeyServiceTest.java deleted file mode 100644 index 9721f327..00000000 --- a/javatests/com/google/aggregate/adtech/worker/encryption/hybrid/key/cloud/CloudReEncryptionKeyServiceTest.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.aggregate.adtech.worker.encryption.hybrid.key.cloud; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.when; - -import com.google.aggregate.adtech.worker.encryption.hybrid.key.EncryptionKey; -import com.google.aggregate.adtech.worker.encryption.hybrid.key.ReEncryptionKeyService.ReencryptionKeyFetchException; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.scp.shared.api.util.HttpClientResponse; -import com.google.scp.shared.api.util.HttpClientWrapper; -import org.apache.http.client.methods.HttpRequestBase; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -@RunWith(JUnit4.class) -public class CloudReEncryptionKeyServiceTest { - - private static final String KEY_ID_1 = "00000000-0000-0000-0000-000000000000"; - private static final String KEY_ID_2 = "00000000-0000-0000-0000-111111111111"; - private static final String KEY_ID_3 = "00000000-0000-0000-0000-222222222222"; - private static final String KEY_ID_4 = "00000000-0000-0000-0000-333333333333"; - private static final String KEY_ID_5 = "00000000-0000-0000-0000-444444444444"; - private static final ImmutableList keySet = - ImmutableList.of(KEY_ID_1, KEY_ID_2, KEY_ID_3, KEY_ID_4, KEY_ID_5); - - @Rule public final MockitoRule mockito = MockitoJUnit.rule(); - - @Mock private HttpClientWrapper httpClient; - CloudReEncryptionKeyService cloudReEncryptionKeyService; - String keyVendingResponse; - String publicKey; - String keyVendingServiceUri = - "https://publickeyservice.aggregationhelper.com/.well-known/aggregation-service/v1/public-keys"; - - @Before - public void setup() { - cloudReEncryptionKeyService = new CloudReEncryptionKeyService(httpClient); - publicKey = - "EkQKBAgCEAMSOhI4CjB0eXBlLmdvb2dsZWFwaXMuY29tL2dvb2dsZS5jcnlwdG8udGluay5BZXNH" - + "Y21LZXkSAhAQGAEYARohAJryfZtZSsWNdh86h3sOuxRurI4q/Qg2ECaABVGfgOu6IiEAjAYDniN7v5mb" - + "bMhPbXVSkPhEZFx84sB7MKB/AiN6KBI="; - keyVendingResponse = - String.format( - "{\"keys\":[{\"id\":\"%s\",\"key\":\"%s\"}," - + "{\"id\":\"%s\",\"key\":\"%s\"}," - + "{\"id\":\"%s\",\"key\":\"%s\"}," - + "{\"id\":\"%s\",\"key\":\"%s\"}," - + "{\"id\":\"%s\",\"key\":\"%s\"}]}", - KEY_ID_1, publicKey, KEY_ID_2, publicKey, KEY_ID_3, publicKey, KEY_ID_4, publicKey, - KEY_ID_5, publicKey); - } - - @Test - public void getCloudProviderKey_succeeds() throws Exception { - HttpClientResponse response = buildFakeResponse(200, keyVendingResponse); - when(httpClient.execute(any(HttpRequestBase.class))).thenReturn(response); - - EncryptionKey key = cloudReEncryptionKeyService.getEncryptionPublicKey(keyVendingServiceUri); - - assertThat(keySet).contains(key.id()); - } - - @Test - public void getCloudProviderKey_fails() throws Exception { - HttpClientResponse response = buildFakeResponse(500, keyVendingResponse); - when(httpClient.execute(any(HttpRequestBase.class))).thenReturn(response); - - assertThrows( - ReencryptionKeyFetchException.class, - () -> cloudReEncryptionKeyService.getEncryptionPublicKey(keyVendingServiceUri)); - } - - private HttpClientResponse buildFakeResponse(int statusCode, String body) { - HttpClientResponse response = HttpClientResponse.create(statusCode, body, ImmutableMap.of()); - 
return response; - } -} diff --git a/javatests/com/google/aggregate/adtech/worker/encryption/publickeyuri/BUILD b/javatests/com/google/aggregate/adtech/worker/encryption/publickeyuri/BUILD deleted file mode 100644 index 69d96c14..00000000 --- a/javatests/com/google/aggregate/adtech/worker/encryption/publickeyuri/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@rules_java//java:defs.bzl", "java_test") - -package(default_visibility = ["//visibility:public"]) - -java_test( - name = "EncryptionKeyConfigFactoryTest", - srcs = ["EncryptionKeyConfigFactoryTest.java"], - deps = [ - "//java/com/google/aggregate/adtech/worker/encryption/publickeyuri:encryption_key_config", - "//java/external:google_truth", - ], -) diff --git a/javatests/com/google/aggregate/adtech/worker/encryption/publickeyuri/EncryptionKeyConfigFactoryTest.java b/javatests/com/google/aggregate/adtech/worker/encryption/publickeyuri/EncryptionKeyConfigFactoryTest.java deleted file mode 100644 index ebc68a6e..00000000 --- a/javatests/com/google/aggregate/adtech/worker/encryption/publickeyuri/EncryptionKeyConfigFactoryTest.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.aggregate.adtech.worker.encryption.publickeyuri; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; - -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class EncryptionKeyConfigFactoryTest { - @Test - public void getValidEncryptionKeyConfig_succeeds() { - CloudEncryptionKeyConfig gcpCloudEncryptionKeyConfig = - CloudEncryptionKeyConfig.builder() - .setKeyVendingServiceUri( - "https://publickeyservice-a.postsb-a.test.aggregationhelper.com/.well-known/aggregation-service/v1/public-keys") - .build(); - CloudEncryptionKeyConfig cloudEncryptionKeyConfig = - EncryptionKeyConfigFactory.getCloudEncryptionKeyConfig("GCP"); - assertThat(cloudEncryptionKeyConfig.keyVendingServiceUri()) - .isEqualTo(gcpCloudEncryptionKeyConfig.keyVendingServiceUri()); - } - - @Test - public void getInvalidEncryptionKeyConfig_fails() { - assertThrows( - IllegalArgumentException.class, - () -> EncryptionKeyConfigFactory.getCloudEncryptionKeyConfig("invalid-cloud")); - } -} diff --git a/javatests/com/google/aggregate/adtech/worker/testing/AvroReportsFileReaderTest.java b/javatests/com/google/aggregate/adtech/worker/testing/AvroReportsFileReaderTest.java deleted file mode 100644 index c9ccd1b6..00000000 --- a/javatests/com/google/aggregate/adtech/worker/testing/AvroReportsFileReaderTest.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2022 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.google.aggregate.adtech.worker.testing; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; - -import com.google.acai.Acai; -import com.google.aggregate.adtech.worker.model.EncryptedReport; -import com.google.aggregate.adtech.worker.writer.avro.LocalAvroResultFileWriter; -import com.google.common.collect.ImmutableList; -import com.google.common.io.ByteSource; -import com.google.common.jimfs.Configuration; -import com.google.common.jimfs.Jimfs; -import com.google.inject.AbstractModule; -import java.io.IOException; -import java.nio.file.FileSystem; -import java.nio.file.Path; -import javax.inject.Inject; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -@RunWith(JUnit4.class) -public class AvroReportsFileReaderTest { - - @Rule public final Acai acai = new Acai(TestEnv.class); - - @Inject LocalAvroResultFileWriter localAvroResultFileWriter; - - // Under test - @Inject AvroReportsFileReader avroReportsFileReader; - - private FileSystem filesystem; - private Path avroFile; - private ImmutableList reports; - - private final ByteSource encryptedReport1Payload = ByteSource.wrap(new byte[] {0x00, 0x01}); - private final ByteSource encryptedReport2Payload = ByteSource.wrap(new byte[] {0x01, 0x02}); - private final EncryptedReport encryptedReport1 = - EncryptedReport.builder() - .setPayload(encryptedReport1Payload) - .setKeyId("key1") - .setSharedInfo("foo") - .build(); - - private final EncryptedReport encryptedReport2 = - EncryptedReport.builder() - .setPayload(encryptedReport2Payload) - .setKeyId("key1") - .setSharedInfo("foo") - .build(); - - @Before - public void setUp() { - filesystem = - Jimfs.newFileSystem(Configuration.unix().toBuilder().setWorkingDirectory("/").build()); - avroFile = filesystem.getPath("reports.avro"); - reports = ImmutableList.of(encryptedReport1, encryptedReport2); - } - - /** Writes reports and reads to confirm data is read correctly. 
*/ - @Test - public void testLocalReportFile_writesSuccessfully() throws Exception { - localAvroResultFileWriter.writeLocalReportFile(reports.stream(), avroFile); - - ImmutableList writtenReports = - avroReportsFileReader.readAvroReportsFile(avroFile); - - assertThat(writtenReports.get(0).sharedInfo()).isEqualTo(encryptedReport1.sharedInfo()); - assertTrue(writtenReports.get(0).payload().contentEquals(encryptedReport1.payload())); - assertThat(writtenReports.get(0).keyId()).isEqualTo(encryptedReport1.keyId()); - - assertThat(writtenReports.get(1).sharedInfo()).isEqualTo(encryptedReport2.sharedInfo()); - assertTrue(writtenReports.get(1).payload().contentEquals(encryptedReport2.payload())); - assertThat(writtenReports.get(1).keyId()).isEqualTo(encryptedReport2.keyId()); - } - - @Test - public void readMissingFile_throwsException() throws Exception { - Path missingAvroFile = filesystem.getPath("filedoesnotexist.avro"); - - localAvroResultFileWriter.writeLocalReportFile(reports.stream(), missingAvroFile); - - assertThrows(IOException.class, () -> avroReportsFileReader.readAvroReportsFile(avroFile)); - } - - public static final class TestEnv extends AbstractModule {} -} diff --git a/javatests/com/google/aggregate/adtech/worker/testing/BUILD b/javatests/com/google/aggregate/adtech/worker/testing/BUILD index df41b667..8670cfaf 100644 --- a/javatests/com/google/aggregate/adtech/worker/testing/BUILD +++ b/javatests/com/google/aggregate/adtech/worker/testing/BUILD @@ -149,18 +149,3 @@ java_test( "//java/external:jimfs", ], ) - -java_test( - name = "AvroReportsFileReaderTest", - srcs = ["AvroReportsFileReaderTest.java"], - deps = [ - "//java/com/google/aggregate/adtech/worker/model", - "//java/com/google/aggregate/adtech/worker/testing:avro_reports_file_reader", - "//java/com/google/aggregate/adtech/worker/writer/avro", - "//java/external:acai", - "//java/external:google_truth", - "//java/external:guava", - "//java/external:guice", - "//java/external:jimfs", - ], -) diff --git a/javatests/com/google/aggregate/adtech/worker/testing/InMemoryResultLoggerTest.java b/javatests/com/google/aggregate/adtech/worker/testing/InMemoryResultLoggerTest.java index d73590f0..81e1e7dd 100644 --- a/javatests/com/google/aggregate/adtech/worker/testing/InMemoryResultLoggerTest.java +++ b/javatests/com/google/aggregate/adtech/worker/testing/InMemoryResultLoggerTest.java @@ -21,9 +21,7 @@ import com.google.aggregate.adtech.worker.exceptions.ResultLogException; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.common.collect.ImmutableList; -import com.google.common.io.ByteSource; import com.google.scp.operator.cpio.jobclient.model.Job; import com.google.scp.operator.cpio.jobclient.testing.FakeJobGenerator; import java.math.BigInteger; @@ -99,46 +97,10 @@ public void getDebugAggregationWithoutLogging() { .contains("MaterializedAggregations is null. 
Maybe results did not get logged."); } - @Test - public void logInMemoryReports_logSucceeds() throws ResultLogException { - EncryptedReport encryptedReport1 = - EncryptedReport.builder() - .setPayload(ByteSource.wrap(new byte[] {0x00, 0x01})) - .setKeyId("key1") - .setSharedInfo("foo") - .build(); - EncryptedReport encryptedReport2 = - EncryptedReport.builder() - .setPayload(ByteSource.wrap(new byte[] {0x01, 0x02})) - .setKeyId("key2") - .setSharedInfo("foo") - .build(); - ImmutableList<EncryptedReport> encryptedReports = - ImmutableList.of(encryptedReport1, encryptedReport2); - - inMemoryResultLogger.logReports(encryptedReports, FakeJobGenerator.generate("foo"), "1"); - - assertThat(inMemoryResultLogger.getMaterializedEncryptedReports()) - .containsExactly(encryptedReport1, encryptedReport2); - } - - @Test - public void logNullReports_throwsException() { - ResultLogException exception = - assertThrows( - ResultLogException.class, () -> inMemoryResultLogger.getMaterializedEncryptedReports()); - - assertThat(exception).hasCauseThat().isInstanceOf(IllegalStateException.class); - assertThat(exception) - .hasMessageThat() - .contains("MaterializedEncryptionReports is null. Maybe results did not get logged."); - } - @Test public void throwsWhenSetTo() { inMemoryResultLogger.setShouldThrow(true); ImmutableList<AggregatedFact> aggregatedFacts = ImmutableList.of(); - ImmutableList<EncryptedReport> encryptedReports = ImmutableList.of(); Job Job = FakeJobGenerator.generate("foo"); assertThrows( @@ -147,8 +109,5 @@ public void throwsWhenSetTo() { assertThrows( ResultLogException.class, () -> inMemoryResultLogger.logResults(aggregatedFacts, Job, /* isDebugRun= */ true)); - assertThrows( - ResultLogException.class, - () -> inMemoryResultLogger.logReports(encryptedReports, Job, "1")); } } diff --git a/javatests/com/google/aggregate/adtech/worker/util/BUILD b/javatests/com/google/aggregate/adtech/worker/util/BUILD index c8be7343..3a5a0494 100644 --- a/javatests/com/google/aggregate/adtech/worker/util/BUILD +++ b/javatests/com/google/aggregate/adtech/worker/util/BUILD @@ -71,3 +71,13 @@ java_test( "//java/external:google_truth", ], ) + +java_test( + name = "ReportingOriginUtilsTest", + srcs = ["ReportingOriginUtilsTest.java"], + deps = [ + "//java/com/google/aggregate/adtech/worker/util", + "//java/external:google_truth", + "//java/external:guava", + ], +) diff --git a/javatests/com/google/aggregate/adtech/worker/util/ReportingOriginUtilsTest.java b/javatests/com/google/aggregate/adtech/worker/util/ReportingOriginUtilsTest.java index 6fa44766..46731445 100644 --- a/javatests/com/google/aggregate/adtech/worker/util/ReportingOriginUtilsTest.java +++ b/javatests/com/google/aggregate/adtech/worker/util/ReportingOriginUtilsTest.java @@ -56,11 +56,10 @@ public void convertToSite_whenUrlWithTrailingSlashProvided() } @Test - public void convertToSite_whenUrlWithPortProvided() throws InvalidReportingOriginException { - assertThat( - ReportingOriginUtils.convertReportingOriginToSite( - "http://about.foo.blogspot.com:8443/bar")) - .isEqualTo("https://foo.blogspot.com"); + public void convertToSite_whenUrlWithPortProvided() + throws InvalidReportingOriginException { + assertThat(ReportingOriginUtils.convertReportingOriginToSite("http://about.foo.blogspot.com:8443/bar")) + .isEqualTo("https://foo.blogspot.com"); } @Test diff --git a/javatests/com/google/aggregate/adtech/worker/validation/JobValidatorTest.java b/javatests/com/google/aggregate/adtech/worker/validation/JobValidatorTest.java index e605db2d..5caf5e9b 100644 ---
a/javatests/com/google/aggregate/adtech/worker/validation/JobValidatorTest.java +++ b/javatests/com/google/aggregate/adtech/worker/validation/JobValidatorTest.java @@ -56,7 +56,9 @@ public void validate_noAttributionReportToKeyInParams_fails() { assertThat(exception) .hasMessageThat() - .containsMatch("Job parameters does not have an attribution_report_to field for the Job"); + .containsMatch( + "Exactly one of 'attribution_report_to' and 'reporting_site' fields should be specified" + + " for the Job"); } @Test @@ -71,7 +73,8 @@ public void validate_noAttributionReportTo_fails() { assertThat(exception) .hasMessageThat() - .containsMatch("Job parameters does not have an attribution_report_to field for the Job"); + .containsMatch( + "The 'attribution_report_to' field in the Job parameters is empty for the Job"); } @Test @@ -344,6 +347,39 @@ public void validate_invalidFilteringIds_throws() { () -> JobValidator.validate(Optional.of(jobWithNonNumberIds), /* domainOptional= */ true)); } + @Test + public void validate_noReportingSite_fails() { + ImmutableMap<String, String> jobParams = ImmutableMap.of("reporting_site", ""); + Job job = buildJob(jobParams).build(); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> JobValidator.validate(Optional.of(job), /* domainOptional= */ false)); + + assertThat(exception) + .hasMessageThat() + .containsMatch("The 'reporting_site' field in the Job parameters is empty for the Job"); + } + + @Test + public void validate_attributionReportToAndReportingSiteBothPresent_fails() { + ImmutableMap<String, String> jobParams = + ImmutableMap.of("attribution_report_to", "someOrigin", "reporting_site", "someSite"); + Job job = buildJob(jobParams).build(); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> JobValidator.validate(Optional.of(job), /* domainOptional= */ false)); + + assertThat(exception) + .hasMessageThat() + .containsMatch( + "Exactly one of 'attribution_report_to' and 'reporting_site' fields should be specified" + + " for the Job"); + } + private Job.Builder buildJob(ImmutableMap<String, String> jobParams) { return jobBuilder.setRequestInfo(requestInfoBuilder.putAllJobParameters(jobParams).build()); } diff --git a/javatests/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidatorTest.java b/javatests/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidatorTest.java index fad99524..4e991a43 100644 --- a/javatests/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidatorTest.java +++ b/javatests/com/google/aggregate/adtech/worker/validation/ReportingOriginMatchesRequestValidatorTest.java @@ -20,6 +20,8 @@ import static com.google.aggregate.adtech.worker.model.SharedInfo.LATEST_VERSION; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth8.assertThat; +import static com.google.aggregate.adtech.worker.model.ErrorCounter.ATTRIBUTION_REPORT_TO_MALFORMED; +import static com.google.aggregate.adtech.worker.model.ErrorCounter.REPORTING_SITE_MISMATCH; import com.google.aggregate.adtech.worker.model.ErrorMessage; import com.google.aggregate.adtech.worker.model.Payload; @@ -59,6 +61,16 @@ public void setUp() { ctx = FakeJobGenerator.generateBuilder("").build(); } + private Job createTestJob(ImmutableMap<String, String> jobParameters) { + return ctx.toBuilder() + .setRequestInfo( + ctx.requestInfo().toBuilder() + .clearJobParameters() + .putAllJobParameters(jobParameters) + .build()) + .build(); + } + /** * Test
that the validation passed when the report and the aggregation request ({@code Job}) have * matching attributionReportTo values. @@ -69,14 +81,7 @@ public void testMatchingPasses() { reportBuilder .setSharedInfo(sharedInfoBuilder.setReportingOrigin("foo.com").build()) .build(); - Job testCtx = - ctx.toBuilder() - .setRequestInfo( - ctx.requestInfo().toBuilder() - .clearJobParameters() - .putAllJobParameters(ImmutableMap.of("attribution_report_to", "foo.com")) - .build()) - .build(); + Job testCtx = createTestJob(ImmutableMap.of("attribution_report_to", "foo.com")); Optional<ErrorMessage> validationError = validator.validate(report, testCtx); @@ -93,18 +98,67 @@ public void testMismatchingFails() { reportBuilder .setSharedInfo(sharedInfoBuilder.setReportingOrigin("foo.com").build()) .build(); - Job testCtx = - ctx.toBuilder() - .setRequestInfo( - ctx.requestInfo().toBuilder() - .clearJobParameters() - .putAllJobParameters(ImmutableMap.of("attribution_report_to", "bar.com")) - .build()) - .build(); + Job testCtx = createTestJob(ImmutableMap.of("attribution_report_to", "bar.com")); Optional<ErrorMessage> validationError = validator.validate(report, testCtx); assertThat(validationError).isPresent(); assertThat(validationError.get().category()).isEqualTo(ATTRIBUTION_REPORT_TO_MISMATCH); } + + /** + * Test that the validation passes when the report's reporting origin belongs to the site provided + * in the aggregation request ({@code Job}). + */ + @Test + public void siteProvided_reportOriginBelongsToSite_success() { + Report report1 = + reportBuilder + .setSharedInfo(sharedInfoBuilder.setReportingOrigin("https://origin1.foo.com").build()) + .build(); + Report report2 = + reportBuilder + .setSharedInfo(sharedInfoBuilder.setReportingOrigin("https://origin2.foo.com").build()) + .build(); + Job testCtx = createTestJob(ImmutableMap.of("reporting_site", "https://foo.com")); + + Optional<ErrorMessage> validationError1 = validator.validate(report1, testCtx); + Optional<ErrorMessage> validationError2 = validator.validate(report2, testCtx); + + assertThat(validationError1).isEmpty(); + assertThat(validationError2).isEmpty(); + } + + /** + * Test that the validation fails when the report's reporting origin belongs to a different site + * than the one provided in the aggregation request ({@code Job}). + */ + @Test + public void siteProvided_reportOriginDoesNotBelongsToSite_failure() { + Report report = + reportBuilder + .setSharedInfo(sharedInfoBuilder.setReportingOrigin("https://origin.foo.com").build()) + .build(); + Job testCtx = createTestJob(ImmutableMap.of("reporting_site", "https://foo1.com")); + + Optional<ErrorMessage> validationError = validator.validate(report, testCtx); + + assertThat(validationError).isPresent(); + assertThat(validationError.get().category()).isEqualTo(REPORTING_SITE_MISMATCH); + } + + /** Tests validation failure when the report's reporting origin is malformed.
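A scheme-less origin such as "origin.foo.com" cannot be converted to a site, so the validator flags it with ATTRIBUTION_REPORT_TO_MALFORMED rather than REPORTING_SITE_MISMATCH.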
*/ + @Test + public void siteProvided_reportOriginInvalid_failure() { + Report report = + reportBuilder + .setSharedInfo(sharedInfoBuilder.setReportingOrigin("origin.foo.com").build()) + .build(); + Job testCtx = createTestJob(ImmutableMap.of("reporting_site", "https://foo1.com")); + + Optional validationError = validator.validate(report, testCtx); + + assertThat(validationError).isPresent(); + assertThat(validationError.get().category()).isEqualTo(ATTRIBUTION_REPORT_TO_MALFORMED); + } } diff --git a/javatests/com/google/aggregate/adtech/worker/writer/avro/BUILD b/javatests/com/google/aggregate/adtech/worker/writer/avro/BUILD index b1ff2595..7d4e9038 100644 --- a/javatests/com/google/aggregate/adtech/worker/writer/avro/BUILD +++ b/javatests/com/google/aggregate/adtech/worker/writer/avro/BUILD @@ -21,7 +21,6 @@ java_test( srcs = ["LocalAvroResultFileWriterTest.java"], deps = [ "//java/com/google/aggregate/adtech/worker/model", - "//java/com/google/aggregate/adtech/worker/testing:avro_reports_file_reader", "//java/com/google/aggregate/adtech/worker/testing:avro_results_file_reader", "//java/com/google/aggregate/adtech/worker/writer", "//java/com/google/aggregate/adtech/worker/writer/avro", diff --git a/javatests/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriterTest.java b/javatests/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriterTest.java index c6db000f..1f63942e 100644 --- a/javatests/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriterTest.java +++ b/javatests/com/google/aggregate/adtech/worker/writer/avro/LocalAvroResultFileWriterTest.java @@ -18,16 +18,12 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertThrows; -import static org.junit.Assert.assertTrue; import com.google.acai.Acai; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; -import com.google.aggregate.adtech.worker.testing.AvroReportsFileReader; import com.google.aggregate.adtech.worker.testing.AvroResultsFileReader; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter.FileWriteException; import com.google.common.collect.ImmutableList; -import com.google.common.io.ByteSource; import com.google.common.jimfs.Configuration; import com.google.common.jimfs.Jimfs; import com.google.inject.AbstractModule; @@ -48,29 +44,10 @@ public class LocalAvroResultFileWriterTest { // Under test @Inject LocalAvroResultFileWriter localAvroResultFileWriter; - @Inject AvroResultsFileReader avroResultsFileReader; - @Inject AvroReportsFileReader avroReportsFileReader; - private FileSystem filesystem; private Path avroFile; ImmutableList results; - private ImmutableList reports; - private final ByteSource encryptedReport1Payload = ByteSource.wrap(new byte[] {0x00, 0x01}); - private final ByteSource encryptedReport2Payload = ByteSource.wrap(new byte[] {0x01, 0x02}); - private final EncryptedReport encryptedReport1 = - EncryptedReport.builder() - .setPayload(encryptedReport1Payload) - .setKeyId("key1") - .setSharedInfo("foo") - .build(); - - private final EncryptedReport encryptedReport2 = - EncryptedReport.builder() - .setPayload(encryptedReport2Payload) - .setKeyId("key2") - .setSharedInfo("bar") - .build(); @Before public void setUp() throws Exception { @@ -83,7 +60,6 @@ public void setUp() throws Exception { AggregatedFact.create(BigInteger.valueOf(123), 50L), AggregatedFact.create(BigInteger.valueOf(456), 30L), 
AggregatedFact.create(BigInteger.valueOf(789), 40L)); - reports = ImmutableList.of(encryptedReport1, encryptedReport2); } /** @@ -109,32 +85,6 @@ public void testExceptionOnFailedWrite() throws Exception { () -> localAvroResultFileWriter.writeLocalFile(results.stream(), nonExistentDirectory)); } - @Test - public void localReportWrite_succeeds() throws Exception { - localAvroResultFileWriter.writeLocalReportFile(reports.stream(), avroFile); - - ImmutableList writtenReports = - avroReportsFileReader.readAvroReportsFile(avroFile); - assertThat(writtenReports.get(0).sharedInfo()).isEqualTo(encryptedReport1.sharedInfo()); - assertTrue(writtenReports.get(0).payload().contentEquals(encryptedReport1.payload())); - assertThat(writtenReports.get(0).keyId()).isEqualTo(encryptedReport1.keyId()); - - assertThat(writtenReports.get(1).sharedInfo()).isEqualTo(encryptedReport2.sharedInfo()); - assertTrue(writtenReports.get(1).payload().contentEquals(encryptedReport2.payload())); - assertThat(writtenReports.get(1).keyId()).isEqualTo(encryptedReport2.keyId()); - } - - @Test - public void localReportWrite_invalidWritePath_fails() throws Exception { - Path nonExistentDirectory = - avroFile.getFileSystem().getPath("/doesnotexist", avroFile.toString()); - - assertThrows( - FileWriteException.class, - () -> - localAvroResultFileWriter.writeLocalReportFile(reports.stream(), nonExistentDirectory)); - } - @Test public void testFileExtension() { assertThat(localAvroResultFileWriter.getFileExtension()).isEqualTo(".avro"); diff --git a/javatests/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriterTest.java b/javatests/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriterTest.java index da4ee1b0..d7b3a70b 100644 --- a/javatests/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriterTest.java +++ b/javatests/com/google/aggregate/adtech/worker/writer/json/LocalJsonResultFileWriterTest.java @@ -23,11 +23,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.acai.Acai; import com.google.aggregate.adtech.worker.model.AggregatedFact; -import com.google.aggregate.adtech.worker.model.EncryptedReport; import com.google.aggregate.adtech.worker.util.NumericConversions; import com.google.aggregate.adtech.worker.writer.LocalResultFileWriter.FileWriteException; import com.google.common.collect.ImmutableList; -import com.google.common.io.ByteSource; import com.google.common.jimfs.Configuration; import com.google.common.jimfs.Jimfs; import com.google.inject.AbstractModule; @@ -56,24 +54,6 @@ public class LocalJsonResultFileWriterTest { private FileSystem filesystem; private Path jsonFile; - private ImmutableList reports; - - // Not testing for payload, since encrypted payload in json is not useful. 
- private final ByteSource encryptedReportPayload = ByteSource.wrap(new byte[] {0x00, 0x01}); - private final EncryptedReport encryptedReport1 = - EncryptedReport.builder() - .setPayload(encryptedReportPayload) - .setKeyId("key1") - .setSharedInfo("foo") - .build(); - - private final EncryptedReport encryptedReport2 = - EncryptedReport.builder() - .setPayload(encryptedReportPayload) - .setKeyId("key2") - .setSharedInfo("bar") - .build(); - @Before public void setUp() throws Exception { filesystem = @@ -86,7 +66,6 @@ public void setUp() throws Exception { AggregatedFact.create(NumericConversions.createBucketFromInt(123), 50L), AggregatedFact.create(NumericConversions.createBucketFromInt(456), 30L), AggregatedFact.create(NumericConversions.createBucketFromInt(789), 40L)); - reports = ImmutableList.of(encryptedReport1, encryptedReport2); } /** @@ -121,41 +100,6 @@ public void testExceptionOnFailedWrite() throws Exception { () -> localJsonResultFileWriter.writeLocalFile(results.stream(), nonExistentDirectory)); } - @Test - public void writeLocalJsonReport_succeeds() throws Exception { - localJsonResultFileWriter.writeLocalReportFile(reports.stream(), jsonFile); - ObjectMapper mapper = new ObjectMapper(); - JsonNode jsonNode = mapper.readTree(Files.newInputStream(jsonFile)); - List writtenReports = new ArrayList<>(); - jsonNode - .iterator() - .forEachRemaining( - entry -> { - writtenReports.add( - EncryptedReport.builder() - .setSharedInfo((entry.get("shared_info").asText())) - .setKeyId(entry.get("key_id").asText()) - .setPayload(encryptedReportPayload) - .build()); - }); - assertThat(writtenReports.get(0).sharedInfo()).isEqualTo(encryptedReport1.sharedInfo()); - assertThat(writtenReports.get(0).keyId()).isEqualTo(encryptedReport1.keyId()); - - assertThat(writtenReports.get(1).sharedInfo()).isEqualTo(encryptedReport2.sharedInfo()); - assertThat(writtenReports.get(1).keyId()).isEqualTo(encryptedReport2.keyId()); - } - - @Test - public void writeLocalJsonReport_invalidPath_fails() throws Exception { - Path nonExistentDirectory = - jsonFile.getFileSystem().getPath("/doesnotexist", jsonFile.toString()); - - assertThrows( - FileWriteException.class, - () -> - localJsonResultFileWriter.writeLocalReportFile(reports.stream(), nonExistentDirectory)); - } - @Test public void testFileExtension() { assertThat(localJsonResultFileWriter.getFileExtension()).isEqualTo(".json"); diff --git a/terraform/gcp/fetch_terraform.sh b/terraform/gcp/fetch_terraform.sh index 01ffc7a0..52b0a7b3 100644 --- a/terraform/gcp/fetch_terraform.sh +++ b/terraform/gcp/fetch_terraform.sh @@ -80,8 +80,8 @@ frontend_service_jar = "../../jars/FrontendServiceHttpCloudFunction_${VERSION}.j worker_scale_in_jar = "../../jars/WorkerScaleInCloudFunction_${VERSION}.jar" # Coordinator service accounts to impersonate for authorization and authentication -coordinator_a_impersonate_service_account = "a-opallowedusr@ps-msmt-coord-prd-gg-svcacc.iam.gserviceaccount.com" -coordinator_b_impersonate_service_account = "b-opallowedusr@ps-msmt-coord-prd-gg-svcacc.iam.gserviceaccount.com" +coordinator_a_impersonate_service_account = "a-opallowedusr@ps-msmt-coord-prd-g3p-svcacc.iam.gserviceaccount.com" +coordinator_b_impersonate_service_account = "b-opallowedusr@ps-prod-msmt-type2-e541.iam.gserviceaccount.com" EOT diff --git a/worker/gcp/BUILD b/worker/gcp/BUILD index 5f1abd4a..a850533d 100644 --- a/worker/gcp/BUILD +++ b/worker/gcp/BUILD @@ -16,6 +16,74 @@ load("//build_defs/worker/gcp:deploy.bzl", "worker_gcp_deployment") package(default_visibility = 
["//visibility:public"]) +worker_gcp_deployment( + name = "worker_mp_gcp_g3p_prod", + cmd = [ + "WorkerRunner_prod_deploy.jar", + "--client_config_env", + "GCP", + "--job_client", + "GCP", + "--blob_storage_client", + "GCP_CS_CLIENT", + "--decryption_key_service", + "GCP_KMS_MULTI_PARTY_DECRYPTION_KEY_SERVICE", + "--primary_encryption_key_service_base_url", + "https://privatekeyservice-a.msmt-3.gcp.privacysandboxservices.com/v1alpha", + "--secondary_encryption_key_service_base_url", + "https://privatekeyservice-b.msmt-4.gcp.privacysandboxservices.com/v1alpha", + "--primary_encryption_key_service_cloudfunction_url", + "https://a-us-central1-encryption-key-service-cloudfunctio-zihnau4cbq-uc.a.run.app", + "--secondary_encryption_key_service_cloudfunction_url", + "https://b-us-central1-encryption-key-service-cloudfunctio-mnlu5dzbga-uc.a.run.app", + "--coordinator_a_kms_key", + "gcp-kms://projects/ps-msmt-a-coord-prd-g3p/locations/us/keyRings/a_key_encryption_ring/cryptoKeys/a_key_encryption_key", + "--coordinator_b_kms_key", + "gcp-kms://projects/ps-prod-msmt-type2-e541/locations/us/keyRings/b_key_encryption_ring/cryptoKeys/b_key_encryption_key", + "--coordinator_a_wip_provider", + "projects/306633382134/locations/global/workloadIdentityPools/a-opwip/providers/a-opwip-pvdr", + "--coordinator_a_sa", + "a-opverifiedusr@ps-msmt-coord-prd-g3p-wif.iam.gserviceaccount.com", + "--coordinator_b_wip_provider", + "projects/364328752810/locations/global/workloadIdentityPools/b-opwip/providers/b-opwip-pvdr", + "--coordinator_b_sa", + "b-opverifiedusr@ps-prod-msmt-type2-e541.iam.gserviceaccount.com", + "--coordinator_a_privacy_budgeting_service_base_url", + "https://mp-pbs-a.msmt-3.gcp.privacysandboxservices.com/v1", + "--coordinator_a_privacy_budgeting_service_auth_endpoint", + "https://a-us-central1-pbs-auth-cloudfunction-zihnau4cbq-uc.a.run.app", + "--coordinator_b_privacy_budgeting_service_base_url", + "https://mp-pbs-b.msmt-4.gcp.privacysandboxservices.com/v1", + "--coordinator_b_privacy_budgeting_service_auth_endpoint", + "https://b-us-central1-pbs-auth-cloudfunction-mnlu5dzbga-uc.a.run.app", + "--privacy_budgeting", + "HTTP", + "--param_client", + "GCP", + "--metric_client", + "GCP", + "--lifecycle_client", + "GCP", + "--pbs_client", + "GCP", + "--noising", + "DP_NOISING", + "--return_stack_trace", + "--parallel_summary_upload_enabled", + "--streaming_output_domain_processing_enabled", + "--parallel_fact_noising_enabled", + "--labeled_privacy_budget_keys_enabled", + ], + entrypoint = [ + "/usr/bin/java", + "-XX:+ExitOnOutOfMemoryError", + "-XX:MaxRAMPercentage=75.0", + "-jar", + ], + files = ["//java/com/google/aggregate/adtech/worker/gcp:WorkerRunner_prod_deploy.jar"], + labels = {"tee.launch_policy.allow_cmd_override": "false"}, +) + worker_gcp_deployment( name = "worker_mp_gcp_prod", cmd = [