Count of Range Sum
这道题, 看似简单,O(n*n)的方法很好想到,然而O(nlogn)并不好想到。
这道题试图用segment tree写,写到最后search的时候发现不对劲,传统search会遗漏掉一些空间,本质因为传统search传入的是lowidx和highidx,所以query的时候可以有效剪枝。而这里是用lower bound和higher bound进行的剪枝。没法有效进行。
参考了论坛的segment tree解法,非常的不直观,需要对sum数组建立segment tree,再查询。查询的方式也很奇怪,花了时间研究后放弃这个仅有17票的解法。
merge sort的思路是这样的
- 首先建立sum[n+1]数组。如[1,2,3],sum数组为[0,1,3,6],sum[j]-sum[i]可求出range sum。
- 递归调用merge(start,mid)和merge(mid,end),此时(start,mid)是“单调有序”的sum数组,(mid,end)是“单调有序”的sum数组,并且(mid,end)的sum index 一定比(start,mid)的sum index大。最重要的性质。
- 求解用当前左半部分(start,mid)和右半部分(mid,end)构造 range sum, 有多少合法的解决。 这里又利用到了two pointers的思路。 对左半部分每个i小于mid进行枚举,在右半部分找到对应的j和k,满足sum[j]-sum[i]<=upper, lower<=sum[k]-sum[i], 因为右半部分单调,j>=k.
- 最后记得将左右半部分merge sort,供下次使用。
public class Solution {
public int countRangeSum(int[] nums, int lower, int upper) {
long[] sums = new long[nums.length+1];
for(int i=0;i<nums.length;i++){
sums[i+1]=sums[i]+nums[i];
}
int count = mergeCount(sums, 0, nums.length+1, lower, upper);
return count;
}
public int mergeCount(long[] sums, int start, int end, int lower, int upper){
if(end-start<=1) return 0;
int mid = start+(end-start)/2;
int count = mergeCount(sums, start, mid, lower, upper)+mergeCount(sums, mid, end, lower, upper);
long[] cache = new long[end-start]; int c = 0;
int j = mid; int k=mid; int i = start; int l = mid;
while(i<mid){
while(k<end&&sums[k]-sums[i]<lower) k++;
while(j<end&&sums[j]-sums[i]<=upper) j++;
while(l<end&&sums[l]<=sums[i]) cache[c++]=sums[l++];
count+=j-k;
cache[c] = sums[i];
i++;c++;
}
System.arraycopy(cache, 0, sums, start, l - start);
return count;
}
}
这道题我最喜欢的最简单最直接的方式是利用和 count of smaller numbers after itself的思路,对sum数组建立BST,然后对于每一个sum[j],查询有多少sum[i]在[sum[j]-upper,sum[j]-lower]的range里。 注意查询的时候先查询再插入。
public class Solution {
class TreeNode{
int count;
int leftCount;
long val;
TreeNode left;
TreeNode right;
public TreeNode(long val){
this.val = val;
count = 1;
leftCount = 0;
}
}
public void insert(TreeNode node, long val){
while(node!=null){
if(node.val==val){
node.count++;
return;
}
else if(node.val<val){
if(node.right!=null) node = node.right;
else {node.right = new TreeNode(val); return;}
}
else if(node.val>val){
node.leftCount++;
if(node.left!=null) node = node.left;
else {node.left = new TreeNode(val);return;}
}
}
}
public int getBound(TreeNode node, long val, boolean include){
int count = 0;
while(node!=null){
if(node.val==val){
count+=node.leftCount;
count+=include?node.count:0;
return count;
}
else if(val>node.val){
count+=node.count+node.leftCount;
node = node.right;
}
else{
node = node.left;
}
}
return count;
}
public int countRangeSum(int[] nums, int lower, int upper) {
TreeNode root = new TreeNode(0); long sum = 0; int count=0;
for(int i=0;i<nums.length;i++){
sum+=nums[i];
count += getBound(root,sum-lower, true)-getBound(root, sum-upper, false);
// System.out.println("i" + i + "count: "+ count);
insert(root, sum);
//System.out.println("root");
}
return count;
}
}