diff --git a/pydust/src/pytypes.zig b/pydust/src/pytypes.zig index 767b1ef9..8a7bd411 100644 --- a/pydust/src/pytypes.zig +++ b/pydust/src/pytypes.zig @@ -599,6 +599,40 @@ fn BinaryOperator( }; } +fn EqualsOperator( + comptime definition: type, + comptime op: []const u8, +) type { + return struct { + const equals = std.mem.eql(u8, op, "__eq__"); + fn call(pyself: *ffi.PyObject, pyother: *ffi.PyObject) callconv(.C) ?*ffi.PyObject { + const func = @field(definition, op); + const typeInfo = @typeInfo(@TypeOf(func)); + const sig = funcs.parseSignature(op, typeInfo.Fn, &.{}); + + if (sig.selfParam == null) @compileError(op ++ " must take a self parameter"); + if (sig.nargs != 1) @compileError(op ++ " must take exactly one parameter after self parameter"); + + // If other arg is of the same type as self we can short circuit equality of both objects aren't of same type + if (sig.argsParam == *const definition) { + const selfType = py.type_(py.object(pyself)) catch return null; + const otherType = py.type_(py.object(pyother)) catch return null; + if (otherType.obj.py != selfType.obj.py) { + return if (equals) py.False().obj.py else py.True().obj.py; + } + } + + const self: *PyTypeStruct(definition) = @ptrCast(pyself); + const other = tramp.Trampoline( + sig.argsParam orelse unreachable, + ).unwrap(.{ .py = pyother }) catch return null; + + const result = tramp.coerceError(func(&self.state, other)) catch return null; + return (py.createOwned(result) catch return null).py; + } + }; +} + fn RichCompare(comptime definition: type) type { const BinaryFunc = *const fn (*ffi.PyObject, *ffi.PyObject) callconv(.C) ?*ffi.PyObject; const errorMsg = @@ -669,9 +703,13 @@ fn RichCompare(comptime definition: type) type { const compareFuncs = blk: { var funcs_: [6]?BinaryFunc = .{ null, null, null, null, null, null }; - for (funcs.compareFuncs, 0..) |func, i| { - if (@hasDecl(definition, func)) { - funcs_[i] = &BinaryOperator(definition, func).call; + for (&funcs_, funcs.compareFuncs) |*func, funcName| { + if (@hasDecl(definition, funcName)) { + if (std.mem.eql(u8, funcName, "__eq__") or std.mem.eql(u8, funcName, "__ne__")) { + func.* = &EqualsOperator(definition, funcName).call; + } else { + func.* = &BinaryOperator(definition, funcName).call; + } } } break :blk funcs_; diff --git a/test/test_operators.py b/test/test_operators.py index 480d93fa..941bf8c2 100644 --- a/test/test_operators.py +++ b/test/test_operators.py @@ -133,6 +133,16 @@ def test_justequals(): assert cmp1 <= cmp2 +# Test short circuit logic in pydust that handles +# cases where __eq__ expects same types but values clearly are not +def test_justequals_different_type(): + cmp1 = operators.Equals(1) + cmp2 = 2 + + assert not cmp1 == cmp2 + assert cmp1 != cmp2 + + def test_justLessThan(): cmp1 = operators.LessThan("abc") cmp2 = operators.LessThan("abd")