Skip to content

Commit

Permalink
Merge pull request #981 from python-babel/mypy-misc-fix
Browse files Browse the repository at this point in the history
Misc. mypy-discovered fixes
  • Loading branch information
akx authored Mar 1, 2023
2 parents 0aa54ca + 69aafef commit 0ce196f
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 22 deletions.
10 changes: 6 additions & 4 deletions babel/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __init__(
self.variant = variant
#: the modifier
self.modifier = modifier
self.__data = None
self.__data: localedata.LocaleDataDict | None = None

identifier = str(self)
identifier_without_modifier = identifier.partition('@')[0]
Expand Down Expand Up @@ -260,6 +260,7 @@ def negotiate(
aliases=aliases)
if identifier:
return Locale.parse(identifier, sep=sep)
return None

@classmethod
def parse(
Expand Down Expand Up @@ -468,9 +469,9 @@ def get_display_name(self, locale: Locale | str | None = None) -> str | None:
details.append(locale.variants.get(self.variant))
if self.modifier:
details.append(self.modifier)
details = filter(None, details)
if details:
retval += f" ({', '.join(details)})"
detail_string = ', '.join(atom for atom in details if atom)
if detail_string:
retval += f" ({detail_string})"
return retval

display_name = property(get_display_name, doc="""\
Expand Down Expand Up @@ -1080,6 +1081,7 @@ def default_locale(category: str | None = None, aliases: Mapping[str, str] = LOC
return get_locale_identifier(parse_locale(locale))
except ValueError:
pass
return None


def negotiate_locale(preferred: Iterable[str], available: Iterable[str], sep: str = '_', aliases: Mapping[str, str] = LOCALE_ALIASES) -> str | None:
Expand Down
2 changes: 1 addition & 1 deletion babel/messages/jslexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def unquote_string(string: str) -> str:
assert string and string[0] == string[-1] and string[0] in '"\'`', \
'string provided is not properly delimited'
string = line_join_re.sub('\\1', string[1:-1])
result = []
result: list[str] = []
add = result.append
pos = 0

Expand Down
14 changes: 6 additions & 8 deletions babel/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,8 @@ def list_currencies(locale: Locale | str | None = None) -> set[str]:
"""
# Get locale-scoped currencies.
if locale:
currencies = Locale.parse(locale).currencies.keys()
else:
currencies = get_global('all_currencies')
return set(currencies)
return set(Locale.parse(locale).currencies)
return set(get_global('all_currencies'))


def validate_currency(currency: str, locale: Locale | str | None = None) -> None:
Expand Down Expand Up @@ -103,7 +101,7 @@ def normalize_currency(currency: str, locale: Locale | str | None = None) -> str
if isinstance(currency, str):
currency = currency.upper()
if not is_currency(currency, locale):
return
return None
return currency


Expand Down Expand Up @@ -706,7 +704,7 @@ def _format_currency_long_name(

# Step 5.
if not format:
format = locale.decimal_formats[format]
format = locale.decimal_formats[None]

pattern = parse_pattern(format)

Expand Down Expand Up @@ -810,7 +808,7 @@ def format_percent(
"""
locale = Locale.parse(locale)
if not format:
format = locale.percent_formats[format]
format = locale.percent_formats[None]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator)
Expand Down Expand Up @@ -849,7 +847,7 @@ def format_scientific(
"""
locale = Locale.parse(locale)
if not format:
format = locale.scientific_formats[format]
format = locale.scientific_formats[None]
pattern = parse_pattern(format)
return pattern.apply(
number, locale, decimal_quantization=decimal_quantization)
Expand Down
23 changes: 14 additions & 9 deletions babel/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ def _find_unit_pattern(unit_id: str, locale: Locale | str | None = LC_NUMERIC) -
for unit_pattern in sorted(unit_patterns, key=len):
if unit_pattern.endswith(unit_id):
return unit_pattern
return None


def format_unit(
value: float | decimal.Decimal,
value: str | float | decimal.Decimal,
measurement_unit: str,
length: Literal['short', 'long', 'narrow'] = 'long',
format: str | None = None,
Expand Down Expand Up @@ -184,28 +185,28 @@ def _find_compound_unit(
# units like "kilometer" or "hour" into actual units like "length-kilometer" and
# "duration-hour".

numerator_unit = _find_unit_pattern(numerator_unit, locale=locale)
denominator_unit = _find_unit_pattern(denominator_unit, locale=locale)
resolved_numerator_unit = _find_unit_pattern(numerator_unit, locale=locale)
resolved_denominator_unit = _find_unit_pattern(denominator_unit, locale=locale)

# If either was not found, we can't possibly build a suitable compound unit either.
if not (numerator_unit and denominator_unit):
if not (resolved_numerator_unit and resolved_denominator_unit):
return None

# Since compound units are named "speed-kilometer-per-hour", we'll have to slice off
# the quantities (i.e. "length", "duration") from both qualified units.

bare_numerator_unit = numerator_unit.split("-", 1)[-1]
bare_denominator_unit = denominator_unit.split("-", 1)[-1]
bare_numerator_unit = resolved_numerator_unit.split("-", 1)[-1]
bare_denominator_unit = resolved_denominator_unit.split("-", 1)[-1]

# Now we can try and rebuild a compound unit specifier, then qualify it:

return _find_unit_pattern(f"{bare_numerator_unit}-per-{bare_denominator_unit}", locale=locale)


def format_compound_unit(
numerator_value: float | decimal.Decimal,
numerator_value: str | float | decimal.Decimal,
numerator_unit: str | None = None,
denominator_value: float | decimal.Decimal = 1,
denominator_value: str | float | decimal.Decimal = 1,
denominator_unit: str | None = None,
length: Literal["short", "long", "narrow"] = "long",
format: str | None = None,
Expand Down Expand Up @@ -289,7 +290,11 @@ def format_compound_unit(
denominator_value = ""

formatted_denominator = format_unit(
denominator_value, denominator_unit, length=length, format=format, locale=locale
denominator_value,
measurement_unit=(denominator_unit or ""),
length=length,
format=format,
locale=locale,
).strip()
else: # Bare denominator
formatted_denominator = format_decimal(denominator_value, format=format, locale=locale)
Expand Down

0 comments on commit 0ce196f

Please sign in to comment.