二叉查找树

二叉树的定义

我们可以将二叉树定义为一个空链接,或者是一个有左右两个链接的节点,每个链接都指向一棵(独立的)子二叉树。

定义:
一棵二叉查找树(BST)是一棵二叉树,其中的每一个节点都含有一个 Comparable 的键(以及相关联的值)且每个节点的键都大于其左子树中的任意节点的键而小于右子树的任意节点的键。

基本实现

树节点的实现

树由 Node 对象组成,每个对象都含有:

  • 一对键值
  • 两条链接
  • 一个节点计数器N
1
2
3
4
5
6
7
8
9
10
11
private class Node
{
private Key key;
private Value val;
private Node left, right;
public Node(Key key, Value val)
{
this.key = key;
this.val = val;
}
}

BST implementation (skeleton)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
public class BST<Key extends Comparable<Key>, Value>
{
private Node root; // root of BST

private class Node
{ /* 树节点 */ }

public void put(Key key, Value val)
{/* 添加操作 */}

public Value get(Key key)
{/* 查询操作 */}

public void delete(Key key)
{/* 删除操作 */}

public Iterable<Key> iterator()
{/* 迭代器操作 */}

}

查找和排序方法的实现

有序性相关的方法与删除操作

最大键和最小键

如果根结点的左链接为空,那么一棵二叉查找树中最小的键就是根节点;如果左链接为非空,那么树中的最小键就是左子树中的最小键。

找出最大键的方法也是类似的,只是变为查找右子树而已。

向上取整和向下取整

如果给定的键 key 小于二叉查找树的根结点的键,那么小于等于 key 的最大键 floor(key) 一定在根结点的左子树中;如果给定的键 key 大于二叉查找树的根结点,那么只有当根结点右子树中存在小于等于 key 的节点时,小于等于 key 的最大键才会出现在右子树中,否则根结点就是小于等于 key 的最大键。

选择操作

排名

删除最大键和删除最小键

删除操作

范围查找

完整实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.StdOut;

public class BST<Key extends Comparable<Key>, Value> {
private Node root; // 二叉查找树的根结点

private class Node {
private Key key; // 键
private Value val; // 值
private Node left, right; // 指向子树的链接
private int N; // 以该节点为根的子树中的节点总数

public Node(Key key, Value val, int N) {
this.key = key;
this.val = val;
this.N = N;
}
}

public int size() {
return size(root);
}

private int size(Node x) {
if (x == null) return 0;
else return x.N;
}

public Value get(Key key) {
return get(root, key);
}

private Value get(Node x, Key key) {
// 以 x 为根结点的子树中查找并返回 key 所对应的值
// 如果找不到则返回 null
if (x == null) return null;
int cmp = key.compareTo(x.key);
if (cmp < 0) return get(x.left, key);
else if (cmp > 0) return get(x.right, key);
else return x.val;
}

public void put(Key key, Value val) {
// 查找 key,找到则更新它的值,否则为它创建一个新的节点
root = put(root, key, val);
}

private Node put(Node x, Key key, Value val) {
// 如果 key 存在于以 x 为根结点的子树中则更新它的值;
// 否则将以 key 和 val 为键值对的新节点插入到该子树中;
if (x == null) return new Node(key, val, 1);
int cmp = key.compareTo(x.key);
if (cmp < 0) x.left = put(x.left, key, val);
else if (cmp > 0) x.right = put(x.right, key, val);
else x.val = val;
x.N = size(x.left) + size(x.right) + 1;
return x;
}

public Key min() {
return min(root).key;
}

private Node min(Node x) {
if (x.left == null) return x;
return min(x.left);
}

public Key max() {
return max(root).key;
}

private Node max(Node x) {
if (x.right== null) return x;
return max(x.right);
}

public Key floor(Key key) {
Node x = floor(root, key);
if (x == null) return null;
return x.key;
}

private Node floor(Node x, Key key) {
if (x == null) return null;
int cmp = key.compareTo(x.key);
if (cmp == 0) return x;
if (cmp < 0) return floor(x.left, key);
Node t = floor(x.right, key);
if (t != null) return t;
else return x;
}

public Key select(int k) {
return select(root, k).key;
}
private Node select(Node x, int k)
{
// 返回排名为 k 的节点
if (x == null) return null;
int t = size(x.left);
if (t > k) return select(x.left, k);
else if (t < k) return select(x.right, k - t - 1);
else return x;
}

public int rank(Key key)
{
return rank(key, root);
}
private int rank(Key key, Node x)
{
// 返回以 x 为根节点的子树中小于 x.key 的键的数量
if (x == null) return 0;
int cmp = key.compareTo(x.key);
if (cmp < 0) return rank(key, x.left);
else if (cmp > 0) return 1 + size(x.left) + rank(key, x.right);
else return size(x.left);
}

public void deleteMin()
{
root = deleteMin(root);
}

private Node deleteMin(Node x)
{
if (x.left == null) return x.right;
x.left = deleteMin(x.left);
x.N = size(x.left) + size(x.right) + 1;
return x;
}

public void delete(Key key)
{
root = delete(root, key);
}

private Node delete(Node x, Key key)
{
if (x==null) return null;
int cmp = key.compareTo(x.key);
if (cmp < 0) x.left = delete(x.left, key);
else if (cmp > 0) x.right = delete(x.right, key);
else {
if (x.right == null) return x.left;
if (x.left == null) return x.right;
Node t = x;
x = min(t.right);
x.right = deleteMin(t.right);
x.left = t.left;
}
x.N = size(x.left) + size(x.right) + 1;
return x;
}

private void print(Node x)
{
if (x == null) return;
print(x.left);
StdOut.println(x.key);
print(x.right);
}

public Iterable<Key> keys()
{
return keys(min(), max());
}
public Iterable<Key> keys(Key lo, Key hi)
{
Queue<Key> queue = new Queue<Key>();
keys(root, queue, lo, hi);
return queue;
}
private void keys(Node x, Queue<Key> queue, Key lo, Key hi)
{
if (x == null) return;
int cmplo = lo.compareTo(x.key);
int cmphi = hi.compareTo(x.key);
if (cmplo < 0) keys(x.left, queue, lo, hi);
if (cmplo <= 0 && cmphi >= 0) queue.enqueue(x.key);
if (cmphi > 0) keys(x.right, queue, lo, hi);
}
}