Skip to content

Commit

Permalink
Deduplicate outputs in luigi.task.flatten_output (spotify#3106)
Browse files Browse the repository at this point in the history
  • Loading branch information
starhel committed Jan 31, 2023
1 parent 9d50be4 commit d9a950e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 9 deletions.
23 changes: 14 additions & 9 deletions luigi/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
It is a central concept of Luigi and represents the state of the workflow.
See :doc:`/tasks` for an overview.
"""

from collections import deque, OrderedDict
from contextlib import contextmanager
import logging
import traceback
Expand Down Expand Up @@ -955,7 +955,7 @@ def getpaths(struct):

def flatten(struct):
"""
Creates a flat list of all all items in structured output (dicts, lists, items):
Creates a flat list of all items in structured output (dicts, lists, items):
.. code-block:: python
Expand Down Expand Up @@ -992,14 +992,19 @@ def flatten(struct):
def flatten_output(task):
    """
    Lists all output targets by recursively walking output-less (wrapper) tasks.

    Tasks are visited breadth-first, starting from ``task``. A task whose
    flattened output is non-empty contributes those targets and is not
    expanded any further; a task without output is replaced by the tasks it
    requires. Every task is handled at most once, so a dependency that is
    reachable through several wrappers (a "diamond") contributes its outputs
    a single time, and the targets appear in the order in which their tasks
    were first discovered.

    :param task: the task to start from. It and every task reachable through
                 ``requires()`` must be hashable, because visited tasks are
                 remembered in a set.
    :return: a flat list with the output targets that were found.
    """
    collected = []
    # Remember every task that was ever queued: this deduplicates outputs and
    # also keeps shared (or cyclic) wrapper tasks from being expanded again.
    seen = {task}
    tasks_to_process = deque([task])
    while tasks_to_process:
        current_task = tasks_to_process.popleft()
        # Flatten once and reuse the result instead of calling output() twice.
        current_outputs = flatten(current_task.output())
        if current_outputs:
            collected.extend(current_outputs)
            continue
        # Output-less (wrapper) task: look at its requirements instead.
        for dependency in flatten(current_task.requires()):
            if dependency not in seen:
                seen.add(dependency)
                tasks_to_process.append(dependency)

    return collected


def _task_wraps(task_class):
Expand Down
49 changes: 49 additions & 0 deletions test/task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,55 @@ class TaskB(luigi.Task):
)


class TaskFlattenOutputTest(unittest.TestCase):
    """Exercises ``luigi.task.flatten_output`` on single, wrapped and diamond-shaped tasks."""

    def test_single_task(self):
        # A task that declares outputs itself yields exactly those outputs.
        targets = [luigi.LocalTarget("f1.txt"), luigi.LocalTarget("f2.txt")]

        class TestTask(luigi.ExternalTask):
            def output(self):
                return targets

        flattened = luigi.task.flatten_output(TestTask())
        self.assertListEqual(flattened, targets)

    def test_wrapper_task(self):
        # The wrapper is expected to expose the outputs of the tasks it
        # requires, in the order the requirements are listed.
        first_target = luigi.LocalTarget("f1.txt")
        second_target = luigi.LocalTarget("f2.txt")

        class Test1Task(luigi.ExternalTask):
            def output(self):
                return first_target

        class Test2Task(luigi.ExternalTask):
            def output(self):
                return second_target

        @luigi.util.requires(Test1Task, Test2Task)
        class TestWrapperTask(luigi.WrapperTask):
            pass

        flattened = luigi.task.flatten_output(TestWrapperTask())
        self.assertListEqual(flattened, [first_target, second_target])

    def test_wrapper_tasks_diamond(self):
        # Both intermediate wrappers require the same task; its single
        # output must be reported once, not once per path.
        shared_targets = [luigi.LocalTarget("file.txt")]

        class TestTask(luigi.ExternalTask):
            def output(self):
                return shared_targets

        @luigi.util.requires(TestTask)
        class LeftWrapperTask(luigi.WrapperTask):
            pass

        @luigi.util.requires(TestTask)
        class RightWrapperTask(luigi.WrapperTask):
            pass

        @luigi.util.requires(LeftWrapperTask, RightWrapperTask)
        class MasterWrapperTask(luigi.WrapperTask):
            pass

        flattened = luigi.task.flatten_output(MasterWrapperTask())
        self.assertListEqual(flattened, shared_targets)


class ExternalizeTaskTest(LuigiTestCase):

def test_externalize_taskclass(self):
Expand Down

0 comments on commit d9a950e

Please sign in to comment.