diff --git a/.gitmodules b/.gitmodules index 6cb63fc0b..e3de79a86 100644 --- a/.gitmodules +++ b/.gitmodules @@ -17,3 +17,6 @@ [submodule "src/3rd_party/simple-websocket-server"] path = src/3rd_party/simple-websocket-server url = https://github.com/marian-nmt/Simple-WebSocket-Server +[submodule "src/3rd_party/cub"] + path = src/3rd_party/cub + url = https://github.com/NVIDIA/cub diff --git a/CHANGELOG.md b/CHANGELOG.md index 4182d72b6..75cc5ddf2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## [Unreleased] ### Added +- Adds a fast path to perform a max reduction along the last axis to reduce the H2D communication. +- Refactors the beam search to batch processing of secondary factors for factored vocabulary models. - Add --train-embedder-rank for fine-tuning any encoder(-decoder) model for multi-lingual similarity via softmax-margin loss - Add --logical-epoch that allows to redefine the displayed epoch counter as a multiple of n data epochs, updates or labels. Also allows to define width of fractional part with second argument. - Add --metrics chrf for computing ChrF according to https://www.aclweb.org/anthology/W15-3049/ and SacreBLEU reference implementation diff --git a/src/3rd_party/cub b/src/3rd_party/cub new file mode 160000 index 000000000..52d58a889 --- /dev/null +++ b/src/3rd_party/cub @@ -0,0 +1 @@ +Subproject commit 52d58a88904da39c374e44a6a8ae0e4dcca5b71a diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 6dcf7fd89..a79ab0ad2 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,3 +1,4 @@ +add_definitions(-DCUB_IGNORE_DEPRECATED_CPP_DIALECT=1) add_subdirectory(3rd_party) include_directories(.) 
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index c565e0357..2082a06c6 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -1,3 +1,8 @@ +/* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + #pragma once #include "tensors/backend.h" @@ -495,6 +500,9 @@ struct ReduceNodeOp : public UnaryNodeOp { case ReduceNodeOpCode::min: return {NodeOp(Reduce(_1, min(_1,_2), std::numeric_limits::max(), val_, child(0)->val()))}; case ReduceNodeOpCode::max: + if(axis_ == child(0)->shape().size() - 1 && graph()->getBackend()->getDeviceId().type == DeviceType::gpu ) { // last-axis max on GPU: use the dedicated kernel (the CPU version aborts) + return {NodeOp(ReduceMaxLastAxis(val_, child(0)->val()))}; + } return {NodeOp(Reduce(_1, max(_1,_2), std::numeric_limits::lowest(), val_, child(0)->val()))}; case ReduceNodeOpCode::prod: return {NodeOp(Reduce(_1, _1 * _2, 1.0f, val_, child(0)->val()))}; diff --git a/src/layers/generic.cpp b/src/layers/generic.cpp index 6c760aace..16c1107d8 100755 --- a/src/layers/generic.cpp +++ b/src/layers/generic.cpp @@ -1,3 +1,9 @@ +/* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + +#include "graph/node_initializers.h" #include "marian.h" #include "layers/generic.h" @@ -66,6 +72,40 @@ namespace marian { // return logits_.front(); //} + // Adds the weighted, max-normalized scores of each secondary factor group in factorGroups to the matching + // entry of expandedPathScores and returns the results, one Expr per factor group. hypIndices holds the + // reshuffling indices of all groups back to back (batchSize * beamSize entries per group). + std::vector<Expr> Logits::getSecondaryFactorLogits(const std::vector<size_t>& factorGroups, + const std::vector<IndexType>& hypIndices, + size_t batchSize, size_t beamSize, + const std::vector<Expr>& expandedPathScores, + float scorerWeight) const { + const size_t totalElts = batchSize * beamSize; + std::vector<Expr> updatedPathScores(factorGroups.size()); + auto indices = graph()->indices(hypIndices); // build the index node once for all factor groups + + for(int fgIndex = 0; fgIndex < (int)factorGroups.size(); ++fgIndex) { + size_t factorGroup = factorGroups[fgIndex]; + ABORT_IF(factorGroup == 0, "Lemmas not supported"); + + // Find and 
subtract max from factor scores + auto sel = logits_[factorGroup]->loss(); // [localBeamSize, 1, dimBatch, dimFactorVocab] + sel = sel - max(sel, -1); // max over the last (factor vocab) axis + + // Obtain slice for indices + int start = (int)totalElts * fgIndex; // this group's indices occupy [start, end) of the flattened index vector + int end = (int)totalElts * (fgIndex + 1); + Slice fgSlice(start, end, 1); + Expr fgIndices = slice(indices, 0, fgSlice); + + // Select relevant scores + Expr logProbs = rnn::State::select(sel, fgIndices, (int)beamSize, /*isBatchMajor=*/false); + updatedPathScores[fgIndex] = expandedPathScores[fgIndex] + scorerWeight * logProbs; + } + + return updatedPathScores; + } + // get logits for one factor group // For groupIndex == 0, the function also requires the shortlist if there is one. Expr Logits::getFactoredLogits(size_t groupIndex, Ptr shortlist /*= nullptr*/, const std::vector& hypIndices /*= {}*/, size_t beamSize /*= 0*/) const { @@ -88,8 +128,10 @@ namespace marian { } // if selIdx are given, then we must reshuffle accordingly - if (!hypIndices.empty()) // use the same function that shuffles decoder state - sel = rnn::State::select(sel, hypIndices, (int)beamSize, /*isBatchMajor=*/false); + if (!hypIndices.empty()) { // use the same function that shuffles decoder state + auto indices = graph()->indices(hypIndices); + sel = rnn::State::select(sel, indices, (int)beamSize, /*isBatchMajor=*/false); + } return sel; } diff --git a/src/layers/generic.h b/src/layers/generic.h index e83663357..c226844ee 100755 --- a/src/layers/generic.h +++ b/src/layers/generic.h @@ -1,3 +1,8 @@ +/* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + #pragma once #include "marian.h" @@ -115,6 +120,8 @@ class Logits { Logits(std::vector>&& logits, Ptr embeddingFactorMapping) // factored-output constructor : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {} Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors + std::vector<Expr> 
getSecondaryFactorLogits(const std::vector<size_t>& factorGroups, const std::vector<IndexType>& hypIndices, size_t batchSize, size_t beamSize, + const std::vector<Expr>& expandedPathScores, float scorerWeight) const; // add the weighted scores of all secondary factor groups in factorGroups to expandedPathScores Expr getFactoredLogits(size_t groupIndex, Ptr shortlist = nullptr, const std::vector& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle //Ptr getRationalLoss() const; // assume it holds a loss: get that Expr applyLossFunction(const Words& labels, const std::function& lossFn) const; diff --git a/src/rnn/types.h b/src/rnn/types.h index 47424b759..739c7c496 100644 --- a/src/rnn/types.h +++ b/src/rnn/types.h @@ -1,5 +1,12 @@ +/* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + #pragma once +#include "common/definitions.h" +#include "common/shape.h" #include "marian.h" #include @@ -12,7 +19,7 @@ struct State { Expr output; Expr cell; - State select(const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] + State select(Expr selIdx, // [beamIndex * activeBatchSize + batchIndex] int beamSize, bool isBatchMajor) const { return{ select(output, selIdx, beamSize, isBatchMajor), select(cell, selIdx, beamSize, isBatchMajor) }; @@ -20,15 +27,14 @@ struct State { // this function is also called by Logits static Expr select(Expr sel, // [beamSize, dimTime, dimBatch, dimDepth] or [beamSize, dimBatch, dimTime, dimDepth] (dimTime = 1 for RNN) - const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] + Expr selIdx, // [beamIndex * activeBatchSize + batchIndex] int beamSize, bool isBatchMajor) { if (!sel) return sel; // keep nullptr untouched sel = atleast_4d(sel); - - int dimBatch = (int)selIdx.size() / beamSize; + int dimBatch = (int)selIdx->shape().elements() / beamSize; // selIdx is now an index node, so its length comes from its shape int dimDepth = sel->shape()[-1]; int dimTime = isBatchMajor ? 
sel->shape()[-2] : sel->shape()[-3]; @@ -83,8 +89,30 @@ class States { States select(const std::vector& selIdx, // [beamIndex * activeBatchSize + batchIndex] int beamSize, bool isBatchMajor) const { States selected; + Expr indices; + + // We need to check if either a state's cell or output fields are non-null. In this case, we need + // to select rows from at least one of the tensors. If only some exprs are non-null, the call to + // select will handle this for us by returning a null expr naturally. + for (auto& state : states_) { + if (state.cell) { + indices = state.cell->graph()->indices(selIdx); // first non-null expr supplies the graph that owns the index node + break; + } + + if (state.output) { + indices = state.output->graph()->indices(selIdx); + break; + } + } + + // If indices is null here, then all of the state.cell and state.output entries are null. Therefore, + // select will ignore the null indices expr and simply return a null pointer which is the expected + // behavior + + // GPU OPT: Implement kernel to batch these on GPU for(auto& state : states_) - selected.push_back(state.select(selIdx, beamSize, isBatchMajor)); + selected.push_back(state.select(indices, beamSize, isBatchMajor)); return selected; } diff --git a/src/tensors/cpu/tensor_operators.cpp b/src/tensors/cpu/tensor_operators.cpp index 211283d58..63f3757a1 100755 --- a/src/tensors/cpu/tensor_operators.cpp +++ b/src/tensors/cpu/tensor_operators.cpp @@ -3,6 +3,12 @@ * SPDX-License-Identifier: MIT */ + /* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + + #include "tensors/tensor_operators.h" #include "tensors/cpu/backend.h" #include "tensors/allocator.h" @@ -24,6 +30,11 @@ namespace cpu { ABORT("Not implemented"); } +void ReduceMaxLastAxis(Tensor /*out*/, + const marian::Tensor& /*input*/) { + ABORT("Not implemented"); // GPU-only fast path; ReduceNodeOp dispatches here only for GPU devices +} + template void CopyCastTo(To* out, const From* in, int length) { for(int i = 0; i < length; ++i) diff --git a/src/tensors/gpu/tensor_operators.cu 
b/src/tensors/gpu/tensor_operators.cu index 2552b7c7e..b983dcf07 100644 --- a/src/tensors/gpu/tensor_operators.cu +++ b/src/tensors/gpu/tensor_operators.cu @@ -1,3 +1,8 @@ +/* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + #include "common/types.h" #include "tensors/tensor_operators.h" @@ -8,6 +13,22 @@ #include "tensors/gpu/cuda_helpers.h" #include "tensors/gpu/add_all.h" +#include "3rd_party/reduce_all.h" + +#if COMPILE_FP16 +#include <cuda_fp16.h> +__device__ __forceinline__ half max(const half a, const half b) { + return a > b ? a : b; +} +#endif + + +#if CUDA_VERSION >= 11000 +#include <cub/cub.cuh> +#else +#include "cub/cub/cub.cuh" +#endif + namespace marian { @@ -2930,5 +2951,75 @@ void PoolingWithMaskingBackward(Tensor adj, width, lastWidth); } + +// Max-reduces the rows of a row-major [gridDim.x, innerDimSize] input: block b reduces row b and thread 0 +// writes the result to outTensor[b]. Must be launched with exactly BLOCK_THREADS threads per block. +template <typename T, int BLOCK_THREADS> +__global__ void gReduceMaxLastAxis(T* outTensor, const T* inputTensor, int innerDimSize) { + + typedef cub::BlockReduce<T, BLOCK_THREADS> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + size_t inputBlockStartOffset = (size_t)blockIdx.x * innerDimSize; // widen first so large tensors cannot overflow 32-bit arithmetic + const T* blockInputPtr = inputTensor + inputBlockStartOffset; + T blockMax = cub::FpLimits<T>::Lowest(); + + for(int tid = threadIdx.x; tid < innerDimSize; tid += BLOCK_THREADS) { // each thread folds a strided part of the row + blockMax = max(blockMax, blockInputPtr[tid]); + } + + T aggregate = BlockReduce(temp_storage).Reduce(blockMax, cub::Max()); // keep type T: storing this in an int truncated the maximum + if(threadIdx.x == 0) outTensor[blockIdx.x] = aggregate; // thread 0 publishes the row maximum +} + +#define CASE_THREADS(BLOCKS, THREADS) \ + case THREADS: \ + gReduceMaxLastAxis<T, THREADS><<<BLOCKS, THREADS>>>( \ + out, \ + input, \ + sizeOfLastDim); \ + break; \ + +// Rounds threads up to a power of two (at least 32) and launches the matching kernel instantiation. +template <typename T> +void ReduceMaxLastAxisTyped(T* out, const T* input, int sizeOfLastDim, int blocks, int threads) { + threads = std::max(nextPow2((unsigned int)threads), 32U); + switch(threads) { + CASE_THREADS(blocks, 32); + CASE_THREADS(blocks, 64); + CASE_THREADS(blocks, 128); + CASE_THREADS(blocks, 256); + CASE_THREADS(blocks, 512); + CASE_THREADS(blocks, 1024); + default: + ABORT("Invalid 
number of threads in config for ReduceMaxLastAxis"); + } +} + +// Writes the maximum over the last axis of input to out, using one thread block per output element. +// Supports float32 and, if compiled with COMPILE_FP16, float16; aborts for other types. +void ReduceMaxLastAxis(Tensor out, + const marian::Tensor& input) { + + cudaSetDevice(out->getDeviceId().no); + int outputElts = out->shape().elements(); + int inputElts = input->shape().elements(); + + int sizeOfLastDim = input->shape()[-1]; + int blocks = inputElts / sizeOfLastDim; // number of rows, i.e. one block per reduced row + + ABORT_IF(blocks != outputElts, "Expected {} elts in output tensor but tensor has size {}", blocks, outputElts); + int threads = std::min(sizeOfLastDim, MAX_THREADS); + + if(out->type() == Type::float32) { + ReduceMaxLastAxisTyped(out->data<float>(), input->data<float>(), sizeOfLastDim, blocks, threads); + #if COMPILE_FP16 + } else if(out->type() == Type::float16) { + ReduceMaxLastAxisTyped(out->data<half>(), input->data<half>(), sizeOfLastDim, blocks, threads); + #endif + } else { + ABORT("ReduceMaxLastAxis not implemented for type {}", out->type()); + } +} } // namespace gpu } // namespace marian diff --git a/src/tensors/tensor_operators.h b/src/tensors/tensor_operators.h index e075244f5..2a7a6f185 100644 --- a/src/tensors/tensor_operators.h +++ b/src/tensors/tensor_operators.h @@ -1,3 +1,8 @@ +/* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + #pragma once #include "common/definitions.h" @@ -103,6 +108,7 @@ DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bo DISPATCH8(ProdBatched, marian::Tensor, Ptr, const marian::Tensor, const marian::Tensor, bool, bool, float, float) DISPATCH9(CSRProd, marian::Tensor, Ptr, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float) +DISPATCH2(ReduceMaxLastAxis, marian::Tensor, const marian::Tensor&) DISPATCH2(Softmax, marian::Tensor, marian::Tensor) DISPATCH3(SoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor) diff --git a/src/translator/beam_search.cpp b/src/translator/beam_search.cpp old mode 100755 new mode 100644 index 
9335c55bd..f6c509357 --- a/src/translator/beam_search.cpp +++ b/src/translator/beam_search.cpp @@ -1,3 +1,8 @@ + /* Part of this file was contributed by NVIDIA under license: + * Copyright (C) 2020 NVIDIA Corporation + * SPDX-License-Identifier: MIT + */ + #include "translator/beam_search.h" #include "data/factored_vocab.h" @@ -334,8 +339,9 @@ Histories BeamSearch::search(Ptr graph, Ptr if (maxBeamSize == 0) break; - for (size_t factorGroup = 0; factorGroup < numFactorGroups; factorGroup++) { - // for factored vocabs, we do one factor at a time, but without updating the scorer for secondary factors + // We first process the lemmas then all of the remaining factor groups in parallel. + for(int processingLemmas = 1; processingLemmas >= 0; --processingLemmas) { + // for factored vocabs, we do lemmas then all the factor groups, the scorer is not updated for the secondary factor //********************************************************************** // create constant containing previous path scores for current beam @@ -344,69 +350,107 @@ Histories BeamSearch::search(Ptr graph, Ptr std::vector hypIndices; // [maxBeamSize, 1, currentDimBatch, 1] (flattened) tensor index ((beamHypIdx, batchIdx), flattened) of prev hyp that a hyp originated from std::vector prevWords; // [maxBeamSize, 1, currentDimBatch, 1] (flattened) word that a hyp ended in, for advancing the decoder-model's history Expr prevPathScores; // [maxBeamSize, 1, currentDimBatch, 1], path score that a hyp ended in (last axis will broadcast into vocab size when adding expandedPathScores) + std::vector factorGroupsToExpand; // A list of all of the factor groups to be expanded bool anyCanExpand = false; // stays false if all hyps are invalid factor expansions - if(t == 0 && factorGroup == 0) { // no scores yet + if(t == 0 && processingLemmas) { // no scores yet prevPathScores = graph->constant({1, 1, 1, 1}, inits::fromValue(0)); anyCanExpand = true; + // We need to expand factorGroup 0 only + 
factorGroupsToExpand.push_back(0); + // at the beginning all batch entries are used batchIndices.resize(origDimBatch); std::iota(batchIndices.begin(), batchIndices.end(), 0); } else { - if(factorGroup == 0) // only factorGroup==0 can subselect neural state + if(processingLemmas) // only factorGroup==0 can subselect neural state for(int currentBatchIdx = 0; currentBatchIdx < beams.size(); ++currentBatchIdx) // loop over batch entries (active sentences) if(!beams[currentBatchIdx].empty() || !PURGE_BATCH) // for each beam check batchIndices.push_back(prevBatchIdxMap[currentBatchIdx]); // which batch entries were active in previous step std::vector prevScores; - for(size_t beamHypIdx = 0; beamHypIdx < maxBeamSize; ++beamHypIdx) { // loop over globally maximal beam-size (maxBeamSize) - for(int origBatchIdx = 0; origBatchIdx < origDimBatch; ++origBatchIdx) { // loop over all batch entries (active and inactive) - auto& beam = beams[origBatchIdx]; - if(beamHypIdx < beam.size()) { - auto hyp = beam[beamHypIdx]; - auto word = hyp->getWord(); - auto canExpand = (!factoredVocab || factoredVocab->canExpandFactoredWord(hyp->getWord(), factorGroup)); - //LOG(info, "[{}, {}] Can expand {} with {} -> {}", batchIdx, beamHypIdx, (*batch->back()->vocab())[hyp->getWord()], factorGroup, canExpand); - anyCanExpand |= canExpand; - - auto currentBatchIdx = origBatchIdx; - if(PURGE_BATCH) { - if(factorGroup == 0) - currentBatchIdx = prevBatchIdxMap[origBatchIdx]; // subselection may happen for factorGroup == 0 - else - currentBatchIdx = batchIdxMap[origBatchIdx]; // no subselection happens for factorGroup > 0, - // but we treat it like a next step, since a step - // happened for factorGroup == 0 - } - - auto hypIndex = (IndexType)(hyp->getPrevStateIndex() * currentDimBatch + currentBatchIdx); // (beamHypIdx, batchIdx), flattened, for index_select() operation - - hypIndices.push_back(hypIndex); // (beamHypIdx, batchIdx), flattened as said above. 
- prevWords .push_back(word); - prevScores.push_back(canExpand ? hyp->getPathScore() : INVALID_PATH_SCORE); - } else { // pad to maxBeamSize (dummy hypothesis) - if(!PURGE_BATCH || !beam.empty()) { // but only if we are not pruning and the beam is not deactivated yet - hypIndices.push_back(0); - prevWords.push_back(trgEosId); // (unused, but must be valid) - prevScores.push_back((float)INVALID_PATH_SCORE); + // If we are processing the lemmas, we want to only process factor group 0. Therefore the bound is [0, 1) + // For all other factors, we batch the [1, numFactorGroups) factors together and process them in parallel. + size_t factorGroupBound = processingLemmas? 1 : numFactorGroups; + size_t factorGroupStart = processingLemmas? 0 : 1; + for(size_t factorGroup = factorGroupStart; factorGroup < factorGroupBound; ++factorGroup) { + bool factorCanExpand = false; + for(size_t beamHypIdx = 0; beamHypIdx < maxBeamSize; ++beamHypIdx) { // loop over globally maximal beam-size (maxBeamSize) + for(int origBatchIdx = 0; origBatchIdx < origDimBatch; ++origBatchIdx) { // loop over all batch entries (active and inactive) + auto& beam = beams[origBatchIdx]; + if(beamHypIdx < beam.size()) { + auto hyp = beam[beamHypIdx]; + auto word = hyp->getWord(); + auto canExpand = (!factoredVocab || factoredVocab->canExpandFactoredWord(hyp->getWord(), factorGroup)); + + //LOG(info, "[{}, {}] Can expand {} with {} -> {}", batchIdx, beamHypIdx, (*batch->back()->vocab())[hyp->getWord()], factorGroup, canExpand); + factorCanExpand |= canExpand; + anyCanExpand |= canExpand; + + auto currentBatchIdx = origBatchIdx; + if(PURGE_BATCH) { + if(factorGroup == 0) + currentBatchIdx = prevBatchIdxMap[origBatchIdx]; // subselection may happen for factorGroup == 0 + else + currentBatchIdx = batchIdxMap[origBatchIdx]; // no subselection happens for factorGroup > 0, + // but we treat it like a next step, since a step + // happened for factorGroup == 0 + } + + auto hypIndex = 
(IndexType)(hyp->getPrevStateIndex() * currentDimBatch + currentBatchIdx); // (beamHypIdx, batchIdx), flattened, for index_select() operation + + hypIndices.push_back(hypIndex); // (beamHypIdx, batchIdx), flattened as said above. + prevWords .push_back(word); + prevScores.push_back(canExpand ? hyp->getPathScore() : INVALID_PATH_SCORE); + } else { // pad to maxBeamSize (dummy hypothesis) + if(!PURGE_BATCH || !beam.empty()) { // but only if we are not pruning and the beam is not deactivated yet + hypIndices.push_back(0); + prevWords.push_back(trgEosId); // (unused, but must be valid) + prevScores.push_back((float)INVALID_PATH_SCORE); + } } } } + + // If none of the factor groups can be expanded + if(!factorCanExpand && !processingLemmas) { + int newElts = currentDimBatch * (int)maxBeamSize; + hypIndices.resize(hypIndices.size() - newElts); + prevWords.resize(prevWords.size() - newElts); + prevScores.resize(prevScores.size() - newElts); + } else { + factorGroupsToExpand.push_back(factorGroup); + } } - if(factorGroup == 0) - currentDimBatch = (IndexType) batchIndices.size(); // keep batch size constant for all factor groups in a time step - prevPathScores = graph->constant({(int)maxBeamSize, 1, (int)currentDimBatch, 1}, inits::fromVector(prevScores)); + + // keep batch size constant for all factor groups in a time step + if(processingLemmas) currentDimBatch = (IndexType) batchIndices.size(); + + // Avoid unnecessary memcpy on GPU if no words can expand this factor. 
+ if(anyCanExpand) prevPathScores = graph->constant({(int)factorGroupsToExpand.size() * (int)maxBeamSize, 1, (int)currentDimBatch, 1}, inits::fromVector(prevScores)); } + if (!anyCanExpand) // all words cannot expand this factor: skip continue; + std::vector expandedPathScoresForFactorGroups(factorGroupsToExpand.size()); + if(processingLemmas) { + expandedPathScoresForFactorGroups[0] = prevPathScores; + } else { + for(int fgIndex = 0; fgIndex < factorGroupsToExpand.size(); ++fgIndex) { + Slice window(fgIndex * (int)maxBeamSize, (fgIndex + 1) * (int)maxBeamSize, 1); + auto scoreSlice = slice(prevPathScores, 0, window); + scoreSlice = reshape(scoreSlice, {(int)maxBeamSize, 1, (int)currentDimBatch, 1}); + expandedPathScoresForFactorGroups[fgIndex] = scoreSlice; + } + } + //********************************************************************** // compute expanded path scores with word prediction probs from all scorers - auto expandedPathScores = prevPathScores; // will become [maxBeamSize, 1, currDimBatch, dimVocab] - Expr logProbs; for(size_t i = 0; i < scorers_.size(); ++i) { - if (factorGroup == 0) { + if (processingLemmas) { + Expr logProbs; // compute output probabilities for current output time step // - uses hypIndices[index in beam, 1, batch index, 1] to reorder scorer state to reflect the top-N in beams[][] // - adds prevWords [index in beam, 1, batch index, 1] to the scorer's target history @@ -424,8 +468,10 @@ Histories BeamSearch::search(Ptr graph, Ptr else { auto shortlist = scorers_[i]->getShortlist(); - logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, shortlist); // [maxBeamSize, 1, currentDimBatch, dimVocab] + logProbs = states[i]->getLogProbs().getFactoredLogits(0 /*factorGroup*/, shortlist); // [maxBeamSize, 1, currentDimBatch, dimVocab] } + // expand all hypotheses, [maxBeamSize, 1, currentDimBatch, 1] -> [maxBeamSize, 1, currentDimBatch, dimVocab] + expandedPathScoresForFactorGroups[0] = expandedPathScoresForFactorGroups[0] + 
scorers_[i]->getWeight() * logProbs; } else { // add secondary factors @@ -437,53 +483,61 @@ Histories BeamSearch::search(Ptr graph, Ptr // and push out other hypotheses. Hence, we exclude those here by setting the path score to // INVALID_PATH_SCORE. Instead, toHyps() explicitly propagates those hyps by simply copying the // previous hypothesis. - logProbs = states[i]->getLogProbs().getFactoredLogits(factorGroup, /*shortlist=*/ nullptr, hypIndices, maxBeamSize); // [maxBeamSize, 1, currentDimBatch, dimVocab] + + expandedPathScoresForFactorGroups = states[i]->getLogProbs().getSecondaryFactorLogits(factorGroupsToExpand, hypIndices, currentDimBatch, maxBeamSize, + expandedPathScoresForFactorGroups, scorers_[i]->getWeight()); } - // expand all hypotheses, [maxBeamSize, 1, currentDimBatch, 1] -> [maxBeamSize, 1, currentDimBatch, dimVocab] - expandedPathScores = expandedPathScores + scorers_[i]->getWeight() * logProbs; } // make beams continuous - expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [currentDimBatch, 1, maxBeamSize, dimVocab] - + for(auto& expandedPathScores : expandedPathScoresForFactorGroups) { + expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [currentDimBatch, 1, maxBeamSize, dimVocab] + } + // perform NN computation - if(t == 0 && factorGroup == 0) + if(t == 0 && processingLemmas) graph->forward(); else graph->forwardNext(); //********************************************************************** // suppress specific symbols if not at right positions - if(unkColId != -1 && factorGroup == 0) - suppressWord(expandedPathScores, unkColId); - for(auto state : states) - state->blacklist(expandedPathScores, batch); + for(auto& expandedPathScores : expandedPathScoresForFactorGroups) { + if(unkColId != -1 && processingLemmas) + suppressWord(expandedPathScores, unkColId); + for(auto state : states) + state->blacklist(expandedPathScores, batch); + } //********************************************************************** // perform 
beam search - - // find N best amongst the (maxBeamSize * dimVocab) hypotheses - std::vector nBestKeys; // [currentDimBatch, maxBeamSize] flattened -> (batchIdx, beamHypIdx, word idx) flattened - std::vector nBestPathScores; // [currentDimBatch, maxBeamSize] flattened - getNBestList(/*in*/ expandedPathScores->val(), // [currentDimBatch, 1, maxBeamSize, dimVocab or dimShortlist] - /*N=*/ maxBeamSize, // desired beam size - /*out*/ nBestPathScores, - /*out*/ nBestKeys, - /*first=*/t == 0 && factorGroup == 0); // @TODO: this is only used for checking presently, and should be removed altogether - // Now, nBestPathScores contain N-best expandedPathScores for each batch and beam, - // and nBestKeys for each their original location (batchIdx, beamHypIdx, word). - - // combine N-best sets with existing search space (beams) to updated search space - beams = toHyps(nBestKeys, nBestPathScores, - /*nBestBeamSize*/expandedPathScores->shape()[-2], // used for interpretation of keys - /*vocabSize=*/expandedPathScores->shape()[-1], // used for interpretation of keys - beams, - states, // used for keeping track of per-ensemble-member path score - batch, // only used for propagating alignment info - factoredVocab, factorGroup, - emptyBatchEntries, // [origDimBatch] - empty source batch entries are marked with true - batchIdxMap); // used to create a reverse batch index map to recover original batch indices for this step - } // END FOR factorGroup = 0 .. 
numFactorGroups-1 + for(int fgIndex = 0; fgIndex < (int) factorGroupsToExpand.size(); ++fgIndex) { + Expr expandedPathScores = expandedPathScoresForFactorGroups[fgIndex]; + int factorGroup = (int)factorGroupsToExpand[fgIndex]; + + // find N best amongst the (maxBeamSize * dimVocab) hypotheses + std::vector nBestKeys; // [currentDimBatch, maxBeamSize] flattened -> (batchIdx, beamHypIdx, word idx) flattened + std::vector nBestPathScores; // [currentDimBatch, maxBeamSize] flattened + getNBestList(/*in*/ expandedPathScores->val(), // [currentDimBatch, 1, maxBeamSize, dimVocab or dimShortlist] + /*N=*/ maxBeamSize, // desired beam size + /*out*/ nBestPathScores, + /*out*/ nBestKeys, + /*first=*/t == 0 && processingLemmas); // @TODO: this is only used for checking presently, and should be removed altogether + // Now, nBestPathScores contain N-best expandedPathScores for each batch and beam, + // and nBestKeys for each their original location (batchIdx, beamHypIdx, word). + + // combine N-best sets with existing search space (beams) to updated search space + beams = toHyps(nBestKeys, nBestPathScores, + /*nBestBeamSize*/expandedPathScores->shape()[-2], // used for interpretation of keys + /*vocabSize=*/expandedPathScores->shape()[-1], // used for interpretation of keys + beams, + states, // used for keeping track of per-ensemble-member path score + batch, // only used for propagating alignment info + factoredVocab, factorGroup, + emptyBatchEntries, // [origDimBatch] - empty source batch entries are marked with true + batchIdxMap); // used to create a reverse batch index map to recover original batch indices for this step + } + } // END FOR processingLemmas = 1 .. 0 prevBatchIdxMap = batchIdxMap; // save current batchIdx map to be used in next step; we are then going to look one step back