Segment Tree 的应用

  • 最适合用 Segment tree 的情形最好同时满足以下三点:

    • 区间查找 min/max

    • 频繁 update

    • 频繁 query


Range Sum Query - Mutable

这个题注意算range sum的时候,if(i==cur.start && j==cur.end) return cur.val; 这个要放在后面。否则当前的cur没判断,就进入cur.left cur.right.

//if(i==cur.start && j==cur.end) return cur.val; 
        int left = Math.max(i, cur.left);
        int right = Math.min(j, cur.end);
        if(cur.start==left && cur.right==right) return cur.val;
public class NumArray {
    private class SegmentTreeNode{
        int start;
        int end;
        int sum;
        SegmentTreeNode left, right;

        public SegmentTreeNode(){}
        public SegmentTreeNode(int start, int end, int sum){
            this.start = start;
            this.end = end;
            this.sum = sum;
            this.left = null;
            this.right = null;
        }
    }

    SegmentTreeNode root;

    public NumArray(int[] nums) {
        root = buildTree(nums, 0, nums.length - 1);
    }

    private SegmentTreeNode buildTree(int[] nums, int start, int end){
        if(nums == null || nums.length == 0) return null;
        if(start == end) return new SegmentTreeNode(start, end, nums[start]);

        int mid = start + (end - start) / 2;

        SegmentTreeNode left = buildTree(nums, start, mid);
        SegmentTreeNode right = buildTree(nums, mid + 1, end);

        SegmentTreeNode root = new SegmentTreeNode(start, end, left.sum + right.sum);
        root.left = left;
        root.right = right;

        return root;
    }

    void update(int i, int val) {
        update(root, i, val);
    }

    private void update(SegmentTreeNode root, int i, int val){
        if(root == null) return;
        if(i < root.start) return;
        if(i > root.end) return;

        if(root.start == i && root.end == i) {
            root.sum = val;
            return;
        }

        update(root.left, i, val);
        update(root.right, i, val);

        root.sum = root.left.sum + root.right.sum;
    }

    public int sumRange(int i, int j) {
        return sumRange(root, i, j);
    }

    private int sumRange(SegmentTreeNode root, int i, int j){
        if(root == null) return 0;
        if(i > root.end) return 0;
        if(j < root.start) return 0;

        i = Math.max(i, root.start);
        j = Math.min(j, root.end);

        if(root.start == i && root.end == j) return root.sum;

        int left = sumRange(root.left, i, j);
        int right = sumRange(root.right, i, j);

        return left + right;
    }
}


// Your NumArray object will be instantiated and called as such:
// NumArray numArray = new NumArray(nums);
// numArray.sumRange(0, 1);
// numArray.update(1, 10);
// numArray.sumRange(1, 2);

Range Sum Query - Immutable

这题从类型上讲,看着和上一题非常像。然而其实这题因为需要的操作比较简单,其实就是一个 prefix sum 数组的 dp ...

教育了我们 segment tree 虽屌,也不要一言不合就随便用。。

  • 最适合用 Segment tree 的情形最好同时满足以下三点:

    • 区间查找

    • 频繁 update

    • 频繁 query

  • 在只有区间没有 update 的情况下,其实是一个一维/二维的 DP 问题,并不能体现出 segment tree 的优势。

  • 前缀和数组记得在最前面加上 sum = 0 的 padding.

public class NumArray {
    int[] prefixSum;

    public NumArray(int[] nums) {
        if(nums == null || nums.length == 0) return;

        prefixSum = new int[nums.length + 1];
        prefixSum[0] = 0;
        prefixSum[1] = nums[0];
        for(int i = 1; i < nums.length; i++){
            prefixSum[i + 1] = prefixSum[i] + nums[i];
        }
    }

    public int sumRange(int i, int j) {
        return prefixSum[j + 1] - prefixSum[i];
    }
}

Range Sum Query 2D - Mutable

这道题当然可以把矩阵降维之后用 segment tree 解,把一个 region 拆分成若干个 interval of rows 然后把结果加起来,但是很慢。

这题既体现了 segment tree的应用,又暴露了 segment tree的问题。

  • width = m, height = n, 现有 1D segment tree 的复杂度

    • build O(mn)

    • update O(log(mn))

    • query O(n * log (mn))

  • 因为这题更适合用binary index tree解,另一个教程贴在这里,还有这里,加上这个陈老师推荐的中文帖子

Fenwick tree can also be used to update and query subarrays in multidimensional arrays with complexity, where d is number of dimensions and n is the number of elements along each dimension.

public class NumMatrix {
    class TreeNode{
        int start;
        int end;
        int val;
        TreeNode left;
        TreeNode right;
        public TreeNode(int start, int end, int val){
            this.start = start;
            this.end = end;
            this.val = val;
        }
    }


    public int getIdx(int a, int b){
        return a*width+b;
    }

    public TreeNode buildTree(int[][] matrix, int start, int end){
        if(start==end){
            return new TreeNode(start, end, matrix[start/width][end%width]);
        }
        int mid = start + (end-start)/2;
        TreeNode left = buildTree(matrix, start, mid);
        TreeNode right = buildTree(matrix, mid+1, end);
        TreeNode cur = new TreeNode(start, end, left.val+right.val);
        cur.left = left;
        cur.right = right;
        return cur;
    }

    int width;
    int height;
    TreeNode root;
    public NumMatrix(int[][] matrix) {
        if(matrix == null || matrix.length == 0) return;
        height = matrix.length;
        width = matrix[0].length;

        root=buildTree(matrix, 0, width*height-1);
    }

    public void update(TreeNode root, int index, int val){
        if(root==null || index<root.start || index>root.end) return;
        if(root.start == index && root.end==index) {root.val = val;return;}
        update(root.left, index, val);
        update(root.right, index, val);
        root.val = root.left.val+root.right.val;
    }

    public void update(int row, int col, int val) {
        update(root, getIdx(row, col), val);
    }

    public int sumRegion(TreeNode root, int start, int end){
        if(root==null || start>root.end || end<root.start) return 0;
        start = Math.max(start, root.start);
        end = Math.min(end, root.end);
        if(start==root.start && end==root.end) {return root.val;}
        return sumRegion(root.left, start, end) + sumRegion(root.right, start, end);
    }

    public int sumRegion(int row1, int col1, int row2, int col2) {
        int sum = 0;
        for(int i=row1;i<=row2;i++){
            sum+=sumRegion(root, getIdx(i,col1), getIdx(i, col2));
        }
        return sum;
    }
}

Count of Range Sum

这道题试图用segment tree写,写到最后search的时候发现不对劲,传统search会遗漏掉一些空间,本质因为传统search传入的是lowidx和highidx,所以query的时候可以有效剪枝。而这里是用lower bound和higher bound进行的剪枝。没法有效进行。

参考了论坛的segment tree解法,非常的不直观,需要对sum数组建立segment tree,再查询。查询的方式也很奇怪,花了时间研究后放弃这个仅有17票的解法。

更简洁高效的解法是merge sort。见two pointers, merge sort

results matching ""

    No results matching ""