Range Sum Problems

Introduction

There are four similar range sum problems on LeetCode:

303. Range Sum Query - Immutable: the immutable version, which has no update operation.
307. Range Sum Query - Mutable: Given an integer array nums, find the sum of the elements between indices i and j (i ≤ j), inclusive. The update(i, val) function modifies nums by updating the element at index i to val.

Given nums = [1, 3, 5]

sumRange(0, 2) -> 9
update(1, 2)
sumRange(0, 2) -> 8
304. Range Sum Query 2D - Immutable: the immutable version, which has no update operation.
308. Range Sum Query 2D - Mutable: Given a 2D matrix matrix, find the sum of the elements inside the rectangle defined by its upper left corner (row1, col1) and lower right corner (row2, col2).

[Figure: the example matrix, with the query rectangle outlined in red (reference source)]

In that figure, the rectangle outlined in red is defined by (row1, col1) = (2, 1) and (row2, col2) = (4, 3), and its elements sum to 8.

Immutable Range Sum

The 1D Immutable Range Sum problem can be solved by accumulating the sum from index 0 to each index i in an additional array sum[]. After that, it takes only O(1) time to calculate the sum from i to j as sum[j] - sum[i-1], or simply sum[j] when i = 0. Pre-computing sum[] in the constructor takes O(N) time, and the array never changes afterwards because there is no update operation.
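Here is a small worked example using the array nums = [1, 3, 5] from the 307 statement. Only the values matter here, because this version has no update:

```latex
\begin{aligned}
\mathrm{sum}[0] &= 1, \quad \mathrm{sum}[1] = 1 + 3 = 4, \quad \mathrm{sum}[2] = 4 + 5 = 9 \\
\mathrm{sumRange}(1, 2) &= \mathrm{sum}[2] - \mathrm{sum}[0] = 9 - 1 = 8 \\
\mathrm{sumRange}(0, 2) &= \mathrm{sum}[2] = 9
\end{aligned}
```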

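The 2D Range Sum problem follows the same idea. We create an additional matrix whose entry (i, j) stores the sum of the area from (0, 0) to (i, j). Dynamic programming avoids repeatedly calculating the sums of previous areas, because each entry is built from its neighbors as sum[i][j] = matrix[i][j] + sum[i-1][j] + sum[i][j-1] - sum[i-1][j-1]. The last term removes the area that both neighbors cover, and terms with a negative index count as 0.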

/* It takes O(1) to sum after an O(N) pre-computation of the sums from
 * the start to each index i, stored in a new array. An update would
 * cost O(N), because every later prefix sum would have to change.
 */
public class RangeSumQueryImmutable {
    private int[] sums;

    public RangeSumQueryImmutable(int[] nums) {
        sums = new int[nums.length];
        int sum = 0;
        for (int i = 0; i < nums.length; i++) {
            sum += nums[i];
            sums[i] = sum;
        }
    }

    // Sum of nums[i..j], inclusive.
    public int sumRange(int i, int j) {
        if (i <= 0) return sums[j];
        return sums[j] - sums[i-1];
    }

    public static void main(String[] args) {
        int[] nums = new int[]{0, 9, 5, 7, 3};
        RangeSumQueryImmutable obj = new RangeSumQueryImmutable(nums);
        int sum = obj.sumRange(0, 3);
        System.out.println(sum); // 0 + 9 + 5 + 7 = 21
    }
}
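The code above only covers the 1D case. Below is a minimal sketch of the 2D prefix-sum recurrence described earlier, under a hypothetical class name, RangeSumQuery2DImmutable. It pads the prefix matrix with an extra zero row and column so that the i - 1 and j - 1 terms never go out of bounds. The 5×5 sample matrix in main is an assumption, chosen so that the rectangle (2, 1) to (4, 3) sums to 8 as in the problem description.

```java
/**
 * 2D prefix sums: sums[i + 1][j + 1] holds the sum of the area from (0, 0)
 * to (i, j). The extra zero row and column remove the edge checks.
 */
class RangeSumQuery2DImmutable {
    private final int[][] sums;

    RangeSumQuery2DImmutable(int[][] matrix) {
        int rows = matrix.length;
        int cols = rows == 0 ? 0 : matrix[0].length;
        sums = new int[rows + 1][cols + 1];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                // DP step: area above + area to the left - their shared overlap.
                sums[i + 1][j + 1] = matrix[i][j] + sums[i][j + 1]
                        + sums[i + 1][j] - sums[i][j];
            }
        }
    }

    // Sum of the rectangle (row1, col1) to (row2, col2), inclusive, in O(1).
    int sumRegion(int row1, int col1, int row2, int col2) {
        return sums[row2 + 1][col2 + 1] - sums[row1][col2 + 1]
                - sums[row2 + 1][col1] + sums[row1][col1];
    }

    public static void main(String[] args) {
        // Hypothetical sample matrix: the region (2, 1) to (4, 3) sums to 8.
        int[][] matrix = {
            {3, 0, 1, 4, 2},
            {5, 6, 3, 2, 1},
            {1, 2, 0, 1, 5},
            {4, 1, 0, 1, 7},
            {1, 0, 3, 0, 5}
        };
        RangeSumQuery2DImmutable obj = new RangeSumQuery2DImmutable(matrix);
        System.out.println(obj.sumRegion(2, 1, 4, 3)); // 8
    }
}
```

Construction is O(rows × cols), and every query after that is four lookups, mirroring the 1D sum[j] - sum[i-1].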

Mutable Range Sum

The prefix-sum method used for the immutable version does not carry over here. An update must change every prefix sum from index i to the end, which takes O(N) time. Because update and sum operations are assumed to be evenly distributed, the overall cost per operation becomes O(N).
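The sketch below makes that O(N) update cost concrete. It keeps the prefix-sum array and patches it on every update, and the class name NaiveRangeSumQueryMutable is hypothetical. sumRange stays O(1), but an update at index 0 rewrites every prefix sum. The main method replays the example from the 307 statement.

```java
/**
 * Hypothetical naive mutable version: keep the prefix sums and patch them on
 * every update. sumRange is O(1); update is O(N - i), i.e. O(N) in the worst case.
 */
class NaiveRangeSumQueryMutable {
    private final int[] nums;
    private final int[] sums; // sums[k] = nums[0] + ... + nums[k]

    NaiveRangeSumQueryMutable(int[] input) {
        nums = input.clone();
        sums = new int[nums.length];
        int running = 0;
        for (int k = 0; k < nums.length; k++) {
            running += nums[k];
            sums[k] = running;
        }
    }

    void update(int i, int val) {
        int diff = val - nums[i];
        nums[i] = val;
        // Every prefix sum that includes index i has to change.
        for (int k = i; k < sums.length; k++) {
            sums[k] += diff;
        }
    }

    int sumRange(int i, int j) {
        return i == 0 ? sums[j] : sums[j] - sums[i - 1];
    }

    public static void main(String[] args) {
        NaiveRangeSumQueryMutable obj = new NaiveRangeSumQueryMutable(new int[]{1, 3, 5});
        System.out.println(obj.sumRange(0, 2)); // 9
        obj.update(1, 2);
        System.out.println(obj.sumRange(0, 2)); // 8
    }
}
```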

Data structures such as the Segment Tree and the Binary Indexed Tree bring both update and sum operations down to O(log(N)), at the cost of O(N) extra memory.

Segment Tree

A Segment Tree is a binary tree in which each node covers a contiguous range of the input array and stores the sum of that range. The left child covers the left half of its parent's range and the right child covers the right half, so each node's value is the sum of its two children. The root covers the whole input array, and ranges are split recursively until each leaf contains only one element.

[Figure: an example segment tree (figure source)]

An update takes O(log(N)) time. On each level of the segment tree, the difference between the new and old values is added to the one segment that contains the index. If index i lies between the start and end of the current node, updateNode() adds the difference to that node's sum and recursively updates its left and right children, each of which repeats the same range check.

The sum operation is a little harder to understand than the update operation. It starts at the root node, runs in O(log(N)) time, and distinguishes three cases:
* The node range lies completely outside the query range: return 0.
* The node range lies completely within the query range: return the value of this node.
* The node range and the query range partially overlap: recursively search and sum both the left and right children with the same query parameters i and j.

/********************** Segment Tree *****************************/
class SegmentTreeNode {
    public int start, end, sum;
    public SegmentTreeNode lnode;
    public SegmentTreeNode rnode;

    public SegmentTreeNode(int start, int end, int sum) {
        this.start = start;
        this.end = end;
        this.sum = sum;
        this.lnode = null;
        this.rnode = null;
    }
}

public class SegmentTreeSumQuery {
    private SegmentTreeNode root;
    private int[] nums;

    // Recursively split [start, end] in half; each node stores its range sum.
    private SegmentTreeNode buildTree(int[] nums, int start, int end) {
        if (start > end) {
            return null;
        }
        else if (start == end) {
            return new SegmentTreeNode(start, end, nums[start]);
        }
        else {
            int mid = (start + end) / 2;
            SegmentTreeNode left = buildTree(nums, start, mid);
            SegmentTreeNode right = buildTree(nums, mid+1, end);
            SegmentTreeNode root = new SegmentTreeNode(start, end, left.sum + right.sum);
            root.lnode = left;
            root.rnode = right;
            return root;
        }
    }

    private void updateNode(SegmentTreeNode node, int i, int diff) {
        if (node == null) return;
        // Update current node with diff and recursively update its children.
        if (i >= node.start && i <= node.end) {
            node.sum += diff;
            updateNode(node.lnode, i, diff);
            updateNode(node.rnode, i, diff);
        }
    }

    private int sumRangeNode(SegmentTreeNode node, int i, int j) {
        if (node == null) { return 0; }
        // Node range is completely outside of query range.
        // Either on left or on right without any overlap.
        if (node.end < i || node.start > j) {
            return 0;
        }
        // Node within query range.
        else if (node.start >= i && node.end <= j) {
            return node.sum;
        }
        // Partial overlap: sum both children.
        else {
            return sumRangeNode(node.lnode, i, j) + sumRangeNode(node.rnode, i, j);
        }
    }

    // Segment Tree method Constructor
    public SegmentTreeSumQuery(int[] nums) {
        int start = 0, end = nums.length-1;
        this.root = buildTree(nums, start, end);
        // Keep a private copy so that update() does not modify the caller's array.
        this.nums = nums.clone();
    }

    // Public function available to user.
    public void update(int i, int val) {
        int diff = val - nums[i];
        nums[i] = val;
        updateNode(this.root, i, diff);
    }

    // Public function available to user.
    public int sumRange(int i, int j) {
        return sumRangeNode(this.root, i, j);
    }
}
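SegmentTreeSumQuery allocates one node object per range. A common alternative, sketched below as an assumption rather than the author's code, stores the same tree in a flat int array with a heap layout: node k covers a range [lo, hi], and its children are 2k + 1 and 2k + 2. The class name ArraySegmentTree is hypothetical. The query uses the same three cases as sumRangeNode(). The update works differently: it assigns the new value at the leaf and recomputes the sums on the way back up, instead of pushing a difference down from the root.

```java
/**
 * Hypothetical segment tree stored in a flat array (heap layout).
 * Node k covers [lo, hi]; children are 2k + 1 (left half) and 2k + 2 (right half).
 */
class ArraySegmentTree {
    private final int n;
    private final int[] tree;

    ArraySegmentTree(int[] nums) {
        n = nums.length;
        tree = new int[Math.max(1, 4 * n)]; // 4n slots always fit this layout
        if (n > 0) build(nums, 0, 0, n - 1);
    }

    private void build(int[] nums, int k, int lo, int hi) {
        if (lo == hi) { tree[k] = nums[lo]; return; }
        int mid = lo + (hi - lo) / 2;
        build(nums, 2 * k + 1, lo, mid);
        build(nums, 2 * k + 2, mid + 1, hi);
        tree[k] = tree[2 * k + 1] + tree[2 * k + 2];
    }

    // Set nums[i] = val, then recompute sums on the path back to the root.
    void update(int i, int val) { update(0, 0, n - 1, i, val); }

    private void update(int k, int lo, int hi, int i, int val) {
        if (lo == hi) { tree[k] = val; return; }
        int mid = lo + (hi - lo) / 2;
        if (i <= mid) update(2 * k + 1, lo, mid, i, val);
        else update(2 * k + 2, mid + 1, hi, i, val);
        tree[k] = tree[2 * k + 1] + tree[2 * k + 2];
    }

    int sumRange(int i, int j) { return query(0, 0, n - 1, i, j); }

    private int query(int k, int lo, int hi, int i, int j) {
        if (hi < i || lo > j) return 0;          // node range outside query range
        if (i <= lo && hi <= j) return tree[k];  // node range inside query range
        int mid = lo + (hi - lo) / 2;            // partial overlap: ask both halves
        return query(2 * k + 1, lo, mid, i, j) + query(2 * k + 2, mid + 1, hi, i, j);
    }

    public static void main(String[] args) {
        ArraySegmentTree st = new ArraySegmentTree(new int[]{1, 3, 5});
        System.out.println(st.sumRange(0, 2)); // 9
        st.update(1, 2);
        System.out.println(st.sumRange(0, 2)); // 8
    }
}
```

Allocating 4n slots is a conservative bound for this recursive layout. In exchange, the tree avoids per-node objects and references.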

Binary Indexed Tree

A Binary Indexed Tree, also known as a Fenwick Tree, is a data structure for efficiently updating elements and calculating prefix sums of a number table, with O(log(N)) time for both update and sum operations. The tree itself is a single additional array, one slot longer than the input because index 0 is unused, and its implementation is remarkably short and simple!

The basic idea is that any integer can be written in binary as a sum of powers of 2, for example 13 = 2^3 + 2^2 + 2^0. Let's call the input array S[] and the tree array T[], both indexed from 1. For each index idx, let r be the position of the last 1 bit in the binary form of idx, counting from 0 at the least significant bit. T[idx] is then responsible for the sum of S[idx - 2^r + 1] through S[idx]. In our example 13 = 1101 in binary, so r = 0 and T[13] holds only S[13]. Likewise, T[12] (1100) covers S[9..12] and T[8] (1000) covers S[1..8], so S[1] + ... + S[13] = T[13] + T[12] + T[8], i.e. T[1101] + T[1100] + T[1000]. To read a prefix sum, the strategy is to repeatedly replace the last 1 bit of the index with 0 until the index becomes all zeros. In two's complement, idx & -idx isolates the last 1 bit, so idx -= (idx & -idx) yields the next index. An update walks the other way with idx += (idx & -idx), visiting every node whose range covers idx. Both loops take fewer than 10 lines of code, which makes the whole trick amazingly simple.

Article Reference from TopCoder

Reference source
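The index walks described above can be checked with a short sketch. The class and method names (BitIndexWalk, lowbit, readPath, updatePath) are hypothetical helpers, and indices are 1-based as in the explanation. For idx = 13, a read visits T[13], T[12] and T[8], and an update in a table of size 16 visits 13, 14 and 16.

```java
import java.util.ArrayList;
import java.util.List;

class BitIndexWalk {
    // Isolates the last 1 bit of idx (relies on two's-complement negation).
    static int lowbit(int idx) { return idx & -idx; }

    // Nodes summed by a prefix read: clear the last 1 bit until idx is 0.
    static List<Integer> readPath(int idx) {
        List<Integer> path = new ArrayList<>();
        for (; idx > 0; idx -= lowbit(idx)) path.add(idx);
        return path;
    }

    // Nodes touched by an update: add the last 1 bit until idx passes n.
    static List<Integer> updatePath(int idx, int n) {
        List<Integer> path = new ArrayList<>();
        for (; idx <= n; idx += lowbit(idx)) path.add(idx);
        return path;
    }

    public static void main(String[] args) {
        for (int idx : readPath(13)) {
            // T[idx] covers S[idx - lowbit(idx) + 1 .. idx]
            System.out.println("T[" + idx + "] covers S[" + (idx - lowbit(idx) + 1) + ".." + idx + "]");
        }
        System.out.println(updatePath(13, 16)); // [13, 14, 16]
    }
}
```

Running main prints T[13] covers S[13..13], T[12] covers S[9..12] and T[8] covers S[1..8], whose ranges together tile S[1..13].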

class BinaryIndexedTreeSumQuery {
    private int[] tree;
    private int[] array;
    private int maxVal;

    // BIT method Constructor
    public BinaryIndexedTreeSumQuery(int[] nums) {
        this.maxVal = nums.length;
        // Ignore first index and set to 0.
        this.tree = new int[nums.length+1];
        this.array = new int[nums.length+1];
        for (int i = 0; i < nums.length; i++) {
            update(i, nums[i]);
        }
    }

    // Get the sum of elements 0 to idx (0-based); read(-1) returns 0.
    public int read(int idx) {
        idx++;
        int sum = 0;
        while (idx > 0) {
            sum += tree[idx];
            idx -= (idx & -idx);
        }
        return sum;
    }

    // Public function available to user, update both original array and tree.
    public void update(int idx, int val) {
        idx++;
        int i = idx;
        int old = array[i];
        // Add the difference to every tree node whose range covers index i.
        while (idx <= maxVal) {
            tree[idx] += val - old;
            idx += (idx & -idx);
        }
        array[i] = val;
    }

    // Public function available to user.
    public int sumRange(int i, int j) {
        return read(j) - read(i-1);
    }
}
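BinaryIndexedTreeSumQuery builds its tree by calling update() once per element, which costs O(N log N). If construction time matters, a BIT can also be built in O(N). First copy each element into its own slot, then add each node's partial sum to the next node that covers it, at index idx + (idx & -idx). Below is a sketch of that variant under a hypothetical class name, LinearBuildBinaryIndexedTree, not the author's code. It keeps the same 0-based sumRange and update interface.

```java
/**
 * Hypothetical BIT variant with O(N) construction.
 * tree is 1-indexed; tree[idx] holds the sum of the 1-based range
 * [idx - lowbit(idx) + 1, idx].
 */
class LinearBuildBinaryIndexedTree {
    private final int[] tree;
    private final int[] nums; // current values, 0-indexed

    LinearBuildBinaryIndexedTree(int[] input) {
        int n = input.length;
        nums = input.clone();
        tree = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            tree[i] += nums[i - 1];        // tree[i] now holds its complete range sum
            int parent = i + (i & -i);     // next node whose range covers i
            if (parent <= n) tree[parent] += tree[i];
        }
    }

    void update(int i, int val) {
        int diff = val - nums[i];
        nums[i] = val;
        for (int idx = i + 1; idx < tree.length; idx += idx & -idx) {
            tree[idx] += diff;
        }
    }

    // Sum of nums[0..i] (0-based); returns 0 for i = -1.
    int prefix(int i) {
        int sum = 0;
        for (int idx = i + 1; idx > 0; idx -= idx & -idx) {
            sum += tree[idx];
        }
        return sum;
    }

    int sumRange(int i, int j) {
        return prefix(j) - prefix(i - 1);
    }

    public static void main(String[] args) {
        LinearBuildBinaryIndexedTree bit = new LinearBuildBinaryIndexedTree(new int[]{1, 3, 5});
        System.out.println(bit.sumRange(0, 2)); // 9
        bit.update(1, 2);
        System.out.println(bit.sumRange(0, 2)); // 8
    }
}
```

The build relies on processing indices in increasing order. By the time node i pushes its sum to its parent, every smaller node that feeds into i has already been added.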

Move from 1D to 2D

This follow-up question asks us to extend the 1D Binary Indexed Tree to 2D in the same manner. When position (i, j) is updated to a new value val, the outer while loop walks forward over the rows i, i + (i & -i), ... up to the last row, just like the 1D update. For each of those rows, the inner while loop treats the row as a 1D Binary Indexed Tree and adds the difference val - old at columns j, j + (j & -j), ... up to the last column. A read() function sums the area from (0, 0) to (i, j) by walking backward in both dimensions. It is then used to sum an arbitrary area: take the main area, subtract the two side areas above and to the left, and add back their overlapping corner, which was subtracted twice.

Reference source
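Written out, the cropping step that sumRegion() performs is the formula below. Here P(r, c) denotes read(r, c), the sum of the area from (0, 0) to (r, c). P is 0 whenever r or c is -1, because read() then starts from index 0 and its loop never runs:

```latex
\mathrm{sumRegion}(r_1, c_1, r_2, c_2) = P(r_2, c_2) - P(r_1 - 1, c_2) - P(r_2, c_1 - 1) + P(r_1 - 1, c_1 - 1)
```

The area above the rectangle, P(r_1 - 1, c_2), and the area to its left, P(r_2, c_1 - 1), both contain the corner P(r_1 - 1, c_1 - 1). That corner is therefore subtracted twice and has to be added back once.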

public class RangeSumQuery2DMutable {
    private int rowMax;
    private int colMax;
    private int[][] matrix;
    private int[][] tree;

    public RangeSumQuery2DMutable(int[][] m) {
        if (m.length == 0 || m[0].length == 0) return;
        // Ignore index 0 and set value to 0.
        this.rowMax = m.length;
        this.colMax = m[0].length;
        this.matrix = new int[rowMax+1][colMax+1];
        this.tree = new int[rowMax+1][colMax+1];
        for (int i = 0; i < rowMax; i++) {
            for (int j = 0; j < colMax; j++) {
                update(i, j, m[i][j]);
            }
        }
    }

    // Set cell (row, col) to val by adding the difference to every covering tree node.
    public void update(int row, int col, int val) {
        row++;
        col++;
        int r = row, c = col;
        int old = this.matrix[r][c];
        while (row <= rowMax) {
            int tcol = col; // A copy of col variable for inner loop.
            while (tcol <= colMax) {
                tree[row][tcol] += val - old;
                tcol += (tcol & -tcol);
            }
            row += (row & -row);
        }
        this.matrix[r][c] = val;
    }

    // Sum of the area from (0, 0) to (row, col); returns 0 when row or col is -1.
    public int read(int row, int col) {
        row++;
        col++;
        int sum = 0;
        while (row > 0) {
            int tcol = col;
            while (tcol > 0) {
                sum += tree[row][tcol];
                tcol -= (tcol & -tcol);
            }
            row -= (row & -row);
        }
        return sum;
    }

    // Main area minus the two side areas, plus the corner that was subtracted twice.
    public int sumRegion(int row1, int col1, int row2, int col2) {
        return read(row2, col2) - read(row1-1, col2) - read(row2, col1-1) + read(row1-1, col1-1);
    }
}