Skip to content

Commit

Permalink
Merge pull request #1381 from palonso/add-maest-models
Browse files Browse the repository at this point in the history
Add MAEST models
  • Loading branch information
dbogdanov authored Oct 20, 2023
2 parents efc65d8 + 92f0831 commit 77a6a95
Show file tree
Hide file tree
Showing 18 changed files with 857 additions and 146 deletions.
106 changes: 101 additions & 5 deletions doc/sphinxdoc/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ If you use any of the models in your research, please cite the following paper::
booktitle={International Conference on Acoustics, Speech and Signal Processing ({ICASSP})},
year={2020}
}

.. highlight:: default


Expand Down Expand Up @@ -137,6 +137,105 @@ Models:
*Note: We provide models operating with a fixed batch size of 64 samples since it was not possible to port the version with dynamic batch size from ONNX to TensorFlow. Additionally, an ONNX version of the model with* `dynamic batch <https://essentia.upf.edu/models/feature-extractors/discogs-effnet/discogs-effnet-bsdynamic-1.onnx>`_ *size is provided.*


MAEST
^^^^^

Music Audio Efficient Spectrogram Transformer (`MAEST <https://github.com/palonso/MAEST/>`_) trained to predict music style labels using an in-house dataset annotated with Discogs metadata.
We offer versions of MAEST trained with sequence lengths ranging from 5 to 30 seconds (``5s``, ``10s``, ``20s``, and ``30s``), and trained starting from different intial weights: from random initialization (``fs``), from `DeiT <https://doi.org/10.48550/arXiv.2012.12877>`_ pre-trained weights (``dw``), and from `PaSST <https://doi.org/10.48550/arXiv.2106.07139>`_ pre-trained weights (``pw``). Additionally, we offer a version of MAEST trained following a teacher student setup (``ts``).
According to our study ``discogs-maest-30s-pw``, achieved the most competitive performance in most downstream tasks (refer to the `paper <http://hdl.handle.net/10230/58023>`_ for details).


Models:

.. collapse:: ⬇️ <a class="reference external">discogs-maest-30s-pw</a>

|
[`weights <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-30s-pw-1.pb>`_, `metadata <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-30s-pw-1.json>`_]

Model trained with a multi-label classification objective targeting 400 Discogs styles.

Python code for embedding extraction:

.. literalinclude:: ../../src/examples/python/models/scripts/feature-extractors/maest/discogs-maest-30s-pw-1_embeddings.py

.. collapse:: ⬇️ <a class="reference external">discogs-maest-30s-pw-ts</a>

|
[`weights <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-30s-pw-ts-1.pb>`_, `metadata <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-30s-pw-ts-1.json>`_]

Model trained with a multi-label classification objective targeting 400 Discogs styles.

Python code for embedding extraction:

.. literalinclude:: ../../src/examples/python/models/scripts/feature-extractors/maest/discogs-maest-30s-pw-ts-1_embeddings.py

.. collapse:: ⬇️ <a class="reference external">discogs-maest-20s-pw</a>

|
[`weights <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-20s-pw-1.pb>`_, `metadata <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-20s-pw-1.json>`_]

Model trained with a multi-label classification objective targeting 400 Discogs styles.

Python code for embedding extraction:

.. literalinclude:: ../../src/examples/python/models/scripts/feature-extractors/maest/discogs-maest-20s-pw-1_embeddings.py

.. collapse:: ⬇️ <a class="reference external">discogs-maest-10s-pw</a>

|
[`weights <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-10s-pw-1.pb>`_, `metadata <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-10s-pw-1.json>`_]

Model trained with a multi-label classification objective targeting 400 Discogs styles.

Python code for embedding extraction:

.. literalinclude:: ../../src/examples/python/models/scripts/feature-extractors/maest/discogs-maest-10s-pw-1_embeddings.py

.. collapse:: ⬇️ <a class="reference external">discogs-maest-10s-fs</a>

|
[`weights <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-10s-fs-1.pb>`_, `metadata <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-10s-fs-1.json>`_]

Model trained with a multi-label classification objective targeting 400 Discogs styles.

Python code for embedding extraction:

.. literalinclude:: ../../src/examples/python/models/scripts/feature-extractors/maest/discogs-maest-10s-fs-1_embeddings.py

.. collapse:: ⬇️ <a class="reference external">discogs-maest-10s-dw</a>

|
[`weights <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-10s-dw-1.pb>`_, `metadata <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-10s-dw-1.json>`_]

Model trained with a multi-label classification objective targeting 400 Discogs styles.

Python code for embedding extraction:

.. literalinclude:: ../../src/examples/python/models/scripts/feature-extractors/maest/discogs-maest-10s-dw-1_embeddings.py

.. collapse:: ⬇️ <a class="reference external">discogs-maest-5s-pw</a>

|
[`weights <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-5s-pw-1.pb>`_, `metadata <https://essentia.upf.edu/models/feature-extractors/maest/discogs-maest-5s-pw-1.json>`_]

Model trained with a multi-label classification objective targeting 400 Discogs styles.

Python code for embedding extraction:

.. literalinclude:: ../../src/examples/python/models/scripts/feature-extractors/maest/discogs-maest-5s-pw-1_embeddings.py


*Note: It is possible to retrieve the output of each attention layer by setting* ``output=StatefulParitionedCall:n`` *, where* ``n`` *is the index of the layer (starting from 1).*
*The output from the attention layers should be interpreted as* ``[batch_index, 1, token_number, embeddings_size]``
*, where the first and second tokens (i.e.,* ``[0, 0, :2, :]`` *) correspond to the* ``CLS`` *and* ``DIST`` *tokens respectively, and the following ones to input signal.*

OpenL3
^^^^^^

Expand Down Expand Up @@ -240,7 +339,7 @@ The name of these models is a combination of the classification/regression task
*Note: TensorflowPredict2D has to be configured with the correct output layer name for each classifier. Check the attached JSON file to find the name of the output layer on each case.*


Music genre and style
Music genre and style
^^^^^^^^^^^^^^^^^^^^^


Expand Down Expand Up @@ -2071,6 +2170,3 @@ Models:
Python code for predictions:

.. literalinclude :: ../../src/examples/python/models/scripts/tempo/tempocnn/deeptemp-k16-3_predictions.py
17 changes: 6 additions & 11 deletions src/algorithms/machinelearning/tensorflowpredict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@ const Tensor<Real> TensorflowPredict::TFToTensor(
TF_Output TensorflowPredict::graphOperationByName(const string nodeName) {
int index = 0;
const char* name = nodeName.c_str();
string newNodeName;

// TensorFlow operations (or nodes from the graph perspective) return tensors named <nodeName:n>, where n goes
// from 0 to the number of outputs. The first output tensor of a node can be extracted implicitly (nodeName)
Expand All @@ -374,22 +375,16 @@ TF_Output TensorflowPredict::graphOperationByName(const string nodeName) {
string::size_type n = nodeName.find(':');
if (n != string::npos) {
try {
string::size_type next_char;
index = stoi(nodeName.substr(n + 1), &next_char);

if (n + next_char + 1 != nodeName.size()) {
throw EssentiaException("TensorflowPredict: `" + nodeName + "` is not a valid node name, the index cannot "
"be followed by other characters. Make sure that all your inputs and outputs follow "
"the pattern `nodeName:n`, where `n` in an integer that goes from 0 to the number "
"of outputs of the node - 1.");
}
newNodeName = nodeName.substr(0, n);
name = newNodeName.c_str();
index = stoi(nodeName.substr(n + 1, nodeName.size()));

} catch (const invalid_argument& ) {
throw EssentiaException("TensorflowPredict: `" + nodeName + "` is not a valid node name. Make sure that all "
"your inputs and outputs follow the pattern `nodeName:n`, where `n` in an integer that "
"goes from 0 to the number of outputs of the node - 1.");
}
name = nodeName.substr(0, n).c_str();
}

}

TF_Operation* oper = TF_GraphOperationByName(_graph, name);
Expand Down
Loading

0 comments on commit 77a6a95

Please sign in to comment.