
ML Pipeline DAG Scheduler with Critical Path Analysis

#270 · MLOps · Hard

Solve on deep-ml.com

Problem

Build a DAG scheduler for an ML pipeline. Given a set of tasks with dependencies and execution times, determine the topological execution order, the critical path (the longest path through the DAG, measured by total task duration), and the minimum total execution time (makespan) assuming unlimited parallelism.
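The solution code expects the pipeline as a dict mapping each task id to `{"duration": float, "deps": [task_id, ...]}`. The snippet below shows a hypothetical five-stage pipeline in that format (task names and durations are illustrative only) together with a hypothetical `validate_tasks` helper, not part of the original solution, that checks the assumptions the scheduler relies on: every task has a non-negative duration and every dependency names a known task.

```python
# Hypothetical pipeline: names and durations (e.g. minutes) are illustrative only.
PIPELINE: dict[str, dict] = {
    "ingest":   {"duration": 4.0, "deps": []},
    "clean":    {"duration": 3.0, "deps": ["ingest"]},
    "features": {"duration": 5.0, "deps": ["ingest"]},
    "train":    {"duration": 10.0, "deps": ["clean", "features"]},
    "evaluate": {"duration": 2.0, "deps": ["train"]},
}


def validate_tasks(tasks: dict[str, dict]) -> list[str]:
    """Return a list of problems found in a task dict (an empty list means valid)."""
    problems: list[str] = []
    for task_id, info in tasks.items():
        if "duration" not in info:
            problems.append(f"{task_id}: missing 'duration'")
        elif info["duration"] < 0:
            problems.append(f"{task_id}: negative duration {info['duration']}")
        for dep in info.get("deps", []):
            if dep not in tasks:
                problems.append(f"{task_id}: unknown dependency '{dep}'")
            elif dep == task_id:
                problems.append(f"{task_id}: depends on itself")
    return problems


print(validate_tasks(PIPELINE))  # → []
print(validate_tasks({"train": {"duration": -1.0, "deps": ["data"]}}))
```

Checks like these are cheap to run before scheduling; cycles longer than a self-loop are left to the topological sort, which detects them anyway.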

Solution

Topologically sort the tasks, compute earliest start/finish times in a forward pass, then compute latest start/finish times in a backward pass. Tasks whose earliest and latest start coincide (zero slack) form the critical path.
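The backward pass is what provides per-task slack. If only the makespan and one critical path are needed, a common alternative (a swapped-in technique, not the approach this solution uses) is to record, during the forward pass, which dependency finishes last for each task, then follow those pointers back from the task that finishes last. A minimal sketch, assuming the same task-dict format, valid dependency names, and non-negative durations; `critical_path_forward_only` is a hypothetical name:

```python
from collections import deque
from typing import Optional


def critical_path_forward_only(tasks: dict[str, dict]) -> tuple[float, list[str]]:
    """Makespan and one longest path from a single forward pass (assumes a DAG)."""
    # Kahn's algorithm for a topological order
    succ: dict[str, list[str]] = {t: [] for t in tasks}
    indeg: dict[str, int] = {t: 0 for t in tasks}
    for t, info in tasks.items():
        for dep in info.get("deps", []):
            succ[dep].append(t)
            indeg[t] += 1
    queue = deque(t for t in tasks if indeg[t] == 0)
    order: list[str] = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for s in succ[node]:
            indeg[s] -= 1
            if indeg[s] == 0:
                queue.append(s)
    if len(order) != len(tasks):
        raise ValueError("cycle detected")
    if not order:
        return 0.0, []

    # Forward pass that also remembers the dependency finishing last
    finish: dict[str, float] = {}
    best_pred: dict[str, Optional[str]] = {}
    for t in order:
        start, pred = 0.0, None
        for dep in tasks[t].get("deps", []):
            if pred is None or finish[dep] > start:
                start, pred = finish[dep], dep
        finish[t] = start + tasks[t]["duration"]
        best_pred[t] = pred

    # Follow predecessor pointers back from the task that finishes last
    end = max(order, key=finish.__getitem__)
    path: list[str] = []
    cur: Optional[str] = end
    while cur is not None:
        path.append(cur)
        cur = best_pred[cur]
    return finish[end], path[::-1]


demo = {
    "a": {"duration": 2.0, "deps": []},
    "b": {"duration": 3.0, "deps": ["a"]},
    "c": {"duration": 1.0, "deps": ["a"]},
    "d": {"duration": 2.0, "deps": ["b", "c"]},
}
print(critical_path_forward_only(demo))  # → (7.0, ['a', 'b', 'd'])
```

The trade-off is that this variant reports one longest path and no slack, so it cannot say how far a non-critical task (here `c`) may slip without delaying the pipeline.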

```python
from collections import deque


def dag_scheduler(tasks: dict[str, dict]) -> dict:
    # tasks format: {task_id: {"duration": float, "deps": [task_id, ...]}}
    # Durations are assumed to be non-negative.

    # Build successor lists (adj) and in-degree counts
    in_degree: dict[str, int] = {t: 0 for t in tasks}
    adj: dict[str, list[str]] = {t: [] for t in tasks}
    for t, info in tasks.items():
        for dep in info.get("deps", []):
            if dep not in tasks:
                return {"error": f"Unknown dependency '{dep}' for task '{t}'"}
            adj[dep].append(t)
            in_degree[t] += 1

    # Topological sort (Kahn's algorithm)
    queue = deque(t for t in tasks if in_degree[t] == 0)
    topo_order: list[str] = []
    while queue:
        node = queue.popleft()
        topo_order.append(node)
        for neighbor in adj[node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    # Tasks on a cycle never reach in-degree 0, so they are never emitted
    if len(topo_order) != len(tasks):
        return {"error": "Cycle detected in DAG"}

    # Forward pass: earliest start = max earliest finish over dependencies
    earliest_start: dict[str, float] = {}
    earliest_finish: dict[str, float] = {}
    for t in topo_order:
        es = max((earliest_finish[d] for d in tasks[t].get("deps", [])), default=0.0)
        earliest_start[t] = es
        earliest_finish[t] = es + tasks[t]["duration"]

    # Minimum total time = max of all earliest finish times (0 for no tasks)
    makespan = max(earliest_finish.values(), default=0.0)

    # Backward pass: latest finish = min latest start over successors
    latest_finish: dict[str, float] = {}
    latest_start: dict[str, float] = {}
    for t in reversed(topo_order):
        if not adj[t]:  # no successors: may finish as late as the makespan
            latest_finish[t] = makespan
        else:
            latest_finish[t] = min(latest_start[s] for s in adj[t])
        latest_start[t] = latest_finish[t] - tasks[t]["duration"]

    # Slack; zero slack (after rounding away float noise) marks a critical task
    slack: dict[str, float] = {}
    critical: set[str] = set()
    for t in topo_order:
        slack[t] = round(latest_start[t] - earliest_start[t], 6)
        if abs(slack[t]) < 1e-9:
            critical.add(t)

    # Reconstruct one critical path as a chain. Listing every zero-slack task
    # would interleave tied branches, so walk from the first critical task in
    # topological order (it starts at time 0) to the critical successor that
    # starts earliest, i.e. exactly when the current task finishes.
    critical_path: list[str] = []
    current = next((t for t in topo_order if t in critical), None)
    while current is not None:
        critical_path.append(current)
        nxt = [s for s in adj[current] if s in critical]
        current = min(nxt, key=earliest_start.get) if nxt else None

    # Parallel levels: one level after the deepest dependency, in O(V + E)
    depth: dict[str, int] = {}
    levels: list[list[str]] = []
    for t in topo_order:
        depth[t] = max((depth[d] + 1 for d in tasks[t].get("deps", [])), default=0)
        if depth[t] == len(levels):
            levels.append([])
        levels[depth[t]].append(t)

    return {
        "topo_order": topo_order,
        "critical_path": critical_path,
        "makespan": round(makespan, 6),
        "earliest_start": {t: round(v, 6) for t, v in earliest_start.items()},
        "slack": slack,
        "parallel_levels": levels,
    }
```
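A case worth testing any critical-path reconstruction against: when two branches have the same length, every task on both branches has zero slack, so the zero-slack tasks taken in topological order do not form a single chain. The sketch below recomputes the forward and backward passes for a hypothetical diamond-shaped pipeline with tied branches, using a hard-coded topological order (an assumption made to keep it short).

```python
# Hypothetical diamond: prep -> (branch_a, branch_b) -> merge, with tied branches.
tasks = {
    "prep":     {"duration": 1.0, "deps": []},
    "branch_a": {"duration": 4.0, "deps": ["prep"]},
    "branch_b": {"duration": 4.0, "deps": ["prep"]},
    "merge":    {"duration": 2.0, "deps": ["branch_a", "branch_b"]},
}
order = ["prep", "branch_a", "branch_b", "merge"]  # assumed topological order
successors = {t: [s for s in order if t in tasks[s]["deps"]] for t in order}

# Forward pass: earliest start/finish
es: dict[str, float] = {}
ef: dict[str, float] = {}
for t in order:
    es[t] = max((ef[d] for d in tasks[t]["deps"]), default=0.0)
    ef[t] = es[t] + tasks[t]["duration"]
makespan = max(ef.values())

# Backward pass: latest finish/start
lf: dict[str, float] = {}
ls: dict[str, float] = {}
for t in reversed(order):
    lf[t] = min((ls[s] for s in successors[t]), default=makespan)
    ls[t] = lf[t] - tasks[t]["duration"]

slack = {t: ls[t] - es[t] for t in order}
zero_slack = [t for t in order if slack[t] == 0]
print(makespan, zero_slack)  # → 7.0 ['prep', 'branch_a', 'branch_b', 'merge']
# All four tasks have zero slack, yet the list is not a path:
# branch_b does not depend on branch_a.
print("branch_a" in tasks["branch_b"]["deps"])  # → False
```

Recovering a single chain means starting at a zero-slack task that begins at time 0 and repeatedly moving to a zero-slack successor that starts exactly when the current task finishes, until the makespan is reached; each such walk is one of the tied critical paths.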

Explanation

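  1. Build the DAG from the task dependencies and topologically sort it with Kahn's algorithm; if some tasks are never emitted, the graph contains a cycle.
  2. Forward pass: a task's earliest start is the maximum earliest finish among its dependencies (0 if it has none), and its earliest finish is its earliest start plus its duration.
  3. The makespan (minimum total time under unlimited parallelism) is the maximum earliest finish across all tasks.
  4. Backward pass: a task's latest finish is the minimum latest start among its successors (the makespan if it has none), and its latest start is its latest finish minus its duration, i.e. the latest it can start without delaying the makespan.
  5. Slack = latest start - earliest start. Tasks with zero slack lie on a critical path; when several longest paths tie, there is more than one.
  6. The total duration of the critical path equals the makespan, so the pipeline cannot finish sooner unless the durations of critical tasks are reduced.
  7. Parallel levels group tasks by dependency depth: level 0 holds tasks with no dependencies, and every other task sits one level after its highest-level dependency, so all tasks in a level can run concurrently.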
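Steps 2 to 5 can be written as recurrences. The notation is introduced here and is not in the original: d_t is the duration of task t, pred(t) its dependencies, succ(t) the tasks that depend on it, and T the makespan. The second block works the numbers for a hypothetical diamond A (2) → B (3) and C (1) → D (2), where D depends on both B and C.

```latex
\begin{aligned}
ES_t &= \max_{u \in \mathrm{pred}(t)} EF_u \ \ (\text{or } 0 \text{ if } \mathrm{pred}(t) = \emptyset), & EF_t &= ES_t + d_t,\\
T &= \max_t EF_t,\\
LF_t &= \min_{s \in \mathrm{succ}(t)} LS_s \ \ (\text{or } T \text{ if } \mathrm{succ}(t) = \emptyset), & LS_t &= LF_t - d_t,\\
\mathrm{slack}_t &= LS_t - ES_t = LF_t - EF_t.
\end{aligned}

\begin{aligned}
ES_D &= \max(EF_B, EF_C) = \max(5, 3) = 5, & T &= EF_D = 5 + 2 = 7,\\
LS_C &= LF_C - d_C = LS_D - 1 = 5 - 1 = 4, & \mathrm{slack}_C &= LS_C - ES_C = 4 - 2 = 2,\\
\mathrm{slack}_A &= \mathrm{slack}_B = \mathrm{slack}_D = 0, & d_A + d_B + d_D &= 2 + 3 + 2 = 7 = T.
\end{aligned}
```

C can therefore slip by up to 2 time units without delaying the pipeline, while A, B and D form the critical path.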

Complexity

  • Time: O(V + E) where V is the number of tasks and E is the number of dependency edges
  • Space: O(V + E) for adjacency lists and scheduling data
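To see where the O(V + E) bound comes from, the hypothetical helper below instruments Kahn's algorithm: each task is dequeued once, and each dependency edge is inspected once, when its source task is dequeued. It is an illustrative sketch, not part of the original solution, run on a hypothetical layered DAG.

```python
from collections import deque


def kahn_with_counts(tasks: dict[str, dict]) -> tuple[list[str], int, int]:
    """Kahn's algorithm that also counts dequeues (vertex work) and edge inspections."""
    succ: dict[str, list[str]] = {t: [] for t in tasks}
    indeg: dict[str, int] = {t: 0 for t in tasks}
    for t, info in tasks.items():
        for dep in info.get("deps", []):
            succ[dep].append(t)
            indeg[t] += 1

    queue = deque(t for t in tasks if indeg[t] == 0)
    order: list[str] = []
    dequeues = edge_checks = 0
    while queue:
        node = queue.popleft()
        dequeues += 1
        order.append(node)
        for s in succ[node]:
            edge_checks += 1  # each edge dep -> s is inspected exactly once
            indeg[s] -= 1
            if indeg[s] == 0:
                queue.append(s)
    return order, dequeues, edge_checks


# Hypothetical layered DAG: 3 layers of 4 tasks, each depending on every task in the previous layer.
layers = [[f"L{i}_{j}" for j in range(4)] for i in range(3)]
tasks = {t: {"duration": 1.0, "deps": [] if i == 0 else list(layers[i - 1])}
         for i, layer in enumerate(layers) for t in layer}
num_edges = sum(len(info["deps"]) for info in tasks.values())
order, dequeues, edge_checks = kahn_with_counts(tasks)
print(len(tasks), num_edges, dequeues, edge_checks)  # → 12 32 12 32
```

On a cyclic input the tasks on the cycle are never dequeued, so `dequeues` falls short of V; that shortfall is exactly the check the solution uses to report a cycle.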