link: https://www.lintcode.com/problem/range-sum-query-2d-mutable/description
Description
Given a 2D matrix `matrix`, find the sum of the elements inside the rectangle defined by its upper-left corner (row1, col1) and lower-right corner (row2, col2). All indices are 0-based.
1. The matrix is only modifiable by the update function.
2. You may assume the calls to update and sumRegion are distributed evenly.
3. You may assume that row1 ≤ row2 and col1 ≤ col2.
Example
Given matrix = [
[3, 0, 1, 4, 2],
[5, 6, 3, 2, 1],
[1, 2, 0, 1, 5],
[4, 1, 0, 1, 7],
[1, 0, 3, 0, 5]
]
sumRegion(2, 1, 4, 3) -> 8
update(3, 2, 2)
sumRegion(2, 1, 4, 3) -> 10
思路
为 matrix 中的每一行建立一棵 segment tree。C++ 版本在 LintCode 上跑不过最后一个 test case，会超时（每次 sumRegion 都要逐行查询，代价为 O((row2 - row1 + 1) · log C)），决定暂时放弃……
以下 segment tree 的 build、update、query 写法都比较简单。update 写了两种（其中一种以注释形式保留），都 work，可作为模板。
/**
 * Node of a pointer-based segment tree.
 *
 * A node covers the inclusive index range [start, end]. `sum` starts at 0 and
 * is filled in by the tree-building code; in this file it holds the sum of the
 * row elements in that range. Children start out as nullptr.
 */
class SegmentTreeNodeII {
public:
    int start;                 // first index covered (inclusive)
    int end;                   // last index covered (inclusive)
    int sum;                   // aggregate for [start, end]; 0 until set by the builder
    SegmentTreeNodeII* left;   // left child, or nullptr
    SegmentTreeNodeII* right;  // right child, or nullptr

    SegmentTreeNodeII(int s, int e)
        : start(s), end(e), sum(0), left(nullptr), right(nullptr) {}
};
class NumMatrix {
private:
vector<SegmentTreeNodeII*> nodes; // each node represents head of a segment
// tree for *a row*
vector<vector<int>> m;
SegmentTreeNodeII* build_segment_tree(vector<int>& row, int start, int end) {
if (start > end) {
return nullptr;
}
SegmentTreeNodeII* node = new SegmentTreeNodeII(start, end);
if (start == end) {
node->sum = row[start];
return node;
}
int mid = start + (end - start)/2;
node->left = build_segment_tree(row, start, mid);
node->right = build_segment_tree(row, mid+1, end);
if (node->left) {
node->sum += node->left->sum;
}
if (node->right) {
node->sum += node->right->sum;
}
return node;
}
/**
* here we assume idx is within range of node->left ~ node->right
*/
// void update_segment_tree(SegmentTreeNodeII* node, int idx, int diff) {
// if (diff == 0) {
// return;
// }
// if (!node) {
// return;
// }
// if (idx >= node->start && idx <= node->end) {
// node->sum += diff;
// update_segment_tree(node->left, idx, diff);
// update_segment_tree(node->right, idx, diff);
// }
// }
void update_segment_tree(SegmentTreeNodeII* node, int idx, int diff) {
if (!node) {
return;
}
if (node->start > idx || node->end < idx) {
return;
}
node->sum += diff;
int mid = node->start + (node->end - node->start)/2;
if (idx <= mid) {
update_segment_tree(node->left, idx, diff);
}
else {
update_segment_tree(node->right, idx, diff);
}
}
int query(SegmentTreeNodeII* node, int left, int right) {
if (!node || node->start > right || node->end < left) {
return 0;
}
if (node->start >= left && node->end <= right) {
return node->sum;
}
return query(node->left, left, right) + query(node->right, left, right);
}
public:
NumMatrix(vector<vector<int>> matrix) {
m = matrix; // make a local copy of the 2d matrix.
nodes.assign(m.size(), nullptr);
for (int i = 0; i < m.size(); i++) {
nodes[i] = build_segment_tree(m[i], 0, m[i].size()-1);
}
}
void update(int row, int col, int val) {
int diff = val - m[row][col];
if (diff == 0) {
return;
}
m[row][col] += diff;
update_segment_tree(nodes[row], col, diff);
}
int sumRegion(int row1, int col1, int row2, int col2) {
int res = 0;
for (int i = row1; i <= row2; i++) {
res += query(nodes[i], col1, col2);
}
return res;
}
};
/**
 * Your NumMatrix object will be instantiated and called as such:
 * NumMatrix obj(matrix);
 * obj.update(row, col, val);
 * int param_2 = obj.sumRegion(row1, col1, row2, col2);
 */
网友评论