Skip to content

Commit

Permalink
Default handling for __eq__ with typed other (#184)
Browse files Browse the repository at this point in the history
If user defines `__eq__` or `__ne__`  with other not being same type as self we perform is
instance check before delegating to user function
  • Loading branch information
robert3005 authored Oct 5, 2023
1 parent 173158c commit 01cdd47
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 3 deletions.
44 changes: 41 additions & 3 deletions pydust/src/pytypes.zig
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,40 @@ fn BinaryOperator(
};
}

fn EqualsOperator(
comptime definition: type,
comptime op: []const u8,
) type {
return struct {
const equals = std.mem.eql(u8, op, "__eq__");
fn call(pyself: *ffi.PyObject, pyother: *ffi.PyObject) callconv(.C) ?*ffi.PyObject {
const func = @field(definition, op);
const typeInfo = @typeInfo(@TypeOf(func));
const sig = funcs.parseSignature(op, typeInfo.Fn, &.{});

if (sig.selfParam == null) @compileError(op ++ " must take a self parameter");
if (sig.nargs != 1) @compileError(op ++ " must take exactly one parameter after self parameter");

// If other arg is of the same type as self we can short circuit equality of both objects aren't of same type
if (sig.argsParam == *const definition) {
const selfType = py.type_(py.object(pyself)) catch return null;
const otherType = py.type_(py.object(pyother)) catch return null;
if (otherType.obj.py != selfType.obj.py) {
return if (equals) py.False().obj.py else py.True().obj.py;
}
}

const self: *PyTypeStruct(definition) = @ptrCast(pyself);
const other = tramp.Trampoline(
sig.argsParam orelse unreachable,
).unwrap(.{ .py = pyother }) catch return null;

const result = tramp.coerceError(func(&self.state, other)) catch return null;
return (py.createOwned(result) catch return null).py;
}
};
}

fn RichCompare(comptime definition: type) type {
const BinaryFunc = *const fn (*ffi.PyObject, *ffi.PyObject) callconv(.C) ?*ffi.PyObject;
const errorMsg =
Expand Down Expand Up @@ -669,9 +703,13 @@ fn RichCompare(comptime definition: type) type {

const compareFuncs = blk: {
var funcs_: [6]?BinaryFunc = .{ null, null, null, null, null, null };
for (funcs.compareFuncs, 0..) |func, i| {
if (@hasDecl(definition, func)) {
funcs_[i] = &BinaryOperator(definition, func).call;
for (&funcs_, funcs.compareFuncs) |*func, funcName| {
if (@hasDecl(definition, funcName)) {
if (std.mem.eql(u8, funcName, "__eq__") or std.mem.eql(u8, funcName, "__ne__")) {
func.* = &EqualsOperator(definition, funcName).call;
} else {
func.* = &BinaryOperator(definition, funcName).call;
}
}
}
break :blk funcs_;
Expand Down
10 changes: 10 additions & 0 deletions test/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ def test_justequals():
assert cmp1 <= cmp2


# Test short circuit logic in pydust that handles
# cases where __eq__ expects same types but values clearly are not
def test_justequals_different_type():
cmp1 = operators.Equals(1)
cmp2 = 2

assert not cmp1 == cmp2
assert cmp1 != cmp2


def test_justLessThan():
cmp1 = operators.LessThan("abc")
cmp2 = operators.LessThan("abd")
Expand Down

0 comments on commit 01cdd47

Please sign in to comment.