Skip to content

Commit

Permalink
advancedrewrite: Support simple syntax and improve advanced syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxr1998 committed Dec 13, 2023
1 parent 4b1c7dd commit 304a052
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 29 deletions.
155 changes: 126 additions & 29 deletions beetsplug/advancedrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,40 @@

"""Plugin to rewrite fields based on a given query."""

import re
import shlex
from collections import defaultdict

import confuse

from beets import ui
from beets.dbcore import AndQuery, query_from_strings
from beets.dbcore.types import MULTI_VALUE_DSV
from beets.library import Album, Item
from beets.plugins import BeetsPlugin
from beets.ui import UserError


def rewriter(field, rules):
def simple_rewriter(field, rules):
"""Template field function factory.
Create a template field function that rewrites the given field
with the given rewriting rules.
``rules`` must be a list of (pattern, replacement) pairs.
"""

def fieldfunc(item):
value = item._values_fixed[field]
for pattern, replacement in rules:
if pattern.match(value.lower()):
# Rewrite activated.
return replacement
# Not activated; return original value.
return value

return fieldfunc


def advanced_rewriter(field, rules):
"""Template field function factory.
Create a template field function that rewrites the given field
Expand Down Expand Up @@ -53,40 +75,115 @@ def __init__(self):
super().__init__()

template = confuse.Sequence(
{
"match": str,
"field": str,
"replacement": str,
}
confuse.OneOf(
[
confuse.MappingValues(str),
{
"match": str,
"replacements": confuse.MappingValues(
confuse.OneOf([str, confuse.Sequence(str)]),
),
},
]
)
)

# Gather all the rewrite rules for each field.
rules = defaultdict(list)
simple_rules = defaultdict(list)
advanced_rules = defaultdict(list)
for rule in self.config.get(template):
query = query_from_strings(
AndQuery,
Item,
prefixes={},
query_parts=shlex.split(rule["match"]),
)
fieldname = rule["field"]
replacement = rule["replacement"]
if fieldname not in Item._fields:
raise ui.UserError(
"invalid field name (%s) in rewriter" % fieldname
if "match" not in rule:
# Simple syntax
if len(rule) != 1:
raise UserError(
"Simple rewrites must have only one rule, "
"but found multiple entries. "
"Did you forget to prepend a dash (-)?"
)
key, value = next(iter(rule.items()))
try:
fieldname, pattern = key.split(None, 1)
except ValueError:
raise UserError(
f"Invalid simple rewrite specification {key}"
)
if fieldname not in Item._fields:
raise UserError(
f"invalid field name {fieldname} in rewriter"
)
self._log.debug(
f"adding simple rewrite '{pattern}' → '{value}' "
f"for field {fieldname}"
)
self._log.debug(
"adding template field {0} → {1}", fieldname, replacement
)
rules[fieldname].append((query, replacement))
if fieldname == "artist":
# Special case for the artist field: apply the same
# rewrite for "albumartist" as well.
rules["albumartist"].append((query, replacement))
pattern = re.compile(pattern.lower())
simple_rules[fieldname].append((pattern, value))
if fieldname == "artist":
# Special case for the artist field: apply the same
# rewrite for "albumartist" as well.
simple_rules["albumartist"].append((pattern, value))
else:
# Advanced syntax
match = rule["match"]
replacements = rule["replacements"]
if len(replacements) == 0:
raise UserError(
"Advanced rewrites must have at least one replacement"
)
query = query_from_strings(
AndQuery,
Item,
prefixes={},
query_parts=shlex.split(match),
)
for fieldname, replacement in replacements.items():
if fieldname not in Item._fields:
raise UserError(
f"Invalid field name {fieldname} in rewriter"
)
self._log.debug(
f"adding advanced rewrite to '{replacement}' "
f"for field {fieldname}"
)
if isinstance(replacement, list):
if Item._fields[fieldname] is not MULTI_VALUE_DSV:
raise UserError(
f"Field {fieldname} is not a multi-valued field "
f"but a list was given: {', '.join(replacement)}"
)
elif isinstance(replacement, str):
if Item._fields[fieldname] is MULTI_VALUE_DSV:
replacement = list(replacement)
else:
raise UserError(
f"Invalid type of replacement {replacement} "
f"for field {fieldname}"
)

advanced_rules[fieldname].append((query, replacement))
# Special case for the artist(s) field:
# apply the same rewrite for "albumartist(s)" as well.
if fieldname == "artist":
advanced_rules["albumartist"].append(
(query, replacement)
)
elif fieldname == "artists":
advanced_rules["albumartists"].append(
(query, replacement)
)
elif fieldname == "artist_sort":
advanced_rules["albumartist_sort"].append(
(query, replacement)
)

# Replace each template field with the new rewriter function.
for fieldname, fieldrules in rules.items():
getter = rewriter(fieldname, fieldrules)
for fieldname, fieldrules in simple_rules.items():
getter = simple_rewriter(fieldname, fieldrules)
self.template_fields[fieldname] = getter
if fieldname in Album._fields:
self.album_template_fields[fieldname] = getter

for fieldname, fieldrules in advanced_rules.items():
getter = advanced_rewriter(fieldname, fieldrules)
self.template_fields[fieldname] = getter
if fieldname in Album._fields:
self.album_template_fields[fieldname] = getter
142 changes: 142 additions & 0 deletions test/plugins/test_advancedrewrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# This file is part of beets.
# Copyright 2023, Max Rumpf.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.

"""Test the advancedrewrite plugin for various configurations.
"""

import unittest
from test.helper import TestHelper

from beets.ui import UserError

PLUGIN_NAME = "advancedrewrite"


class AdvancedRewritePluginTest(unittest.TestCase, TestHelper):
def setUp(self):
self.setup_beets()

def tearDown(self):
self.unload_plugins()
self.teardown_beets()

def test_simple_rewrite_example(self):
self.config[PLUGIN_NAME] = [
{"artist ODD EYE CIRCLE": "이달의 소녀 오드아이써클"},
]
self.load_plugins(PLUGIN_NAME)

item = self.add_item(
title="Uncover",
artist="ODD EYE CIRCLE",
albumartist="ODD EYE CIRCLE",
album="Mix & Match",
)

self.assertEqual(item.artist, "이달의 소녀 오드아이써클")

def test_advanced_rewrite_example(self):
self.config[PLUGIN_NAME] = [
{
"match": "mb_artistid:dec0f331-cb08-4c8e-9c9f-aeb1f0f6d88c year:..2022",
"replacements": {
"artist": "이달의 소녀 오드아이써클",
"artist_sort": "LOONA / ODD EYE CIRCLE",
},
},
]
self.load_plugins(PLUGIN_NAME)

item_a = self.add_item(
title="Uncover",
artist="ODD EYE CIRCLE",
albumartist="ODD EYE CIRCLE",
artist_sort="ODD EYE CIRCLE",
albumartist_sort="ODD EYE CIRCLE",
album="Mix & Match",
mb_artistid="dec0f331-cb08-4c8e-9c9f-aeb1f0f6d88c",
year=2017,
)
item_b = self.add_item(
title="Air Force One",
artist="ODD EYE CIRCLE",
albumartist="ODD EYE CIRCLE",
artist_sort="ODD EYE CIRCLE",
albumartist_sort="ODD EYE CIRCLE",
album="ODD EYE CIRCLE <Version Up>",
mb_artistid="dec0f331-cb08-4c8e-9c9f-aeb1f0f6d88c",
year=2023,
)

# Assert that all replacements were applied to item_a
self.assertEqual("이달의 소녀 오드아이써클", item_a.artist)
self.assertEqual("LOONA / ODD EYE CIRCLE", item_a.artist_sort)
self.assertEqual("LOONA / ODD EYE CIRCLE", item_a.albumartist_sort)

# Assert that no replacements were applied to item_b
self.assertEqual("ODD EYE CIRCLE", item_b.artist)

def test_advanced_rewrite_example_with_multi_valued_field(self):
self.config[PLUGIN_NAME] = [
{
"match": "artist:배유빈 feat. 김미현",
"replacements": {
"artists": ["유빈", "미미"],
},
},
]
self.load_plugins(PLUGIN_NAME)

item = self.add_item(
artist="배유빈 feat. 김미현",
artists=["배유빈", "김미현"],
)

self.assertEqual(item.artists, ["유빈", "미미"])

def test_fail_when_replacements_empty(self):
self.config[PLUGIN_NAME] = [
{
"match": "artist:A",
"replacements": {},
},
]
with self.assertRaises(
UserError,
msg="Advanced rewrites must have at least one replacement",
):
self.load_plugins(PLUGIN_NAME)

def test_fail_when_rewriting_single_valued_field_with_list(self):
self.config[PLUGIN_NAME] = [
{
"match": "artist:'A & B'",
"replacements": {
"artist": ["C", "D"],
},
},
]
with self.assertRaises(
UserError,
msg="Field artist is not a multi-valued field but a list was given: C, D",
):
self.load_plugins(PLUGIN_NAME)


def suite():
return unittest.TestLoader().loadTestsFromName(__name__)


if __name__ == "__main__":
unittest.main(defaultTest="suite")

0 comments on commit 304a052

Please sign in to comment.