diff --git a/locust/test/test_interruptable_task.py b/locust/test/test_interruptable_task.py new file mode 100644 index 0000000000..b40ce4361a --- /dev/null +++ b/locust/test/test_interruptable_task.py @@ -0,0 +1,48 @@ +from collections import defaultdict +from unittest import TestCase + +from locust import SequentialTaskSet, User, constant, task +from locust.env import Environment +from locust.exception import StopUser + + +class InterruptableTaskSet(SequentialTaskSet): + counter: defaultdict[str, int] = defaultdict(int) + + def on_start(self): + super().on_start() + self.counter["on_start"] += 1 + + @task + def t1(self): + self.counter["t1"] += 1 + self.interrupt(reschedule=False) + + @task + def t2(self): + self.counter["t2"] += 1 + + def on_stop(self): + super().on_stop() + self.counter["on_stop"] += 1 + if self.counter["on_stop"] >= 2: + raise StopUser() + + +class TestInterruptableTask(TestCase): + def setUp(self): + super().setUp() + + class InterruptableUser(User): + host = "127.0.0.1" + tasks = [InterruptableTaskSet] + wait_time = constant(0) + + self.locust = InterruptableUser(Environment(catch_exceptions=True)) + + def test_interruptable_task(self): + self.locust.run() + self.assertEqual(InterruptableTaskSet.counter.get("on_start"), 2) + self.assertEqual(InterruptableTaskSet.counter.get("t1"), 2) + self.assertEqual(InterruptableTaskSet.counter.get("t2", 0), 0) + self.assertEqual(InterruptableTaskSet.counter.get("on_stop"), 2) diff --git a/locust/user/task.py b/locust/user/task.py index a09beb7603..68d0b63baa 100644 --- a/locust/user/task.py +++ b/locust/user/task.py @@ -350,6 +350,8 @@ def run(self): except InterruptTaskSet as e: try: self.on_stop() + except (StopUser, GreenletExit): + raise except Exception: logging.error("Uncaught exception in on_stop: \n%s", traceback.format_exc()) if e.reschedule: