diff --git a/Benchmarks/Models/WordSeg.swift b/Benchmarks/Models/WordSeg.swift
index ff56c5a002c..b6340179e93 100644
--- a/Benchmarks/Models/WordSeg.swift
+++ b/Benchmarks/Models/WordSeg.swift
@@ -106,14 +106,14 @@ struct WordSegBenchmark: Benchmark {
       from: [sentence],
       alphabet: dataset.alphabet,
       maxLength: maximumSequenceLength,
-      minFreq: 10
+      minFrequency: 10
     )

     let modelParameters = SNLM.Parameters(
-      ndim: 512,
-      dropoutProb: 0.5,
-      chrVocab: dataset.alphabet,
-      strVocab: lexicon,
+      hiddenSize: 512,
+      dropoutProbability: 0.5,
+      alphabet: dataset.alphabet,
+      lexicon: lexicon,
       order: 5
     )
diff --git a/Datasets/CMakeLists.txt b/Datasets/CMakeLists.txt
index 3f1b519cb7d..b2344d69e4c 100644
--- a/Datasets/CMakeLists.txt
+++ b/Datasets/CMakeLists.txt
@@ -21,7 +21,7 @@ add_library(Datasets
   TensorPair.swift
   TextUnsupervised/TextUnsupervised.swift
   WordSeg/WordSegDataset.swift
-  WordSeg/WordSegRecord.swift
+  WordSeg/Phrase.swift
   ImageSegmentationDataset.swift
   OxfordIIITPets/OxfordIIITPets.swift)
 target_link_libraries(Datasets PUBLIC
diff --git a/Datasets/WordSeg/WordSegRecord.swift b/Datasets/WordSeg/Phrase.swift
similarity index 73%
rename from Datasets/WordSeg/WordSegRecord.swift
rename to Datasets/WordSeg/Phrase.swift
index d0049a2a2c2..e2ccb02fad1 100644
--- a/Datasets/WordSeg/WordSegRecord.swift
+++ b/Datasets/WordSeg/Phrase.swift
@@ -14,10 +14,17 @@

 import ModelSupport

-public struct WordSegRecord {
+/// A sequence of text for use in word segmentation.
+public struct Phrase {
+
+  /// A raw, unprocessed sequence of text.
   public let plainText: String
+
+  /// A sequence of text in numeric form, derived from `plainText`.
   public let numericalizedText: CharacterSequence

+  /// Creates an instance containing both raw (`plainText`) and processed
+  /// (`numericalizedText`) forms of a sequence of text.
   public init(plainText: String, numericalizedText: CharacterSequence) {
     self.plainText = plainText
     self.numericalizedText = numericalizedText
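The renamed `Phrase` pairs the raw text of a sentence with its numericalized form. A minimal sketch of building one by hand, using the `Alphabet` and `CharacterSequence` types from `ModelSupport` (the two-letter alphabet and its marker strings are illustrative values, not part of this change):

```swift
import Datasets
import ModelSupport

// "a" maps to 0 and "b" to 1; the markers receive the next integers (2, 3, 4).
let alphabet = Alphabet(["a", "b"], eos: "</s>", eow: "</w>", pad: "<pad>")

// Numericalization appends the end-of-sequence marker.
let numericalized = try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "ab")

let phrase = Phrase(plainText: "ab", numericalizedText: numericalized)
print(phrase.numericalizedText.characters)  // [0, 1, 2], where 2 is </s>
```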
diff --git a/Datasets/WordSeg/WordSegDataset.swift b/Datasets/WordSeg/WordSegDataset.swift
index 4b918898de6..8bb166d7aa6 100644
--- a/Datasets/WordSeg/WordSegDataset.swift
+++ b/Datasets/WordSeg/WordSegDataset.swift
@@ -15,110 +15,129 @@
 import Foundation
 import ModelSupport

+/// A dataset targeted at the problem of word segmentation.
+///
+/// The reference archive was published in the paper "Learning to Discover,
+/// Ground, and Use Words with Segmental Neural Language Models" by Kazuya
+/// Kawakami, Chris Dyer, and Phil Blunsom:
+/// https://www.aclweb.org/anthology/P19-1645.pdf.
 public struct WordSegDataset {
-  public let training: [WordSegRecord]
-  public private(set) var testing: [WordSegRecord]?
-  public private(set) var validation: [WordSegRecord]?
+
+  /// The training data.
+  public let trainingPhrases: [Phrase]
+
+  /// The test data.
+  public private(set) var testingPhrases: [Phrase]
+
+  /// The validation data.
+  public private(set) var validationPhrases: [Phrase]
+
+  /// A mapping between characters used in the dataset and densely-packed
+  /// integers.
   public let alphabet: Alphabet

+  /// A description of the source archive and the data files it contains.
+  private struct DownloadableArchive {

-  private static func load(data: Data) throws -> [String] {
-    guard let contents: String = String(data: data, encoding: .utf8) else {
-      throw CharacterErrors.nonUtf8Data
-    }
-    return load(contents: contents)
-  }
+    /// A [web resource](https://en.wikipedia.org/wiki/Web_resource) that can be unpacked
+    /// into data files described by other properties of `self`.
+    let location = URL(string: "https://s3.eu-west-2.amazonaws.com/k-kawakami/seg.zip")!

-  private struct DownloadDetails {
-    var archiveLocation = URL(string: "https://s3.eu-west-2.amazonaws.com/k-kawakami")!
-    var archiveFileName = "seg"
-    var archiveExtension = "zip"
-    var testingFilePath = "br/br-text/te.txt"
-    var trainingFilePath = "br/br-text/tr.txt"
-    var validationFilePath = "br/br-text/va.txt"
-  }
+    /// The path to the test data within the unpacked archive.
+    let testingFilePath = "br/br-text/te.txt"

-  private static func load(contents: String) -> [String] {
-    var strings = [String]()
+    /// The path to the training data within the unpacked archive.
+    let trainingFilePath = "br/br-text/tr.txt"

-    for line in contents.components(separatedBy: .newlines) {
-      let trimmed = line.trimmingCharacters(in: .whitespaces)
-      if trimmed.isEmpty { continue }
-      strings.append(trimmed)
-    }
-    return strings
+    /// The path to the validation data within the unpacked archive.
+    let validationFilePath = "br/br-text/va.txt"
+  }
+
+  /// Returns phrases parsed from `data` in UTF-8, separated by newlines.
+  private static func load(data: Data) -> [Substring] {
+    let contents = String(decoding: data, as: Unicode.UTF8.self)
+    let splitContents = contents.split(separator: "\n", omittingEmptySubsequences: true)
+    return splitContents
   }

+  /// Returns the union of all characters in `phrases`.
+  ///
+  /// - Parameter eos: the end of sequence marker.
+  /// - Parameter eow: the end of word marker.
+  /// - Parameter pad: the padding marker.
   private static func makeAlphabet(
-    datasets training: [String],
-    _ otherSequences: [String]?...,
+    phrases: [Substring],
     eos: String = "</s>",
     eow: String = "</w>",
     pad: String = "<pad>"
   ) -> Alphabet {
-    var letters: Set<Character> = []
-
-    for dataset in otherSequences + [training] {
-      guard let dataset = dataset else { continue }
-      for sentence in dataset {
-        for character in sentence {
-          if !character.isWhitespace { letters.insert(character) }
-        }
-      }
-    }
+    let letters = Set(phrases.joined().lazy.filter { !$0.isWhitespace })

     // Sort the letters to make it easier to interpret ints vs letters.
-    var sorted = Array(letters)
-    sorted.sort()
+    let sorted = Array(letters).sorted()

     return Alphabet(sorted, eos: eos, eow: eow, pad: pad)
   }

-  private static func convertDataset(_ dataset: [String], alphabet: Alphabet) throws
-    -> [WordSegRecord]
-  {
-    return try dataset.map {
-      let trimmed = $0.components(separatedBy: .whitespaces).joined()
-      return try WordSegRecord(
-        plainText: $0,
-        numericalizedText: CharacterSequence(
-          alphabet: alphabet, appendingEoSTo: trimmed))
-    }
-  }
-  private static func convertDataset(_ dataset: [String]?, alphabet: Alphabet) throws
-    -> [WordSegRecord]?
+  /// Numericalizes `dataset` with the mapping in `alphabet`, to be used with the
+  /// WordSeg model.
+  ///
+  /// - Note: Omits any phrase that cannot be converted to `CharacterSequence`.
+  private static func numericalizeDataset(_ dataset: [Substring], alphabet: Alphabet)
+    -> [Phrase]
   {
-    if let ds = dataset {
-      let tmp: [WordSegRecord] = try convertDataset(ds, alphabet: alphabet)  // Use tmp to disambiguate function
-      return tmp
+    var phrases = [Phrase]()
+
+    for data in dataset {
+      let trimmed = data.split(separator: " ", omittingEmptySubsequences: true).joined()
+      guard
+        let numericalizedText = try?
CharacterSequence( + alphabet: alphabet, appendingEoSTo: trimmed) + else { continue } + let phrase = Phrase( + plainText: String(data), + numericalizedText: numericalizedText) + phrases.append(phrase) } - return nil + + return phrases } + /// Creates an instance containing phrases from the reference archive. + /// + /// - Throws: an error in the Cocoa domain, if the default training file + /// cannot be read. public init() throws { - let downloadDetails = DownloadDetails() + let source = DownloadableArchive() let localStorageDirectory: URL = DatasetUtilities.defaultDirectory .appendingPathComponent("WordSeg", isDirectory: true) - WordSegDataset.downloadIfNotPresent(to: localStorageDirectory, downloadDetails: downloadDetails) + Self.downloadIfNotPresent( + to: localStorageDirectory, source: source) + let archiveFileName = source.location.deletingPathExtension().lastPathComponent let archiveDirectory = localStorageDirectory - .appendingPathComponent(downloadDetails.archiveFileName) + .appendingPathComponent(archiveFileName) let trainingFilePath = archiveDirectory - .appendingPathComponent(downloadDetails.trainingFilePath).path + .appendingPathComponent(source.trainingFilePath).path let validationFilePath = archiveDirectory - .appendingPathComponent(downloadDetails.validationFilePath).path + .appendingPathComponent(source.validationFilePath).path let testingFilePath = archiveDirectory - .appendingPathComponent(downloadDetails.testingFilePath).path + .appendingPathComponent(source.testingFilePath).path try self.init( training: trainingFilePath, validation: validationFilePath, testing: testingFilePath) } + /// Creates an instance containing phrases from `trainingFile`, and + /// optionally `validationFile` and `testingFile`. + /// + /// - Throws: an error in the Cocoa domain, if `trainingFile` cannot be + /// read. public init( training trainingFile: String, validation validationFile: String? = nil, @@ -127,53 +146,38 @@ public struct WordSegDataset { let trainingData = try Data( contentsOf: URL(fileURLWithPath: trainingFile), options: .alwaysMapped) - let training = try Self.load(data: trainingData) - var validation: [String]? = nil - var testing: [String]? = nil + let validationData = try Data( + contentsOf: URL(fileURLWithPath: validationFile ?? "/dev/null"), + options: .alwaysMapped) - if let validationFile = validationFile { - let data = try Data( - contentsOf: URL(fileURLWithPath: validationFile), - options: .alwaysMapped) - validation = try Self.load(data: data) - } + let testingData = try Data( + contentsOf: URL(fileURLWithPath: testingFile ?? "/dev/null"), + options: .alwaysMapped) - if let testingFile = testingFile { - let data: Data = try Data( - contentsOf: URL(fileURLWithPath: testingFile), - options: .alwaysMapped) - testing = try Self.load(data: data) - } - self.alphabet = Self.makeAlphabet(datasets: training, validation, testing) - self.training = try Self.convertDataset(training, alphabet: self.alphabet) - self.validation = try Self.convertDataset(validation, alphabet: self.alphabet) - self.testing = try Self.convertDataset(testing, alphabet: self.alphabet) + self.init( + training: trainingData, validation: validationData, testing: testingData) } + /// Creates an instance containing phrases from `trainingData`, and + /// optionally `validationData` and `testingData`. public init( training trainingData: Data, validation validationData: Data?, testing testingData: Data? - ) - throws - { - let training = try Self.load(data: trainingData) - var validation: [String]? 
= nil - var testing: [String]? = nil - if let validationData = validationData { - validation = try Self.load(data: validationData) - } - if let testingData = testingData { - testing = try Self.load(data: testingData) - } - - self.alphabet = Self.makeAlphabet(datasets: training, validation, testing) - self.training = try Self.convertDataset(training, alphabet: self.alphabet) - self.validation = try Self.convertDataset(validation, alphabet: self.alphabet) - self.testing = try Self.convertDataset(testing, alphabet: self.alphabet) + ) { + let training = Self.load(data: trainingData) + let validation = Self.load(data: validationData ?? Data()) + let testing = Self.load(data: testingData ?? Data()) + + self.alphabet = Self.makeAlphabet(phrases: training + validation + testing) + self.trainingPhrases = Self.numericalizeDataset(training, alphabet: self.alphabet) + self.validationPhrases = Self.numericalizeDataset(validation, alphabet: self.alphabet) + self.testingPhrases = Self.numericalizeDataset(testing, alphabet: self.alphabet) } + /// Downloads and unpacks `source` to `directory` if it does not + /// exist locally. private static func downloadIfNotPresent( - to directory: URL, downloadDetails: DownloadDetails + to directory: URL, source: DownloadableArchive ) { let downloadPath = directory.path let directoryExists = FileManager.default.fileExists(atPath: downloadPath) @@ -182,11 +186,15 @@ public struct WordSegDataset { guard !directoryExists || directoryEmpty else { return } + let remoteRoot = source.location.deletingLastPathComponent() + let filename = source.location.deletingPathExtension().lastPathComponent + let fileExtension = source.location.pathExtension + // Downloads and extracts dataset files. let _ = DatasetUtilities.downloadResource( - filename: downloadDetails.archiveFileName, - fileExtension: downloadDetails.archiveExtension, - remoteRoot: downloadDetails.archiveLocation, + filename: filename, + fileExtension: fileExtension, + remoteRoot: remoteRoot, localStorageDirectory: directory, extract: true) } } diff --git a/Examples/WordSeg/main.swift b/Examples/WordSeg/main.swift index 8a68e9efcf5..ae89d0342e6 100644 --- a/Examples/WordSeg/main.swift +++ b/Examples/WordSeg/main.swift @@ -18,9 +18,9 @@ import TensorFlow import TextModels // Model flags -let ndim = 512 // Hidden unit size. +let hiddenSize = 512 // Hidden unit size. // Training flags -let dropoutProb = 0.5 // Dropout rate. +let dropoutProbability = 0.5 // Dropout rate. let order = 5 // Power of length penalty. let maxEpochs = 1000 // Maximum number of training epochs. var trainingLossHistory = [Float]() // Keep track of loss. @@ -30,7 +30,7 @@ let learningRate: Float = 1e-3 // Initial learning rate. let lambd: Float = 0.00075 // Weight of length penalty. // Lexicon flags. let maxLength = 10 // Maximum length of a string. -let minFreq = 10 // Minimum frequency of a string. +let minFrequency = 10 // Minimum frequency of a string. // Load user-provided data files. 
 let dataset: WordSegDataset
@@ -50,23 +50,23 @@
 default: usage()
 }

-let sequences = dataset.training.map { $0.numericalizedText }
+let sequences = dataset.trainingPhrases.map { $0.numericalizedText }
 let lexicon = Lexicon(
   from: sequences,
   alphabet: dataset.alphabet,
   maxLength: maxLength,
-  minFreq: minFreq
+  minFrequency: minFrequency
 )

 let modelParameters = SNLM.Parameters(
-  ndim: ndim,
-  dropoutProb: dropoutProb,
-  chrVocab: dataset.alphabet,
-  strVocab: lexicon,
+  hiddenSize: hiddenSize,
+  dropoutProbability: dropoutProbability,
+  alphabet: dataset.alphabet,
+  lexicon: lexicon,
   order: order
 )

-let device = Device.defaultXLA
+let device = Device.defaultTFEager
 var model = SNLM(parameters: modelParameters)
 model.move(to: device)
@@ -80,8 +80,8 @@ for epoch in 1...maxEpochs {
   Context.local.learningPhase = .training
   var trainingLossSum: Float = 0
   var trainingBatchCount = 0
-  for record in dataset.training {
-    let sentence = record.numericalizedText
+  for phrase in dataset.trainingPhrases {
+    let sentence = phrase.numericalizedText
     let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor<Float> in
       let lattice = model.buildLattice(sentence, maxLen: maxLength, device: device)
       let score = lattice[sentence.count].semiringScore
@@ -107,11 +107,11 @@ for epoch in 1...maxEpochs {
   trainingLossHistory.append(trainingLoss)
   reduceLROnPlateau(lossHistory: trainingLossHistory, optimizer: optimizer)

-  guard let validationDataset = dataset.validation else {
+  if dataset.validationPhrases.count < 1 {
     print(
       """
       [Epoch \(epoch)] \
-      Training loss: \(trainingLoss))
+      Training loss: \(trainingLoss)
       """
     )

@@ -131,8 +131,8 @@ for epoch in 1...maxEpochs {
   var validationBatchCount = 0
   var validationCharacterCount = 0
   var validationPlainText: String = ""
-  for record in validationDataset {
-    let sentence = record.numericalizedText
+  for phrase in dataset.validationPhrases {
+    let sentence = phrase.numericalizedText
     var lattice = model.buildLattice(sentence, maxLen: maxLength, device: device)
     let score = lattice[sentence.count].semiringScore

@@ -141,8 +141,8 @@ for epoch in 1...maxEpochs {
     validationCharacterCount += sentence.count

     // View a sample segmentation once per epoch.
-    if validationBatchCount == validationDataset.count {
-      let bestPath = lattice.viterbi(sentence: record.numericalizedText)
+    if validationBatchCount == dataset.validationPhrases.count {
+      let bestPath = lattice.viterbi(sentence: phrase.numericalizedText)
       validationPlainText = Lattice.pathToPlainText(path: bestPath, alphabet: dataset.alphabet)
     }
   }
diff --git a/Models/Text/WordSeg/Lattice.swift b/Models/Text/WordSeg/Lattice.swift
index 0a0cfb1cd81..bd1c7ef9d72 100644
--- a/Models/Text/WordSeg/Lattice.swift
+++ b/Models/Text/WordSeg/Lattice.swift
@@ -23,23 +23,41 @@ import TensorFlow
 import Glibc
 #endif

-/// Lattice
+/// A structure used for scoring all possible segmentations of a character
+/// sequence.
 ///
-/// Represents the lattice used by the WordSeg algorithm.
+/// The path with the best score provides the most likely segmentation.
 public struct Lattice: Differentiable {
-  /// Edge
+
+  /// Represents a word.
   ///
-  /// Represents an Edge
+  /// At each character position, an edge is constructed for every possible
+  /// segmentation of the preceding portion of the sequence.
   public struct Edge: Differentiable {
+
+    /// The node position immediately preceding this edge.
     @noDerivative public var start: Int
+
+    /// The node position immediately following this edge.
@noDerivative public var end: Int + + /// The characters composing a word. @noDerivative public var string: CharacterSequence + + /// The log likelihood of this segmentation. public var logp: Float - // expectation + /// The expected score for this segmentation. public var score: SemiRing + + /// The expected total score for this segmentation. public var totalScore: SemiRing + /// Creates an edge for `sentence` between `start` and `end`. + /// + /// Uses the log probability `logp` and the power of the length penalty + /// `order` to calculate the regularization factor and form the current + /// score. Sums this score with `previous` to determine the total score. @differentiable init( start: Int, end: Int, sentence: CharacterSequence, logp: Float, @@ -58,6 +76,8 @@ public struct Lattice: Differentiable { self.totalScore = self.score * previous } + /// Creates an edge for `string` between `start` and `end` and sets the + /// log probability `logp`, `score`, and `totalScore`. @differentiable public init( start: Int, end: Int, string: CharacterSequence, logp: Float, @@ -72,17 +92,32 @@ public struct Lattice: Differentiable { } } - /// Node + /// Represents a word boundary. /// - /// Represents a node in the lattice + /// When a lattice is built, a start node is created, followed by one for + /// every character in the sequence, representing every potential boundary. + /// + /// - Note: Scores are only meaningful in relation to incoming edges and the + /// start node has no incoming edges. public struct Node: Differentiable { + + /// The incoming edge with the highest score. @noDerivative public var bestEdge: Edge? + + /// The score of the best incoming edge. public var bestScore: Float = 0.0 + + /// All incoming edges. public var edges = [Edge]() + + /// A composite score of all incoming edges. public var semiringScore: SemiRing = SemiRing.one + /// Creates an empty instance. init() {} + /// Creates a node preceded by `bestEdge`, sets incoming edges to + /// `edges`, and stores `bestScore` and `semiringScore`. @differentiable public init( bestEdge: Edge?, bestScore: Float, edges: [Edge], @@ -94,20 +129,24 @@ public struct Lattice: Differentiable { self.semiringScore = semiringScore } + /// Returns a sum of the total score of all incoming edges. @differentiable func computeSemiringScore() -> SemiRing { // TODO: Reduceinto and += edges.differentiableMap { $0.totalScore }.sum() } + /// Calculates and sets the current semiring score. @differentiable mutating func recomputeSemiringScore() { semiringScore = computeSemiringScore() } } + /// Represents the position of word boundaries. var positions: [Node] + /// Accesses the node at the `index`th position. @differentiable public subscript(index: Int) -> Node { get { return positions[index] } @@ -121,16 +160,20 @@ public struct Lattice: Differentiable { // _modify { yield &positions[index] } } + /// Creates an empty instance with a start node, followed by `count` nodes. init(count: Int) { positions = Array(repeating: Node(), count: count + 1) } + /// Creates an instance with the nodes in `positions`. public init(positions: [Node]) { self.positions = positions } + /// Returns the path representing the best segmentation of `sentence`. public mutating func viterbi(sentence: CharacterSequence) -> [Edge] { - // Forwards pass + // Forward pass + // Starts at 1 since the 0 node has no incoming edges. for position in 1...sentence.count { var bestScore = -Float.infinity var bestEdge: Edge! 
@@ -145,7 +188,7 @@ public struct Lattice: Differentiable { self[position].bestEdge = bestEdge } - // Backwards + // Backward pass var bestPath: [Edge] = [] var nextEdge = self[sentence.count].bestEdge! while nextEdge.start != 0 { @@ -157,6 +200,9 @@ public struct Lattice: Differentiable { return bestPath.reversed() } + /// Returns the plain text encoded in `path`, using `alphabet`. + /// + /// This represents the segmentation of the full character sequence. public static func pathToPlainText(path: [Edge], alphabet: Alphabet) -> String { var plainText = [String]() for edge in path { @@ -171,6 +217,8 @@ public struct Lattice: Differentiable { } extension Lattice: CustomStringConvertible { + + /// The plain text description of this instance that describes all nodes. public var description: String { """ [ @@ -181,6 +229,9 @@ extension Lattice: CustomStringConvertible { } extension Lattice.Node: CustomStringConvertible { + + /// The plain text description of this instance that describes all incoming + /// edges. public var description: String { var edgesStr: String if edges.isEmpty { @@ -196,13 +247,19 @@ extension Lattice.Node: CustomStringConvertible { } extension Lattice.Edge: CustomStringConvertible { + + /// The plain text description of this instance with all edge details. public var description: String { "[\(start)->\(end)] logp: \(logp), score: \(score.shortDescription), total score: \(totalScore.shortDescription), sentence: \(string)" } } -/// SE-0259-esque equality with tolerance extension Lattice { + + /// Returns true when all nodes in `self` are within `tolerance` of all + /// nodes in `other`. + /// + /// - Note: This behavior is modeled after SE-0259. public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool { guard self.positions.count == other.positions.count else { print("positions count mismatch: \(self.positions.count) != \(other.positions.count)") @@ -221,6 +278,11 @@ extension Lattice { } extension Lattice.Node { + + /// Returns true when all properties and edges in `self` are within + /// `tolerance` of all properties and edges in `other`. + /// + /// - Note: This behavior is modeled after SE-0259. public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool { guard self.edges.count == other.edges.count else { return false } @@ -243,6 +305,11 @@ extension Lattice.Node { } extension Lattice.Edge { + + /// Returns true when the log likelihood and scores in `self` are within + /// `tolerance` of the log likelihood and scores in `other`. + /// + /// - Note: This behavior is modeled after SE-0259. public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool { let diffP = abs(self.logp - other.logp) return self.start == other.start && self.end == other.end diff --git a/Models/Text/WordSeg/Model.swift b/Models/Text/WordSeg/Model.swift index ca065171e6f..660a5cadf54 100644 --- a/Models/Text/WordSeg/Model.swift +++ b/Models/Text/WordSeg/Model.swift @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + // Original Paper: // "Learning to Discover, Ground, and Use Words with Segmental Neural Language // Models" @@ -18,102 +19,137 @@ // https://www.aclweb.org/anthology/P19-1645.pdf // This implementation is not affiliated with DeepMind and has not been // verified by the authors. + import ModelSupport import TensorFlow -/// SNLM -/// -/// A representation of the Segmental Neural Language Model. 
-///
-/// \ref https://www.aclweb.org/anthology/P19-1645.pdf
+/// A Segmental Neural Language Model for word segmentation, as described in
+/// the above paper.
 public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
+
+  /// A set of configuration parameters that define model behavior.
   public struct Parameters {
-    public var ndim: Int
-    public var dropoutProb: Double
-    public var chrVocab: Alphabet
-    public var strVocab: Lexicon
+
+    /// The hidden unit size.
+    public var hiddenSize: Int
+
+    /// The dropout rate.
+    public var dropoutProbability: Double
+
+    /// The union of characters used in this model.
+    public var alphabet: Alphabet
+
+    /// Contiguous sequences of characters encountered in the training data.
+    public var lexicon: Lexicon
+
+    /// The power of the length penalty.
     public var order: Int

+    /// Creates an instance with `hiddenSize` units, `dropoutProbability`
+    /// rate, `alphabet`, `lexicon`, and `order` power of length penalty.
     public init(
-      ndim: Int,
-      dropoutProb: Double,
-      chrVocab: Alphabet,
-      strVocab: Lexicon,
+      hiddenSize: Int,
+      dropoutProbability: Double,
+      alphabet: Alphabet,
+      lexicon: Lexicon,
       order: Int
     ) {
-      self.ndim = ndim
-      self.dropoutProb = dropoutProb
-      self.chrVocab = chrVocab
-      self.strVocab = strVocab
+      self.hiddenSize = hiddenSize
+      self.dropoutProbability = dropoutProbability
+      self.alphabet = alphabet
+      self.lexicon = lexicon
       self.order = order
     }
   }

+  /// The configuration parameters that define model behavior.
   @noDerivative public var parameters: Parameters

   // MARK: - Encoder
+
+  /// The embedding layer for the encoder.
   public var encoderEmbedding: Embedding<Float>
+
+  /// The LSTM layer for the encoder.
   public var encoderLSTM: LSTM<Float>

   // MARK: - Interpolation weight
+
+  /// The interpolation weight, which determines the proportion of
+  /// contributions from the lexical memory and character generation.
   public var mlpInterpolation: MLP

   // MARK: - Lexical memory
+
+  /// The lexical memory.
   public var mlpMemory: MLP

   // MARK: - Character-level decoder
+
+  /// The embedding layer for the decoder.
   public var decoderEmbedding: Embedding<Float>
+
+  /// The LSTM layer for the decoder.
   public var decoderLSTM: LSTM<Float>
+
+  /// The dense layer for the decoder.
   public var decoderDense: Dense<Float>

   // MARK: - Other layers
+
+  /// The dropout layer for both the encoder and decoder.
   public var dropout: Dropout<Float>

   // MARK: - Initializer
+
+  /// Creates an instance with the configuration defined by `parameters`.
   public init(parameters: Parameters) {
     self.parameters = parameters

     // Encoder
     self.encoderEmbedding = Embedding(
-      vocabularySize: parameters.chrVocab.count,
-      embeddingSize: parameters.ndim)
+      vocabularySize: parameters.alphabet.count,
+      embeddingSize: parameters.hiddenSize)
     self.encoderLSTM = LSTM(
       LSTMCell(
-        inputSize: parameters.ndim,
+        inputSize: parameters.hiddenSize,
         hiddenSize:
-          parameters.ndim))
+          parameters.hiddenSize))

     // Interpolation weight
     self.mlpInterpolation = MLP(
-      nIn: parameters.ndim,
-      nHidden: parameters.ndim,
-      nOut: 2,
-      dropoutProbability: parameters.dropoutProb)
+      inputSize: parameters.hiddenSize,
+      hiddenSize: parameters.hiddenSize,
+      outputSize: 2,
+      dropoutProbability: parameters.dropoutProbability)

     // Lexical memory
     self.mlpMemory = MLP(
-      nIn: parameters.ndim,
-      nHidden: parameters.ndim,
-      nOut: parameters.strVocab.count,
-      dropoutProbability: parameters.dropoutProb)
+      inputSize: parameters.hiddenSize,
+      hiddenSize: parameters.hiddenSize,
+      outputSize: parameters.lexicon.count,
+      dropoutProbability: parameters.dropoutProbability)

     // Character-level decoder
     self.decoderEmbedding = Embedding(
-      vocabularySize: parameters.chrVocab.count,
-      embeddingSize: parameters.ndim)
+      vocabularySize: parameters.alphabet.count,
+      embeddingSize: parameters.hiddenSize)
     self.decoderLSTM = LSTM(
       LSTMCell(
-        inputSize: parameters.ndim,
+        inputSize: parameters.hiddenSize,
         hiddenSize:
-          parameters.ndim))
-    self.decoderDense = Dense(inputSize: parameters.ndim, outputSize: parameters.chrVocab.count)
+          parameters.hiddenSize))
+    self.decoderDense = Dense(
+      inputSize: parameters.hiddenSize, outputSize: parameters.alphabet.count)

     // Other layers
-    self.dropout = Dropout(probability: parameters.dropoutProb)
+    self.dropout = Dropout(probability: parameters.dropoutProbability)
   }

   // MARK: - Encode
-  /// Returns the hidden states of the encoder LSTM applied to the given sentence.
+
+  /// Returns the hidden states of the encoder LSTM applied to `x`, using
+  /// `device`.
   public func encode(_ x: CharacterSequence, device: Device) -> [Tensor<Float>] {
     var embedded = encoderEmbedding(x.tensor(device: device))
     embedded = dropout(embedded)
@@ -125,7 +161,9 @@ public struct SNLM: EuclideanDifferentiable, KeyPathIterable {
   }

   // MARK: - Decode
-  /// Returns log probabilities for each of the candidates.
+
+  /// Returns the log probabilities for each sequence in `candidates`, given
+  /// hidden `state` from the encoder LSTM, using `device`.
   public func decode(_ candidates: [CharacterSequence], _ state: Tensor<Float>, device: Device)
     -> Tensor<Float>
   {
@@ -135,16 +173,16 @@
     var xBatch: [Int32] = []
     var yBatch: [Int32] = []
     for candidate in candidates {
-      let padding = Array(repeating: parameters.chrVocab.pad, count: maxLen - candidate.count - 1)
+      let padding = Array(repeating: parameters.alphabet.pad, count: maxLen - candidate.count - 1)

       // x is </w>{sentence}{padding}
-      xBatch.append(parameters.chrVocab.eow)
+      xBatch.append(parameters.alphabet.eow)
       xBatch.append(contentsOf: candidate.characters)
       xBatch.append(contentsOf: padding)

       // y is {sentence}</w>{padding}
       yBatch.append(contentsOf: candidate.characters)
-      yBatch.append(parameters.chrVocab.eow)
+      yBatch.append(parameters.alphabet.eow)
       yBatch.append(contentsOf: padding)
     }

@@ -157,26 +195,26 @@
       shape: [candidates.count, maxLen], scalars: yBatch, on: device
     ).transposed()

-    // [time x batch x ndim]
+    // [time x batch x hiddenSize]
     var embeddedX = decoderEmbedding(x)
     embeddedX = dropout(embeddedX)

-    // [batch x ndim]
+    // [batch x hiddenSize]
     let stateBatch = state.rankLifted().tiled(multiples: [candidates.count, 1])

-    // [time] array of LSTM states whose `hidden` and `cell` fields have shape [batch x ndim]
+    // [time] array of LSTM states whose `hidden` and `cell` fields have shape [batch x hiddenSize]
     let decoderStates = decoderLSTM(
       embeddedX.unstacked(),
       initialState: LSTMCell<Float>.State(
         cell: Tensor(zeros: stateBatch.shape, on: device),
         hidden: stateBatch))

-    // [time x batch x ndim]
+    // [time x batch x hiddenSize]
     var decoderResult = Tensor(
       stacking: decoderStates.differentiableMap { $0.hidden })
     decoderResult = dropout(decoderResult)

-    // [time x batch x chrVocab.count]
+    // [time x batch x alphabet.count]
     let logits = decoderDense(decoderResult)

     // [time x batch]
@@ -189,7 +227,7 @@
     ).reshaped(to: y.shape)

     // [time x batch]
-    let padScalars = [Int32](repeating: parameters.chrVocab.pad, count: candidates.count * maxLen)
+    let padScalars = [Int32](repeating: parameters.alphabet.pad, count: candidates.count * maxLen)
     let noPad = Tensor<Int32>(y .!= Tensor(shape: y.shape, scalars: padScalars, on: device))
     let noPadFloat = Tensor<Float>(noPad)
@@ -202,13 +240,17 @@
   }

   // MARK: - buildLattice
+
+  /// Returns the log probability for `candidate` from the lexical memory
+  /// `logp_lex`.
   func get_logp_lex(_ logp_lex: [Float], _ candidate: CharacterSequence) -> Float {
-    guard let index = parameters.strVocab.dictionary[candidate] else {
+    guard let index = parameters.lexicon.dictionary[candidate] else {
       return -Float.infinity
     }
     return logp_lex[Int(index)]
   }

+  /// Returns a lattice for `sentence` with `maxLen` maximum sequence length.
   @differentiable
   public func buildLattice(_ sentence: CharacterSequence, maxLen: Int, device: Device) -> Lattice {
     var lattice = Lattice(count: sentence.count)
@@ -221,12 +263,12 @@
         // TODO: avoid copies?
         let candidate = CharacterSequence(
-          alphabet: parameters.chrVocab,
+          alphabet: parameters.alphabet,
           characters: sentence[pos..<next])
-        if candidate.count != 1 && candidate.last == parameters.chrVocab.eos {
+        if candidate.count != 1 && candidate.last == parameters.alphabet.eos {
           // Prohibit strings such as ["t", "h", "e", "</s>"]
           continue
         }
@@ -276,15 +318,19 @@
 }

 extension Array {
-  // NOTE(TF-1277): this mutating method exists as a workaround for `Array.subscript._modify` not
-  // being differentiable.
-  //
-  // Semantically, it behaves like `Array.subscript.set`.
+
+  /// Sets the `index`th element of `self` to `value`.
+  ///
+  /// Semantically, this function behaves like `Array.subscript.set`.
+  ///
+  /// - Note: this mutating method exists as a workaround for
+  ///   `Array.subscript._modify` not being differentiable (TF-1277).
   @inlinable
   mutating func update(at index: Int, to value: Element) {
     self[index] = value
   }

+  /// Returns the value and pullback of `self.update`.
   @usableFromInline
   @derivative(of: update)
   mutating func vjpUpdate(at index: Int, to value: Element) -> (
@@ -301,17 +347,27 @@ extension Array {
   }
 }

+/// A multilayer perceptron with three layers.
 public struct MLP: Layer {
+
+  /// The first dense layer.
   public var dense1: Dense<Float>
+
+  /// The dropout layer.
   public var dropout: Dropout<Float>
+
+  /// The second dense layer.
   public var dense2: Dense<Float>

-  public init(nIn: Int, nHidden: Int, nOut: Int, dropoutProbability: Double) {
-    dense1 = Dense(inputSize: nIn, outputSize: nHidden, activation: tanh)
+  /// Creates an instance with `inputSize`, `hiddenSize`,
+  /// `dropoutProbability`, and `outputSize`.
+  public init(inputSize: Int, hiddenSize: Int, outputSize: Int, dropoutProbability: Double) {
+    dense1 = Dense(inputSize: inputSize, outputSize: hiddenSize, activation: tanh)
     dropout = Dropout(probability: dropoutProbability)
-    dense2 = Dense(inputSize: nHidden, outputSize: nOut, activation: logSoftmax)
+    dense2 = Dense(inputSize: hiddenSize, outputSize: outputSize, activation: logSoftmax)
   }

+  /// Returns the result of applying all three layers in sequence to `input`.
   @differentiable
   public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
     return dense2(dropout(dense1(input)))
@@ -319,10 +375,12 @@
 }

 extension Tensor {
-  // NOTE(TF-1008): this is a duplicate of `Tensor.scalars` that is needed for differentiation
-  // correctness. It exists as a workaround for TF-1008: per-instance zero tangent vectors.
-  //
-  // Remove this when differentiation uses per-instance zeros
+
+  /// Returns the scalars of `self`.
+  ///
+  /// - Note: this is a workaround for TF-1008 that is needed for
+  ///   differentiation correctness.
+  // TODO: Remove this when differentiation uses per-instance zeros
   //   (`Differentiable.zeroTangentVectorInitializer`) instead of static zeros
   //   (`AdditiveArithmetic.zero`).
   @differentiable(where Scalar: TensorFlowFloatingPoint)
   func scalarsADHack(device: Device) -> [Scalar] {
     scalars
   }

+  /// Returns the value and pullback of `self.scalarsADHack`.
   @derivative(of: scalarsADHack)
   func vjpScalarsADHack(device: Device) -> (
     value: [Scalar], pullback: (Array<Scalar>.TangentVector) -> Tensor
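Pulling the renamed pieces together, a rough end-to-end sketch of driving the model (toy sizes and untrained weights, so the resulting segmentation is arbitrary; `Alphabet`, `CharacterSequence`, and `Lexicon` come from `ModelSupport`):

```swift
import ModelSupport
import TensorFlow
import TextModels

let alphabet = Alphabet(["a", "b"], eos: "</s>", eow: "</w>", pad: "<pad>")
let lexicon = Lexicon([
  CharacterSequence(alphabet: alphabet, characters: [0, 1]),  // "ab"
  CharacterSequence(alphabet: alphabet, characters: [1, 0]),  // "ba"
])

var model = SNLM(
  parameters: SNLM.Parameters(
    hiddenSize: 2,
    dropoutProbability: 0,
    alphabet: alphabet,
    lexicon: lexicon,
    order: 5))

let device = Device.defaultTFEager
model.move(to: device)

// Score all segmentations of "abab", then read out the best path.
let abab = CharacterSequence(alphabet: alphabet, characters: [0, 1, 0, 1])
var lattice = model.buildLattice(abab, maxLen: 5, device: device)
let bestPath = lattice.viterbi(sentence: abab)
print(Lattice.pathToPlainText(path: bestPath, alphabet: alphabet))
```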
diff --git a/Models/Text/WordSeg/SemiRing.swift b/Models/Text/WordSeg/SemiRing.swift
index d985099dcc4..9e70cd897ee 100644
--- a/Models/Text/WordSeg/SemiRing.swift
+++ b/Models/Text/WordSeg/SemiRing.swift
@@ -20,9 +20,10 @@ import Glibc
 #endif

-/// logSumExp(_:)
+/// Returns the log of the sum of the exponentials of the elements of `x`.
 ///
-/// logSumExp (see https://en.wikipedia.org/wiki/LogSumExp)
+/// Used for numerical stability when dealing with very small values.
 @differentiable
 public func logSumExp(_ x: [Float]) -> Float {
   if x.count == 0 { return -Float.infinity}
@@ -46,31 +47,42 @@ public func vjpLogSumExp(_ x: [Float]) -> (
   return (logSumExp(x), pb)
 }

-/// logSumExp(_:_:)
+/// Returns the log of the sum of the exponentials of `lhs` and `rhs`.
 ///
-/// Specialized logSumExp for 2 floats.
+/// Used for numerical stability when dealing with very small values.
 @differentiable
 public func logSumExp(_ lhs: Float, _ rhs: Float) -> Float {
   return logSumExp([lhs, rhs])
 }

-/// SemiRing
-///
-/// Represents a SemiRing
+/// A storage mechanism for scoring inside a lattice.
 public struct SemiRing: Differentiable {
+
+  /// The log likelihood.
   public var logp: Float
+
+  /// The regularization factor.
   public var logr: Float

+  /// Creates an instance with log likelihood `logp` and regularization
+  /// factor `logr`.
   @differentiable
   public init(logp: Float, logr: Float) {
     self.logp = logp
     self.logr = logr
   }

+  /// The baseline score of zero.
   static var zero: SemiRing { SemiRing(logp: -Float.infinity, logr: -Float.infinity) }
+
+  /// The baseline score of one.
   static var one: SemiRing { SemiRing(logp: 0.0, logr: -Float.infinity) }
 }

+/// Multiplies `lhs` by `rhs`.
+///
+/// Since scores are on a logarithmic scale, products become sums.
 @differentiable
 func * (_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing {
   return SemiRing(
@@ -78,6 +90,7 @@ func * (_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing {
     logr: logSumExp(lhs.logp + rhs.logr, rhs.logp + lhs.logr))
 }

+/// Adds `lhs` and `rhs`.
 @differentiable
 func + (_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing {
   return SemiRing(
@@ -86,6 +99,8 @@ func + (_ lhs: SemiRing, _ rhs: SemiRing) -> SemiRing {
 }

 extension Array where Element == SemiRing {
+
+  /// Returns the sum of all scores in the collection.
   @differentiable
   func sum() -> SemiRing {
     return SemiRing(
@@ -95,13 +110,18 @@ extension Array where Element == SemiRing {
 }

 extension SemiRing {
+
+  /// The plain text description of this instance with score details.
   var shortDescription: String {
     "(\(logp), \(logr))"
   }
 }

-/// SE-0259-esque equality with tolerance
 extension SemiRing {
+
+  /// Returns true when `self` is within `tolerance` of `other`.
+  ///
+  /// - Note: This behavior is modeled after SE-0259.
   // TODO(abdulras) see if we can use ulp as a default tolerance
   @inlinable
   public func isAlmostEqual(to other: Self, tolerance: Float) -> Bool {
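The log-space formulation is what keeps these scores from underflowing: `*` on `SemiRing` adds log-probabilities (a product of probabilities) and `+` routes through `logSumExp` (a sum of probabilities). A small numerical illustration of the public `logSumExp` (the `SemiRing` operators themselves are internal to the module):

```swift
import TextModels

let a: Float = -800
let b: Float = -801

// Naively, exp(-800) underflows to 0 in Float, so
// log(exp(a) + exp(b)) would produce log(0) == -inf.
// logSumExp instead computes max + log(exp(a - max) + exp(b - max)).
print(logSumExp(a, b))  // ≈ -799.69 instead of -inf
```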
diff --git a/Support/Text/WordSeg/Alphabet.swift b/Support/Text/WordSeg/Alphabet.swift
index 52d6672ef7e..96a5356914b 100644
--- a/Support/Text/WordSeg/Alphabet.swift
+++ b/Support/Text/WordSeg/Alphabet.swift
@@ -14,20 +14,32 @@

 import TensorFlow

-/// Alphabet maps from characters in a string to Int32 representations.
+/// A mapping between individual characters and their integer representation.
 ///
-/// Note: we map from String in order to support multi-character metadata sequences such as </s>.
+/// - Note: We map from String in order to support multi-character metadata
+///   sequences such as `</s>`.
 ///
-/// In Python implementations, this is sometimes called the character vocabulary.
+/// - Note: In Python implementations, this is sometimes called the character
+///   vocabulary.
 public struct Alphabet {
+
+  /// A type whose instances represent a character.
   public typealias Element = String

+  /// A one-to-one mapping between a set of characters and unique integers.
   public var dictionary: BijectiveDictionary<String, Int32>

+  /// A marker denoting the end of a sequence.
   public let eos: Int32
+
+  /// A marker denoting the end of a word.
   public let eow: Int32
+
+  /// A marker used for padding inside a sequence.
   public let pad: Int32

+  /// Creates an instance containing a mapping from `letters` to unique
+  /// integers, including markers `eos`, `eow`, and `pad`.
   public init<C: Collection>(_ letters: C, eos: String, eow: String, pad: String)
   where C.Element == Character {
     self.dictionary = .init(zip(letters.lazy.map { String($0) }, 0...))
@@ -42,6 +54,8 @@ public struct Alphabet {
     self.dictionary[pad] = self.pad
   }

+  /// Creates an instance containing a mapping from `letters` to unique
+  /// integers, including markers `eos`, `eow`, and `pad`.
   public init<C: Collection>(_ letters: C, eos: String, eow: String, pad: String)
   where C.Element == Element {
     self.dictionary = .init(zip(letters.lazy.map { String($0) }, 0...))
@@ -56,8 +70,10 @@ public struct Alphabet {
     self.dictionary[pad] = self.pad
   }

+  /// A count of unique characters, including markers.
   public var count: Int { return dictionary.count }

+  /// Accesses the integer for `key`, returning `nil` if it does not exist.
   public subscript(key: String) -> Int32? {
     return dictionary[key]
   }
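A quick sketch of the mapping an `Alphabet` builds (values follow from the initializer above: letters first, in order, then the three markers):

```swift
import ModelSupport

let alphabet = Alphabet(["a", "b"], eos: "</s>", eow: "</w>", pad: "<pad>")
print(alphabet.count)     // 5: two letters plus three markers
print(alphabet["a"]!)     // 0
print(alphabet["b"]!)     // 1
print(alphabet["</s>"]!)  // 2; multi-character keys are why Element is String
```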
diff --git a/Support/Text/WordSeg/CharacterSequence.swift b/Support/Text/WordSeg/CharacterSequence.swift
index 27d2f7748a4..60721561dc3 100644
--- a/Support/Text/WordSeg/CharacterSequence.swift
+++ b/Support/Text/WordSeg/CharacterSequence.swift
@@ -14,16 +14,26 @@

 import TensorFlow

-/// An Int32-based representation of a string to be used with the WordSeg model.
+/// A sequence of characters represented by integers.
 public struct CharacterSequence: Hashable {
+
+  /// Represents an ordered sequence of characters.
   public let characters: [Int32]
+
+  /// A marker denoting the end of the sequence.
   private let eos: Int32

+  /// Creates an empty instance without meaningful contents.
   public init(_debug: Int) {
     self.characters = []
     self.eos = -1
   }

+  /// Creates a sequence from `string`, using `alphabet`, appended with the
+  /// end marker.
+  ///
+  /// - Throws: `CharacterErrors.unknownCharacter` if `string` contains a
+  ///   character that does not exist in `alphabet`.
   public init(alphabet: Alphabet, appendingEoSTo string: String) throws {
     var characters = [Int32]()
     characters.reserveCapacity(string.count + 1)
@@ -37,34 +47,70 @@ public struct CharacterSequence: Hashable {
     self.init(alphabet: alphabet, characters: characters)
   }

+  /// Creates a sequence from `characters` and sets the end marker from
+  /// `alphabet`.
+  ///
+  /// - Note: Assumes `characters` contains an end marker.
   private init(alphabet: Alphabet, characters: [Int32]) {
     self.characters = characters
     self.eos = alphabet.eos
   }

+  /// Creates a sequence from `characters` and sets the end marker from
+  /// `alphabet`.
+  ///
+  /// - Note: Assumes `characters` contains an end marker.
   public init(alphabet: Alphabet, characters: ArraySlice<Int32>) {
     self.characters = [Int32](characters)
     self.eos = alphabet.eos
   }

+  /// Accesses the `index`th character.
   public subscript(index: Int32) -> Int32 {
     return characters[Int(index)]
   }

+  /// Accesses characters within `range`.
   public subscript(range: Range<Int>) -> ArraySlice<Int32> {
     return characters[range]
   }

+  /// The count of characters in the sequence, including the end marker.
+  public var count: Int { return characters.count }
+
+  /// The last character in the sequence, if `characters` is not empty.
+  ///
+  /// - Note: This is usually the end marker.
+  public var last: Int32? { return characters.last }
+
+  /// Representation for character generation on `device`, with the end marker
+  /// moved to the beginning.
   public func tensor(device: Device) -> Tensor<Int32> {
     Tensor<Int32>([self.eos] + characters[0..<characters.count - 1], on: device)
   }
 }
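How the end marker flows through a `CharacterSequence`, as a sketch (same toy alphabet as above):

```swift
import ModelSupport
import TensorFlow

let alphabet = Alphabet(["a", "b"], eos: "</s>", eow: "</w>", pad: "<pad>")

let seq = try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "ab")
print(seq.characters)  // [0, 1, 2]; the end marker 2 is appended
print(seq.count)       // 3, including the end marker

// For generation, tensor(device:) moves the end marker to the front: [2, 0, 1].
let t = seq.tensor(device: .defaultTFEager)
print(t)
```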
diff --git a/Support/Text/WordSeg/Lexicon.swift b/Support/Text/WordSeg/Lexicon.swift
--- a/Support/Text/WordSeg/Lexicon.swift
+++ b/Support/Text/WordSeg/Lexicon.swift
@@ ... @@ public struct Lexicon {
   public typealias Element = CharacterSequence
   public var dictionary: BijectiveDictionary<CharacterSequence, Int32>

+  /// A count of unique logical words in the lexicon.
   public var count: Int { return dictionary.count }

+  /// Creates an instance containing `sequences`.
   public init<C: Collection>(_ sequences: C) where C.Element == Element {
     self.dictionary = .init(zip(sequences, 0...))
   }

+  /// Creates an instance containing `sequences` using `alphabet`, truncating
+  /// elements at `maxLength` and including only those appearing at least
+  /// `minFrequency` times.
   public init(
     from sequences: [CharacterSequence],
     alphabet: Alphabet,
     maxLength: Int,
-    minFreq: Int
+    minFrequency: Int
   ) {
     var histogram: [ArraySlice<Int32>: Int] = [:]
@@ -51,7 +60,7 @@
       }
     }

-    let frequentWordCandidates = histogram.filter { $0.1 >= minFreq }
+    let frequentWordCandidates = histogram.filter { $0.1 >= minFrequency }

     let vocab = frequentWordCandidates.map {
       CharacterSequence(alphabet: alphabet, characters: $0.0)
@@ -59,20 +68,3 @@
     self.init(vocab)
   }
 }
-
-public enum CharacterErrors: Error {
-  case unknownCharacter(character: Character, index: Int, sentence: String)
-  case nonUtf8Data
-}
-
-extension CharacterErrors: CustomStringConvertible {
-  public var description: String {
-    switch self {
-    case let .unknownCharacter(character, index, sentence):
-      return
-        "Unknown character '\(character)' encountered at index \(index) while converting sentence \"\(sentence)\" to a character sequence."
-    case .nonUtf8Data:
-      return "Non-UTF8 data encountered."
-    }
-  }
-}
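A sketch of how `maxLength` and the renamed `minFrequency` shape the lexicon: candidate words are collected from the training sequences, truncated at `maxLength`, and kept only when they occur at least `minFrequency` times (toy values here):

```swift
import ModelSupport

let alphabet = Alphabet(["a", "b"], eos: "</s>", eow: "</w>", pad: "<pad>")
let phrases = ["ab", "ab", "ab", "ba"]
let sequences = phrases.map { try! CharacterSequence(alphabet: alphabet, appendingEoSTo: $0) }

// Candidates occurring fewer than three times, such as "ba", are dropped;
// "ab" occurs three times and survives.
let lexicon = Lexicon(from: sequences, alphabet: alphabet, maxLength: 2, minFrequency: 3)
```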
diff --git a/Tests/DatasetsTests/WordSeg/WordSegDatasetTests.swift b/Tests/DatasetsTests/WordSeg/WordSegDatasetTests.swift
index e3ea2da3d37..09ca58e6545 100644
--- a/Tests/DatasetsTests/WordSeg/WordSegDatasetTests.swift
+++ b/Tests/DatasetsTests/WordSeg/WordSegDatasetTests.swift
@@ -17,29 +17,50 @@ import ModelSupport
 import XCTest

 class WordSegDatasetTests: XCTestCase {
-  func testCreateWordSegDataset() {
+  func testCreateWordSegDatasetReference() {
     do {
       let dataset = try WordSegDataset()
-      XCTAssertEqual(dataset.training.count, 7832)
-      XCTAssertEqual(dataset.validation!.count, 979)
-      XCTAssertEqual(dataset.testing!.count, 979)
+      XCTAssertEqual(dataset.trainingPhrases.count, 7832)
+      XCTAssertEqual(dataset.validationPhrases.count, 979)
+      XCTAssertEqual(dataset.testingPhrases.count, 979)

       // Check the first example in each set.
       let trainingExample: [Int32] = [
         26, 16, 22, 24, 2, 15, 21, 21, 16, 20, 6, 6, 21,
         9, 6, 3, 16, 16, 12, 28,
       ]
-      XCTAssertEqual(dataset.training[0].numericalizedText.characters, trainingExample)
+      XCTAssertEqual(dataset.trainingPhrases[0].numericalizedText.characters, trainingExample)

       let validationExample: [Int32] = [9, 6, 13, 13, 16, 14, 10, 14, 10, 28]
-      XCTAssertEqual(dataset.validation![0].numericalizedText.characters, validationExample)
+      XCTAssertEqual(dataset.validationPhrases[0].numericalizedText.characters, validationExample)

       let testingExample: [Int32] = [
         13, 6, 21, 14, 6, 20, 6, 6, 10, 7, 10, 4, 2, 15, 20, 6, 6, 2, 15, 26, 3, 16, 5, 26, 10,
         15, 21, 9, 2, 21, 14, 10, 19, 19, 16, 19, 28,
       ]
-      XCTAssertEqual(dataset.testing![0].numericalizedText.characters, testingExample)
+      XCTAssertEqual(dataset.testingPhrases[0].numericalizedText.characters, testingExample)
+    } catch {
+      XCTFail(error.localizedDescription)
+    }
+  }
+
+  func testCreateWordSegDatasetTrainingOnly() {
+    do {
+      let localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+        .appendingPathComponent("WordSeg", isDirectory: true)
+      let trainingFile = localStorageDirectory.appendingPathComponent("/seg/br/br-text/tr.txt")
+      let dataset = try WordSegDataset(training: trainingFile.path)
+      XCTAssertEqual(dataset.trainingPhrases.count, 7832)
+      XCTAssertEqual(dataset.validationPhrases.count, 0)
+      XCTAssertEqual(dataset.testingPhrases.count, 0)
+
+      // Check the first example in each set.
+      let trainingExample: [Int32] = [
+        26, 16, 22, 24, 2, 15, 21, 21, 16, 20, 6, 6, 21,
+        9, 6, 3, 16, 16, 12, 28,
+      ]
+      XCTAssertEqual(dataset.trainingPhrases[0].numericalizedText.characters, trainingExample)
     } catch {
       XCTFail(error.localizedDescription)
     }
@@ -57,16 +78,17 @@ class WordSegDatasetTests: XCTestCase {
         Data(
           bytesNoCopy: UnsafeMutableRawPointer(mutating: address), count: pointer.count,
           deallocator: .none)
-      dataset = try? WordSegDataset(training: training, validation: nil, testing: nil)
+      dataset = WordSegDataset(training: training, validation: nil, testing: nil)
     }

     // 'a', 'h', 'l', 'p', '</s>', '</w>', '<pad>'
     XCTAssertEqual(dataset?.alphabet.count, 7)
-    XCTAssertEqual(dataset?.training.count, 1)
+    XCTAssertEqual(dataset?.trainingPhrases.count, 1)
   }

   static var allTests = [
-    ("testCreateWordSegDataset", testCreateWordSegDataset),
+    ("testCreateWordSegDatasetReference", testCreateWordSegDatasetReference),
+    ("testCreateWordSegDatasetTrainingOnly", testCreateWordSegDatasetTrainingOnly),
     ("testWordSegDatasetLoad", testWordSegDatasetLoad),
   ]
 }
diff --git a/Tests/SupportTests/WordSegSupportTests.swift b/Tests/SupportTests/WordSegSupportTests.swift
index 78406e6a6f2..c97feda32da 100644
--- a/Tests/SupportTests/WordSegSupportTests.swift
+++ b/Tests/SupportTests/WordSegSupportTests.swift
@@ -72,7 +72,7 @@ class WordSegSupportTests: XCTestCase {
         try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "alpha"),
         try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "beta"),
         try! CharacterSequence(alphabet: alphabet, appendingEoSTo: "gamma"),
-      ], alphabet: alphabet, maxLength: 5, minFreq: 4)
+      ], alphabet: alphabet, maxLength: 5, minFrequency: 4)
     XCTAssertEqual(lexicon.count, 3)
   }
diff --git a/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift
index b5610df265c..06f179697e2 100644
--- a/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift
+++ b/Tests/TextTests/WordSegmentationTests/ProbeLayers.swift
@@ -123,35 +123,35 @@ func almostEqual(

 class WordSegProbeLayerTests: XCTestCase {
   func testProbeEncoder() {
-    // chrVocab is:
+    // alphabet is:
     // 0 - a
     // 1 - b
     // 2 - </s>
     // 3 - </w>
     // 4 - <pad>
-    let chrVocab: Alphabet = Alphabet(
+    let alphabet: Alphabet = Alphabet(
       [
         "a", "b",
       ], eos: "</s>", eow: "</w>", pad: "<pad>")

-    // strVocab is:
+    // lexicon is:
     // 0 - aa
     // 1 - bb
     // 2 - ab
     // 3 - ba
-    let strVocab: Lexicon = Lexicon([
-      CharacterSequence(alphabet: chrVocab, characters: [0, 0]),  // "aa"
-      CharacterSequence(alphabet: chrVocab, characters: [1, 1]),  // "bb"
-      CharacterSequence(alphabet: chrVocab, characters: [0, 1]),  // "ab"
-      CharacterSequence(alphabet: chrVocab, characters: [1, 0]),  // "ba"
+    let lexicon: Lexicon = Lexicon([
+      CharacterSequence(alphabet: alphabet, characters: [0, 0]),  // "aa"
+      CharacterSequence(alphabet: alphabet, characters: [1, 1]),  // "bb"
+      CharacterSequence(alphabet: alphabet, characters: [0, 1]),  // "ab"
+      CharacterSequence(alphabet: alphabet, characters: [1, 0]),  // "ba"
     ])

     var model = SNLM(
       parameters: SNLM.Parameters(
-        ndim: 2,
-        dropoutProb: 0,
-        chrVocab: chrVocab,
-        strVocab: strVocab,
+        hiddenSize: 2,
+        dropoutProbability: 0,
+        alphabet: alphabet,
+        lexicon: lexicon,
         order: 5))

     model.setParameters(Example1.parameters)
@@ -159,7 +159,7 @@ class WordSegProbeLayerTests: XCTestCase {
     print("Encoding")
     let encoderStates = model.encode(
-      CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1]), device: device)  // "abab"
+      CharacterSequence(alphabet: alphabet, characters: [0, 1, 0, 1]), device: device)  // "abab"
     let encoderStatesTensor = Tensor(stacking: encoderStates)
     print("Expected: \(Example1.expectedEncoding)")
     print("Actual: \(encoderStatesTensor)")
@@ -185,8 +185,8 @@ class WordSegProbeLayerTests: XCTestCase {
     print("Decode")
     let decoded = model.decode(
       [
-        CharacterSequence(alphabet: chrVocab, characters: [0, 0, 0]),  // "aaa"
-        CharacterSequence(alphabet: chrVocab, characters: [0, 1]),  // "ab"
+        CharacterSequence(alphabet: alphabet, characters: [0, 0, 0]),  // "aaa"
+        CharacterSequence(alphabet: alphabet, characters: [0, 1]),  // "ab"
       ],
       encoderStates[0],
       device: device
     )
@@ -197,7 +197,7 @@ class WordSegProbeLayerTests: XCTestCase {
     print("OK!\n")

     print("Build Lattice")
-    let abab = CharacterSequence(alphabet: chrVocab, characters: [0, 1, 0, 1])
+    let abab = CharacterSequence(alphabet: alphabet, characters: [0, 1, 0, 1])
     let lattice = model.buildLattice(abab, maxLen: 5, device: device)

     XCTAssert(lattice.isAlmostEqual(to: Example1.lattice, tolerance: 1e-5))