diff --git a/Datasets/CMakeLists.txt b/Datasets/CMakeLists.txt
index 88e2ba3b9f6..a9708ee10f1 100644
--- a/Datasets/CMakeLists.txt
+++ b/Datasets/CMakeLists.txt
@@ -4,7 +4,6 @@ add_library(Datasets
   COCO/COCO.swift
   COCO/COCODataset.swift
   COCO/COCOVariant.swift
-  CoLA/Batchable.swift
   CoLA/CoLA.swift
   CoLA/DataUtilities.swift
   ImageClassificationDataset.swift
diff --git a/Datasets/CoLA/Batchable.swift b/Datasets/CoLA/Batchable.swift
deleted file mode 100644
index 57faca82d36..00000000000
--- a/Datasets/CoLA/Batchable.swift
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Adapted from: https://github.com/eaplatanios/nca/blob/master/Sources/NCA/Utilities/Protocols.swift
-
-import TensorFlow
-
-public protocol Batchable {
-  static func batch(_ values: [Self]) -> Self
-}
-
-extension Tensor: Batchable {
-  public static func batch(_ values: [Tensor]) -> Tensor {
-    Tensor(stacking: values, alongAxis: 0)
-  }
-}
-
-extension KeyPathIterable {
-  public static func batch(_ values: [Self]) -> Self {
-    var result = values[0]
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<UInt8>.self) {
-      result[keyPath: kp] = Tensor.batch(values.map { $0[keyPath: kp] })
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<Int32>.self) {
-      result[keyPath: kp] = Tensor.batch(values.map { $0[keyPath: kp] })
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<Int64>.self) {
-      result[keyPath: kp] = Tensor.batch(values.map { $0[keyPath: kp] })
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<Float>.self) {
-      result[keyPath: kp] = Tensor.batch(values.map { $0[keyPath: kp] })
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<Double>.self) {
-      result[keyPath: kp] = Tensor.batch(values.map { $0[keyPath: kp] })
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<UInt8>?.self) {
-      let keyPathValues = values.map { $0[keyPath: kp] }
-      if keyPathValues[0] != nil {
-        result[keyPath: kp] = Tensor.batch(keyPathValues.map { $0! })
-      } else {
-        result[keyPath: kp] = nil
-      }
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<Int32>?.self) {
-      let keyPathValues = values.map { $0[keyPath: kp] }
-      if keyPathValues[0] != nil {
-        result[keyPath: kp] = Tensor.batch(keyPathValues.map { $0! })
-      } else {
-        result[keyPath: kp] = nil
-      }
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<Int64>?.self) {
-      let keyPathValues = values.map { $0[keyPath: kp] }
-      if keyPathValues[0] != nil {
-        result[keyPath: kp] = Tensor.batch(keyPathValues.map { $0! })
-      } else {
-        result[keyPath: kp] = nil
-      }
-    }
-    for kp in result.recursivelyAllWritableKeyPaths(to: Tensor<Float>?.self) {
-      let keyPathValues = values.map { $0[keyPath: kp] }
-      if keyPathValues[0] != nil {
-        result[keyPath: kp] = Tensor.batch(keyPathValues.map { $0! })
-      } else {
-        result[keyPath: kp] = nil
-      }
-    }
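The deleted `Batchable` protocol existed only to stack every `Tensor` field of a batch element via key paths. Under the Epochs API the same result comes from collating directly: `TextBatch` values go through `paddedAndCollated(to:)`, and label tensors are stacked with the `Tensor(stacking:alongAxis:)` initializer that `Batchable` wrapped. A minimal sketch of that replacement (the values are illustrative):

```swift
import TensorFlow

// What `Tensor.batch(_:)` did, expressed directly: stack per-example label
// tensors along a new leading axis.
let labels: [Tensor<Int32>] = [Tensor(0), Tensor(1), Tensor(1)]
let batchedLabels = Tensor(stacking: labels, alongAxis: 0)  // shape: [3]
```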
diff --git a/Datasets/CoLA/CoLA.swift b/Datasets/CoLA/CoLA.swift
index 8f359cfced7..0f1c99cb8e9 100644
--- a/Datasets/CoLA/CoLA.swift
+++ b/Datasets/CoLA/CoLA.swift
@@ -12,92 +12,81 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 //
-// Adapted from: https://gist.github.com/eaplatanios/5163c8d503f9e56f11b5b058fb041d62
-// Changes:
-// - Rename `Architecture` to `BERTClassifier`.
-// - In `CoLA.update`:
-//   - Change `Architecture.classify` to `BERTClassifier.callAsFunction`.
-//   - Change `softmaxCrossEntropy` to `sigmoidCrossEntropy`.
+// Originally adapted from:
+// https://gist.github.com/eaplatanios/5163c8d503f9e56f11b5b058fb041d62
 
 import Foundation
 import ModelSupport
 import TensorFlow
 
-public struct CoLA {
-  public let directoryURL: URL
-  public let trainExamples: [Example]
-  public let devExamples: [Example]
-  public let testExamples: [Example]
-  public let maxSequenceLength: Int
-  public let batchSize: Int
-
-  public typealias ExampleIterator = IndexingIterator<[Example]>
-  public typealias TrainDataIterator = PrefetchIterator<
-    GroupedIterator<MapIterator<ExampleIterator, DataBatch>>
-  >
-  public typealias DevDataIterator = GroupedIterator<MapIterator<ExampleIterator, DataBatch>>
-  public typealias TestDataIterator = DevDataIterator
-
-  public var trainDataIterator: TrainDataIterator
-  public var devDataIterator: DevDataIterator
-  public var testDataIterator: TestDataIterator
-}
+/// A `TextBatch` with the corresponding labels.
+public typealias LabeledTextBatch = (data: TextBatch, label: Tensor<Int32>)
+
+/// CoLA example.
+public struct CoLAExample {
+  /// The unique identifier representing the `Example`.
+  public let id: String
+  /// The text of the `Example`.
+  public let sentence: String
+  /// The label of the `Example`.
+  public let isAcceptable: Bool?
+
+  /// Creates an instance from `id`, `sentence` and `isAcceptable`.
+  public init(id: String, sentence: String, isAcceptable: Bool?) {
+    self.id = id
+    self.sentence = sentence
+    self.isAcceptable = isAcceptable
+  }
+}
 
-//===-----------------------------------------------------------------------------------------===//
-// Data
-//===-----------------------------------------------------------------------------------------===//
+public struct CoLA<Entropy: RandomNumberGenerator> {
+  /// The directory where the dataset will be downloaded.
+  public let directoryURL: URL
+  /// The type of the labeled samples.
+  public typealias Samples = LazyMapSequence<[CoLAExample], LabeledTextBatch>
+  /// The training texts.
+  public let trainingExamples: Samples
+  /// The validation texts.
+  public let validationExamples: Samples
+
+  /// The sequence length to which every sentence will be padded.
+  public let maxSequenceLength: Int
+  /// The batch size.
+  public let batchSize: Int
+
+  /// The type of the collection of batches.
+  public typealias Batches = Slices<Sampling<Samples, ArraySlice<Int>>>
+  /// The type of the training sequence of epochs.
+  public typealias TrainEpochs = LazyMapSequence<TrainingEpochs<Samples, Entropy>,
+    LazyMapSequence<Batches, LabeledTextBatch>>
+  /// The sequence of training data (epochs of batches).
+  public var trainingEpochs: TrainEpochs
+  /// The validation batches.
+  public var validationBatches: LazyMapSequence<Slices<Samples>, LabeledTextBatch>
+
+  /// The URL from which to download the dataset.
+  private let url: URL = URL(
+    string: String(
+      "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/"
+        + "o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4"))!
+}
+
+// Data
 extension CoLA {
-  /// CoLA example.
-  public struct Example {
-    public let id: String
-    public let sentence: String
-    public let isAcceptable: Bool?
-
-    public init(id: String, sentence: String, isAcceptable: Bool?) {
-      self.id = id
-      self.sentence = sentence
-      self.isAcceptable = isAcceptable
-    }
-  }
-
-  /// CoLA data batch.
-  public struct DataBatch: KeyPathIterable {
-    public var inputs: TextBatch  // TODO: !!! Mutable in order to allow for batching.
-    public var labels: Tensor<Int32>?  // TODO: !!! Mutable in order to allow for batching.
-
-    public init(inputs: TextBatch, labels: Tensor<Int32>?) {
-      self.inputs = inputs
-      self.labels = labels
-    }
-  }
-
-  /// URL pointing to the downloadable ZIP file that contains the CoLA dataset.
-  private static let url: URL = URL(
-    string: String(
-      "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/"
-        + "o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4"))!
-
-  internal enum FileType: String {
-    case train = "train"
-    case dev = "dev"
-    case test = "test"
-  }
-
-  internal static func load(fromFile fileURL: URL, fileType: FileType) throws -> [Example] {
-    let lines = try parse(tsvFileAt: fileURL)
-
-    if fileType == .test {
-      // The test data file has a header.
-      return lines.dropFirst().enumerated().map { (i, lineParts) in
-        Example(id: lineParts[0], sentence: lineParts[1], isAcceptable: nil)
-      }
-    }
-
-    return lines.enumerated().map { (i, lineParts) in
-      Example(id: lineParts[0], sentence: lineParts[3], isAcceptable: lineParts[1] == "1")
-    }
+  internal static func load(fromFile fileURL: URL, isTest: Bool = false) throws -> [CoLAExample] {
+    let lines = try parse(tsvFileAt: fileURL)
+
+    if isTest {
+      // The test data file has a header.
+      return lines.dropFirst().enumerated().map { (i, lineParts) in
+        CoLAExample(id: lineParts[0], sentence: lineParts[1], isAcceptable: nil)
+      }
+    }
+
+    return lines.enumerated().map { (i, lineParts) in
+      CoLAExample(id: lineParts[0], sentence: lineParts[3], isAcceptable: lineParts[1] == "1")
+    }
   }
 }
 
 internal func parse(tsvFileAt fileURL: URL) throws -> [[String]] {
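The `load(fromFile:isTest:)` function above fixes the column layout it expects from the GLUE TSV files: in `train.tsv` and `dev.tsv` the label sits in column 1 and the sentence in column 3, while the headed `test.tsv` carries the sentence in column 1 and no label. A small sketch of that mapping (the row values are hypothetical, not taken from the dataset):

```swift
// train.tsv / dev.tsv row (no header): [source, label, original annotation, sentence]
let trainRow = ["gj04", "1", "", "The book was written by John."]
let trainExample = CoLAExample(
  id: trainRow[0], sentence: trainRow[3], isAcceptable: trainRow[1] == "1")

// test.tsv row (after dropping the header): [id, sentence]
let testRow = ["0", "Somebody just left."]
let testExample = CoLAExample(id: testRow[0], sentence: testRow[1], isAcceptable: nil)
```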
@@ -110,91 +99,97 @@ internal func parse(tsvFileAt fileURL: URL) throws -> [[String]] {
 }
 
 extension CoLA {
-  public init(
-    exampleMap: @escaping (Example) -> DataBatch,
-    taskDirectoryURL: URL,
-    maxSequenceLength: Int,
-    batchSize: Int,
-    dropRemainder: Bool
-  ) throws {
-    self.directoryURL = taskDirectoryURL.appendingPathComponent("CoLA")
-    let dataURL = directoryURL.appendingPathComponent("data")
-    let compressedDataURL = dataURL.appendingPathComponent("downloaded-data.zip")
-
-    // Download the data, if necessary.
-    try download(from: CoLA.url, to: compressedDataURL)
-
-    // Extract the data, if necessary.
-    let extractedDirectoryURL = compressedDataURL.deletingPathExtension()
-    if !FileManager.default.fileExists(atPath: extractedDirectoryURL.path) {
-      try extract(zipFileAt: compressedDataURL, to: extractedDirectoryURL)
-    }
+  /// Creates an instance in `taskDirectoryURL` with batches of size `batchSize`
+  /// by `maxSequenceLength`.
+  ///
+  /// - Parameters:
+  ///   - entropy: a source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - exampleMap: a transform that processes a `CoLAExample` into a `LabeledTextBatch`.
+  public init(
+    taskDirectoryURL: URL,
+    maxSequenceLength: Int,
+    batchSize: Int,
+    entropy: Entropy,
+    exampleMap: @escaping (CoLAExample) -> LabeledTextBatch
+  ) throws {
+    self.directoryURL = taskDirectoryURL.appendingPathComponent("CoLA")
+    let dataURL = directoryURL.appendingPathComponent("data")
+    let compressedDataURL = dataURL.appendingPathComponent("downloaded-data.zip")
+
+    // Download the data, if necessary.
+    try download(from: url, to: compressedDataURL)
+
+    // Extract the data, if necessary.
+    let extractedDirectoryURL = compressedDataURL.deletingPathExtension()
+    if !FileManager.default.fileExists(atPath: extractedDirectoryURL.path) {
+      try extract(zipFileAt: compressedDataURL, to: extractedDirectoryURL)
+    }
 
-    #if false
-    // FIXME: Need to generalize `DatasetUtilities.downloadResource` to accept
-    // arbitrary full URLs instead of constructing full URL from filename and
-    // file extension.
-    DatasetUtilities.downloadResource(
-      filename: "\(subDirectory)", fileExtension: "zip",
-      remoteRoot: url.deletingLastPathComponent(),
-      localStorageDirectory: directory)
-    #endif
-
-    // Load the data files into arrays of examples.
-    let dataFilesURL = extractedDirectoryURL.appendingPathComponent("CoLA")
-    self.trainExamples = try CoLA.load(
-      fromFile: dataFilesURL.appendingPathComponent("train.tsv"),
-      fileType: .train)
-    self.devExamples = try CoLA.load(
-      fromFile: dataFilesURL.appendingPathComponent("dev.tsv"),
-      fileType: .dev)
-    self.testExamples = try CoLA.load(
-      fromFile: dataFilesURL.appendingPathComponent("test.tsv"),
-      fileType: .test)
-
-    self.maxSequenceLength = maxSequenceLength
-    self.batchSize = batchSize
-
-    // Create the data iterators used for training and evaluating.
-    self.trainDataIterator = trainExamples.shuffled().makeIterator()  // TODO: [RNG] Seed support.
-      .map(exampleMap)
-      .grouped(
-        keyFn: { _ in 0 },
-        sizeFn: { _ in batchSize / maxSequenceLength },
-        reduceFn: {
-          DataBatch(
-            inputs: $0.map { $0.inputs }.paddedAndCollated(to: maxSequenceLength),
-            labels: Tensor.batch($0.map { $0.labels! }))
-        },
-        dropRemainder: dropRemainder
-      )
-      .prefetched(count: 2)
-    self.devDataIterator = devExamples.makeIterator()
-      .map(exampleMap)
-      .grouped(
-        keyFn: { _ in 0 },
-        sizeFn: { _ in batchSize / maxSequenceLength },
-        reduceFn: {
-          DataBatch(
-            inputs: $0.map { $0.inputs }.paddedAndCollated(to: maxSequenceLength),
-            labels: Tensor.batch($0.map { $0.labels! }))
-        },
-        dropRemainder: dropRemainder
-      )
-    self.testDataIterator = testExamples.makeIterator()
-      .map(exampleMap)
-      .grouped(
-        keyFn: { _ in 0 },
-        sizeFn: { _ in batchSize / maxSequenceLength },
-        reduceFn: {
-          DataBatch(
-            inputs: $0.map { $0.inputs }.paddedAndCollated(to: maxSequenceLength),
-            labels: nil)
-        },
-        dropRemainder: dropRemainder
-      )
-  }
+    #if false
+    // FIXME: Need to generalize `DatasetUtilities.downloadResource` to accept
+    // arbitrary full URLs instead of constructing full URL from filename and
+    // file extension.
+    DatasetUtilities.downloadResource(
+      filename: "\(subDirectory)", fileExtension: "zip",
+      remoteRoot: url.deletingLastPathComponent(),
+      localStorageDirectory: directory)
+    #endif
+
+    // Load the data files.
+    let dataFilesURL = extractedDirectoryURL.appendingPathComponent("CoLA")
+    trainingExamples = try CoLA.load(
+      fromFile: dataFilesURL.appendingPathComponent("train.tsv")
+    ).lazy.map(exampleMap)
+
+    validationExamples = try CoLA.load(
+      fromFile: dataFilesURL.appendingPathComponent("dev.tsv")
+    ).lazy.map(exampleMap)
+
+    self.maxSequenceLength = maxSequenceLength
+    self.batchSize = batchSize
+
+    // Create the training sequence of epochs.
+    trainingEpochs = TrainingEpochs(
+      samples: trainingExamples, batchSize: batchSize / maxSequenceLength, entropy: entropy
+    ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledTextBatch> in
+      batches.lazy.map {
+        (
+          data: $0.map(\.data).paddedAndCollated(to: maxSequenceLength),
+          label: Tensor($0.map(\.label))
+        )
+      }
+    }
+
+    // Create the validation collection of batches.
+    validationBatches = validationExamples.inBatches(of: batchSize / maxSequenceLength).lazy.map {
+      (
+        data: $0.map(\.data).paddedAndCollated(to: maxSequenceLength),
+        label: Tensor($0.map(\.label))
+      )
+    }
+  }
 }
+
+extension CoLA where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance in `taskDirectoryURL` with batches of size `batchSize`
+  /// by `maxSequenceLength`.
+  ///
+  /// - Parameter exampleMap: a transform that processes a `CoLAExample` into a `LabeledTextBatch`.
+  public init(
+    taskDirectoryURL: URL,
+    maxSequenceLength: Int,
+    batchSize: Int,
+    exampleMap: @escaping (CoLAExample) -> LabeledTextBatch
+  ) throws {
+    try self.init(
+      taskDirectoryURL: taskDirectoryURL,
+      maxSequenceLength: maxSequenceLength,
+      batchSize: batchSize,
+      entropy: SystemRandomNumberGenerator(),
+      exampleMap: exampleMap
+    )
+  }
+}
\ No newline at end of file
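Because the shuffling entropy is injected, the training-epoch order can be made fully reproducible by passing any value-semantic `RandomNumberGenerator` instead of the system one. A sketch with a hand-rolled seedable generator (`SplitMix64` and `makeLabeledBatch` are illustrative, not part of this change; the latter stands in for the BERT preprocessing closure used in `Examples/BERT-CoLA/main.swift`):

```swift
// A tiny seedable RNG with value semantics (SplitMix64).
struct SplitMix64: RandomNumberGenerator {
  private var state: UInt64
  init(seed: UInt64) { self.state = seed }
  mutating func next() -> UInt64 {
    state &+= 0x9E37_79B9_7F4A_7C15
    var z = state
    z = (z ^ (z >> 30)) &* 0xBF58_476D_1CE4_E5B9
    z = (z ^ (z >> 27)) &* 0x94D0_49BB_1331_11EB
    return z ^ (z >> 31)
  }
}

// Same seed, same sequence of training epochs.
var cola = try CoLA(
  taskDirectoryURL: workspaceURL,
  maxSequenceLength: 128,
  batchSize: 1024,
  entropy: SplitMix64(seed: 42),
  exampleMap: makeLabeledBatch)
```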
diff --git a/Datasets/CoLA/DataUtilities.swift b/Datasets/CoLA/DataUtilities.swift
index f8ec50bc469..909c897ccc4 100644
--- a/Datasets/CoLA/DataUtilities.swift
+++ b/Datasets/CoLA/DataUtilities.swift
@@ -11,8 +11,6 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-//
-// Adapted from: https://gist.github.com/eaplatanios/5004c5857ec3140651ccef6766123ac2
 
 import Foundation
 
@@ -20,311 +18,6 @@ import Foundation
 #if canImport(FoundationNetworking)
 import FoundationNetworking
 #endif
 
-extension IteratorProtocol {
-  /// Returns an iterator that maps elements of this iterator using the provided function.
-  ///
-  /// - Parameters:
-  ///   - mapFn: Function used to map the iterator elements.
-  public func map<MappedElement>(
-    _ mapFn: @escaping (Element) -> MappedElement
-  ) -> MapIterator<Self, MappedElement> {
-    MapIterator(self, mapFn: mapFn)
-  }
-
-  /// Returns an iterator that repeats this iterator indefinitely.
-  public func repeated() -> RepeatIterator<Self> {
-    RepeatIterator(self)
-  }
-
-  /// Returns an iterator that shuffles this iterator using a temporary buffer.
-  ///
-  /// - Parameters:
-  ///   - bufferSize: Size of the shuffle buffer.
-  public func shuffled(bufferSize: Int) -> ShuffleIterator<Self> {
-    ShuffleIterator(self, bufferSize: bufferSize)
-  }
-
-  // TODO: [DOC] Add documentation string.
-  public func grouped(
-    keyFn: @escaping (Element) -> Int,
-    sizeFn: @escaping (Int) -> Int,
-    reduceFn: @escaping ([Element]) -> Element,
-    dropRemainder: Bool = false
-  ) -> GroupedIterator<Self> {
-    GroupedIterator(self, keyFn: keyFn, sizeFn: sizeFn, reduceFn: reduceFn, dropRemainder: dropRemainder)
-  }
-
-  // TODO: [DOC] Add documentation string.
-  public func prefetched(count: Int) -> PrefetchIterator<Self> {
-    PrefetchIterator(self, prefetchCount: count)
-  }
-}
-
-extension IteratorProtocol where Element: KeyPathIterable {
-  /// Returns an iterator that batches elements of this iterator.
-  ///
-  /// - Parameters:
-  ///   - batchSize: Batch size.
-  public func batched(batchSize: Int) -> BatchIterator<Self> {
-    BatchIterator(self, batchSize: batchSize)
-  }
-}
-
-/// Iterator that maps elements of another iterator using the provided function.
-public struct MapIterator<Base: IteratorProtocol, MappedElement>: IteratorProtocol {
-  private var iterator: Base
-  private let mapFn: (Base.Element) -> MappedElement
-
-  public init(_ iterator: Base, mapFn: @escaping (Base.Element) -> MappedElement) {
-    self.iterator = iterator
-    self.mapFn = mapFn
-  }
-
-  public mutating func next() -> MappedElement? {
-    if let element = iterator.next() { return mapFn(element) }
-    return nil
-  }
-}
-
-/// Iterator that repeats another iterator indefinitely.
-public struct RepeatIterator<Base: IteratorProtocol>: IteratorProtocol {
-  private let originalIterator: Base
-  private var currentIterator: Base
-
-  public init(_ iterator: Base) {
-    self.originalIterator = iterator
-    self.currentIterator = iterator
-  }
-
-  public mutating func next() -> Base.Element? {
-    if let element = currentIterator.next() {
-      return element
-    }
-    currentIterator = originalIterator
-    return currentIterator.next()
-  }
-}
-
-/// Iterator that shuffles another iterator using a temporary buffer.
-public struct ShuffleIterator<Base: IteratorProtocol>: IteratorProtocol {
-  private let bufferSize: Int
-  private var iterator: Base
-  private var buffer: [Base.Element]
-  private var bufferIndex: Int
-
-  public init(_ iterator: Base, bufferSize: Int) {
-    self.bufferSize = bufferSize
-    self.iterator = iterator
-    self.buffer = []
-    self.bufferIndex = 0
-  }
-
-  public mutating func next() -> Base.Element? {
-    if buffer.isEmpty || (bufferIndex >= bufferSize && bufferSize != -1) { fillBuffer() }
-    if buffer.isEmpty { return nil }
-    bufferIndex += 1
-    return buffer[bufferIndex - 1]
-  }
-
-  private mutating func fillBuffer() {
-    buffer = []
-    bufferIndex = 0
-    while let element = iterator.next(), bufferIndex < bufferSize || bufferSize == -1 {
-      buffer.append(element)
-      bufferIndex += 1
-    }
-    bufferIndex = 0
-  }
-}
-
-/// Iterator that batches elements from another iterator.
-public struct BatchIterator<Base: IteratorProtocol>: IteratorProtocol
-where Base.Element: KeyPathIterable {
-  private let batchSize: Int
-  private var iterator: Base
-  private var buffer: [Base.Element]
-
-  public init(_ iterator: Base, batchSize: Int) {
-    self.batchSize = batchSize
-    self.iterator = iterator
-    self.buffer = []
-    self.buffer.reserveCapacity(batchSize)
-  }
-
-  public mutating func next() -> Base.Element? {
-    while buffer.count < batchSize {
-      if let element = iterator.next() {
-        buffer.append(element)
-      } else {
-        break
-      }
-    }
-    if buffer.isEmpty { return nil }
-    let batch = Base.Element.batch(buffer)
-    buffer = []
-    buffer.reserveCapacity(batchSize)
-    return batch
-  }
-}
-
-/// Iterator that groups elements from another iterator.
-public struct GroupedIterator<Base: IteratorProtocol>: IteratorProtocol {
-  private let keyFn: (Base.Element) -> Int
-  private let sizeFn: (Int) -> Int
-  private let reduceFn: ([Base.Element]) -> Base.Element
-  private let dropRemainder: Bool
-  private var iterator: Base
-  private var groups: [Int: [Base.Element]]
-  private var currentGroup: Dictionary<Int, [Base.Element]>.Index? = nil
-
-  public init(
-    _ iterator: Base,
-    keyFn: @escaping (Base.Element) -> Int,
-    sizeFn: @escaping (Int) -> Int,
-    reduceFn: @escaping ([Base.Element]) -> Base.Element,
-    dropRemainder: Bool = false
-  ) {
-    self.keyFn = keyFn
-    self.sizeFn = sizeFn
-    self.reduceFn = reduceFn
-    self.dropRemainder = dropRemainder
-    self.iterator = iterator
-    self.groups = [Int: [Base.Element]]()
-  }
-
-  public mutating func next() -> Base.Element? {
-    var elements: [Base.Element]? = nil
-    while elements == nil {
-      if let element = iterator.next() {
-        let key = keyFn(element)
-        if !groups.keys.contains(key) {
-          groups[key] = [element]
-        } else {
-          groups[key]!.append(element)
-        }
-        if groups[key]!.count >= sizeFn(key) {
-          elements = groups.removeValue(forKey: key)!
-        }
-      } else {
-        break
-      }
-    }
-    guard let elementsToReduce = elements else {
-      if dropRemainder { return nil }
-      if currentGroup == nil { currentGroup = groups.values.startIndex }
-      if currentGroup! >= groups.values.endIndex { return nil }
-      while groups.values[currentGroup!].isEmpty {
-        currentGroup = groups.values.index(after: currentGroup!)
-      }
-      let elementsToReduce = groups.values[currentGroup!]
-      currentGroup = groups.values.index(after: currentGroup!)
-      return reduceFn(elementsToReduce)
-    }
-    return reduceFn(elementsToReduce)
-  }
-}
-
-/// Iterator that prefetches elements from another iterator asynchronously.
-public struct PrefetchIterator<Base: IteratorProtocol>: IteratorProtocol {
-  private let iterator: Base
-  private let prefetchCount: Int
-
-  private var queue: BlockingQueue<Base.Element>
-
-  public init(_ iterator: Base, prefetchCount: Int) {
-    self.iterator = iterator
-    self.prefetchCount = prefetchCount
-    self.queue = BlockingQueue(count: prefetchCount, iterator: iterator)
-  }
-
-  public mutating func next() -> Base.Element? {
-    queue.read()
-  }
-
-  // TODO: !!! This is needed because `BlockingQueue` is a class. Figure out a better solution.
-  public func copy() -> PrefetchIterator {
-    PrefetchIterator(iterator, prefetchCount: prefetchCount)
-  }
-}
-
-extension PrefetchIterator {
-  internal class BlockingQueue<Element> {
-    private let prefetchingDispatchQueue: DispatchQueue = DispatchQueue(label: "PrefetchIterator")
-    private let writeSemaphore: DispatchSemaphore
-    private let readSemaphore: DispatchSemaphore
-    private let deletedSemaphore: DispatchSemaphore
-    private let dispatchQueue: DispatchQueue
-    private var array: [Element?]
-    private var readIndex: Int
-    private var writeIndex: Int
-    private var depleted: Bool
-    private var deleted: Bool
-
-    internal init(
-      count: Int,
-      iterator: Base
-    ) where Base.Element == Element {
-      self.writeSemaphore = DispatchSemaphore(value: count)
-      self.readSemaphore = DispatchSemaphore(value: 0)
-      self.deletedSemaphore = DispatchSemaphore(value: 0)
-      self.dispatchQueue = DispatchQueue(label: "BlockingQueue")
-      self.array = [Element?](repeating: nil, count: count)
-      self.readIndex = 0
-      self.writeIndex = 0
-      self.depleted = false
-      self.deleted = false
-      var iterator = iterator
-      prefetchingDispatchQueue.async { [unowned self] () in
-        while !self.deleted {
-          if let element = iterator.next() {
-            self.write(element)
-          } else {
-            self.depleted = true
-            self.readSemaphore.signal()
-            self.deletedSemaphore.signal()
-            break
-          }
-        }
-        self.readSemaphore.signal()
-        self.deletedSemaphore.signal()
-      }
-    }
-
-    deinit {
-      self.deleted = true
-
-      // Signal the write semaphore to make sure it's not in use anymore. Its final value must be
-      // greater or equal to its initial value.
-      for _ in 0...array.count { writeSemaphore.signal() }
-
-      // Wait for the delete semaphore to make sure the prefetching thread is done.
-      deletedSemaphore.wait()
-    }
-
-    private func write(_ element: Element) {
-      writeSemaphore.wait()
-      dispatchQueue.sync {
-        array[writeIndex % array.count] = element
-        writeIndex += 1
-      }
-      readSemaphore.signal()
-    }
-
-    internal func read() -> Element? {
-      if self.depleted { return nil }
-      readSemaphore.wait()
-      let element = dispatchQueue.sync { () -> Element? in
-        let element = array[readIndex % array.count]
-        array[readIndex % array.count] = nil
-        readIndex += 1
-        return element
-      }
-      writeSemaphore.signal()
-      return element
-    }
-  }
-}
-
 public func extract(zipFileAt source: URL, to destination: URL) throws {
   print("Extracting file at '\(source.path)'.")
   let process = Process()
@@ -346,4 +39,4 @@ public func extract(tarGZippedFileAt source: URL, to destination: URL) throws {
   process.arguments = ["-c", "tar -C \(destination.path) -xzf \(source.path)"]
   try process.run()
   process.waitUntilExit()
-}
+}
\ No newline at end of file
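All of the iterator machinery deleted above is superseded by the generic Epochs API that ships with the `TensorFlow` module: `TrainingEpochs` takes over the shuffling previously done by `ShuffleIterator`, `inBatches(of:)` plus `paddedAndCollated(to:)` replace `GroupedIterator`/`BatchIterator`, and no prefetching equivalent is used. A minimal sketch of the same shuffle-and-batch pattern on toy data:

```swift
import TensorFlow

let samples = Array(0..<10)

// An infinite sequence of shuffled epochs; each epoch is a collection of
// batches, and partial trailing batches are dropped.
let epochs = TrainingEpochs(
  samples: samples, batchSize: 3, entropy: SystemRandomNumberGenerator())
for batches in epochs.prefix(2) {
  for batch in batches {
    print(Array(batch))  // e.g. [7, 2, 9]
  }
}

// Validation-style batching: plain, ordered slices.
for batch in samples.inBatches(of: 3) {
  print(Array(batch))  // [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]
}
```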
diff --git a/Examples/BERT-CoLA/main.swift b/Examples/BERT-CoLA/main.swift
index 17c8961e681..fbcd407a2ec 100644
--- a/Examples/BERT-CoLA/main.swift
+++ b/Examples/BERT-CoLA/main.swift
@@ -41,21 +41,17 @@ var bertClassifier = BERTClassifier(bert: bert, classCount: 1)
 let maxSequenceLength = 128
 let batchSize = 1024
 
-// Create a function that converts examples to data batches.
-let exampleMapFn: (CoLA.Example) -> CoLA.DataBatch = { example -> CoLA.DataBatch in
-  let textBatch = bertClassifier.bert.preprocess(
-    sequences: [example.sentence],
-    maxSequenceLength: maxSequenceLength)
-  return CoLA.DataBatch(
-    inputs: textBatch, labels: example.isAcceptable.map { Tensor($0 ? 1 : 0) })
-}
-
 var cola = try CoLA(
-  exampleMap: exampleMapFn,
-  taskDirectoryURL: workspaceURL,
-  maxSequenceLength: maxSequenceLength,
-  batchSize: batchSize,
-  dropRemainder: true)
+  taskDirectoryURL: workspaceURL,
+  maxSequenceLength: maxSequenceLength,
+  batchSize: batchSize,
+  entropy: SystemRandomNumberGenerator()
+) { (example: CoLAExample) -> LabeledTextBatch in
+  let textBatch = bertClassifier.bert.preprocess(
+    sequences: [example.sentence],
+    maxSequenceLength: maxSequenceLength)
+  return (data: textBatch, label: Tensor(example.isAcceptable! ? 1 : 0))
+}
 
 print("Dataset acquired.")
 
@@ -72,15 +68,14 @@ var optimizer = WeightDecayedAdam(
   maxGradientGlobalNorm: 1)
 
 print("Training BERT for the CoLA task!")
-for epoch in 1...3 {
-  print("[Epoch \(epoch)]")
+for (epoch, epochBatches) in cola.trainingEpochs.prefix(3).enumerated() {
+  print("[Epoch \(epoch + 1)]")
   Context.local.learningPhase = .training
   var trainingLossSum: Float = 0
   var trainingBatchCount = 0
-  var trainingDataIterator = cola.trainDataIterator
 
-  while let batch = withDevice(.cpu, perform: { trainingDataIterator.next() }) {
-    let (documents, labels) = (batch.inputs, Tensor<Float>(batch.labels!))
+  for batch in epochBatches {
+    let (documents, labels) = (batch.data, Tensor<Float>(batch.label))
     let (loss, gradients) = valueWithGradient(at: bertClassifier) { model -> Tensor<Float> in
       let logits = model(documents)
       return sigmoidCrossEntropy(
@@ -103,11 +98,10 @@ for epoch in 1...3 {
   Context.local.learningPhase = .inference
   var devLossSum: Float = 0
   var devBatchCount = 0
-  var devDataIterator = cola.devDataIterator
   var devPredictedLabels = [Bool]()
   var devGroundTruth = [Bool]()
-  while let batch = withDevice(.cpu, perform: { devDataIterator.next() }) {
-    let (documents, labels) = (batch.inputs, batch.labels!)
+  for batch in cola.validationBatches {
+    let (documents, labels) = (batch.data, Tensor<Float>(batch.label))
     let logits = bertClassifier(documents)
     let loss = sigmoidCrossEntropy(
       logits: logits.squeezingShape(at: -1),
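One detail worth keeping in mind when reading both the dataset and the example: `batchSize` counts tokens, not sentences, which is why `batchSize / maxSequenceLength` is what gets passed to `TrainingEpochs` and `inBatches(of:)`:

```swift
// With the values used in Examples/BERT-CoLA/main.swift:
let batchSize = 1024          // measured in tokens
let maxSequenceLength = 128   // tokens per padded sentence
let sentencesPerBatch = batchSize / maxSequenceLength  // 8 sentences per batch
```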