Skip to content

Commit

Permalink
Add support for @calc_node(ignore_self)
Browse files Browse the repository at this point in the history
  • Loading branch information
edparcell committed Aug 30, 2024
1 parent ac4a625 commit 0417924
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
21 changes: 12 additions & 9 deletions loman/computeengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
import tempfile
import traceback
import typing
from collections import defaultdict
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait
from datetime import datetime
from enum import Enum

import inspect
from typing import List, Dict
from typing import List, Dict, Tuple, Any, Callable

import decorator
import dill
Expand Down Expand Up @@ -108,13 +107,13 @@ def add_to_comp(self, comp, name, ctx):

@dataclass
class InputNode(Node):
args: List = field(default_factory=list)
args: Tuple[Any, ...] = field(default_factory=tuple)
kwds: Dict = field(default_factory=dict)

def add_to_comp(self, comp: 'Computation', name: str, ctx: dict):
kwds = ctx.copy()
kwds.update(self.kwds)
comp.add_node(name, *self.args, **kwds)
comp.add_node(name, **kwds)


def input_node(*args, **kwds):
Expand All @@ -123,18 +122,22 @@ def input_node(*args, **kwds):

@dataclass
class CalcNode(Node):
f: typing.Callable
args: List = field(default_factory=list)
f: Callable
kwds: Dict = field(default_factory=dict)

def add_to_comp(self, comp, name, ctx):
kwds = ctx.copy()
kwds.update(self.kwds)
comp.add_node(name, self.f, *self.args, **kwds)
comp.add_node(name, self.f, **kwds)


def calc_node(f, *args, **kwds):
return CalcNode(f, args, kwds)
def calc_node(f=None, **kwds):
def wrap(func):
return CalcNode(func, kwds)

if f is None:
return wrap
return wrap(f)


def ComputationFactory(maybe_cls=None, *, ignore_self=True):
Expand Down
22 changes: 22 additions & 0 deletions test/test_class_style_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,25 @@ def d(self, b, c):
comp.add_node('self', value=1)
comp.compute_all()
assert comp.s.d == States.UPTODATE and comp.v.d == 10


def test_computation_factory_methods_calc_node_ignore_self():
@ComputationFactory(ignore_self=False)
class FooComp:
a = input_node(value=3)

@calc_node
def b(a):
return a + 1

@calc_node(ignore_self=True)
def c(self, a):
return 2 * a

@calc_node
def d(b, c):
return b + c

comp = FooComp()
comp.compute_all()
assert comp.s.d == States.UPTODATE and comp.v.d == 10

0 comments on commit 0417924

Please sign in to comment.