Skip to content

Commit

Permalink
Parallel PIRProcessDatabase
Browse files Browse the repository at this point in the history
  • Loading branch information
gleb032 committed Oct 15, 2024
1 parent 90f4b99 commit 81b0ece
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 65 deletions.
163 changes: 102 additions & 61 deletions Sources/PIRProcessDatabase/ProcessDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ struct ResolvedArguments: CustomStringConvertible, Encodable {
}

@main
struct ProcessDatabase: ParsableCommand {
struct ProcessDatabase: AsyncParsableCommand {
static let configuration: CommandConfiguration = .init(
commandName: "PIRProcessDatabase")

Expand All @@ -311,13 +311,18 @@ struct ProcessDatabase: ParsableCommand {
""")
var configFile: String

@Flag(name: .customLong("parallel"),
inversion: .prefixedNo,
help: "Enables parallel processing.")
var parallel = false

/// Performs the processing on the given database.
/// - Parameters:
/// - config: The configuration for the PIR processing.
/// - scheme: The HE scheme.
/// - Throws: Error upon processing the database.
@inlinable
mutating func process<Scheme: HeScheme>(config: Arguments, scheme: Scheme.Type) throws {
mutating func process<Scheme: HeScheme>(config: Arguments, scheme: Scheme.Type) async throws {
let database: [KeywordValuePair] =
try Apple_SwiftHomomorphicEncryption_Pir_V1_KeywordDatabase(from: config.inputDatabase).native()

Expand All @@ -339,69 +344,39 @@ struct ProcessDatabase: ParsableCommand {
keyCompression: config.keyCompression,
trialsPerShard: config.trialsPerShard)

var evaluationKeyConfig = EvaluationKeyConfig()
let context = try Context(encryptionParameters: processArgs.encryptionParameters)
let keywordDatabase = try KeywordDatabase(rows: database, sharding: processArgs.databaseConfig.sharding)
ProcessDatabase.logger
.info("Sharded database into \(keywordDatabase.shards.count) shards")
for (shardID, shard) in keywordDatabase.shards
.sorted(by: { $0.0.localizedStandardCompare($1.0) == .orderedAscending })
{
func logEvent(event: ProcessKeywordDatabase.ProcessShardEvent) throws {
switch event {
case let .cuckooTableEvent(.createdTable(table)):
let summary = try table.summarize()
ProcessDatabase.logger.info("Created cuckoo table \(summary)")
case let .cuckooTableEvent(.expandingTable(table)):
let summary = try table.summarize()
ProcessDatabase.logger.info("Expanding cuckoo table \(summary)")
case let .cuckooTableEvent(.finishedExpandingTable(table)):
let summary = try table.summarize()
ProcessDatabase.logger.info("Finished expanding cuckoo table \(summary)")
case let .cuckooTableEvent(.insertedKeywordValuePair(index, _)):
let reportingPercentage = 10
let shardFraction = shard.rows.count / reportingPercentage
if (index + 1).isMultiple(of: shardFraction) {
let percentage = Float(reportingPercentage * (index + 1)) / Float(shardFraction)
ProcessDatabase.logger
.info("Inserted \(index + 1) / \(shard.rows.count) keywords \(percentage)%")
ProcessDatabase.logger.info("Sharded database into \(keywordDatabase.shards.count) shards")

let shards = keywordDatabase.shards.sorted { $0.0.localizedStandardCompare($1.0) == .orderedAscending }

var evaluationKeyConfig = EvaluationKeyConfig()
if parallel {
try await withThrowingTaskGroup(of: EvaluationKeyConfig.self) { group in
for (shardID, shard) in shards {
group.addTask { @Sendable [self] in
try await processShard(
shardID: shardID,
shard: shard,
config: config,
context: context,
processArgs: processArgs)
}
}
}

ProcessDatabase.logger.info("Processing shard \(shardID) with \(shard.rows.count) rows")
let processed = try ProcessKeywordDatabase.processShard(
shard: shard,
with: processArgs,
onEvent: logEvent)
if config.trialsPerShard > 0 {
guard let row = shard.rows.first else {
throw PirError.emptyDatabase
for try await processedEvaluationKeyConfig in group {
evaluationKeyConfig = [evaluationKeyConfig, processedEvaluationKeyConfig].union()
}
ProcessDatabase.logger.info("Validating shard \(shardID)")
let validationResults = try ProcessKeywordDatabase
.validateShard(shard: processed,
row: KeywordValuePair(keyword: row.key, value: row.value),
trials: config.trialsPerShard, context: context)
let description = try validationResults.description()
ProcessDatabase.logger.info("ValidationResults \(description)")
}

let outputDatabaseFilename = config.outputDatabase.replacingOccurrences(
of: "SHARD_ID",
with: String(shardID))
try processed.database.save(to: outputDatabaseFilename)
ProcessDatabase.logger.info("Saved shard \(shardID) to \(outputDatabaseFilename)")

let shardEvaluationKeyConfig = processed.evaluationKeyConfig
evaluationKeyConfig = [evaluationKeyConfig, shardEvaluationKeyConfig].union()

let shardPirParameters = try processed.proto(context: context)
let outputParametersFilename = config.outputPirParameters.replacingOccurrences(
of: "SHARD_ID",
with: String(shardID))
try shardPirParameters.save(to: outputParametersFilename)
ProcessDatabase.logger.info("Saved shard \(shardID) PIR parameters to \(outputParametersFilename)")
} else {
for (shardID, shard) in shards {
let processedEvaluationKeyConfig = try await processShard(
shardID: shardID,
shard: shard, config:
config, context: context,
processArgs: processArgs)
evaluationKeyConfig = [evaluationKeyConfig, processedEvaluationKeyConfig].union()
}
}

if let evaluationKeyConfigFile = config.outputEvaluationKeyConfig {
Expand All @@ -411,14 +386,80 @@ struct ProcessDatabase: ParsableCommand {
}
}

mutating func run() throws {
private func processShard<Scheme: HeScheme>(
shardID: String,
shard: KeywordDatabaseShard,
config: ResolvedArguments,
context: Context<Scheme>,
processArgs: ProcessKeywordDatabase.Arguments<Scheme>) async throws -> EvaluationKeyConfig
{
var logger = ProcessDatabase.logger
logger[metadataKey: "shardID"] = .string(shardID)

func logEvent(event: ProcessKeywordDatabase.ProcessShardEvent) throws {
switch event {
case let .cuckooTableEvent(.createdTable(table)):
let summary = try table.summarize()
logger.info("Created cuckoo table \(summary)")
case let .cuckooTableEvent(.expandingTable(table)):
let summary = try table.summarize()
logger.info("Expanding cuckoo table \(summary)")
case let .cuckooTableEvent(.finishedExpandingTable(table)):
let summary = try table.summarize()
logger.info("Finished expanding cuckoo table \(summary)")
case let .cuckooTableEvent(.insertedKeywordValuePair(index, _)):
let reportingPercentage = 10
let shardFraction = shard.rows.count / reportingPercentage
if (index + 1).isMultiple(of: shardFraction) {
let percentage = Float(reportingPercentage * (index + 1)) / Float(shardFraction)
logger.info("Inserted \(index + 1) / \(shard.rows.count) keywords \(percentage)%")
}
}
}

logger.info("Processing shard with \(shard.rows.count) rows")
let processed = try ProcessKeywordDatabase.processShard(
shard: shard,
with: processArgs,
onEvent: logEvent)

if config.trialsPerShard > 0 {
guard let row = shard.rows.first else {
throw PirError.emptyDatabase
}
logger.info("Validating shard")
let validationResults = try ProcessKeywordDatabase
.validateShard(shard: processed,
row: KeywordValuePair(keyword: row.key, value: row.value),
trials: config.trialsPerShard, context: context)
let description = try validationResults.description()
logger.info("ValidationResults \(description)")
}

let outputDatabaseFilename = config.outputDatabase.replacingOccurrences(
of: "SHARD_ID",
with: String(shardID))
try processed.database.save(to: outputDatabaseFilename)
logger.info("Saved shard to \(outputDatabaseFilename)")

let shardPirParameters = try processed.proto(context: context)
let outputParametersFilename = config.outputPirParameters.replacingOccurrences(
of: "SHARD_ID",
with: String(shardID))
try shardPirParameters.save(to: outputParametersFilename)
logger.info("Saved shard PIR parameters to \(outputParametersFilename)")

return processed.evaluationKeyConfig
}

mutating func run() async throws {
let configURL = URL(fileURLWithPath: configFile)
let configData = try Data(contentsOf: configURL)
let config = try JSONDecoder().decode(Arguments.self, from: configData)
if config.rlweParameters.supportsScalar(UInt32.self) {
try process(config: config, scheme: Bfv<UInt32>.self)
try await process(config: config, scheme: Bfv<UInt32>.self)
} else {
try process(config: config, scheme: Bfv<UInt64>.self)
try await process(config: config, scheme: Bfv<UInt64>.self)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions Sources/PrivateInformationRetrieval/KeywordDatabase.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ extension Sharding {
}

/// A shard of a ``KeywordDatabase``.
public struct KeywordDatabaseShard: Hashable, Codable {
public struct KeywordDatabaseShard: Hashable, Codable, Sendable {
/// Identifier for the shard.
public let shardID: String
/// Rows in the database.
Expand Down Expand Up @@ -204,7 +204,7 @@ extension KeywordDatabaseShard: Collection {
}

/// Configuration for a ``KeywordDatabase``.
public struct KeywordDatabaseConfig: Hashable, Codable {
public struct KeywordDatabaseConfig: Hashable, Codable, Sendable {
public let sharding: Sharding
public let keywordPirConfig: KeywordPirConfig

Expand Down Expand Up @@ -264,7 +264,7 @@ public struct KeywordDatabase {
/// Utilities for processing a ``KeywordDatabase``.
public enum ProcessKeywordDatabase {
/// Arguments for processing a keyword database.
public struct Arguments<Scheme: HeScheme>: Codable {
public struct Arguments<Scheme: HeScheme>: Codable, Sendable {
/// Database configuration.
public let databaseConfig: KeywordDatabaseConfig
/// Encryption parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import Foundation
import HomomorphicEncryption

/// Configuration for a ``KeywordDatabase``.
public struct KeywordPirConfig: Hashable, Codable {
public struct KeywordPirConfig: Hashable, Codable, Sendable {
/// Number of dimensions in the database.
@usableFromInline let dimensionCount: Int

Expand Down

0 comments on commit 81b0ece

Please sign in to comment.