Union-Find,并查集基础


并查集(Union-Find)是解决 disjoint sets(不相交集合)类问题的常用数据结构,主要特点如下:

  • 元素的归属 find

  • 集合的融合 union

  • Online algorithm, stream of input

  • 计算 number of connected components

  • 不支持 delete

CSDN 有一个非常好的总结帖子

我做并查集问题的时候,最喜欢的方式是直接无脑撸一个 weighted union find class 出来,然后在具体问题上直接调用接口。。


以下是 Princeton 的 Algorithm 算法课上的样例代码,这老头可真喜欢用 array 啊。。。他的并查集和KMP算法构建方式都稍微有点麻烦。

/**
 * Weighted quick-union (union by size) with path compression over sites 0..n-1.
 *
 * <p>Sites are plain array indices; {@code parent[i] == i} marks a root. Each component is
 * identified by the index of its root.
 */
public class WeightedQuickUnionPathCompressionUF {
    private final int[] parent;  // parent[i] = parent of i
    private final int[] size;    // size[i] = number of sites in tree rooted at i
                                 // Note: only meaningful when i is a root node
    private int count;           // number of components

    /**
     * Initializes the structure with {@code n} sites, each in its own component.
     *
     * @param n the number of sites (numbered 0 through n-1)
     */
    public WeightedQuickUnionPathCompressionUF(int n) {
        count = n;
        parent = new int[n];
        size = new int[n];
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            size[i] = 1;
        }
    }

    /**
     * @return the number of components
     */
    public int count() {
        return count;
    }

    /**
     * Returns the root (component identifier) of the component containing site {@code p}.
     * Uses two-pass path compression: first walks up to the root, then re-points every
     * site on the walked path directly at that root.
     *
     * @param p a site index
     * @return the root of p's component
     */
    public int find(int p) {
        int root = p;
        while (root != parent[root]) {
            root = parent[root];
        }
        while (p != root) {
            int next = parent[p];
            parent[p] = root;
            p = next;
        }
        return root;
    }

    /**
     * Alternative one-pass find using path halving: every visited site is re-pointed at its
     * grandparent while walking up. Returns the same root as {@link #find(int)}.
     *
     * <p>(In the original notes this was a second {@code find(int)} overload with an identical
     * signature, which does not compile; it is kept here under a distinct name.)
     *
     * @param i a site index
     * @return the root of i's component
     */
    public int findPathHalving(int i) {
        for (; i != parent[i]; i = parent[i]) {
            parent[i] = parent[parent[i]]; // path halving
        }
        return i;
    }

    /**
     * @return true if the two sites p and q are in the same component; false otherwise
     */
    public boolean connected(int p, int q) {
        return find(p) == find(q);
    }

    /**
     * Merges the component containing site p with the component containing site q.
     * The root of the smaller tree is attached under the root of the larger one.
     *
     * @param  p the integer representing one site
     * @param  q the integer representing the other site
     */
    public void union(int p, int q) {
        int rootP = find(p);
        int rootQ = find(q);
        if (rootP == rootQ) {
            return; // already connected; must not double-count size or decrement count
        }

        // make smaller root point to larger one
        if (size[rootP] < size[rootQ]) {
            parent[rootP] = rootQ;
            size[rootQ] += size[rootP];
        } else {
            parent[rootQ] = rootP;
            size[rootP] += size[rootQ];
        }
        count--;
    }
}

Longest Consecutive Sequence

其实我们建的是从每个元素到其 parent 的映射(HashMap),或者说是一张元素之间的 graph:值相差 1 的两个元素之间有边,表示 join 关系。

中间有一个 [2147483646,-2147483647,0,2,2147483644,-2147483645,2147483645] 的 test case 始终出 bug,把 find 函数的返回 type 从 int 改到 Integer 就好了。

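看来以后不能总是假设 int 和 Integer 是完全相等的,尤其是这种在 HashMap 里以 Integer 为 key 的情况,要尽可能保持类型正确。需要特别注意:两个 Integer 之间用 == / != 比较的是对象引用而不是数值,只有 -128~127 之间的值会被缓存为同一个对象;上面这个 test case 里几乎全是超出缓存范围的大数,因此任何 Integer 之间的 == / != 比较都可能出错。比较数值时应使用 equals(),或先拆箱成 int 再比较。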

union 时必须先判断两个根是否相同(if (rootp.equals(rootq)) return;),否则当 rootp 和 rootq 是同一个根时,size.put(rootp, sizep + sizeq) 会让 size 翻倍。

public class Solution {

    /**
     * Weighted (union-by-size) union-find keyed by the distinct values of the input array.
     * {@code maxSize} tracks the size of the largest set formed so far.
     */
    class WeightedUnionFind {
        HashMap<Integer, Integer> parent = new HashMap<Integer, Integer>();
        HashMap<Integer, Integer> size = new HashMap<Integer, Integer>();
        int maxSize = 1;

        /** Puts every value of {@code nums} in its own singleton set (duplicates collapse). */
        public WeightedUnionFind(int[] nums) {
            for (int num : nums) {
                parent.put(num, num);
                size.put(num, 1);
            }
        }

        /**
         * Returns the root of the set containing {@code i}, or null if {@code i} was never
         * added. Compresses the visited path so later lookups are cheaper.
         */
        public Integer root(Integer i) {
            if (!parent.containsKey(i)) return null;
            // Work with unboxed ints: comparing Integer objects with == / != compares
            // references, which fails for values outside the -128..127 cache.
            int root = i;
            int up = parent.get(root);
            while (up != root) {
                root = up;
                up = parent.get(root);
            }
            int cur = i;
            while (cur != root) {
                int next = parent.get(cur);
                parent.put(cur, root);
                cur = next;
            }
            return root;
        }

        /** Merges the sets of {@code p} and {@code q}; a no-op if either value is absent. */
        public void union(int p, int q) {
            Integer rootp = root(p);
            Integer rootq = root(q);
            if (rootp == null || rootq == null) return;
            if (rootp.equals(rootq)) return; // same set: merging would double the size

            int sizep = size.get(rootp);
            int sizeq = size.get(rootq);
            int merged = sizep + sizeq;
            if (sizep < sizeq) {
                parent.put(rootp, rootq);
                size.put(rootq, merged);
            } else {
                parent.put(rootq, rootp);
                size.put(rootp, merged);
            }
            maxSize = Math.max(maxSize, merged);
        }
    }

    /**
     * Returns the length of the longest run of consecutive integers in {@code nums}
     * (0 for null or empty input).
     */
    public int longestConsecutive(int[] nums) {
        if (nums == null || nums.length == 0) return 0;
        WeightedUnionFind uf = new WeightedUnionFind(nums);
        for (int num : nums) {
            // Guard the int boundaries: MIN_VALUE - 1 wraps to MAX_VALUE (and vice versa),
            // which would otherwise link the two extremes as if they were consecutive.
            if (num != Integer.MIN_VALUE) uf.union(num, num - 1);
            if (num != Integer.MAX_VALUE) uf.union(num, num + 1);
        }
        return uf.maxSize;
    }
}

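这道题更简单的做法是用 HashMap 维护每个数所在连续区间的长度(只需保证区间两个端点上的值是正确的):每插入一个新数 num,就查出 num-1 和 num+1 所在区间的长度 left / right,合并后的长度为 left + right + 1,再更新新区间的两个端点 num-left 和 num+right。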

public class Solution {
    /**
     * Returns the length of the longest run of consecutive integers in {@code nums}
     * (0 for null or empty input).
     *
     * <p>{@code map} records, for each value seen, the length of the run it belonged to when
     * it was inserted; only the two endpoints of each run are kept up to date, which is all a
     * newly inserted neighbor ever reads.
     */
    public int longestConsecutive(int[] nums) {
        if (nums == null) return 0;
        int ans = 0;
        HashMap<Integer, Integer> map = new HashMap<Integer, Integer>();
        for (int num : nums) {
            if (map.containsKey(num)) continue; // duplicates contribute nothing new

            // Guard the int boundaries: MIN_VALUE - 1 wraps to MAX_VALUE (and vice versa).
            int left = (num == Integer.MIN_VALUE) ? 0 : map.getOrDefault(num - 1, 0);
            int right = (num == Integer.MAX_VALUE) ? 0 : map.getOrDefault(num + 1, 0);
            int sum = left + right + 1;

            map.put(num, sum);
            // Refresh both endpoints of the merged run so future neighbors see its length.
            map.put(num - left, sum);
            map.put(num + right, sum);
            ans = Math.max(ans, sum);
        }
        return ans;
    }
}

Number of Islands II

  • 常犯错误:二维转一维 index 时容易把行列乘错。正确的公式是 index = x * cols + y(本题中 cols 即 n),所以自己推导时最好显式地用 rows / cols 命名,避免和 m / n 搞混。

  • 在这题里降维成一维 index 是可以的,不过要注意边界处理,否则某一行的最后一个元素会连通到下一行的第一个元素上去。

  • 注意四个方向都要用独立的 if 而不是 else if,这样新加入的格子才能同时和四个方向的邻居 union。

public class Solution {
    /**
     * Weighted union-find over grid cells encoded as {@code x * cols + y}. Cells are added
     * lazily; {@code count} is the current number of disjoint sets (islands).
     */
    private static class WeightedUnionFind {
        HashMap<Integer, Integer> parent;
        HashMap<Integer, Integer> size;
        int count;

        public WeightedUnionFind() {
            parent = new HashMap<Integer, Integer>();
            size = new HashMap<Integer, Integer>();
            count = 0;
        }

        /**
         * Returns the root of {@code index}'s set, or null if the cell was never added.
         * Compresses the visited path onto the root.
         */
        public Integer find(Integer index) {
            if (!parent.containsKey(index)) return null;

            // Compare unboxed ints: != on two Integer objects compares references, which is
            // only reliable inside the -128..127 cache.
            int root = index;
            int up = parent.get(root);
            while (up != root) {
                root = up;
                up = parent.get(root);
            }
            int cur = index;
            while (cur != root) {
                int next = parent.get(cur);
                parent.put(cur, root);
                cur = next;
            }
            return root;
        }

        /** Merges the sets of cells {@code a} and {@code b}; a no-op if either is absent. */
        public void union(Integer a, Integer b) {
            Integer aRoot = find(a);
            Integer bRoot = find(b);
            if (aRoot == null || bRoot == null) return;
            if (aRoot.equals(bRoot)) return;

            int aSize = size.get(aRoot);
            int bSize = size.get(bRoot);

            // attach the smaller tree under the larger one
            if (aSize > bSize) {
                parent.put(bRoot, aRoot);
                size.put(aRoot, aSize + bSize);
            } else {
                parent.put(aRoot, bRoot);
                size.put(bRoot, aSize + bSize);
            }
            count--;
        }

        /** Adds {@code index} as a new singleton set; re-adding an existing cell is a no-op. */
        public void add(Integer index) {
            if (!parent.containsKey(index)) {
                parent.put(index, index);
                size.put(index, 1);
                count++;
            }
        }

        public int getCount() {
            return this.count;
        }
    }

    /**
     * Turns the given cells of an m x n all-water grid into land one by one and returns the
     * number of islands after each addition.
     */
    public List<Integer> numIslands2(int m, int n, int[][] positions) {
        List<Integer> list = new ArrayList<Integer>();
        if (positions == null || positions.length == 0) return list;
        WeightedUnionFind uf = new WeightedUnionFind();

        for (int i = 0; i < positions.length; i++) {
            int x = positions[i][0];
            int y = positions[i][1];
            int index = x * n + y; // row-major: row * cols + col

            uf.add(index);

            // Independent ifs (not else-if): a new cell may join up to four neighbors.
            // Column bounds are checked explicitly so a row's last cell never links to the
            // next row's first cell through the flattened index.
            if (x + 1 <= m - 1) uf.union(index, (x + 1) * n + y);
            if (x > 0)          uf.union(index, (x - 1) * n + y);
            if (y + 1 <= n - 1) uf.union(index, x * n + y + 1);
            if (y > 0)          uf.union(index, x * n + y - 1);

            list.add(uf.getCount());
        }

        return list;
    }

}

Number of Connected Components in an Undirected Graph

这题如果是把所有的 nodes 给你,其实很好做,每个点做 dfs 就好了,用 hashset 避免重复访问,毕竟是 undirected graph.

然而这题比较麻烦的地方在于:数据是以一条条 edge 的形式给出的,很自然地适合用 union-find 逐条加边。如果 edges 是以流的形式在线到达,先读完所有 edges 建图再去 dfs 就不现实,也失去了 online algorithm 的优势;当然,如果所有 edges 一次性给出,离线建图 + DFS 也完全可行(见下面的方法二)。

  • 对于 Integer 类型,比较数值要用 a.equals(b),不要用 ==(== 比较的是对象引用,只有 -128~127 的值会被缓存成同一个对象)。

代码轻松愉快~

public class Solution {

    /**
     * Weighted (union-by-size) union-find over nodes 0..n-1 with path compression.
     * {@code count} is the current number of connected components.
     */
    class UnionFind {
        HashMap<Integer, Integer> parent;
        HashMap<Integer, Integer> size;
        int count;

        public UnionFind(int n) {
            parent = new HashMap<Integer, Integer>();
            size = new HashMap<Integer, Integer>();
            for (int i = 0; i < n; i++) {
                parent.put(i, i);
                size.put(i, 1);
            }
            count = n;
        }

        /**
         * Returns the root of node {@code i}'s component, or null if {@code i} is not a node.
         * Re-points every node on the walked path directly at the root.
         */
        public Integer root(int i) {
            if (!parent.containsKey(i)) return null;
            int root = i;
            int up = parent.get(root);
            while (up != root) {
                root = up;
                up = parent.get(root);
            }
            while (i != root) {
                int next = parent.get(i);
                parent.put(i, root);
                i = next;
            }
            return root;
        }

        /** Merges the components of {@code p} and {@code q}; ignores unknown nodes. */
        public void union(int p, int q) {
            Integer rootp = root(p);
            Integer rootq = root(q);
            if (rootp == null || rootq == null) return;
            if (rootp.equals(rootq)) return;
            // Sizes are only maintained at roots, so they must be read from the roots
            // rather than from p and q themselves.
            int sizep = size.get(rootp);
            int sizeq = size.get(rootq);
            if (sizep < sizeq) {
                parent.put(rootp, rootq);
                size.put(rootq, sizep + sizeq);
            } else {
                parent.put(rootq, rootp);
                size.put(rootp, sizep + sizeq); // rootp is the surviving root here
            }
            count--;
        }

    }

    /** Returns the number of connected components in an undirected graph on nodes 0..n-1. */
    public int countComponents(int n, int[][] edges) {
        if (edges == null || edges.length == 0) return n;
        UnionFind uf = new UnionFind(n);
        for (int[] edge : edges) {
            uf.union(edge[0], edge[1]);
        }
        return uf.count;
    }
}

方法二:根据 edges 建立邻接表,然后从每个尚未访问的节点出发做一次 DFS flood fill;每启动一次新的 DFS,就对应一个新的连通分量。

public class Solution {

    /**
     * Counts connected components: builds an adjacency list from the undirected edge list,
     * then starts one flood fill from every node 0..n-1 that no earlier fill has reached.
     */
    public int countComponents(int n, int[][] edges) {
        boolean[] seen = new boolean[n];

        HashMap<Integer, ArrayList<Integer>> adjacency = new HashMap<Integer, ArrayList<Integer>>();
        for (int[] e : edges) {
            int a = e[0];
            int b = e[1];
            adjacency.computeIfAbsent(a, key -> new ArrayList<Integer>()).add(b);
            adjacency.computeIfAbsent(b, key -> new ArrayList<Integer>()).add(a);
        }

        int components = 0;
        for (int node = 0; node < n; node++) {
            if (seen[node]) {
                continue;
            }
            dfs(adjacency, seen, node);
            components++;
        }
        return components;
    }

    /**
     * Recursively marks in {@code vis} every node reachable from {@code cur}. A node that
     * appears in no edge (absent from {@code map}) is left unmarked and explores nothing.
     */
    public void dfs(HashMap<Integer, ArrayList<Integer>> map, boolean[] vis, int cur) {
        boolean shouldExplore = !vis[cur] && map.containsKey(cur);
        if (shouldExplore) {
            vis[cur] = true;
            for (Integer next : map.get(cur)) {
                dfs(map, vis, next);
            }
        }
    }
}

results matching ""

    No results matching ""