diff --git a/beets/util/rate_limiter.py b/beets/util/rate_limiter.py new file mode 100644 index 0000000000..c381b5e262 --- /dev/null +++ b/beets/util/rate_limiter.py @@ -0,0 +1,84 @@ +import threading +import time + + +class RateLimiter: + """Limits the rate at which one or multiple sections of code may be called. + + The limiting is thread-safe: threads that are rate-limited will sleep until they + are not anymore. + + Important: The rate limiter only limits the start of execution of the rate-limited + code. This means for example that for rate-limited web queries of one per second, + it is assured that at most one request is started per second, but there may be + multiple queries running concurrently if the first one takes time to execute. + """ + + def __init__(self, reqs_per_interval: int, interval_sec: float): + """Create the rate limiter with the specified rate + + :param reqs_per_interval: Number of requests that can be done per interval. + Must be strictly positive + :param interval_sec: The interval in seconds. Must be strictly positive + """ + + if reqs_per_interval <= 0.0: + raise ValueError("reqs_per_interval can't be less than 0") + if interval_sec <= 0: + raise ValueError("interval_sec can't be less than 0") + + # Configuration variables + self.reqs_per_interval = reqs_per_interval + self.interval_sec = interval_sec + + # Current state + self.lock = threading.Lock() + self.last_call = 0.0 + self.remaining_requests = None + + def _update_remaining(self): + """Update the number of remaining requests that can be done and the time of + last call + """ + if self.remaining_requests is None: + # On first invocation, we have the number of requests available + self.remaining_requests = float(self.reqs_per_interval) + + else: + # On following invocations, increase the number of requests available + # based on the time since last invocation + since_last_call = time.time() - self.last_call + self.remaining_requests += since_last_call * ( + self.reqs_per_interval / self.interval_sec + ) + # Number of requests cannot exceed the max number per interval + self.remaining_requests = min( + self.remaining_requests, float(self.reqs_per_interval) + ) + + self.last_call = time.time() + + def __enter__(self): + with self.lock: + self._update_remaining() + + # Assert to avoid typing errors + assert self.remaining_requests is not None + + # Delay if necessary + while self.remaining_requests < 0.999: + time.sleep( + (1.0 - self.remaining_requests) + * (self.interval_sec / self.reqs_per_interval) + ) + self._update_remaining() + + # "Pay" for the execution of the rate limited code section + self.remaining_requests -= 1.0 + + return self + + def __exit__(self, exc_type, exc_value, traceback): + # Nothing to do: limiting is only done on the start of execution of the + # rate-limited code + pass diff --git a/beetsplug/_mb_interface.py b/beetsplug/_mb_interface.py new file mode 100644 index 0000000000..9a3b5af2d4 --- /dev/null +++ b/beetsplug/_mb_interface.py @@ -0,0 +1,541 @@ +import json +import re +from typing import TYPE_CHECKING, Literal +from urllib.parse import quote + +import requests.exceptions # For mbzero exception handling +from mbzero import mbzauth, mbzerror +from mbzero import mbzrequest as mbzr + +import beets +from beets.ui import UserError +from beets.util.rate_limiter import RateLimiter + +if TYPE_CHECKING: + from ._typing import JSONDict + + +class MbInterfaceError(Exception): + """Base class for exceptions raised by MbInterface""" + + pass + + +class MbInterfaceBadRequestError(MbInterfaceError): + """Exception raised when the request is ill-formed""" + + pass + + +class MbInterfaceUnauthorizedError(MbInterfaceError): + """Exception raised when the request does not have valid and sufficient + authentication for accessing the resource""" + + pass + + +class MbInterfaceNotFoundError(MbInterfaceError): + """Exception raised when an entity is not found""" + + pass + + +def _convert_mbzero_exception_to_local( + mbz_ex: mbzerror.MbzWebServiceError, +) -> MbInterfaceError: + """Convert the mbzero exception to a MbInterfaceError + + :param mbz_ex: The mbzero exception to convert + :return: The converted exception. Either the error could be converted into a + specific error, or an instance of the base class is returned. + """ + + # Upstream issue to have fine-grained exception handling: + # - https://gitlab.com/mbzero/python-mbzero/-/issues/2 + + # mbz_ex.message is not a str but the exception cause, while mbz_ex.cause is None... + if isinstance(mbz_ex.message, requests.exceptions.HTTPError): + http_error: requests.exceptions.HTTPError = mbz_ex.message + status_code = http_error.response.status_code + + if status_code == 400: + return MbInterfaceBadRequestError(mbz_ex) + elif status_code == 401: + return MbInterfaceUnauthorizedError(mbz_ex) + elif status_code == 404: + return MbInterfaceNotFoundError(mbz_ex) + + # Base exception if no specific one could be found + return MbInterfaceError(mbz_ex) + + +class MbInterface: + """An interface for sending requests using MusicBrainz API""" + + def __init__( + self, + hostname: str, + https: bool, + rate_limiter: RateLimiter, + auth: mbzauth.MbzCredentials | None = None, + ): + self.hostname = hostname + self.https = https + self.rate_limiter = rate_limiter + self.useragent = "beets/{} (https://beets.io/)".format( + beets.__version__ + ) + self.auth = auth + + def _lookup( + self, + entity_type: str, + mbid: str, + includes: list[str], + ) -> bytes: + """Send a lookup request to the configured MusicBrainz API to get information + on a single entity + + :param entity_type: The type of entity to look up + :param mbid: The MusicBrainz ID of the entity to look up + :param includes: List of parameters to request more information to be included + about the entity + :return: The response as bytes + :raises MbInterfaceError: if the request did not succeed + """ + with self.rate_limiter: + return self._send( + mbzr.MbzRequestLookup( + self.useragent, entity_type, mbid, includes + ), + ) + + def _browse( + self, + lookup_entity_type: str, + mbid: str, + linked_entities_type: str, + includes: list[str] = [], + limit: int | None = None, + offset: int | None = None, + ) -> bytes: + """Send a browse request to the configured MusicBrainz API to get entities + linked to looked up one + + :param lookup_entity_type: The type of entity to look up + :param mbid: The MusicBrainz ID of the entity to look up + :param linked_entities_type: The type of linked entities to find + :param includes: List of parameters to request more information to be included + about the entity + :param limit: The number of entities that should be returned + :param offset: Offset used for paging through more than one page of results + :return: The response as bytes + :raises MbInterfaceError: if the request did not succeed + """ + with self.rate_limiter: + return self._send( + mbzr.MbzRequestBrowse( + self.useragent, + linked_entities_type, + lookup_entity_type, + mbid, + includes, + ), + limit=limit, + offset=offset, + ) + + def _search( + self, + entity_type: str, + query: str, + limit: int | None = None, + offset: int | None = None, + **fields, + ) -> bytes: + """Send a search request to the configured MusicBrainz API to search entities + based on a query + + :param entity_type: The type of entity to look up + :param query: The query in the Lucene Search syntax + :param limit: The number of entities that should be returned + :param offset: Offset used for paging through more than one page of results + :return: The response as bytes + :raises MbInterfaceError: if the request did not succeed + """ + + # mbzero does not properly handle queries with special characters. + # Quote the query beforehand to avoid this problem. + # See: https://gitlab.com/mbzero/python-mbzero/-/issues/1 + quoted_query = quote(query) + + with self.rate_limiter: + return self._send( + mbzr.MbzRequestSearch( + self.useragent, entity_type, quoted_query + ), + limit=limit, + offset=offset, + ) + + def _send( + self, + mbr: mbzr.MbzRequestLookup + | mbzr.MbzRequestSearch + | mbzr.MbzRequestBrowse, + limit: int | None = None, + offset: int | None = None, + ) -> bytes: + """Send the request + + :param mbr: The request object + :param limit: The number of entities that should be returned + :param offset: Offset used for paging through more than one page of results + :return: The response as bytes + :raises MbInterfaceError: if the request did not succeed + """ + if self.hostname: + scheme = "https" if self.https else "http" + mbr.set_url(f"{scheme}://{self.hostname}/ws/2") + opts = {} + if limit: + opts["limit"] = limit + if offset: + opts["offset"] = offset + try: + return mbr.send(opts=opts, credentials=self.auth) + except mbzerror.MbzWebServiceError as ex: + raise _convert_mbzero_exception_to_local(ex) + + def _make_query(self, fields: dict[str, str] = {}) -> str: + """Make a Lucene Query string from a dict of fields + + :param fields: Dict of field keys and values used to build the query. + Values will be properly escaped. + :return: The built Lucene Query string + """ + # Encode the query terms as a Lucene query string. + lucene_special = r'([+\-&|!(){}\[\]\^"~*?:\\\/])' + query_parts = [] + + for key, value in fields.items(): + # Escape Lucene's special characters. + value = re.sub(lucene_special, r"\\\1", value) + if value: + value = value.lower() # avoid AND / OR + query_parts.append(f"{key}:({value})") + full_query = " ".join(query_parts).strip() + + if not full_query: + raise ValueError("at least one query term is required") + + return full_query + + @staticmethod + def _remove_none_values(data): + """Iterate recursively over a Python object to remove all None values in + dicts + """ + if isinstance(data, dict): + return { + key: MbInterface._remove_none_values(value) + for key, value in data.items() + if value is not None + } + elif isinstance(data, list): + return [MbInterface._remove_none_values(item) for item in data] + else: + return data + + @staticmethod + def _parse_and_clean_json(data: bytes) -> "JSONDict": + """Parse the JSON data and remove all None values in dicts. + This is needed as the MusicBrainz JSON data contains None values instead of + simply not setting them in dictionaries. + This is also different from the their XML data which only contains filled + values. + + :param data: JSON data as bytes + """ + return MbInterface._remove_none_values(json.loads(data)) + + def browse_recordings( + self, + lookup_entity_type: Literal["artist", "collection", "release", "work"], + mbid: str, + includes: list[str] = [], + limit: int | None = None, + offset: int | None = None, + ) -> "JSONDict": + """Browse recordings linked to an entity + + :param lookup_entity_type: The type of entity whose recordings are to be browsed + :param mbid: The MusicBrainz ID of the entity + :param includes: List of parameters to request more information to be included + about the recordings + :param limit: The number of recordings that should be returned + :param offset: Offset used for paging through more than one page of results + :return: The JSON-decoded response as an object + :raises MbInterfaceError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._browse( + lookup_entity_type, + mbid, + "recording", + includes, + limit=limit, + offset=offset, + ) + ) + + def browse_release( + self, + lookup_entity_type: Literal["artist", "collection", "release"], + mbid: str, + includes: list[str] = [], + limit: int | None = None, + offset: int | None = None, + ) -> "JSONDict": + """Browse releases linked to an entity + + :param lookup_entity_type: The type of entity whose releases are to be + browsed + :param mbid: The MusicBrainz ID of the entity + :param includes: List of parameters to request more information to be included + about the releases + :param limit: The number of releases that should be returned + :param offset: Offset used for paging through more than one page of results + :return: The JSON-decoded response as an object + :raises mbzerror.MbzRequestError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._browse( + lookup_entity_type, + mbid, + "release", + includes, + limit=limit, + offset=offset, + ) + ) + + def browse_release_groups( + self, + lookup_entity_type: Literal["artist", "collection", "release"], + mbid: str, + includes: list[str] = [], + limit: int | None = None, + offset: int | None = None, + ) -> "JSONDict": + """Browse release-groups linked to an entity + + :param lookup_entity_type: The type of entity whose release-groups are to be + browsed + :param mbid: The MusicBrainz ID of the entity + :param includes: List of parameters to request more information to be included + about the release-groups + :param limit: The number of release-groups that should be returned + :param offset: Offset used for paging through more than one page of results + :return: The JSON-decoded response as an object + :raises mbzerror.MbzRequestError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._browse( + lookup_entity_type, + mbid, + "release-group", + includes, + limit=limit, + offset=offset, + ) + ) + + def get_release_by_id( + self, + mbid: str, + includes: list[str] = [], + ) -> "JSONDict": + """Get a release from its ID + + :param mbid: The MusicBrainz ID of the release + :param includes: List of parameters to request more information to be included + about the release + :return: The JSON-decoded response as an object + :raises MbInterfaceError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._lookup( + "release", + mbid, + includes, + ) + ) + + def get_recording_by_id( + self, + mbid: str, + includes: list[str] = [], + ) -> "JSONDict": + """Get a recording from its ID + + :param mbid: The MusicBrainz ID of the entity + :param includes: List of parameters to request more information to be included + about the recording + :return: The JSON-decoded response as an object + :raises MbInterfaceError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._lookup( + "recording", + mbid, + includes, + ) + ) + + def get_work_by_id( + self, + mbid: str, + includes: list[str] = [], + ) -> "JSONDict": + """Get a work from its ID + + :param mbid: The MusicBrainz ID of the entity + :param includes: List of parameters to request more information to be included + about the work + :return: The JSON-decoded response as an object + :raises mbzerror.MbzRequestError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._lookup( + "work", + mbid, + includes, + ) + ) + + def get_user_collections(self) -> "JSONDict": + """Get the collections of the authenticated user + + :return: The JSON-decoded response as an object + :raises MbInterfaceError: if the request did not succeed + """ + + assert self.auth is not None, ( + "get_user_collections requires an authenticated user" + ) + + return MbInterface._parse_and_clean_json( + self._lookup("collection", "", includes=[]) + ) + + def search_releases( + self, + limit: int | None = None, + offset: int | None = None, + **fields: str, + ) -> "JSONDict": + """Search for releases using a query + + :param limit: The number of releases that should be returned + :param offset: Offset used for paging through more than one page of results + :param fields: Dict of fields composing the search query + :return: The JSON-decoded response as an object + :raises MbInterfaceError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._search( + "release", + query=self._make_query(fields), + limit=limit, + offset=offset, + ) + ) + + def search_recordings( + self, + limit: int | None = None, + offset: int | None = None, + **fields: str, + ) -> "JSONDict": + """Search for recordings using a query + + :param limit: The number of recordings that should be returned + :param offset: Offset used for paging through more than one page of results + :param fields: Dict of fields composing the search query + :return: The JSON-decoded response as an object + :raises MbInterfaceError: if the request did not succeed + """ + return MbInterface._parse_and_clean_json( + self._search( + "recording", + query=self._make_query(fields), + limit=limit, + offset=offset, + ) + ) + + +class SharedMbInterface: + """Singleton holding a shared MbInterface. + This can be used to use the same configuration, rate limiting, etc. between + multiple plugins. + """ + + def __new__(cls): + """Create the singleton""" + if not hasattr(cls, "instance"): + cls.instance = super(SharedMbInterface, cls).__new__(cls) + return cls.instance + + def __init__(self): + mb_config = beets.config["musicbrainz"] + mb_config.add( + { + "host": "musicbrainz.org", + "https": False, + "ratelimit": 1, + "ratelimit_interval": 1, + } + ) + mb_config["pass"].redact = True + + hostname = mb_config["host"].as_str() + https = mb_config["https"].get(bool) + # Force https usage for default MusicBrainz server + if hostname == "musicbrainz.org": + https = True + + # Set the auth if the config is set + if mb_config["user"].exists() and mb_config["pass"].exists(): + auth = mbzauth.MbzCredentials() + auth.auth_set( + mb_config["user"].as_str(), mb_config["pass"].as_str() + ) + else: + auth = None + + self.mb_interface = MbInterface( + hostname, + https, + RateLimiter( + reqs_per_interval=mb_config["ratelimit"].get(int), + interval_sec=mb_config["ratelimit_interval"].as_number(), + ), + auth, + ) + + def require_auth_for_plugin(self, plugin_name: str) -> "SharedMbInterface": + """Raise an error if the authentication has not been configured + + :param plugin_name: Name of the plugin that requires authentication to work + """ + if self.mb_interface.auth is None: + raise UserError( + f"MusicBrainz authentication is required for plugin {plugin_name}: " + "musicbrainz.user and musicbrainz.pass need to be set in configuration" + ) + return self + + def get(self) -> MbInterface: + return self.mb_interface diff --git a/beetsplug/listenbrainz.py b/beetsplug/listenbrainz.py index c579645db9..8670857cf9 100644 --- a/beetsplug/listenbrainz.py +++ b/beetsplug/listenbrainz.py @@ -2,13 +2,14 @@ import datetime -import musicbrainzngs import requests from beets import config, ui from beets.plugins import BeetsPlugin from beetsplug.lastimport import process_tracks +from ._mb_interface import SharedMbInterface + class ListenBrainzPlugin(BeetsPlugin): """A Beets plugin for interacting with ListenBrainz.""" @@ -22,6 +23,7 @@ def __init__(self): self.token = self.config["token"].get() self.username = self.config["username"].get() self.AUTH_HEADER = {"Authorization": f"Token {self.token}"} + self.mb_interface = SharedMbInterface().get() config["listenbrainz"]["token"].redact = True def commands(self): @@ -132,10 +134,9 @@ def get_tracks_from_listens(self, listens): def get_mb_recording_id(self, track): """Returns the MusicBrainz recording ID for a track.""" - resp = musicbrainzngs.search_recordings( - query=track["track_metadata"].get("track_name"), + resp = self.mb_interface.search_recordings( + recording=track["track_metadata"].get("track_name"), release=track["track_metadata"].get("release_name"), - strict=True, ) if resp.get("recording-count") == "1": return resp.get("recording-list")[0].get("id") @@ -210,7 +211,7 @@ def get_track_info(self, tracks): track_info = [] for track in tracks: identifier = track.get("identifier") - resp = musicbrainzngs.get_recording_by_id( + resp = self.mb_interface.get_recording_by_id( identifier, includes=["releases", "artist-credits"] ) recording = resp.get("recording") diff --git a/beetsplug/mbcollection.py b/beetsplug/mbcollection.py index 1c010bf504..ca981ae3e4 100644 --- a/beetsplug/mbcollection.py +++ b/beetsplug/mbcollection.py @@ -17,10 +17,16 @@ import musicbrainzngs -from beets import config, ui +from beets import ui from beets.plugins import BeetsPlugin from beets.ui import Subcommand +from ._mb_interface import ( + MbInterfaceError, + MbInterfaceUnauthorizedError, + SharedMbInterface, +) + SUBMISSION_CHUNK_SIZE = 200 FETCH_CHUNK_SIZE = 100 UUID_REGEX = r"^[a-f0-9]{8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}$" @@ -30,12 +36,10 @@ def mb_call(func, *args, **kwargs): """Call a MusicBrainz API function and catch exceptions.""" try: return func(*args, **kwargs) - except musicbrainzngs.AuthenticationError: + except MbInterfaceUnauthorizedError: raise ui.UserError("authentication with MusicBrainz failed") - except (musicbrainzngs.ResponseError, musicbrainzngs.NetworkError) as exc: + except MbInterfaceError as exc: raise ui.UserError(f"MusicBrainz API error: {exc}") - except musicbrainzngs.UsageError: - raise ui.UserError("MusicBrainz credentials missing") def submit_albums(collection_id, release_ids): @@ -44,17 +48,14 @@ def submit_albums(collection_id, release_ids): """ for i in range(0, len(release_ids), SUBMISSION_CHUNK_SIZE): chunk = release_ids[i : i + SUBMISSION_CHUNK_SIZE] + # TODO: mbzero does not support PUT requests... + # - https://gitlab.com/mbzero/python-mbzero/-/issues/3 mb_call(musicbrainzngs.add_releases_to_collection, collection_id, chunk) class MusicBrainzCollectionPlugin(BeetsPlugin): def __init__(self): super().__init__() - config["musicbrainz"]["pass"].redact = True - musicbrainzngs.auth( - config["musicbrainz"]["user"].as_str(), - config["musicbrainz"]["pass"].as_str(), - ) self.config.add( { "auto": False, @@ -62,11 +63,14 @@ def __init__(self): "remove": False, } ) + self.mb_interface = ( + SharedMbInterface().require_auth_for_plugin(self.name).get() + ) if self.config["auto"]: self.import_stages = [self.imported] def _get_collection(self): - collections = mb_call(musicbrainzngs.get_collections) + collections = mb_call(self.mb_interface.get_user_collections) if not collections["collection-list"]: raise ui.UserError("no collections exist for user") @@ -90,7 +94,8 @@ def _get_collection(self): def _get_albums_in_collection(self, id): def _fetch(offset): res = mb_call( - musicbrainzngs.get_releases_in_collection, + self.mb_interface.browse_release, + "collection", id, limit=FETCH_CHUNK_SIZE, offset=offset, @@ -124,6 +129,8 @@ def remove_missing(self, collection_id, lib_albums): remove_me = list(set(albums_in_collection) - lib_ids) for i in range(0, len(remove_me), FETCH_CHUNK_SIZE): chunk = remove_me[i : i + FETCH_CHUNK_SIZE] + # TODO: mbzero does not support DELETE requests... + # - https://gitlab.com/mbzero/python-mbzero/-/issues/3 mb_call( musicbrainzngs.remove_releases_from_collection, collection_id, diff --git a/beetsplug/missing.py b/beetsplug/missing.py index c4bbb83fda..8e5eb54a1b 100644 --- a/beetsplug/missing.py +++ b/beetsplug/missing.py @@ -18,8 +18,7 @@ from collections import defaultdict from collections.abc import Iterator -import musicbrainzngs -from musicbrainzngs.musicbrainz import MusicBrainzError +from mbzero import mbzerror from beets import config, plugins from beets.dbcore import types @@ -27,6 +26,8 @@ from beets.plugins import BeetsPlugin from beets.ui import Subcommand, decargs, print_ +from ._mb_interface import SharedMbInterface + MB_ARTIST_QUERY = r"mb_albumartistid::^\w{8}-\w{4}-\w{4}-\w{4}-\w{12}$" @@ -103,6 +104,8 @@ def __init__(self): } ) + self.mb_interface = SharedMbInterface().get() + self.album_template_fields["missing"] = _missing_count self._command = Subcommand("missing", help=__doc__, aliases=["miss"]) @@ -189,8 +192,8 @@ def _missing_albums(self, lib: Library, query: list[str]) -> None: calculating_total = self.config["total"].get() for (artist, artist_id), album_ids in album_ids_by_artist.items(): try: - resp = musicbrainzngs.browse_release_groups(artist=artist_id) - except MusicBrainzError as err: + resp = self.mb_interface.browse_release_groups(artist_id) + except mbzerror.MbzWebServiceError as err: self._log.info( "Couldn't fetch info for artist '{}' ({}) - '{}'", artist, diff --git a/beetsplug/musicbrainz.py b/beetsplug/musicbrainz.py index dc49d7f261..17aef1dee0 100644 --- a/beetsplug/musicbrainz.py +++ b/beetsplug/musicbrainz.py @@ -16,26 +16,26 @@ from __future__ import annotations -import json -import threading -import time import traceback -import re from collections import Counter from functools import cached_property from itertools import product from typing import TYPE_CHECKING, Any from urllib.parse import urljoin -from mbzero import mbzerror -from mbzero import mbzrequest as mbzr - import beets import beets.autotag.hooks from beets import config, plugins, util from beets.plugins import BeetsPlugin from beets.util.id_extractors import extract_release_id +from ._mb_interface import ( + MbInterface, + MbInterfaceError, + MbInterfaceNotFoundError, + SharedMbInterface, +) + if TYPE_CHECKING: from collections.abc import Iterator, Sequence from typing import Literal @@ -46,7 +46,7 @@ VARIOUS_ARTISTS_ID = "89ad4ac3-39f7-470e-963a-56509c546377" -BASE_URL = "musicbrainz.org/ws/2" +BASE_URL = "https://musicbrainz.org" SKIPPED_TRACKS = ["[data track]"] @@ -60,278 +60,6 @@ } -# Rate limiting - - -class _RateLimitsSingleton: - limit_interval = 1.0 - limit_requests = 1 - do_rate_limit = True - - def __new__(cls): - """Makes that class a singleton""" - if not hasattr(cls, "instance"): - cls.instance = super(_RateLimitsSingleton, cls).__new__(cls) - return cls.instance - - def get(self): - return (self.limit_interval, self.limit_requests, self.do_rate_limit) - - def set_rate_limit(self, limit_or_interval=1.0, new_requests=1): - """Sets the rate limiting behavior of the module. Must be invoked - before the first Web service call. - If the `limit_or_interval` parameter is set to False then - rate limiting will be disabled. If it is a number then only - a set number of requests (`new_requests`) will be made per - given interval (`limit_or_interval`). - """ - if isinstance(limit_or_interval, bool): - self.do_rate_limit = limit_or_interval - else: - if limit_or_interval <= 0.0: - raise ValueError("limit_or_interval can't be less than 0") - if new_requests <= 0: - raise ValueError("new_requests can't be less than 0") - self.do_rate_limit = True - self.limit_interval = limit_or_interval - self.limit_requests = new_requests - - -class _RateLim(object): - """A decorator that limits the rate at which the function may be - called. The rate is controlled by the `limit_interval` and - `limit_requests` global variables. The limiting is thread-safe; - only one thread may be in the function at a time (acts like a - monitor in this sense). The globals must be set before the first - call to the limited function. - """ - - def __init__(self, fun): - self.fun = fun - self.last_call = 0.0 - self.lock = threading.Lock() - self.remaining_requests = None # Set on first invocation. - - def _update_remaining(self): - """Update remaining requests based on the elapsed time since - they were last calculated. - """ - (limit_interval, limit_requests, do_rate_limit) = ( - _RateLimitsSingleton().get() - ) - - # On first invocation, we have the maximum number of requests - # available. - if self.remaining_requests is None: - self.remaining_requests = float(limit_requests) - - else: - since_last_call = time.time() - self.last_call - self.remaining_requests += since_last_call * ( - limit_requests / limit_interval - ) - self.remaining_requests = min( - self.remaining_requests, float(limit_requests) - ) - - self.last_call = time.time() - - def __call__(self, *args, **kwargs): - (limit_interval, limit_requests, do_rate_limit) = ( - _RateLimitsSingleton().get() - ) - - with self.lock: - if do_rate_limit: - self._update_remaining() - - # Delay if necessary. - while self.remaining_requests < 0.999: - time.sleep( - (1.0 - self.remaining_requests) - * (limit_requests / limit_interval) - ) - self._update_remaining() - - # Call the original function, "paying" for this call. - self.remaining_requests -= 1.0 - return self.fun(*args, **kwargs) - - -# Musicbrainz library interface - - -class MbWebServiceError(mbzerror.MbzWebServiceError): - pass - - -class MusicBrainzError(mbzerror.MbzError): - pass - - -class MbResponseError(mbzerror.MbzWebServiceError): - pass - - -class MbInterface: - BEETS_USERAGENT = "beets/{} (https://beets.io/)".format(beets.__version__) - - def __init__(self, useragent=BEETS_USERAGENT): - self.hostname = BASE_URL - self.https = True - self.useragent = useragent - pass - - def set_hostname(self, hostname, https): - self.hostname = hostname - self.https = https - - @_RateLim - def _lookup( - self, entity, mbid, includes, limit=None, offset=None, params={} - ): - return self._send( - mbzr.MbzRequestLookup(self.useragent, entity, mbid, includes), - limit=limit, - offset=offset, - ) - - @_RateLim - def _browse( - self, entity, bw_entity, mbid, includes=[], limit=None, offset=None - ): - return self._send( - mbzr.MbzRequestBrowse( - self.useragent, entity, bw_entity, mbid, includes - ), - limit=limit, - offset=offset, - ) - - @_RateLim - def _search(self, entity, query, limit=None, offset=None, **fields): - return self._send( - mbzr.MbzRequestSearch(self.useragent, entity, query), - limit=limit, - offset=offset, - ) - - def _send(self, mbr, limit=None, offset=None): - if self.hostname: - mbr.set_url( - "{}://{}".format( - "https" if self.https else "http", self.hostname - ) - ) - opts = {} - if limit: - opts["limit"] = limit - if offset: - opts["offset"] = offset - return mbr.send(opts=opts) - - def _make_params(self, release_status=[], release_type=[]): - params = {} - if len(release_status): - params["status"] = "|".join(release_status) - if len(release_type): - params["type"] = "|".join(release_type) - return params - - def _make_query(self, query="", fields={}): - """`query` is a lucene query string when no fields are set, - but is escaped when any fields are given. `fields` is a dictionary - of key/value query parameters. - """ - # Encode the query terms as a Lucene query string. - lucene_special = r'([+\-&|!(){}\[\]\^"~*?:\\\/])' - query_parts = [] - - if query: - clean_query = util._unicode(query) - if fields: - clean_query = re.sub(lucene_special, r"\\\1", clean_query) - query_parts.append(clean_query.lower()) - else: - query_parts.append(clean_query) - for key, value in fields.items(): - # Escape Lucene's special characters. - value = util._unicode(value) - value = re.sub(lucene_special, r"\\\1", value) - if value: - value = value.lower() # avoid AND / OR - query_parts.append("%s:(%s)" % (key, value)) - full_query = " ".join(query_parts).strip() - - if not full_query: - raise ValueError("at least one query term is required") - - return full_query - - def browse_recordings( - self, bw_entity, mbid, includes=[], limit=None, offset=None - ): - """Get all recordings linked to an artist or a release. - You need to give one MusicBrainz ID.""" - return self._browse( - self, - "recording", - bw_entity, - mbid, - includes, - limit=limit, - offset=offset, - ) - - def get_release_by_id( - self, mbid, includes=[], release_status=[], release_type=[] - ): - return json.loads( - self._lookup( - self, - "release", - mbid, - includes, - params=self._make_params(release_status, release_type), - ) - ) - - def get_recording_by_id( - self, mbid, includes=[], release_status=[], release_type=[] - ): - return json.loads( - self._lookup( - self, - "recording", - mbid, - includes, - params=self._make_params(release_status, release_type), - ) - ) - - def search_releases(self, query="", limit=None, offset=None, **fields): - return json.loads( - self._search( - self, - "release", - query=self._make_query(query, fields), - limit=limit, - offset=offset, - ) - ) - - def search_recordings(self, query="", limit=None, offset=None, **fields): - return json.loads( - self._search( - self, - "recording", - query=self._make_query(query, fields), - limit=limit, - offset=offset, - ) - ) - - class MusicBrainzAPIError(util.HumanReadableError): """An error while talking to MusicBrainz. The `query` field is the parameter to the action and may have any type. @@ -339,8 +67,6 @@ class MusicBrainzAPIError(util.HumanReadableError): def __init__(self, reason, verb, query, tb=None): self.query = query - if isinstance(reason, MbWebServiceError): - reason = "MusicBrainz not reachable" super().__init__(reason, verb, tb) def get_message(self): @@ -450,7 +176,7 @@ def _multi_artist_credit( # An artist. if alias: - cur_artist_name = alias["alias"] + cur_artist_name = alias["name"] else: cur_artist_name = el["artist"]["name"] artist_parts.append(cur_artist_name) @@ -577,6 +303,7 @@ def _is_translation(r): def _find_actual_release_from_pseudo_release( + mb_interface: MbInterface, pseudo_rel: JSONDict, ) -> JSONDict | None: try: @@ -592,7 +319,7 @@ def _find_actual_release_from_pseudo_release( actual_id = translations[0]["target"] - return MbInterface().get_release_by_id(actual_id, includes=RELEASE_INCLUDES) + return mb_interface.get_release_by_id(actual_id, includes=RELEASE_INCLUDES) def _merge_pseudo_and_actual_album( @@ -645,10 +372,7 @@ def __init__(self): super().__init__() self.config.add( { - "host": "musicbrainz.org", - "https": False, - "ratelimit": 1, - "ratelimit_interval": 1, + # The rest of the config is defined in SharedMbInterface "searchlimit": 5, "genres": False, "external_ids": { @@ -661,16 +385,7 @@ def __init__(self): "extra_tags": [], }, ) - hostname = self.config["host"].as_str() - https = self.config["https"].get(bool) - # Only call set_hostname when a custom server is configured. Since - # musicbrainz-ngs connects to musicbrainz.org with HTTPS by default - if hostname != "musicbrainz.org": - MbInterface().set_hostname(hostname, https) - _RateLimitsSingleton().set_rate_limit( - self.config["ratelimit_interval"].as_number(), - self.config["ratelimit"].get(int), - ) + self.mb_interface = SharedMbInterface().get() def track_info( self, @@ -803,8 +518,9 @@ def album_info(self, release: JSONDict) -> beets.autotag.hooks.AlbumInfo: for i in range(0, ntracks, BROWSE_CHUNKSIZE): self._log.debug("Retrieving tracks starting at {}", i) recording_list.extend( - MbInterface().browse_recordings( - release=release["id"], + self.mb_interface.browse_recordings( + "release", + release["id"], limit=BROWSE_CHUNKSIZE, includes=BROWSE_INCLUDES, offset=i, @@ -1074,10 +790,14 @@ def _search_api( self._log.debug( "Searching for MusicBrainz {}s with: {!r}", query_type, filters ) + method = ( + self.mb_interface.search_recordings + if query_type == "recording" + else self.mb_interface.search_releases + ) try: - method = getattr(MbInterface(), f"search_{query_type}s") res = method(limit=self.config["searchlimit"].get(int), **filters) - except MusicBrainzError as exc: + except MbInterfaceError as exc: raise MusicBrainzAPIError( exc, f"{query_type} search", filters, traceback.format_exc() ) @@ -1118,18 +838,20 @@ def album_for_id( return None try: - res = MbInterface().get_release_by_id(albumid, includes=RELEASE_INCLUDES) + res = self.mb_interface.get_release_by_id(albumid, includes=RELEASE_INCLUDES) # resolve linked release relations actual_res = None if res.get("status") == "Pseudo-Release": - actual_res = _find_actual_release_from_pseudo_release(res) + actual_res = _find_actual_release_from_pseudo_release( + self.mb_interface, res + ) - except MbResponseError: + except MbInterfaceNotFoundError: self._log.debug("Album ID match failed.") return None - except MusicBrainzError as exc: + except MbInterfaceError as exc: raise MusicBrainzAPIError( exc, "get release by ID", albumid, traceback.format_exc() ) @@ -1155,11 +877,11 @@ def track_for_id( return None try: - res = MbInterface().get_recording_by_id(trackid, TRACK_INCLUDES) - except MbResponseError: + res = self.mb_interface.get_recording_by_id(trackid, TRACK_INCLUDES) + except MbInterfaceNotFoundError: self._log.debug("Track ID match failed.") return None - except MusicBrainzError as exc: + except MbInterfaceError as exc: raise MusicBrainzAPIError( exc, "get recording by ID", trackid, traceback.format_exc() ) diff --git a/beetsplug/parentwork.py b/beetsplug/parentwork.py index 463a455f57..2a95b2dbc5 100644 --- a/beetsplug/parentwork.py +++ b/beetsplug/parentwork.py @@ -16,17 +16,19 @@ and work composition date """ -import musicbrainzngs +from mbzero import mbzerror from beets import ui from beets.plugins import BeetsPlugin +from ._mb_interface import MbInterface, SharedMbInterface -def direct_parent_id(mb_workid, work_date=None): + +def direct_parent_id(mb_interface: MbInterface, mb_workid: str, work_date=None): """Given a Musicbrainz work id, find the id one of the works the work is part of and the first composition date it encounters. """ - work_info = musicbrainzngs.get_work_by_id( + work_info = mb_interface.get_work_by_id( mb_workid, includes=["work-rels", "artist-rels"] ) if "artist-relation-list" in work_info["work"] and work_date is None: @@ -46,25 +48,25 @@ def direct_parent_id(mb_workid, work_date=None): return None, work_date -def work_parent_id(mb_workid): +def work_parent_id(mb_interface: MbInterface, mb_workid: str): """Find the parent work id and composition date of a work given its id.""" work_date = None while True: - new_mb_workid, work_date = direct_parent_id(mb_workid, work_date) + new_mb_workid, work_date = direct_parent_id( + mb_interface, mb_workid, work_date + ) if not new_mb_workid: return mb_workid, work_date mb_workid = new_mb_workid return mb_workid, work_date -def find_parentwork_info(mb_workid): +def find_parentwork_info(mb_interface: MbInterface, mb_workid): """Get the MusicBrainz information dict about a parent work, including the artist relations, and the composition date for a work's parent work. """ - parent_id, work_date = work_parent_id(mb_workid) - work_info = musicbrainzngs.get_work_by_id( - parent_id, includes=["artist-rels"] - ) + parent_id, work_date = work_parent_id(mb_interface, mb_workid) + work_info = mb_interface.get_work_by_id(parent_id, includes=["artist-rels"]) return work_info, work_date @@ -79,6 +81,8 @@ def __init__(self): } ) + self.mb_interface = SharedMbInterface().get() + if self.config["auto"]: self.import_stages = [self.imported] @@ -192,8 +196,10 @@ def find_work(self, item, force, verbose): work_changed = item.parentwork_workid_current != item.mb_workid if force or not hasparent or work_changed: try: - work_info, work_date = find_parentwork_info(item.mb_workid) - except musicbrainzngs.musicbrainz.WebServiceError as e: + work_info, work_date = find_parentwork_info( + self.mb_interface, item.mb_workid + ) + except mbzerror.MbzWebServiceError as e: self._log.debug("error fetching work: {}", e) return parent_info = self.get_info(item, work_info) diff --git a/poetry.lock b/poetry.lock index b47ed5060e..b1584d8591 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3301,4 +3301,4 @@ web = ["flask", "flask-cors"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4" -content-hash = "bb4f7352f9d55ffa22f09d3482d045cf50a5995714628bddff1419902cdd661f" +content-hash = "0ee33e6710cc82e6fec6c58aa0d451b5058f1979a8dce70c7f3eeb6aa682de88" diff --git a/test/plugins/test_musicbrainz.py b/test/plugins/test_musicbrainz.py index c04f1e2ec0..88494caea2 100644 --- a/test/plugins/test_musicbrainz.py +++ b/test/plugins/test_musicbrainz.py @@ -678,12 +678,12 @@ def _credit_dict(self, suffix=""): def _add_alias(self, credit_dict, suffix="", locale="", primary=False): alias = { - "alias": "ALIAS" + suffix, + "name": "ALIAS" + suffix, "locale": locale, "sort-name": "ALIASSORT" + suffix, } if primary: - alias["primary"] = "primary" + alias["primary"] = True if "aliases" not in credit_dict["artist"]: credit_dict["artist"]["aliases"] = [] credit_dict["artist"]["aliases"].append(alias) @@ -835,7 +835,10 @@ def test_follow_pseudo_releases(self): "country": "COUNTRY", }, ] - with mock.patch("musicbrainz.MbInterface.get_release_by_id") as gp: + + with mock.patch( + "beetsplug._mb_interface.MbInterface.get_release_by_id" + ) as gp: gp.side_effect = side_effect album = self.mb.album_for_id("d2a6f856-b553-40a0-ac54-a321e8e2da02") assert album.country == "COUNTRY" @@ -878,7 +881,9 @@ def test_pseudo_releases_with_empty_links(self): }, ] - with mock.patch("musicbrainz.MbInterface.get_release_by_id") as gp: + with mock.patch( + "beetsplug._mb_interface.MbInterface.get_release_by_id" + ) as gp: gp.side_effect = side_effect album = self.mb.album_for_id("d2a6f856-b553-40a0-ac54-a321e8e2da02") assert album.country is None @@ -920,7 +925,9 @@ def test_pseudo_releases_without_links(self): }, ] - with mock.patch("musicbrainz.MbInterface.get_release_by_id") as gp: + with mock.patch( + "beetsplug._mb_interface.MbInterface.get_release_by_id" + ) as gp: gp.side_effect = side_effect album = self.mb.album_for_id("d2a6f856-b553-40a0-ac54-a321e8e2da02") assert album.country is None @@ -969,7 +976,9 @@ def test_pseudo_releases_with_unsupported_links(self): }, ] - with mock.patch("musicbrainz.MbInterface.get_release_by_id") as gp: + with mock.patch( + "beetsplug._mb_interface.MbInterface.get_release_by_id" + ) as gp: gp.side_effect = side_effect album = self.mb.album_for_id("d2a6f856-b553-40a0-ac54-a321e8e2da02") assert album.country is None @@ -1020,7 +1029,7 @@ def test_get_album_criteria( def test_item_candidates(self, monkeypatch, mb): monkeypatch.setattr( - "musicbrainz.MbInterface.search_recordings", + "beetsplug._mb_interface.MbInterface.search_recordings", lambda *_, **__: {"recordings": [self.RECORDING]}, ) @@ -1031,11 +1040,11 @@ def test_item_candidates(self, monkeypatch, mb): def test_candidates(self, monkeypatch, mb): monkeypatch.setattr( - "musicbrainz.MbInterface.search_releases", + "beetsplug._mb_interface.MbInterface.search_releases", lambda *_, **__: {"releases": [{"id": self.mbid}]}, ) monkeypatch.setattr( - "musicbrainz.MbInterface.get_release_by_id", + "beetsplug._mb_interface.MbInterface.get_release_by_id", lambda *_, **__: { "title": "hi", "id": self.mbid, diff --git a/test/plugins/test_parentwork.py b/test/plugins/test_parentwork.py index 99267f6ffa..35705db091 100644 --- a/test/plugins/test_parentwork.py +++ b/test/plugins/test_parentwork.py @@ -160,7 +160,8 @@ def setUp(self): """Set up configuration""" super().setUp() self.patcher = patch( - "musicbrainzngs.get_work_by_id", side_effect=mock_workid_response + "beetsplug._mb_interface.MbInterface.get_work_by_id", + side_effect=mock_workid_response, ) self.patcher.start() diff --git a/test/util/test_rate_limiter.py b/test/util/test_rate_limiter.py new file mode 100644 index 0000000000..b56cb0c6fd --- /dev/null +++ b/test/util/test_rate_limiter.py @@ -0,0 +1,93 @@ +import time + +from beets.util.rate_limiter import RateLimiter + +# 10 reqs per 0.1 second +REQS_PER_INTERVAL = 10 +INTERVAL_SEC = 0.1 + +# Expected time to wait to be able to do one more request after being rate limited +WAIT_FOR_ONE_REQ = INTERVAL_SEC / REQS_PER_INTERVAL + + +def run_and_collect_delta_start_times(num_reqs: int) -> list[float]: + """Launch requests through the rate limiter and collect the durations between the + time before the first request and the starting time of each request. + + :param num_reqs: Number of requests to run + :return: A list of delta start times in seconds: non rate-limited ones should be + close to 0 + """ + rate_limiter = RateLimiter(REQS_PER_INTERVAL, INTERVAL_SEC) + + delta_start_times = [] + + start = time.time() + for _ in range(num_reqs): + with rate_limiter: + delta_start_times.append(time.time() - start) + + return delta_start_times + + +def test_all_reqs_in_one_interval(): + delta_start_times = run_and_collect_delta_start_times(REQS_PER_INTERVAL) + + for i in range(10): + assert delta_start_times[i] < WAIT_FOR_ONE_REQ, ( + f"request {i} should not have been rate-limited" + ) + + +def test_more_reqs_in_one_interval(): + delta_start_times = run_and_collect_delta_start_times(2 * REQS_PER_INTERVAL) + + # 20 reqs with rate-limitation of 10 reqs per 0.1s + # -> 10 reqs immediately, then 10*(1 req per 0.1s) + + for i in range(10): + assert delta_start_times[i] < WAIT_FOR_ONE_REQ, ( + f"request {i} should not have been rate-limited" + ) + + for i in range(10, len(delta_start_times)): + # Non rate-limited reqs are at interval 0 + # 1st rate-limited req is at interval 1 + # 2nd rate-limited req is at interval 2 + # etc. + expected_interval = i - 9 + expected_start_time = expected_interval * WAIT_FOR_ONE_REQ + + assert delta_start_times[i] >= expected_start_time, ( + f"request {i} has executed sooner than it should have" + ) + assert delta_start_times[i] < ( + expected_start_time + WAIT_FOR_ONE_REQ + ), f"request {i} has executed much later than it should have" + + +def test_reuse_after_no_requests(): + rate_limiter = RateLimiter(REQS_PER_INTERVAL, INTERVAL_SEC) + + # Use up all requests + start = time.time() + for _ in range(REQS_PER_INTERVAL): + with rate_limiter: + pass + end = time.time() + assert (end - start) < WAIT_FOR_ONE_REQ, ( + "requests should not have been rate-limited" + ) + + # Do no request for half an interval + time.sleep(INTERVAL_SEC / 2) + + # Now, we should be able to do half the REQS_PER_INTERVAL with no rate limitation + start = time.time() + for _ in range(REQS_PER_INTERVAL // 2): + with rate_limiter: + pass + end = time.time() + assert (end - start) < WAIT_FOR_ONE_REQ, ( + "requests should not have been rate-limited" + )