Union-Find (DSU)
Track connected components with path compression and union by rank. find() returns the root representative of a component; union() merges two components. Nearly O(1) per operation amortized.
Union-Find, also called Disjoint Set Union (DSU), solves a deceptively common problem: you have a collection of elements, and over time you're told that certain pairs are connected. You need to answer queries like "are A and B in the same group?" and "how many separate groups are there?" efficiently as the connections accumulate.
The straightforward solution, maintaining an adjacency list and running BFS/DFS for each query, costs O(N + E) per query, where E is the number of connections recorded so far. If you have N elements and perform Q union operations followed by Q connectivity queries, you're looking at O(Q · (N + Q)) total, which is far too slow for large inputs. Union-Find answers both union and find operations in nearly O(1) amortized time by maintaining a smarter internal representation.
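For comparison, the naive baseline could look like the sketch below. The function name connected_bfs and the sample connections are illustrative assumptions; every query walks the graph again from scratch, which is the cost Union-Find avoids.

```python
from collections import defaultdict, deque


def connected_bfs(adj, a, b):
    """Return True if b is reachable from a; each call may scan the whole graph."""
    seen = {a}
    queue = deque([a])
    while queue:
        node = queue.popleft()
        if node == b:
            return True
        for nxt in adj[node]:
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False


# Hypothetical input: five elements and three recorded connections.
adj = defaultdict(list)
for u, v in [(0, 1), (1, 2), (3, 4)]:
    adj[u].append(v)
    adj[v].append(u)

print(connected_bfs(adj, 0, 2))  # True
print(connected_bfs(adj, 0, 4))  # False
```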
Each disjoint set is represented as a tree. Every element starts as its own root (parent of itself). When two elements are unioned, one tree's root is attached under the other's root. To check if two elements are in the same set, you find both of their roots: if the roots are the same node, they're connected.
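A minimal sketch of this representation with no balancing and no compression might look as follows. The class name NaiveUnionFind and the depth helper are illustrative only. Unioning elements in an unlucky order builds a single long chain.

```python
class NaiveUnionFind:
    """Tree-based sets with no union by rank and no path compression."""

    def __init__(self, n):
        self.parent = list(range(n))  # every element starts as its own root

    def find(self, x):
        # Walk parent pointers until reaching a node that is its own parent.
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def union(self, x, y):
        rx, ry = self.find(x), self.find(y)
        if rx != ry:
            self.parent[rx] = ry  # always hang x's root under y's root

    def depth(self, x):
        """Number of parent hops from x to its root (for illustration)."""
        hops = 0
        while self.parent[x] != x:
            x = self.parent[x]
            hops += 1
        return hops


uf = NaiveUnionFind(6)
for i in range(5):
    uf.union(i, i + 1)  # builds the chain 0 -> 1 -> 2 -> 3 -> 4 -> 5

print(uf.parent)    # [1, 2, 3, 4, 5, 5]
print(uf.depth(0))  # 5
```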
Without any optimizations this degenerates to a linked list (O(N) per find). Two optimizations together bring it to near-O(1):
Path Compression: When following parent pointers up to the root during a find, make every node along the path point directly to the root. Future finds from those nodes skip straight to the top.
Union by Rank: Track an approximate depth (rank) for each tree. When unioning two trees, attach the shallower one under the deeper one. This keeps trees balanced and short.
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        px, py = self.find(x), self.find(y)
        if px == py:
            return False  # already connected, this edge is redundant
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        return True
```

The find method uses recursive path compression: after the recursive call resolves the root, it writes the root directly into self.parent[x] before returning. On the next call to find(x), it hits the base case immediately.
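With union by rank the trees stay shallow, so the recursion is short. If you drop the rank heuristic, or simply prefer to avoid recursion in Python, an iterative two-pass find is one possible alternative. This is a sketch, not part of the class above: the function name find_iterative and the standalone parent list are assumptions made for illustration.

```python
def find_iterative(parent, x):
    """Return the root of x and compress the path, without recursion."""
    # Pass 1: locate the root.
    root = x
    while parent[root] != root:
        root = parent[root]
    # Pass 2: repoint every node on the path directly at the root.
    while parent[x] != root:
        parent[x], x = root, parent[x]
    return root


# A chain 4 -> 3 -> 2 -> 1 -> 0, as produced by unions without rank.
chain = [0, 0, 1, 2, 3]
print(find_iterative(chain, 4))  # 0
print(chain)                     # [0, 0, 0, 0, 0]
```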
Start with five elements: 0, 1, 2, 3, 4. Initially parent = [0, 1, 2, 3, 4].
1. union(0, 1): roots are 0 and 1, equal ranks (both 0). Attach 1 under 0. Because the ranks were equal, rank[0] becomes 1. parent = [0, 0, 2, 3, 4].
2. union(2, 3): similarly attach 3 under 2, and rank[2] becomes 1. parent = [0, 0, 2, 2, 4].
3. union(0, 2): root of 0 is 0 (rank 1), root of 2 is 2 (rank 1). Tie: attach 2 under 0, rank[0] becomes 2. parent = [0, 0, 0, 2, 4].
4. union(1, 3): find(1): parent[1] = 0, root is 0. find(3): parent[3] = 2, parent[2] = 0, root is 0. Path compression sets parent[3] = 0 directly. Same root, so it returns False. parent = [0, 0, 0, 0, 4].

After step 4, elements 0, 1, 2, 3 are all in one component. Element 4 is isolated. Components: 2.
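The walkthrough can be replayed in code. The sketch below repeats the UnionFind class so that it runs on its own; the trace list is an illustrative helper that records the return value and parent array after each call.

```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n

    def find(self, x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x, y):
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        return True


uf = UnionFind(5)
trace = []
for x, y in [(0, 1), (2, 3), (0, 2), (1, 3)]:
    merged = uf.union(x, y)
    trace.append((merged, list(uf.parent)))
    print(f"union({x}, {y}) -> {merged}, parent = {uf.parent}")

# Count components as the number of elements that are their own parent.
roots = sum(1 for i, p in enumerate(uf.parent) if i == p)
print("components:", roots)  # components: 2
```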
The combined effect of path compression and union by rank gives an amortized time per operation of O(α(N)), where α is the inverse Ackermann function. For any N you will ever encounter in practice (or in the universe), α(N) ≤ 4. You can treat this as constant time.
Notice that union returns False when the two elements already share a root. This is exactly a cycle: if you're building a graph edge by edge and a proposed edge connects two nodes already in the same component, adding it would create a cycle. This makes Union-Find the canonical solution for problems like LeetCode 684. Redundant Connection, where you want to find the first edge that forms a cycle.
```python
def find_redundant_connection(edges):
    uf = UnionFind(len(edges) + 1)
    for u, v in edges:
        if not uf.union(u, v):
            return [u, v]
```

The number of distinct components is the number of elements that are their own parent (i.e., roots). You can track this with a counter: initialize it to N, and decrement it each time union returns True (two previously separate components merge into one).
```python
class UnionFind:
    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.components = n  # every element starts as its own component

    def find(self, x):
        # Unchanged from the version above.
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        self.components -= 1  # two components merged into one
        return True
```

The standard LeetCode 200. Number of Islands is naturally solved with DFS, but Union-Find shines in the incremental variant LeetCode 305. Number of Islands II, where the grid arrives one land cell at a time and you must report the running island count after each addition. For each new land cell, union it with any adjacent land neighbors and track components.
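One way the incremental variant might be sketched is shown below. The function name num_islands_incremental, its (rows, cols, positions) signature, and the dictionary-based storage are assumptions for illustration. The find helper uses path halving, a variant of path compression, rather than the recursive form used elsewhere in this article.

```python
def num_islands_incremental(rows, cols, positions):
    """Return the island count after each land cell in `positions` is added."""
    parent = {}  # cell index -> parent cell index; only land cells appear
    rank = {}
    count = 0
    result = []

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return False
        if rank[ra] < rank[rb]:
            ra, rb = rb, ra
        parent[rb] = ra
        if rank[ra] == rank[rb]:
            rank[ra] += 1
        return True

    for r, c in positions:
        cell = r * cols + c
        if cell not in parent:  # ignore a cell that is already land
            parent[cell] = cell
            rank[cell] = 0
            count += 1  # new land starts as its own island
            for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
                nr, nc = r + dr, c + dc
                if 0 <= nr < rows and 0 <= nc < cols:
                    neighbor = nr * cols + nc
                    if neighbor in parent and union(cell, neighbor):
                        count -= 1  # two islands merged into one
        result.append(count)
    return result


print(num_islands_incremental(3, 3, [(0, 0), (0, 1), (1, 2), (2, 1)]))
# [1, 1, 2, 3]
```

Starting the counter at zero and incrementing it per new land cell, instead of initializing it to the grid size, keeps water cells out of the count.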
In LeetCode 721. Accounts Merge, multiple accounts may share email addresses and need to be merged. Assign each email a numeric ID. For each account, union all its emails together. Finally, group emails by their root representative. This replaces a brute-force merge that repeatedly compares accounts pairwise with a near-linear pass over the emails, followed by a sort within each group.
```python
from collections import defaultdict


def accounts_merge(accounts):
    # Assign each distinct email a numeric ID and remember its owner's name.
    email_to_id = {}
    email_to_name = {}
    uid = 0
    for account in accounts:
        name = account[0]
        for email in account[1:]:
            if email not in email_to_id:
                email_to_id[email] = uid
                email_to_name[email] = name
                uid += 1

    # Union all emails that appear in the same account.
    uf = UnionFind(uid)
    for account in accounts:
        first_id = email_to_id[account[1]]
        for email in account[2:]:
            uf.union(first_id, email_to_id[email])

    # Group emails by their root representative.
    groups = defaultdict(list)
    for email, eid in email_to_id.items():
        groups[uf.find(eid)].append(email)
    return [[email_to_name[emails[0]]] + sorted(emails)
            for emails in groups.values()]
```

In LeetCode 990. Satisfiability of Equality Equations, you're given equations like ["a==b", "b!=c", "c==a"] and must decide if they can all be satisfied simultaneously. The insight: process in two passes. First, union all variables connected by ==. Then, check every != constraint: if both sides share a root, the equations are unsatisfiable.
```python
def equations_possible(equations):
    uf = UnionFind(26)
    for eq in equations:
        if eq[1] == '=':
            uf.union(ord(eq[0]) - ord('a'), ord(eq[3]) - ord('a'))
    for eq in equations:
        if eq[1] == '!':
            if uf.find(ord(eq[0]) - ord('a')) == uf.find(ord(eq[3]) - ord('a')):
                return False
    return True
```

The two-pass pattern appears whenever you have both positive and negative constraints. Positive constraints build groups; negative constraints validate them.
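The same two-pass idea can be sketched for arbitrary hashable variables rather than single lowercase letters. The function name constraints_satisfiable and its two-list input format are assumptions for illustration. Union by rank is left out to keep the sketch short.

```python
def constraints_satisfiable(equal_pairs, unequal_pairs):
    """Two-pass check: build groups from equalities, then test inequalities."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)  # unseen variables start as their own root
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    # Pass 1: positive constraints build the groups.
    for a, b in equal_pairs:
        parent[find(a)] = find(b)

    # Pass 2: a negative constraint fails if both sides share a root.
    return all(find(a) != find(b) for a, b in unequal_pairs)


# The example from the text: a==b, c==a, b!=c.
print(constraints_satisfiable([("a", "b"), ("c", "a")], [("b", "c")]))  # False
print(constraints_satisfiable([("a", "b")], [("b", "c")]))              # True
```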
LeetCode 1202. Smallest String With Swaps gives you a string and a list of index pairs (i, j) where you can swap s[i] and s[j] any number of times. Swaps are transitive: if you can swap indices 0↔1 and 1↔2, you can rearrange positions 0, 1, 2 freely. Union all swappable index pairs, group characters by component, sort each group, and place the sorted characters back.
```python
from collections import defaultdict


def smallest_string_with_swaps(s, pairs):
    n = len(s)
    uf = UnionFind(n)
    for i, j in pairs:
        uf.union(i, j)
    groups = defaultdict(list)
    for i in range(n):
        groups[uf.find(i)].append(i)
    res = list(s)
    for indices in groups.values():
        chars = sorted(res[i] for i in indices)
        for i, c in zip(sorted(indices), chars):
            res[i] = c
    return ''.join(res)
```

This "group by component, operate within each group" pattern is broadly useful whenever transitive relationships unlock local rearrangement.
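The grouping step can be factored out into a reusable helper, sketched below. The name group_by_component is hypothetical, and the helper omits union by rank for brevity. Any per-group operation (sorting, summing, picking a minimum) can then be applied to the returned index lists.

```python
from collections import defaultdict


def group_by_component(n, pairs):
    """Return the connected components of 0..n-1 as sorted lists of indices."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in pairs:
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[ra] = rb

    groups = defaultdict(list)
    for i in range(n):
        groups[find(i)].append(i)  # indices arrive in increasing order
    return sorted(groups.values())


s = "dcab"
groups = group_by_component(len(s), [(0, 3), (1, 2)])
print(groups)  # [[0, 3], [1, 2]]

# Operate within each group: here, sort the characters of each component.
chars = list(s)
for indices in groups:
    for i, ch in zip(indices, sorted(chars[i] for i in indices)):
        chars[i] = ch
print("".join(chars))  # bacd
```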
In LeetCode 947. Most Stones Removed with Same Row or Column, stones sit on a 2D plane. You can remove a stone if it shares a row or column with at least one other remaining stone. The key insight: stones that share a row or column are transitively connected, so each connected component of size k can be reduced to 1 stone, removing k - 1 stones. The answer is total stones minus the number of components.
```python
def remove_stones(stones):
    # Requires the UnionFind variant that tracks self.components.
    uf = UnionFind(len(stones))
    row_map, col_map = {}, {}
    for i, (r, c) in enumerate(stones):
        if r in row_map:
            uf.union(i, row_map[r])
        row_map[r] = i
        if c in col_map:
            uf.union(i, col_map[c])
        col_map[c] = i
    return len(stones) - uf.components
```

Using hash maps to track one representative per row and per column avoids pairwise comparisons.
LeetCode 1319. Number of Operations to Make Network Connected asks for the minimum number of operations to connect all computers in a network. Each operation moves one cable. If the network has fewer than n - 1 cables total, it's impossible. Otherwise, every extra cable (one that creates a cycle) can be repurposed. The answer is simply the number of components minus one.
```python
def make_connected(n, connections):
    if len(connections) < n - 1:
        return -1
    # Requires the UnionFind variant that tracks self.components.
    uf = UnionFind(n)
    for a, b in connections:
        uf.union(a, b)
    return uf.components - 1
```

This is the purest "count components" application: no extra logic beyond checking the edge budget.
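LeetCode 685. Redundant Connection II extends the undirected Redundant Connection to a directed graph that should form a rooted tree. Two problems can arise: a node with two parents, or a cycle, and sometimes both. You need to handle three cases. If no node has two parents, the extra edge is the one that closes a cycle. If a node has two parents and there is no cycle, remove the later of its two incoming edges. If a node has two parents and there is also a cycle, remove whichever of its two incoming edges lies on the cycle.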
```python
def find_redundant_directed_connection(edges):
    n = len(edges)
    parent = [0] * (n + 1)
    cand1 = cand2 = None

    # Step 1: detect if any node has two parents
    for u, v in edges:
        if parent[v] != 0:
            cand1 = [parent[v], v]  # first edge to v
            cand2 = [u, v]          # second edge to v
            break
        parent[v] = u

    # Step 2: union-find, skipping cand2 if it exists
    uf = UnionFind(n + 1)
    for u, v in edges:
        if cand2 and u == cand2[0] and v == cand2[1]:
            continue
        if not uf.union(u, v):
            return cand1 if cand1 else [u, v]
    return cand2
```

The trick is to tentatively skip the second candidate edge. If a cycle still forms, the first candidate caused it. If no cycle forms, the second candidate was the problem.
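When testing a solution to this problem, a brute-force reference can be handy. The sketch below does not use Union-Find: it tries removing each edge, last to first, and checks whether the remaining edges form a rooted tree. It runs in quadratic time and is meant only as a checker; the function names is_rooted_tree and redundant_directed_bruteforce are hypothetical, and nodes are assumed to be labeled 1..n.

```python
def is_rooted_tree(n, edges):
    """True if `edges` (parent -> child) form a rooted tree on nodes 1..n."""
    if len(edges) != n - 1:
        return False
    parent = {}
    children = {i: [] for i in range(1, n + 1)}
    for u, v in edges:
        if v in parent:  # a node with two parents
            return False
        parent[v] = u
        children[u].append(v)
    roots = [i for i in range(1, n + 1) if i not in parent]
    if len(roots) != 1:
        return False
    # Every node must be reachable from the single root.
    seen = {roots[0]}
    stack = [roots[0]]
    while stack:
        node = stack.pop()
        for child in children[node]:
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return len(seen) == n


def redundant_directed_bruteforce(edges):
    """Return the last edge whose removal leaves a rooted tree, or None."""
    n = len(edges)
    for i in range(n - 1, -1, -1):  # prefer the edge that occurs last
        if is_rooted_tree(n, edges[:i] + edges[i + 1:]):
            return edges[i]
    return None


# Two parents, no cycle: node 3 has parents 1 and 2.
print(redundant_directed_bruteforce([[1, 2], [1, 3], [2, 3]]))  # [2, 3]
# Cycle, no node with two parents.
print(redundant_directed_bruteforce([[1, 2], [2, 3], [3, 4], [4, 1], [1, 5]]))  # [4, 1]
# Both: node 1 has two parents and 1 -> 4 -> 2 -> 1 is a cycle.
print(redundant_directed_bruteforce([[2, 1], [3, 1], [4, 2], [1, 4]]))  # [2, 1]
```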
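Time: O(α(N)) amortized per find or union operation, effectively constant.
Space: O(N) for the parent and rank arrays.

You're probably looking at Union-Find when connections arrive incrementally and you need to answer "are A and B in the same group?" or "how many groups are there?", when relationships are transitive, or when you need to find the edge that closes a cycle in an undirected graph.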
Common templates:
- Cycle detection: union returns False when both endpoints already share a root. Example: Redundant Connection.
- Two-pass constraints: union all == constraints first, then verify no != pair shares a root. Example: Satisfiability of Equality Equations.