
Commit

added search similar, handled http exceptions
teticio committed Aug 16, 2021
1 parent 6e5fdbf commit bbfd747
Showing 11 changed files with 534 additions and 44 deletions.
3 changes: 3 additions & 0 deletions Pipfile
@@ -10,6 +10,9 @@ uvicorn = "*"
sqlalchemy = "*"
aiohttp = "*"
aiofiles = "*"
librosa = "*"
yapf = "*"
pytest = "*"

[dev-packages]
yapf = "*"
324 changes: 320 additions & 4 deletions Pipfile.lock

Large diffs are not rendered by default.

118 changes: 96 additions & 22 deletions backend/deejai.py
@@ -1,7 +1,17 @@
import os
import re
import uuid
import pickle
import random
import shutil
import librosa
import requests
import numpy as np
from io import BytesIO
import tensorflow as tf
from keras.models import load_model
from starlette.concurrency import run_in_threadpool
from tensorflow.compat.v1.keras.losses import cosine_proximity


class DeejAI:
@@ -24,20 +34,26 @@ def __init__(self):
self.mp3tovecs = np.array([[mp3tovecs[_], tracktovecs[_]]
for _ in mp3tovecs])
del mp3tovecs, tracktovecs
self.model = load_model(
'speccy_model',
custom_objects={'cosine_proximity': cosine_proximity})

def get_tracks(self):
return self.tracks
return self.tracks

def search(self, string, max_items=100):
tracks = self.tracks
search_string = re.sub(r'([^\s\w]|_)+', '', string.lower()).split()
ids = sorted([
track for track in tracks
if all(word in re.sub(r'([^\s\w]|_)+', '', tracks[track].lower())
for word in search_string)
],
key=lambda x: tracks[x])[:max_items]
return ids
async def search(self, string, max_items=100):
def _search():
tracks = self.tracks
search_string = re.sub(r'([^\s\w]|_)+', '', string.lower()).split()
ids = sorted([
track for track in tracks
if all(word in re.sub(r'([^\s\w]|_)+', '', tracks[track].lower())
for word in search_string)
],
key=lambda x: tracks[x])[:max_items]
return ids

return await run_in_threadpool(_search)

async def playlist(self, track_ids, size, creativity, noise):
if len(track_ids) == 0:
@@ -54,12 +70,12 @@ async def playlist(self, track_ids, size, creativity, noise):
noise=noise)

async def most_similar(self,
mp3tovecs,
weights,
positive=[],
negative=[],
noise=0,
vecs=None):
mp3tovecs,
weights,
positive=[],
negative=[],
noise=0,
vecs=None):
mp3_vecs_i = np.array([
weights[j] *
np.sum([mp3tovecs[i, j]
@@ -83,11 +99,11 @@ async def most_similar,
return result

async def most_similar_by_vec(self,
mp3tovecs,
weights,
positives=[],
negatives=[],
noise=0):
mp3tovecs,
weights,
positives=[],
negatives=[],
noise=0):
mp3_vecs_i = np.array([
weights[j] * np.sum(positives[j] if positives else [] +
-negatives[j] if negatives else [],
@@ -159,3 +175,61 @@ async def make_playlist,
playlist_tracks.append(self.tracks[track_id])
playlist_indices.append(candidate)
return playlist

async def get_similar_vec(self, track_url, max_items=10):
def _get_similar_vec():
y, sr = librosa.load(f'{playlist_id}.{extension}', mono=True)
os.remove(f'{playlist_id}.{extension}')
S = librosa.feature.melspectrogram(y=y,
sr=sr,
n_fft=n_fft,
hop_length=hop_length,
n_mels=n_mels,
fmax=sr / 2)
# hack because Spotify samples are a shade under 30s
x = np.ndarray(shape=(S.shape[1] // slice_size + 1, n_mels,
slice_size, 1),
dtype=float)
for slice in range(S.shape[1] // slice_size):
log_S = librosa.power_to_db(
S[:, slice * slice_size:(slice + 1) * slice_size],
ref=np.max)
if np.max(log_S) - np.min(log_S) != 0:
log_S = (log_S - np.min(log_S)) / (np.max(log_S) -
np.min(log_S))
x[slice, :, :, 0] = log_S
# hack because Spotify samples are a shade under 30s
log_S = librosa.power_to_db(S[:, -slice_size:], ref=np.max)
if np.max(log_S) - np.min(log_S) != 0:
log_S = (log_S - np.min(log_S)) / (np.max(log_S) -
np.min(log_S))
x[-1, :, :, 0] = log_S
return self.model.predict(x)

playlist_id = str(uuid.uuid4())
n_fft = 2048
hop_length = 512
n_mels = self.model.layers[0].input_shape[0][1]
slice_size = self.model.layers[0].input_shape[0][2]

try:
r = requests.get(track_url, allow_redirects=True)
if r.status_code != 200:
return []
extension = 'wav' if 'wav' in r.headers['Content-Type'] else 'mp3'
with open(f'{playlist_id}.{extension}',
'wb') as file: # this is really annoying!
shutil.copyfileobj(BytesIO(r.content), file, length=131072)
vecs = await run_in_threadpool(_get_similar_vec)
candidates = await self.most_similar_by_vec(
self.mp3tovecs[:, np.newaxis, 0, :], [1], [vecs])
ids = [
self.track_ids[candidate]
for candidate in candidates[0:max_items]
]
return ids
except Exception as e:
print(e)
if os.path.exists(f'./{playlist_id}.mp3'):
os.remove(f'./{playlist_id}.mp3')
return []
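
For reference, the spectrogram slicing added in get_similar_vec can be read in isolation as the following simplified sketch; the file name and the n_mels/slice_size values are illustrative assumptions (the committed code derives them from the model's input shape):

import librosa
import numpy as np

n_fft, hop_length = 2048, 512
n_mels, slice_size = 96, 216  # illustrative; the commit reads these from model.layers[0].input_shape

y, sr = librosa.load('preview.mp3', mono=True)  # hypothetical local audio file
S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length,
                                   n_mels=n_mels, fmax=sr / 2)

def normalise(log_S):
    # scale a slice to the 0..1 range unless it is completely flat
    rng = np.max(log_S) - np.min(log_S)
    return (log_S - np.min(log_S)) / rng if rng != 0 else log_S

# one extra slice is taken from the tail because Spotify previews run a shade under 30 seconds
x = np.zeros((S.shape[1] // slice_size + 1, n_mels, slice_size, 1), dtype=float)
for i in range(S.shape[1] // slice_size):
    x[i, :, :, 0] = normalise(
        librosa.power_to_db(S[:, i * slice_size:(i + 1) * slice_size], ref=np.max))
x[-1, :, :, 0] = normalise(librosa.power_to_db(S[:, -slice_size:], ref=np.max))
# x is then fed to the convolutional model: vecs = model.predict(x)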
40 changes: 28 additions & 12 deletions backend/main.py
@@ -9,12 +9,11 @@
from .deejai import DeejAI
from sqlalchemy import desc
from sqlalchemy.orm import Session
from fastapi import Depends, FastAPI
from .database import SessionLocal, engine
from fastapi.staticfiles import StaticFiles
from starlette.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.concurrency import run_in_threadpool
from fastapi import Depends, FastAPI, HTTPException

# create tables if necessary
models.Base.metadata.create_all(bind=engine)
@@ -76,10 +75,16 @@ async def spotify_callback(code: str):
.encode('utf-8')).decode('utf-8')
}
async with aiohttp.ClientSession() as session:
async with session.post('https://accounts.spotify.com/api/token',
data=data,
headers=headers) as response:
json = await response.json()
try:
async with session.post('https://accounts.spotify.com/api/token',
data=data,
headers=headers) as response:
if response.status != 200:
raise HTTPException(status_code=response.status,
detail=response.reason)
json = await response.json()
except aiohttp.ClientError as error:
raise HTTPException(status_code=400, detail=str(error))
url = os.environ.get('APP_URL', '') + "/#" + urllib.parse.urlencode(
{
'access_token': json['access_token'],
@@ -98,17 +103,28 @@ async def spotify_refresh_token(refresh_token: str):
.encode('utf-8')).decode('utf-8')
}
async with aiohttp.ClientSession() as session:
async with session.post('https://accounts.spotify.com/api/token',
data=data,
headers=headers) as response:
json = await response.json()
try:
async with session.post('https://accounts.spotify.com/api/token',
data=data,
headers=headers) as response:
if response.status != 200:
raise HTTPException(status_code=response.status,
detail=response.reason)
json = await response.json()
except aiohttp.ClientError as error:
raise HTTPException(status_code=400, detail=str(error))
return json


@app.post("/search")
async def search_tracks(search: schemas.Search):
ids = await run_in_threadpool(deejai.search, search.string,
search.max_items)
ids = await deejai.search(search.string, search.max_items)
return [{'track_id': id, 'track': deejai.get_tracks()[id]} for id in ids]


@app.post("/search_similar")
async def search_similar_tracks(search: schemas.SearchSimilar):
ids = await deejai.get_similar_vec(search.url, search.max_items)
return [{'track_id': id, 'track': deejai.get_tracks()[id]} for id in ids]
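
For illustration, a minimal sketch of exercising the new endpoint once the API is running; the base URL is an assumption, and the preview URL is the one used in the test further down:

import requests

response = requests.post(
    'http://localhost:8000/search_similar',  # assumed local uvicorn address
    json={
        'url': 'https://p.scdn.co/mp3-preview/04b28b12174a4c4448486070962dae74494c0f70?cid=194086cb37be48ebb45b9ba4ce4c5936',
        'max_items': 5
    })
print(response.json())  # [{'track_id': ..., 'track': ...}, ...]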


5 changes: 5 additions & 0 deletions backend/schemas.py
@@ -8,6 +8,11 @@ class Search(BaseModel):
max_items: Optional[int] = 100


class SearchSimilar(BaseModel):
url: str
max_items: Optional[int] = 10


class NewPlaylist(BaseModel):
track_ids: list
size: Optional[int] = 10
58 changes: 58 additions & 0 deletions backend/test_deejai.py
@@ -87,3 +87,61 @@ def test_update_playlist():
playlist = get_playlist(id, db)
assert (playlist.name == "Test" and playlist.av_rating == 4.5
and playlist.num_ratings == 1)


def test_search_similar():
search = schemas.SearchSimilar(
url=
'https://p.scdn.co/mp3-preview/04b28b12174a4c4448486070962dae74494c0f70?cid=194086cb37be48ebb45b9ba4ce4c5936'
)
assert asyncio.run(search_similar_tracks(search)) == [{
"track_id":
"1a9SiOELQS7YsBQwdEPMuq",
"track":
"Luis Fonsi - Despacito"
}, {
"track_id":
"6rPO02ozF3bM7NnOV4h6s2",
"track":
"Luis Fonsi - Despacito - Remix"
}, {
"track_id":
"5AgTL2WmiCvoObA8fpncKs",
"track":
"Luis Fonsi - Despacito"
}, {
"track_id":
"7dx0Funwrd0LRvquDFQ8fv",
"track":
"Cali Y El Dandee - Lumbra"
}, {
"track_id":
"7CUYHcu0RnbOnMz4RuN07w",
"track":
"Luis Fonsi - Despacito (Featuring Daddy Yankee)"
}, {
"track_id":
"2YFOm3hznEzQsIMmEwGyUg",
"track":
"Leon - Legalna"
}, {
"track_id":
"1tJw60G9KHl7fYVdQ2JDgo",
"track":
"J Balvin - Ginza - Remix"
}, {
"track_id":
"1v3fyyGJRlblbobabiXxIs",
"track":
"Latifah - On My Way"
}, {
"track_id":
"3jWfGOOUffq51fWGQdPV68",
"track":
"Achille Lauro - Non sei come me"
}, {
"track_id":
"2HR9Ih2IjpGEQ3YZl7aRUQ",
"track":
"Jeano - Abow"
}]
5 changes: 3 additions & 2 deletions requirements-dev.txt
@@ -1,8 +1,9 @@
fastapi
tensorflow
yapf
uvicorn
pytest
sqlalchemy
aiohttp
aiofiles
librosa
yapf
pytest
17 changes: 17 additions & 0 deletions requirements-lock.txt
@@ -1,17 +1,21 @@
absl-py==0.13.0
aiofiles==0.7.0
aiohttp==3.7.4.post0
appdirs==1.4.4
asgiref==3.4.1
astunparse==1.6.3
async-timeout==3.0.1
attrs==21.2.0
audioread==2.1.9
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
chardet==4.0.0
charset-normalizer==2.0.4
clang==5.0
click==8.0.1
colorama==0.4.4
decorator==5.0.9
fastapi==0.68.0
flatbuffers==1.12
gast==0.4.0
@@ -23,21 +27,33 @@ grpcio==1.39.0
h11==0.12.0
h5py==3.1.0
idna==3.2
joblib==1.0.1
keras==2.6.0
Keras-Preprocessing==1.1.2
librosa==0.8.1
llvmlite==0.36.0
Markdown==3.3.4
multidict==5.1.0
numba==0.53.1
numpy==1.19.5
oauthlib==3.1.1
opt-einsum==3.3.0
packaging==21.0
pooch==1.4.0
protobuf==3.17.3
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
pydantic==1.8.2
pyparsing==2.4.7
requests==2.26.0
requests-oauthlib==1.3.0
resampy==0.2.2
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.7.1
six==1.15.0
SoundFile==0.10.3.post1
SQLAlchemy==1.4.22
starlette==0.14.2
tensorboard==2.6.0
@@ -46,6 +62,7 @@ tensorboard-plugin-wit==1.8.0
tensorflow==2.6.0
tensorflow-estimator==2.6.0
termcolor==1.1.0
threadpoolctl==2.2.0
typing-extensions==3.7.4.3
urllib3==1.26.6
uvicorn==0.15.0
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,3 +4,4 @@ uvicorn
sqlalchemy
aiohttp
aiofiles
librosa
