Fork me on GitHub

数据结构-线段树

https://blog.csdn.net/zearot/article/details/52280189

https://github.com/raywenderlich/swift-algorithm-club/tree/master/Segment%20Tree

https://blog.csdn.net/johnny901114/article/details/80643017

实现一个线段树

下面实现的线段树,有三个功能:

  1. 把数组构建成一棵线段树
  2. 线段树的修改
  3. 线段树的查询
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    78
    79
    80
    81
    82
    83
    84
    85
    86
    87
    88
    89
    90
    91
    92
    93
    94
    95
    96
    97
    98
    99
    100
    101
    102
    103
    104
    105
    106
    107
    108
    109
    110
    111
    112
    113
    114
    115
    116
    117
    118
    119
    120
    121
    122
    123
    124
    125
    126
    127
    128
    129
    130
    131
    132
    133
    134
    135
    136
    137
    138
    139
    140
    141
    142
    143
    144
    145
    146
    147
    148
    149
    150
    151
    /**
     * Array-backed segment tree over a fixed-length sequence of values.
     *
     * <p>Supports three operations: building the tree from an array, updating a single
     * position, and querying the merged value of an inclusive index range. How two adjacent
     * ranges are combined (sum, min, max, ...) is supplied by the caller through a
     * {@link Merger}.
     *
     * @param <T> element type stored in the tree
     */
    public class SegmentTree<T> {

        /** Flattened tree; the children of node {@code i} live at {@code 2i + 1} and {@code 2i + 2}. */
        private final T[] tree;
        /** Private copy of the source elements, indexed like the caller's array. */
        private final T[] data;

        private final Merger<T> merger;

        /** Combines the values of two adjacent ranges into the value of their union. */
        public interface Merger<T> {
            T merge(T a, T b);
        }

        /**
         * Builds a segment tree over a copy of {@code arr}.
         *
         * <p>An empty array yields an empty tree: {@link #size()} is 0 and every index is
         * out of range.
         *
         * @param arr    source elements; copied, so later changes to {@code arr} are not seen
         * @param merger combines the values of a left and a right sub-range
         */
        @SuppressWarnings("unchecked") // both arrays only ever hold T values supplied by the caller
        public SegmentTree(T[] arr, Merger<T> merger) {
            this.merger = merger;
            this.data = (T[]) new Object[arr.length];
            System.arraycopy(arr, 0, data, 0, arr.length);

            this.tree = (T[]) new Object[data.length * 4];
            // With no elements there is no root range to build; recursing on [0, -1]
            // would read data[0] and fail.
            if (data.length > 0) {
                buildSegmentTree(0, 0, data.length - 1);
            }
        }

        /**
         * Recursively fills the node at {@code treeIndex}, which represents
         * {@code [treeLeft, treeRight]}.
         *
         * @param treeIndex index of the node being built
         * @param treeLeft  left bound (inclusive) of the range covered by the node
         * @param treeRight right bound (inclusive) of the range covered by the node
         */
        private void buildSegmentTree(int treeIndex, int treeLeft, int treeRight) {
            if (treeLeft == treeRight) {
                tree[treeIndex] = data[treeLeft];
                return;
            }
            int leftTreeIndex = getLeft(treeIndex);
            int rightTreeIndex = getRight(treeIndex);
            // Written this way instead of (left + right) / 2 to avoid int overflow on large bounds.
            int mid = treeLeft + (treeRight - treeLeft) / 2;
            buildSegmentTree(leftTreeIndex, treeLeft, mid);
            buildSegmentTree(rightTreeIndex, mid + 1, treeRight);
            // An inner node stores the merge of its two children.
            tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
        }

        /**
         * Returns the merged value of the inclusive range {@code [start, end]}.
         *
         * @param start first index of the range
         * @param end   last index of the range
         * @return the merger applied across all elements in the range
         * @throws IndexOutOfBoundsException if {@code start < 0} or {@code end >= size()}
         * @throws IllegalArgumentException  if {@code start > end}
         */
        public T query(int start, int end) {
            if (start < 0 || end >= data.length) {
                throw new IndexOutOfBoundsException(
                        "Range [" + start + ", " + end + "] is out of bounds for size " + data.length);
            }
            // Without this guard an inverted range never matches a node and recurses forever.
            if (start > end) {
                throw new IllegalArgumentException("start (" + start + ") > end (" + end + ")");
            }
            return query(0, 0, data.length - 1, start, end);
        }

        /**
         * Looks up {@code [queryL, queryR]} inside the subtree rooted at {@code treeIndex}.
         *
         * @param treeIndex node currently being examined
         * @param treeLeft  left bound of the range covered by the node
         * @param treeRight right bound of the range covered by the node
         * @param queryL    left bound of the requested range
         * @param queryR    right bound of the requested range
         * @return merged value of the requested range
         */
        private T query(int treeIndex, int treeLeft, int treeRight, int queryL, int queryR) {
            // 1. The requested range is exactly this node's range.
            if (treeLeft == queryL && treeRight == queryR) {
                return tree[treeIndex];
            }

            int mid = treeLeft + (treeRight - treeLeft) / 2;
            int leftTreeIndex = getLeft(treeIndex);
            int rightTreeIndex = getRight(treeIndex);

            // 2. The requested range lies entirely in the left child.
            if (queryR <= mid) {
                return query(leftTreeIndex, treeLeft, mid, queryL, queryR);
            }
            // 3. The requested range lies entirely in the right child.
            if (queryL >= mid + 1) {
                return query(rightTreeIndex, mid + 1, treeRight, queryL, queryR);
            }

            // 4. The range straddles both children: split it at mid and merge the halves.
            T left = query(leftTreeIndex, treeLeft, mid, queryL, mid);
            T right = query(rightTreeIndex, mid + 1, treeRight, mid + 1, queryR);
            return merger.merge(left, right);
        }

        /**
         * Replaces the element at {@code index} with {@code e} and refreshes every node whose
         * range contains that index.
         *
         * @param index position to overwrite
         * @param e     new value
         * @throws ArrayIndexOutOfBoundsException if {@code index} is outside {@code [0, size())}
         */
        public void update(int index, T e) {
            // The array store rejects a bad index before the tree is touched.
            data[index] = e;
            update(0, 0, data.length - 1, index, e);
        }

        /** Descends to the leaf for {@code index}, then re-merges each node on the way back up. */
        private void update(int treeIndex, int treeLeft, int treeRight, int index, T e) {
            if (treeLeft == treeRight) {
                tree[treeIndex] = e;
                return;
            }

            int mid = treeLeft + (treeRight - treeLeft) / 2;
            int leftChildIndex = getLeft(treeIndex);
            int rightChildIndex = getRight(treeIndex);

            if (index <= mid) {
                update(leftChildIndex, treeLeft, mid, index, e);
            } else {
                update(rightChildIndex, mid + 1, treeRight, index, e);
            }

            // After the leaf changes, this ancestor has to be recomputed from its children.
            tree[treeIndex] = merger.merge(tree[leftChildIndex], tree[rightChildIndex]);
        }

        /**
         * Returns the element currently stored at {@code index}.
         *
         * @param index position to read
         * @return the element at that position
         * @throws ArrayIndexOutOfBoundsException if {@code index} is outside {@code [0, size())}
         */
        public T get(int index) {
            return data[index];
        }

        /** Returns the number of elements the tree was built over. */
        public int size() {
            return data.length;
        }

        /** Returns the array index of the left child of the node at {@code index}. */
        public int getLeft(int index) {
            return index * 2 + 1;
        }

        /** Returns the array index of the right child of the node at {@code index}. */
        public int getRight(int index) {
            return index * 2 + 2;
        }

        /**
         * Renders the non-null tree slots in array order, e.g. {@code [15,6,9]}.
         * A tree with nothing to print renders as {@code []}.
         */
        @Override
        public String toString() {
            StringBuilder builder = new StringBuilder();
            builder.append('[');
            boolean first = true;
            for (T node : tree) {
                if (node == null) {
                    continue;
                }
                // Separator goes before every entry but the first, so no trailing
                // comma ever needs to be deleted.
                if (!first) {
                    builder.append(',');
                }
                builder.append(node);
                first = false;
            }
            builder.append(']');
            return builder.toString();
        }
    }

303号问题

给定数组, 求区间的和, 数组不可变

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/**
 * Range-sum queries over an int array that is never modified after construction
 * (LeetCode problem 303), answered with a segment tree.
 */
class NumArray {

    /** Combines the values of two adjacent ranges into the value of their union. */
    private interface Merger<E> {
        E merge(E a, E b);
    }

    /** Array-backed segment tree; the children of node i sit at 2i + 1 and 2i + 2. */
    private static class SegmentTree<E> {

        private E[] tree;
        private E[] data;
        private Merger<E> merger;

        /** Copies {@code arr} and builds the tree over the copy. */
        @SuppressWarnings("unchecked")
        public SegmentTree(E[] arr, Merger<E> merger) {
            this.merger = merger;

            final int n = arr.length;
            data = (E[]) new Object[n];
            System.arraycopy(arr, 0, data, 0, n);

            tree = (E[]) new Object[4 * n];
            buildSegmentTree(0, 0, n - 1);
        }

        /** Fills the node at {@code treeIndex} so that it represents the range [l...r]. */
        private void buildSegmentTree(int treeIndex, int l, int r) {
            if (l != r) {
                // l + (r - l) / 2 rather than (l + r) / 2, which could overflow
                int mid = l + (r - l) / 2;
                int lower = leftChild(treeIndex);
                int upper = rightChild(treeIndex);
                buildSegmentTree(lower, l, mid);
                buildSegmentTree(upper, mid + 1, r);
                tree[treeIndex] = merger.merge(tree[lower], tree[upper]);
            } else {
                tree[treeIndex] = data[l];
            }
        }

        /** Number of elements the tree was built over. */
        public int getSize() {
            return data.length;
        }

        /** Element stored at {@code index}; rejects indices outside the data array. */
        public E get(int index) {
            boolean inRange = 0 <= index && index < data.length;
            if (!inRange) {
                throw new IllegalArgumentException("Index is illegal.");
            }
            return data[index];
        }

        /** Index of the left child of the node at {@code index} in the array layout. */
        private int leftChild(int index) {
            return 2 * index + 1;
        }

        /** Index of the right child of the node at {@code index} in the array layout. */
        private int rightChild(int index) {
            return 2 * index + 2;
        }

        /** Merged value of the inclusive range [queryL, queryR]. */
        public E query(int queryL, int queryR) {
            boolean valid = 0 <= queryL && queryL <= queryR && queryR < data.length;
            if (!valid) {
                throw new IllegalArgumentException("Index is illegal.");
            }
            return query(0, 0, data.length - 1, queryL, queryR);
        }

        /** Searches [queryL...queryR] inside the subtree rooted at {@code treeIndex}, which covers [l...r]. */
        private E query(int treeIndex, int l, int r, int queryL, int queryR) {
            if (queryL == l && queryR == r) {
                return tree[treeIndex];
            }

            // the node splits into [l...mid] and [mid+1...r]
            int mid = l + (r - l) / 2;
            int lower = leftChild(treeIndex);
            int upper = rightChild(treeIndex);

            if (queryL > mid) {
                return query(upper, mid + 1, r, queryL, queryR);
            }
            if (queryR <= mid) {
                return query(lower, l, mid, queryL, queryR);
            }

            // the request straddles mid: answer each half and merge
            E lowerPart = query(lower, l, mid, queryL, mid);
            E upperPart = query(upper, mid + 1, r, mid + 1, queryR);
            return merger.merge(lowerPart, upperPart);
        }

        /** Every tree slot in array order, unused slots shown as {@code null}. */
        @Override
        public String toString() {
            StringBuilder out = new StringBuilder("[");
            String separator = "";
            for (E node : tree) {
                out.append(separator).append(String.valueOf(node));
                separator = ", ";
            }
            return out.append(']').toString();
        }
    }

    private SegmentTree<Integer> segmentTree;

    /** Boxes {@code nums} and builds a sum tree over it; an empty array leaves the tree unset. */
    public NumArray(int[] nums) {
        if (nums.length == 0) {
            return;
        }
        Integer[] boxed = new Integer[nums.length];
        int slot = 0;
        for (int value : nums) {
            boxed[slot++] = value;
        }
        segmentTree = new SegmentTree<>(boxed, Integer::sum);
    }

    /** Sum of the elements from index {@code i} to index {@code j}, both inclusive. */
    public int sumRange(int i, int j) {
        if (segmentTree != null) {
            return segmentTree.query(i, j);
        }
        throw new IllegalArgumentException("Segment Tree is null");
    }
}

307号问题

给定数组, 求区间的和, 数组可变
需要给线段树添加一个更新方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
/**
 * Range-sum queries over an int array whose elements can be overwritten
 * (LeetCode problem 307), answered with a segment tree that supports point updates.
 */
class NumArray {

    /** Combines the values of two adjacent ranges into the value of their union. */
    private interface Merger<E> {
        E merge(E a, E b);
    }

    /** Array-backed segment tree; the children of node i sit at 2i + 1 and 2i + 2. */
    private static class SegmentTree<E> {

        private E[] tree;
        private E[] data;
        private Merger<E> merger;

        /** Copies {@code arr} and builds the tree over the copy. */
        @SuppressWarnings("unchecked")
        public SegmentTree(E[] arr, Merger<E> merger) {
            this.merger = merger;

            final int count = arr.length;
            data = (E[]) new Object[count];
            System.arraycopy(arr, 0, data, 0, count);

            tree = (E[]) new Object[4 * count];
            buildSegmentTree(0, 0, count - 1);
        }

        /** Fills the node at {@code treeIndex} so that it represents the range [l...r]. */
        private void buildSegmentTree(int treeIndex, int l, int r) {
            if (l == r) {
                tree[treeIndex] = data[l];
                return;
            }

            // l + (r - l) / 2 rather than (l + r) / 2, which could overflow
            final int mid = l + (r - l) / 2;
            final int leftNode = leftChild(treeIndex);
            final int rightNode = rightChild(treeIndex);

            buildSegmentTree(leftNode, l, mid);
            buildSegmentTree(rightNode, mid + 1, r);
            tree[treeIndex] = merger.merge(tree[leftNode], tree[rightNode]);
        }

        /** Number of elements the tree was built over. */
        public int getSize() {
            return data.length;
        }

        /** Element stored at {@code index}; rejects indices outside the data array. */
        public E get(int index) {
            if (!(index >= 0 && index < data.length)) {
                throw new IllegalArgumentException("Index is illegal.");
            }
            return data[index];
        }

        /** Index of the left child of the node at {@code index} in the array layout. */
        private int leftChild(int index) {
            return 2 * index + 1;
        }

        /** Index of the right child of the node at {@code index} in the array layout. */
        private int rightChild(int index) {
            return 2 * index + 2;
        }

        /** Merged value of the inclusive range [queryL, queryR]. */
        public E query(int queryL, int queryR) {
            final boolean wellFormed = queryL >= 0 && queryR < data.length && queryL <= queryR;
            if (!wellFormed) {
                throw new IllegalArgumentException("Index is illegal.");
            }
            return query(0, 0, data.length - 1, queryL, queryR);
        }

        /** Searches [queryL...queryR] inside the subtree rooted at {@code treeIndex}, which covers [l...r]. */
        private E query(int treeIndex, int l, int r, int queryL, int queryR) {
            if (queryL == l && queryR == r) {
                return tree[treeIndex];
            }

            // the node splits into [l...mid] and [mid+1...r]
            final int mid = l + (r - l) / 2;
            final int leftNode = leftChild(treeIndex);
            final int rightNode = rightChild(treeIndex);

            final boolean onlyRight = queryL > mid;
            final boolean onlyLeft = queryR <= mid;
            if (onlyRight) {
                return query(rightNode, mid + 1, r, queryL, queryR);
            }
            if (onlyLeft) {
                return query(leftNode, l, mid, queryL, queryR);
            }

            // the request straddles mid: answer each half and merge
            return merger.merge(
                    query(leftNode, l, mid, queryL, mid),
                    query(rightNode, mid + 1, r, mid + 1, queryR));
        }

        /** Overwrites the element at {@code index} with {@code e}. */
        public void set(int index, E e) {
            if (!(index >= 0 && index < data.length)) {
                throw new IllegalArgumentException("Index is illegal");
            }
            data[index] = e;
            set(0, 0, data.length - 1, index, e);
        }

        /** Updates {@code index} to {@code e} inside the subtree rooted at {@code treeIndex}. */
        private void set(int treeIndex, int l, int r, int index, E e) {
            if (l == r) {
                tree[treeIndex] = e;
                return;
            }

            // the node splits into [l...mid] and [mid+1...r]
            final int mid = l + (r - l) / 2;
            final int leftNode = leftChild(treeIndex);
            final int rightNode = rightChild(treeIndex);

            if (index <= mid) {
                set(leftNode, l, mid, index, e);
            } else {
                set(rightNode, mid + 1, r, index, e);
            }

            // the changed leaf is below this node, so its merged value is recomputed
            tree[treeIndex] = merger.merge(tree[leftNode], tree[rightNode]);
        }

        /** Every tree slot in array order, unused slots shown as {@code null}. */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append('[');
            final int last = tree.length - 1;
            for (int slot = 0; slot <= last; slot++) {
                text.append(String.valueOf(tree[slot]));
                if (slot < last) {
                    text.append(", ");
                }
            }
            text.append(']');
            return text.toString();
        }
    }

    private SegmentTree<Integer> segTree;

    /** Boxes {@code nums} and builds a sum tree over it; an empty array leaves the tree unset. */
    public NumArray(int[] nums) {
        if (nums.length == 0) {
            return;
        }
        Integer[] boxed = new Integer[nums.length];
        int slot = 0;
        for (int value : nums) {
            boxed[slot++] = value;
        }
        segTree = new SegmentTree<>(boxed, Integer::sum);
    }

    /** Sets the element at index {@code i} to {@code val}. */
    public void update(int i, int val) {
        if (segTree == null) {
            throw new IllegalArgumentException("Error");
        }
        segTree.set(i, val);
    }

    /** Sum of the elements from index {@code i} to index {@code j}, both inclusive. */
    public int sumRange(int i, int j) {
        if (segTree == null) {
            throw new IllegalArgumentException("Error");
        }
        return segTree.query(i, j);
    }
}

完整代码

-------------本文结束感谢您的阅读-------------
坚持原创技术分享,您的支持将鼓励我继续创作!
0%