diff --git a/contrib/pyln-testing/pyln/testing/fixtures.py b/contrib/pyln-testing/pyln/testing/fixtures.py index 57932515ba81..970739716c4d 100644 --- a/contrib/pyln-testing/pyln/testing/fixtures.py +++ b/contrib/pyln-testing/pyln/testing/fixtures.py @@ -1,6 +1,6 @@ from concurrent import futures from pyln.testing.db import SqliteDbProvider, PostgresDbProvider -from pyln.testing.utils import NodeFactory, BitcoinD, ElementsD, env, DEVELOPER +from pyln.testing.utils import NodeFactory, BitcoinD, ElementsD, env, DEVELOPER, LightningNode import logging import os @@ -67,6 +67,11 @@ def test_name(request): } +@pytest.fixture +def node_cls(): + return LightningNode + + @pytest.fixture def bitcoind(directory, teardown_checks): chaind = network_daemons[env('TEST_NETWORK', 'regtest')] @@ -145,13 +150,14 @@ def teardown_checks(request): @pytest.fixture -def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks): +def node_factory(request, directory, test_name, bitcoind, executor, db_provider, teardown_checks, node_cls): nf = NodeFactory( test_name, bitcoind, executor, directory=directory, db_provider=db_provider, + node_cls=node_cls ) yield nf diff --git a/contrib/pyln-testing/pyln/testing/utils.py b/contrib/pyln-testing/pyln/testing/utils.py index 24c8660e9c10..a87d9c48ad8d 100644 --- a/contrib/pyln-testing/pyln/testing/utils.py +++ b/contrib/pyln-testing/pyln/testing/utils.py @@ -465,7 +465,7 @@ def getnewaddress(self): class LightningD(TailableProc): def __init__(self, lightning_dir, bitcoindproxy, port=9735, random_hsm=False, node_id=0): TailableProc.__init__(self, lightning_dir) - self.executable = 'lightningd/lightningd' + self.executable = 'lightningd' self.lightning_dir = lightning_dir self.port = port self.cmd_prefix = [] @@ -903,7 +903,7 @@ def passes_filters(hmsg, filters): class NodeFactory(object): """A factory to setup and start `lightningd` daemons. """ - def __init__(self, testname, bitcoind, executor, directory, db_provider): + def __init__(self, testname, bitcoind, executor, directory, db_provider, node_cls): self.testname = testname self.next_id = 1 self.nodes = [] @@ -912,6 +912,7 @@ def __init__(self, testname, bitcoind, executor, directory, db_provider): self.directory = directory self.lock = threading.Lock() self.db_provider = db_provider + self.node_cls = node_cls def split_options(self, opts): """Split node options from cli options @@ -985,7 +986,7 @@ def get_node(self, node_id=None, options=None, dbfile=None, # Get the DB backend DSN we should be using for this test and this # node. db = self.db_provider.get_db(lightning_dir, self.testname, node_id) - node = LightningNode( + node = self.node_cls( node_id, lightning_dir, self.bitcoind, self.executor, db=db, port=port, options=options, **kwargs ) diff --git a/tests/fixtures.py b/tests/fixtures.py index ff49eec60ced..478d50970e2c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,2 +1,19 @@ from utils import DEVELOPER, TEST_NETWORK # noqa: F401,F403 from pyln.testing.fixtures import directory, test_base_dir, test_name, chainparams, node_factory, bitcoind, teardown_checks, db_provider, executor # noqa: F401,F403 +from pyln.testing import utils + +import pytest + + +@pytest.fixture +def node_cls(): + return LightningNode + + +class LightningNode(utils.LightningNode): + def __init__(self, *args, **kwargs): + utils.LightningNode.__init__(self, *args, **kwargs) + + # Yes, we really want to test the local development version, not + # something in out path. + self.daemon.executable = 'lightningd/lightningd'