diff --git a/scrapscript.py b/scrapscript.py index 5abb21ae..101e2320 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -41,6 +41,11 @@ class IntLit(Token): value: int +@dataclass(eq=True) +class FloatLit(Token): + value: float + + @dataclass(eq=True) class StringLit(Token): value: str @@ -174,7 +179,7 @@ def read_one(self) -> Token: return self.read_bytes() raise ParseError(f"unexpected token {c!r}") if c.isdigit(): - return self.read_integer(c) + return self.read_number(c) if c in "()[]{}": custom = { "(": LeftParen, @@ -205,11 +210,23 @@ def read_comment(self) -> None: while self.has_input() and self.read_char() != "\n": pass - def read_integer(self, first_digit: str) -> Token: + def read_number(self, first_digit: str) -> Token: + # TODO: Support floating point numbers with no integer part buf = first_digit - while self.has_input() and (c := self.peek_char()).isdigit(): + has_decimal = False + while self.has_input(): + c = self.peek_char() + if c == ".": + if has_decimal: + raise ParseError(f"unexpected token {c!r}") + has_decimal = True + elif not c.isdigit(): + break self.read_char() buf += c + + if has_decimal: + return self.make_token(FloatLit, float(buf)) return self.make_token(IntLit, int(buf)) def _starts_operator(self, buf: str) -> bool: @@ -350,8 +367,9 @@ def parse(tokens: typing.List[Token], p: float = 0) -> "Object": token = tokens.pop(0) l: Object if isinstance(token, IntLit): - # TODO: Handle float literals l = Int(token.value) + elif isinstance(token, FloatLit): + l = Float(token.value) elif isinstance(token, Name): # TODO: Handle kebab case vars l = Var(token.value) @@ -554,6 +572,21 @@ def __str__(self) -> str: return str(self.value) +@dataclass(eq=True, frozen=True, unsafe_hash=True) +class Float(Object): + value: float + + def serialize(self) -> Dict[bytes, object]: + raise NotImplementedError("serialization for Float is not supported") + + @staticmethod + def deserialize(msg: Dict[str, object]) -> "Float": + raise NotImplementedError("serialization for Float is not supported") + + def __str__(self) -> str: + return str(self.value) + + @dataclass(eq=True, frozen=True, unsafe_hash=True) class String(Object): value: str @@ -971,15 +1004,15 @@ def __str__(self) -> str: return f"#{self.value}" -def unpack_int(obj: Object) -> int: - if not isinstance(obj, Int): - raise TypeError(f"expected Int, got {type(obj).__name__}") +def unpack_number(obj: Object) -> Union[int, float]: + if not isinstance(obj, (Int, Float)): + raise TypeError(f"expected Int or Float, got {type(obj).__name__}") return obj.value -def eval_int(env: Env, exp: Object) -> int: +def eval_number(env: Env, exp: Object) -> Union[int, float]: result = eval_exp(env, exp) - return unpack_int(result) + return unpack_number(result) def eval_str(env: Env, exp: Object) -> str: @@ -1007,20 +1040,30 @@ def make_bool(x: bool) -> Object: return Symbol("true" if x else "false") +def wrap_inferred_number_type(x: Union[int, float]) -> Object: + # TODO: Since this is intended to be a reference implementation + # we should avoid relying heavily on Python's implementation of + # arithmetic operations, type inference, and multiple dispatch. + # Update this to make the interpreter more language agnostic. + if isinstance(x, int): + return Int(x) + return Float(x) + + BINOP_HANDLERS: Dict[BinopKind, Callable[[Env, Object, Object], Object]] = { - BinopKind.ADD: lambda env, x, y: Int(eval_int(env, x) + eval_int(env, y)), - BinopKind.SUB: lambda env, x, y: Int(eval_int(env, x) - eval_int(env, y)), - BinopKind.MUL: lambda env, x, y: Int(eval_int(env, x) * eval_int(env, y)), - BinopKind.DIV: lambda env, x, y: Int(eval_int(env, x) // eval_int(env, y)), - BinopKind.FLOOR_DIV: lambda env, x, y: Int(eval_int(env, x) // eval_int(env, y)), - BinopKind.EXP: lambda env, x, y: Int(eval_int(env, x) ** eval_int(env, y)), - BinopKind.MOD: lambda env, x, y: Int(eval_int(env, x) % eval_int(env, y)), + BinopKind.ADD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) + eval_number(env, y)), + BinopKind.SUB: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) - eval_number(env, y)), + BinopKind.MUL: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) * eval_number(env, y)), + BinopKind.DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) / eval_number(env, y)), + BinopKind.FLOOR_DIV: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) // eval_number(env, y)), + BinopKind.EXP: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) ** eval_number(env, y)), + BinopKind.MOD: lambda env, x, y: wrap_inferred_number_type(eval_number(env, x) % eval_number(env, y)), BinopKind.EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) == eval_exp(env, y)), BinopKind.NOT_EQUAL: lambda env, x, y: make_bool(eval_exp(env, x) != eval_exp(env, y)), - BinopKind.LESS: lambda env, x, y: make_bool(eval_int(env, x) < eval_int(env, y)), - BinopKind.GREATER: lambda env, x, y: make_bool(eval_int(env, x) > eval_int(env, y)), - BinopKind.LESS_EQUAL: lambda env, x, y: make_bool(eval_int(env, x) <= eval_int(env, y)), - BinopKind.GREATER_EQUAL: lambda env, x, y: make_bool(eval_int(env, x) >= eval_int(env, y)), + BinopKind.LESS: lambda env, x, y: make_bool(eval_number(env, x) < eval_number(env, y)), + BinopKind.GREATER: lambda env, x, y: make_bool(eval_number(env, x) > eval_number(env, y)), + BinopKind.LESS_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) <= eval_number(env, y)), + BinopKind.GREATER_EQUAL: lambda env, x, y: make_bool(eval_number(env, x) >= eval_number(env, y)), BinopKind.BOOL_AND: lambda env, x, y: make_bool(eval_bool(env, x) and eval_bool(env, y)), BinopKind.BOOL_OR: lambda env, x, y: make_bool(eval_bool(env, x) or eval_bool(env, y)), BinopKind.STRING_CONCAT: lambda env, x, y: String(eval_str(env, x) + eval_str(env, y)), @@ -1037,6 +1080,8 @@ class MatchError(Exception): def match(obj: Object, pattern: Object) -> Optional[Env]: if isinstance(pattern, Int): return {} if isinstance(obj, Int) and obj.value == pattern.value else None + if isinstance(pattern, Float): + raise MatchError("pattern matching is not supported for Floats") if isinstance(pattern, String): return {} if isinstance(obj, String) and obj.value == pattern.value else None if isinstance(pattern, Var): @@ -1092,7 +1137,7 @@ def match(obj: Object, pattern: Object) -> Optional[Env]: # pylint: disable=redefined-builtin def eval_exp(env: Env, exp: Object) -> Object: logger.debug(exp) - if isinstance(exp, (Int, String, Bytes, Hole, Closure, NativeFunction, Symbol)): + if isinstance(exp, (Int, Float, String, Bytes, Hole, Closure, NativeFunction, Symbol)): return exp if isinstance(exp, Var): value = env.get(exp.name) @@ -1285,6 +1330,23 @@ def test_tokenize_multiple_digits(self) -> None: def test_tokenize_negative_int(self) -> None: self.assertEqual(tokenize("-123"), [Operator("-"), IntLit(123)]) + def test_tokenize_float(self) -> None: + self.assertEqual(tokenize("3.14"), [FloatLit(3.14)]) + + def test_tokenize_negative_float(self) -> None: + self.assertEqual(tokenize("-3.14"), [Operator("-"), FloatLit(3.14)]) + + @unittest.skip("TODO: support floats with no integer part") + def test_tokenize_float_with_no_integer_part(self) -> None: + self.assertEqual(tokenize(".14"), [FloatLit(0.14)]) + + def test_tokenize_float_with_no_decimal_part(self) -> None: + self.assertEqual(tokenize("10."), [FloatLit(10.0)]) + + def test_tokenize_float_with_multiple_decimal_points_raises_parse_error(self) -> None: + with self.assertRaisesRegex(ParseError, re.escape("unexpected token '.'")): + tokenize("1.0.1") + def test_tokenize_binop(self) -> None: self.assertEqual(tokenize("1 + 2"), [IntLit(1), Operator("+"), IntLit(2)]) @@ -1708,6 +1770,12 @@ def test_parse_negative_int_binds_tighter_than_apply(self) -> None: Apply(Binop(BinopKind.SUB, Int(0), Var("l")), Var("r")), ) + def test_parse_decimal_returns_float(self) -> None: + self.assertEqual(parse([FloatLit(3.14)]), Float(3.14)) + + def test_parse_negative_float_returns_binary_sub_float(self) -> None: + self.assertEqual(parse([Operator("-"), FloatLit(3.14)]), Binop(BinopKind.SUB, Int(0), Float(3.14))) + def test_parse_var_returns_var(self) -> None: self.assertEqual(parse([Name("abc_123")]), Var("abc_123")) @@ -2122,6 +2190,18 @@ def test_match_with_inequal_ints_returns_none(self) -> None: def test_match_int_with_non_int_returns_none(self) -> None: self.assertEqual(match(String("abc"), pattern=Int(1)), None) + def test_match_with_equal_floats_raises_match_error(self) -> None: + with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")): + match(Float(1), pattern=Float(1)) + + def test_match_with_inequal_floats_raises_match_error(self) -> None: + with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")): + match(Float(2), pattern=Float(1)) + + def test_match_float_with_non_float_raises_match_error(self) -> None: + with self.assertRaisesRegex(MatchError, re.escape("pattern matching is not supported for Floats")): + match(String("abc"), pattern=Float(1)) + def test_match_with_equal_strings_returns_empty_dict(self) -> None: self.assertEqual(match(String("a"), pattern=String("a")), {}) @@ -2379,6 +2459,10 @@ def test_eval_int_returns_int(self) -> None: exp = Int(5) self.assertEqual(eval_exp({}, exp), Int(5)) + def test_eval_float_returns_float(self) -> None: + exp = Float(3.14) + self.assertEqual(eval_exp({}, exp), Float(3.14)) + def test_eval_str_returns_str(self) -> None: exp = String("xyz") self.assertEqual(eval_exp({}, exp), String("xyz")) @@ -2410,7 +2494,7 @@ def test_eval_with_binop_add_with_int_string_raises_type_error(self) -> None: exp = Binop(BinopKind.ADD, Int(1), String("hello")) with self.assertRaises(TypeError) as ctx: eval_exp({}, exp) - self.assertEqual(ctx.exception.args[0], "expected Int, got String") + self.assertEqual(ctx.exception.args[0], "expected Int or Float, got String") def test_eval_with_binop_sub(self) -> None: exp = Binop(BinopKind.SUB, Int(1), Int(2)) @@ -2421,8 +2505,8 @@ def test_eval_with_binop_mul(self) -> None: self.assertEqual(eval_exp({}, exp), Int(6)) def test_eval_with_binop_div(self) -> None: - exp = Binop(BinopKind.DIV, Int(2), Int(3)) - self.assertEqual(eval_exp({}, exp), Int(0)) + exp = Binop(BinopKind.DIV, Int(3), Int(10)) + self.assertEqual(eval_exp({}, exp), Float(0.3)) def test_eval_with_binop_floor_div(self) -> None: exp = Binop(BinopKind.FLOOR_DIV, Int(2), Int(3)) @@ -2707,7 +2791,7 @@ def test_eval_less_returns_bool(self) -> None: def test_eval_less_on_non_bool_raises_type_error(self) -> None: ast = Binop(BinopKind.LESS, String("xyz"), Int(4)) - with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")): + with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): eval_exp({}, ast) def test_eval_less_equal_returns_bool(self) -> None: @@ -2716,7 +2800,7 @@ def test_eval_less_equal_returns_bool(self) -> None: def test_eval_less_equal_on_non_bool_raises_type_error(self) -> None: ast = Binop(BinopKind.LESS_EQUAL, String("xyz"), Int(4)) - with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")): + with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): eval_exp({}, ast) def test_eval_greater_returns_bool(self) -> None: @@ -2725,7 +2809,7 @@ def test_eval_greater_returns_bool(self) -> None: def test_eval_greater_on_non_bool_raises_type_error(self) -> None: ast = Binop(BinopKind.GREATER, String("xyz"), Int(4)) - with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")): + with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): eval_exp({}, ast) def test_eval_greater_equal_returns_bool(self) -> None: @@ -2734,7 +2818,7 @@ def test_eval_greater_equal_returns_bool(self) -> None: def test_eval_greater_equal_on_non_bool_raises_type_error(self) -> None: ast = Binop(BinopKind.GREATER_EQUAL, String("xyz"), Int(4)) - with self.assertRaisesRegex(TypeError, re.escape("expected Int, got String")): + with self.assertRaisesRegex(TypeError, re.escape("expected Int or Float, got String")): eval_exp({}, ast) def test_boolean_and_evaluates_args(self) -> None: @@ -2791,6 +2875,21 @@ def test_eval_record_with_spread_fails(self) -> None: def test_eval_symbol_returns_symbol(self) -> None: self.assertEqual(eval_exp({}, Symbol("abc")), Symbol("abc")) + def test_eval_float_and_float_addition_returns_float(self) -> None: + self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Float(1.0), Float(2.0))), Float(3.0)) + + def test_eval_int_and_float_addition_returns_float(self) -> None: + self.assertEqual(eval_exp({}, Binop(BinopKind.ADD, Int(1), Float(2.0))), Float(3.0)) + + def test_eval_int_and_float_division_returns_float(self) -> None: + self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Float(2.0))), Float(0.5)) + + def test_eval_float_and_int_division_returns_float(self) -> None: + self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Float(1.0), Int(2))), Float(0.5)) + + def test_eval_int_and_int_division_returns_float(self) -> None: + self.assertEqual(eval_exp({}, Binop(BinopKind.DIV, Int(1), Int(2))), Float(0.5)) + class EndToEndTestsBase(unittest.TestCase): def _run(self, text: str, env: Optional[Env] = None) -> Object: @@ -2805,6 +2904,9 @@ class EndToEndTests(EndToEndTestsBase): def test_int_returns_int(self) -> None: self.assertEqual(self._run("1"), Int(1)) + def test_float_returns_float(self) -> None: + self.assertEqual(self._run("3.14"), Float(3.14)) + def test_bytes_returns_bytes(self) -> None: self.assertEqual(self._run("~~QUJD"), Bytes(b"ABC")) @@ -3604,6 +3706,11 @@ def test_serialize_negative_int(self) -> None: obj = Int(-123) self.assertEqual(obj.serialize(), {b"type": b"Int", b"value": -123}) + def test_serialize_float_raises_not_implemented_error(self) -> None: + obj = Float(3.14) + with self.assertRaisesRegex(NotImplementedError, re.escape("serialization for Float is not supported")): + obj.serialize() + def test_serialize_str(self) -> None: obj = String("abc") self.assertEqual(obj.serialize(), {b"type": b"String", b"value": b"abc"}) @@ -3816,6 +3923,10 @@ def test_pretty_print_int(self) -> None: obj = Int(1) self.assertEqual(str(obj), "1") + def test_pretty_print_float(self) -> None: + obj = Float(3.14) + self.assertEqual(str(obj), "3.14") + def test_pretty_print_string(self) -> None: obj = String("hello") self.assertEqual(str(obj), '"hello"')