3 @brief Data Structure: KD-Tree
4 @details Implementation based on paper @cite book::the_design_and_analysis.
6 @authors Andrei Novikov (pyclustering@yandex.ru)
8 @copyright BSD-3-Clause
14 import matplotlib.pyplot
as plt
21 @brief KD-tree visualizer that provides service to display graphical representation of the tree using
23 @details The visualizer is able to visualize 2D KD-trees only.
25 There is an example how to visualize balanced KD-tree for `TwoDiamonds` sample using `kdtree_visualizer`:
27 from pyclustering.container.kdtree import kdtree_balanced, kdtree_visualizer
28 from pyclustering.utils import read_sample
29 from pyclustering.samples.definitions import FCPS_SAMPLES
31 sample = read_sample(FCPS_SAMPLES.SAMPLE_TWO_DIAMONDS)
32 tree_instance = kdtree_balanced(sample)
34 kdtree_visualizer(tree_instance).visualize()
37 Output result of the example above (balanced tree) - figure 1:
38 @image html kd_tree_balanced_two_diamonds.png "Fig. 1. Balanced KD-tree for sample 'TwoDiamonds'."
40 There is one more example to demonstrate unbalanced KD-tree. `kdtree` class is child class of `kdtree_balanced`
41 that allows points to be added step by step, and thus an unbalanced KD-tree can be built.
43 from pyclustering.container.kdtree import kdtree, kdtree_visualizer
44 from pyclustering.utils import read_sample
45 from pyclustering.samples.definitions import FCPS_SAMPLES
47 sample = read_sample(FCPS_SAMPLES.SAMPLE_TWO_DIAMONDS)
48 tree_instance = kdtree() # Do not use sample in constructor to avoid building of balanced tree.
50 # Fill KD-tree step by step to obtain unbalanced tree.
52 tree_instance.insert(point)
54 kdtree_visualizer(tree_instance).visualize()
57 Output result of the example above (unbalanced tree) - figure 2:
58 @image html kd_tree_unbalanced_two_diamonds.png "Fig. 2. Unbalanced KD-tree for sample 'TwoDiamonds'."
def __init__(self, kdtree_instance):
    """!
    @brief Initialize KD-tree visualizer.

    @param[in] kdtree_instance (kdtree): Instance of a KD-tree that should be visualized.

    """
    self.__tree = kdtree_instance
    self.__colors = ['blue', 'red', 'green']

    # NOTE(review): validate the tree up front so that drawing never starts on an
    # empty or non-2D tree (reconstructed call — confirm against upstream).
    self.__verify()
def visualize(self):
    """!
    @brief Visualize KD-tree using plot in 2-dimensional data-space.

    """
    root = self.__tree.get_root()
    lower, upper = self.__get_limits()

    # NOTE(review): '111' here is a figure *number*, not a subplot spec — looks like a
    # copy-paste from add_subplot; kept as-is to preserve behavior.
    figure = plt.figure(111)
    ax = figure.add_subplot(111)

    self.__draw_node(ax, root, lower, upper)

    ax.set_xlim(lower[0], upper[0])
    ax.set_ylim(lower[1], upper[1])
    plt.show()
def __draw_node(self, ax, node, min, max):
    """!
    @brief Draw the split line of `node` and recurse into both subtrees, shrinking the
            drawing area along the node's discriminator at each step.

    @param[in] ax (Axes): Axes where the node should be drawn.
    @param[in] node (node): Node whose subtree should be drawn.
    @param[in] min (list): Lower corner of the area occupied by the subtree.
    @param[in] max (list): Upper corner of the area occupied by the subtree.

    """
    self.__draw_split_line(ax, node, min, max)

    # Left subtree occupies the sub-area bounded on the right by the split coordinate.
    if node.left is not None:
        rborder = max[:]
        rborder[node.disc] = node.data[node.disc]
        self.__draw_node(ax, node.left, min, rborder)

    # Right subtree occupies the sub-area bounded on the left by the split coordinate.
    if node.right is not None:
        lborder = min[:]
        lborder[node.disc] = node.data[node.disc]
        self.__draw_node(ax, node.right, lborder, max)
def __draw_split_line(self, ax, node, min, max):
    """!
    @brief Draw the node's point and the split line that its discriminator defines
            inside the area [min, max].

    @param[in] ax (Axes): Axes where the split line should be drawn.
    @param[in] node (node): Node whose split line should be drawn.
    @param[in] min (list): Lower corner of the area of the node's subtree.
    @param[in] max (list): Upper corner of the area of the node's subtree.

    """
    max_coord = max[:]
    min_coord = min[:]

    dimension = len(min)
    for d in range(dimension):
        # Pin the discriminator coordinate so the segment lies on the split plane.
        if d == node.disc:
            max_coord[d] = node.data[d]
            min_coord[d] = node.data[d]

    if dimension == 2:
        ax.plot(node.data[0], node.data[1], color='black', marker='.', markersize=6)
        ax.plot([min_coord[0], max_coord[0]], [min_coord[1], max_coord[1]], color=self.__colors[node.disc],
                linestyle='-', linewidth=1)
def __get_limits(self):
    """!
    @brief Calculate the bounding box that covers all nodes of the tree.

    @return (list, list) Minimum and maximum corners (per-dimension) of the bounding box.

    """
    dimension = len(self.__tree.get_root().data)
    nodes = self.__get_all_nodes()

    # Start from inverted infinities so any real coordinate replaces them.
    max, min = [float('-inf')] * dimension, [float('+inf')] * dimension

    for node in nodes:
        for d in range(dimension):
            if max[d] < node.data[d]:
                max[d] = node.data[d]

            if min[d] > node.data[d]:
                min[d] = node.data[d]

    return min, max
def __get_all_nodes(self):
    """!
    @brief Collect every node of the tree by breadth-first traversal.

    @return (list) All nodes of the tree in level order.

    """
    nodes = []

    next_level = [self.__tree.get_root()]

    while len(next_level) != 0:
        cur_level = next_level
        nodes += next_level
        next_level = []

        for cur_node in cur_level:
            children = cur_node.get_children()
            # NOTE(review): a generator is never None — this guard only matters if
            # get_children() can return None; kept to preserve behavior.
            if children is not None:
                next_level += children

    return nodes
def __verify(self):
    """!
    @brief Check that the tree can be visualized: it must be non-empty and contain 2D data.

    @raise ValueError: The KD-tree is empty.
    @raise NotImplementedError: The KD-tree data is not two-dimensional.

    """
    root = self.__tree.get_root()
    if root is None:
        raise ValueError("KD-Tree is empty - nothing to visualize.")

    dimension = len(root.data)
    if dimension != 2:
        raise NotImplementedError("KD-Tree data has '%d' dimension - only KD-tree with 2D data can be visualized."
                                  % dimension)
165 @brief Represents a node in a KD-Tree.
166 @details The KD-Tree node contains point's coordinates, discriminator, payload and pointers to parent and children.
def __init__(self, data=None, payload=None, left=None, right=None, disc=None, parent=None):
    """!
    @brief Creates KD-tree node.

    @param[in] data (list): Data point that is presented as list of coordinates.
    @param[in] payload (any): Payload of node (pointer to essence that is attached to this node).
    @param[in] left (node): Node of KD-Tree that represents left successor.
    @param[in] right (node): Node of KD-Tree that represents right successor.
    @param[in] disc (uint): Index of dimension of that node.
    @param[in] parent (node): Node of KD-Tree that represents parent.

    """
    self.data = data        # coordinates of the point
    self.payload = payload  # user payload attached to the point
    self.left = left        # left child
    self.right = right      # right child
    self.disc = disc        # discriminator: index of the splitting dimension
    self.parent = parent    # parent node (None for the root)
def __repr__(self):
    """!
    @return (string) Default representation of the node.

    """
    left = None
    right = None

    if self.left is not None:
        left = self.left.data

    if self.right is not None:
        right = self.right.data

    return "(%s: [L:'%s', R:'%s'])" % (self.data, left, right)

def __str__(self):
    """!
    @return (string) String representation of the node.

    """
    return self.__repr__()
def get_children(self):
    """!
    @brief Returns list of not `None` children of the node.

    @return (generator) Generator over the not `None` children of the node.

    """
    # Yield existing children in left-to-right order.
    for child in (self.left, self.right):
        if child is not None:
            yield child
245 @brief Represents balanced static KD-tree that does not provide services to add and remove nodes after
247 @details In the term KD tree, k denotes the dimensionality of the space being represented. Each data point is
248 represented as a node in the k-d tree in the form of a record of type node.
250 There is an example how to create KD-tree:
252 from pyclustering.container.kdtree import kdtree_balanced, kdtree_visualizer
253 from pyclustering.utils import read_sample
254 from pyclustering.samples.definitions import FCPS_SAMPLES
256 sample = read_sample(FCPS_SAMPLES.SAMPLE_LSUN)
257 tree_instance = kdtree_balanced(sample)
259 kdtree_visualizer(tree_instance).visualize()
262 Output result of the example above - figure 1.
263 @image html kd_tree_balanced_lsun.png "Fig. 1. Balanced KD-tree for sample 'Lsun'."
def __init__(self, points, payloads=None):
    """!
    @brief Initializes balanced static KD-tree.

    @param[in] points (array_like): Points that should be used to build KD-tree.
    @param[in] payloads (array_like): Payload of each point in `points`.

    """
    # NOTE(review): internal attribute names (_length, _dimension, _point_comparator,
    # _root) reconstructed from their use elsewhere in this file — verify.
    self._length = 0
    self._dimension = 0
    self._point_comparator = None
    self._root = None

    if points is None:
        return

    self._dimension = len(points[0])
    self._point_comparator = self._create_point_comparator(type(points[0]))
    self._length = len(points)

    # Wrap every point into an unlinked node; links are set while building the tree.
    nodes = []
    for i in range(len(points)):
        payload = None
        if payloads is not None:
            payload = payloads[i]

        nodes.append(node(points[i], payload, None, None, -1, None))

    self._root = self.__create_tree(nodes, None, 0)
def __len__(self):
    """!
    @brief Returns amount of nodes in the KD-tree.

    @return (uint) Amount of nodes in the KD-tree.

    """
    return self._length

def get_root(self):
    """!
    @brief Returns root of the tree.

    @return (node) The root of the tree.

    """
    return self._root
def __create_tree(self, nodes, parent, depth):
    """!
    @brief Creates balanced sub-tree using elements from list `nodes`.

    @param[in] nodes (list): List of KD-tree nodes.
    @param[in] parent (node): Parent node that is used as a root to build the sub-tree.
    @param[in] depth (uint): Depth of the tree where children of the `parent` should be placed.

    @return (node) Returns a node that is a root of the built sub-tree.

    """
    if len(nodes) == 0:
        return None

    # The splitting dimension cycles with the depth of the tree.
    discriminator = depth % self._dimension

    nodes.sort(key=lambda n: n.data[discriminator])
    median = len(nodes) // 2

    # Nodes whose key equals the median key must all end up in the right subtree,
    # therefore shift the median to the leftmost element with the same key.
    median = find_left_element(nodes, median, lambda n1, n2: n1.data[discriminator] < n2.data[discriminator])

    new_node = nodes[median]
    new_node.disc = discriminator
    new_node.parent = parent
    new_node.left = self.__create_tree(nodes[:median], new_node, depth + 1)
    new_node.right = self.__create_tree(nodes[median + 1:], new_node, depth + 1)

    return new_node
def _create_point_comparator(self, type_point):
    """!
    @brief Create point comparator.
    @details In case of numpy.ndarray a specific comparator is required, because `==`
              on arrays is element-wise and does not yield a single boolean.

    @param[in] type_point (data_type): Type of point that is stored in KD-node.

    @return (callable) Callable point comparator to compare two points.

    """
    # Fix: compare types by identity (`is`), not equality — the idiomatic and
    # unambiguous way to test for an exact type.
    if type_point is numpy.ndarray:
        return lambda obj1, obj2: numpy.array_equal(obj1, obj2)

    return lambda obj1, obj2: obj1 == obj2
def _find_node_by_rule(self, point, search_rule, cur_node):
    """!
    @brief Search node that satisfies the parameters in the search rule.
    @details If node with specified parameters does not exist then None will be returned,
              otherwise required node will be returned.

    @param[in] point (list): Coordinates of the point whose node should be found.
    @param[in] search_rule (callable): Rule that is called to check whether a node satisfies the search parameters.
    @param[in] cur_node (node): Node from which search should be started.

    @return (node) Node if it satisfies input parameters, otherwise None.

    """
    if cur_node is None:
        cur_node = self._root

    # Standard KD-tree descent: equal keys are kept on the right side.
    while cur_node:
        if cur_node.data[cur_node.disc] <= point[cur_node.disc]:
            if search_rule(cur_node):
                return cur_node

            cur_node = cur_node.right
        else:
            cur_node = cur_node.left

    return None
def find_node_with_payload(self, point, point_payload, cur_node=None):
    """!
    @brief Find node with specified coordinates and payload.
    @details If node with specified parameters does not exist then None will be returned,
              otherwise required node will be returned.

    @param[in] point (list): Coordinates of the point whose node should be found.
    @param[in] point_payload (any): Payload of the node that is searched in the tree.
    @param[in] cur_node (node): Node from which search should be started.

    @return (node) Node if it satisfies input parameters, otherwise None.

    """
    # Match on both coordinates and payload.
    rule_search = lambda node, point=point, payload=point_payload: self._point_comparator(node.data, point) and \
                  node.payload == payload

    return self._find_node_by_rule(point, rule_search, cur_node)
def find_node(self, point, cur_node=None):
    """!
    @brief Find node with coordinates that are defined by specified point.
    @details If node with specified parameters does not exist then None will be returned,
              otherwise required node will be returned.

    @param[in] point (list): Coordinates of the point whose node should be found.
    @param[in] cur_node (node): Node from which search should be started.

    @return (node) Node if it satisfies input parameters, otherwise None.

    """
    # Match on coordinates only — payload is ignored.
    rule_search = lambda node, point=point: self._point_comparator(node.data, point)
    return self._find_node_by_rule(point, rule_search, cur_node)
def find_nearest_dist_node(self, point, distance, retdistance=False):
    """!
    @brief Find nearest neighbor in area with radius = distance.

    @param[in] point (list): Coordinates of the point for which the nearest neighbor is searched.
    @param[in] distance (double): Maximum distance where neighbors are searched.
    @param[in] retdistance (bool): If True - returns neighbor with the distance to it, otherwise only the neighbor.

    @return (node|tuple) Nearest neighbor if 'retdistance' is False, and tuple (distance, node)
             if 'retdistance' is True; None when there is no neighbor within `distance`.

    """
    best_nodes = self.find_nearest_dist_nodes(point, distance)
    if len(best_nodes) == 0:
        return None

    # Entries are (squared_distance, node) tuples — pick the closest one.
    nearest = min(best_nodes, key=lambda item: item[0])

    if retdistance is True:
        return nearest
    else:
        return nearest[1]
def find_nearest_dist_nodes(self, point, distance):
    """!
    @brief Find neighbors that are located in area that is covered by specified distance.

    @param[in] point (list): Coordinates that is considered as centroid for searching.
    @param[in] distance (double): Distance from the center where searching is performed.

    @return (list) Neighbors in area that is specified by point (center) and distance (radius),
             as (squared_distance, node) tuples.

    """
    best_nodes = []
    if self._root is not None:
        # Squared radius is pre-computed once to avoid sqrt during traversal.
        self.__recursive_nearest_nodes(point, distance, distance * distance, self._root, best_nodes)

    return best_nodes
def __recursive_nearest_nodes(self, point, distance, sqrt_distance, node_head, best_nodes):
    """!
    @brief Append neighbors as tuples (squared_distance, node) that are located in the area
            covered by `distance` around `point`.

    @param[in] point (list): Coordinates that is considered as centroid for searching.
    @param[in] distance (double): Distance from the center where searching is performed.
    @param[in] sqrt_distance (double): Squared distance from the center where searching is performed
                (NOTE(review): name suggests square root but the value is distance**2 — see caller).
    @param[in] node_head (node): Node from which searching is performed.
    @param[in|out] best_nodes (list): List of found nodes.

    """
    # Descend into the right subtree only if the search ball crosses the split plane.
    if node_head.right is not None:
        minimum = node_head.data[node_head.disc] - distance
        if point[node_head.disc] >= minimum:
            self.__recursive_nearest_nodes(point, distance, sqrt_distance, node_head.right, best_nodes)

    # Same pruning for the left subtree.
    if node_head.left is not None:
        maximum = node_head.data[node_head.disc] + distance
        if point[node_head.disc] < maximum:
            self.__recursive_nearest_nodes(point, distance, sqrt_distance, node_head.left, best_nodes)

    candidate_distance = euclidean_distance_square(point, node_head.data)
    if candidate_distance <= sqrt_distance:
        best_nodes.append((candidate_distance, node_head))
505 @brief Represents KD Tree that is a space-partitioning data structure for organizing points
506 in a k-dimensional space.
507 @details In the term k-d tree, k denotes the dimensionality of the space being represented. Each data point is
508 represented as a node in the k-d tree in the form of a record of type node. The tree supports
509 dynamic construction when nodes can be dynamically added and removed. As a result KD-tree might not be
510 balanced if methods `insert` and `remove` are used to build the tree. If the tree is built using
511 constructor where all points are passed to the tree then balanced tree is built. Single point search and
512 range-search procedures have `O(n)` complexity in the worst case for an unbalanced tree.
513 If there is no need to build dynamic KD-tree, then it is much better to use static KD-tree
516 There is an example how to use KD-tree to search nodes (points from input data) that are nearest to some point:
518 # Import required modules
519 from pyclustering.samples.definitions import SIMPLE_SAMPLES;
520 from pyclustering.container.kdtree import kdtree;
521 from pyclustering.utils import read_sample;
523 # Read data from text file
524 sample = read_sample(SIMPLE_SAMPLES.SAMPLE_SIMPLE3);
526 # Create instance of KD-tree and initialize (fill) it by read data.
527 # In this case the tree is balanced.
528 tree_instance = kdtree(sample);
530 # Search for nearest point
531 search_distance = 0.3;
532 nearest_node = tree_instance.find_nearest_dist_node([1.12, 4.31], search_distance);
534 # Search for nearest point in radius 0.3
535 nearest_nodes = tree_instance.find_nearest_dist_nodes([1.12, 4.31], search_distance);
536 print("Nearest nodes:", nearest_nodes);
539 In case of building KD-tree using `insert` and `remove` method, the output KD-tree might be unbalanced - here
540 is an example that demonstrates this:
542 from pyclustering.container.kdtree import kdtree, kdtree_visualizer
543 from pyclustering.utils import read_sample
544 from pyclustering.samples.definitions import FCPS_SAMPLES
546 sample = read_sample(FCPS_SAMPLES.SAMPLE_TWO_DIAMONDS)
548 # Build tree using constructor - balanced will be built because tree will know about all points.
549 tree_instance = kdtree(sample)
550 kdtree_visualizer(tree_instance).visualize()
552 # Build tree using `insert` only - unbalanced tree will be built.
553 tree_instance = kdtree()
555 tree_instance.insert(point)
557 kdtree_visualizer(tree_instance).visualize()
560 There are two figures where difference between balanced and unbalanced KD-trees is demonstrated.
562 @image html kd_tree_balanced_two_diamonds.png "Fig. 1. Balanced KD-tree for sample 'TwoDiamonds'."
563 @image html kd_tree_unbalanced_two_diamonds.png "Fig. 2. Unbalanced KD-tree for sample 'TwoDiamonds'."
def __init__(self, data_list=None, payload_list=None):
    """!
    @brief Create kd-tree from list of points and from according list of payloads.
    @details If lists were not specified then empty kd-tree will be created.

    @param[in] data_list (list): Insert points from the list to created KD tree.
    @param[in] payload_list (list): Insert payload from the list to created KD tree, length should be equal to
                length of data_list if it is specified.

    """
    # Delegate the balanced construction to the parent class.
    super().__init__(data_list, payload_list)
def insert(self, point, payload=None):
    """!
    @brief Insert new point with payload to kd-tree.

    @param[in] point (list): Coordinates of the point of inserted node.
    @param[in] payload (any): Payload of inserted node. It can be ID of the node or
                some useful payload that belongs to the point.

    @return (node) Inserted node to the kd-tree.

    """
    # NOTE(review): _length/_dimension bookkeeping reconstructed — verify against
    # the balanced base class and `remove`.
    if self._root is None:
        self._dimension = len(point)
        self._root = node(point, payload, None, None, 0)
        self._length += 1
        return self._root

    cur_node = self._root

    while True:
        # The child's discriminator cycles to the next dimension.
        discriminator = (cur_node.disc + 1) % self._dimension

        if cur_node.data[cur_node.disc] <= point[cur_node.disc]:
            # Keys greater or equal go to the right subtree.
            if cur_node.right is None:
                cur_node.right = node(point, payload, None, None, discriminator, cur_node)
                self._length += 1
                return cur_node.right
            else:
                cur_node = cur_node.right
        else:
            # Smaller keys go to the left subtree.
            if cur_node.left is None:
                cur_node.left = node(point, payload, None, None, discriminator, cur_node)
                self._length += 1
                return cur_node.left
            else:
                cur_node = cur_node.left
def remove(self, point, **kwargs):
    """!
    @brief Remove specified point from kd-tree.
    @details It removes the first found node that satisfies the input parameters. Make sure that
              pair (point, payload) is unique for each node, otherwise the first found is removed.

    @param[in] point (list): Coordinates of the point of removed node.
    @param[in] **kwargs: Arbitrary keyword arguments (available arguments: 'payload').

    <b>Keyword Args:</b><br>
        - payload (any): Payload of the node that should be removed.

    @return (node) Root if node has been successfully removed, otherwise None.

    """
    if 'payload' in kwargs:
        node_for_remove = self.find_node_with_payload(point, kwargs['payload'], None)
    else:
        node_for_remove = self.find_node(point, None)

    if node_for_remove is None:
        return None

    self._length -= 1  # NOTE(review): counter update reconstructed — verify.

    parent = node_for_remove.parent
    minimal_node = self.__recursive_remove(node_for_remove)

    if parent is None:
        # The root itself was removed; its replacement becomes the new root.
        self._root = minimal_node

        if minimal_node is not None:
            minimal_node.parent = None
    else:
        if parent.left is node_for_remove:
            parent.left = minimal_node
        elif parent.right is node_for_remove:
            parent.right = minimal_node

    return self._root
def __recursive_remove(self, node_removed):
    """!
    @brief Delete node and return the node that replaces it (root of the fixed-up subtree).

    @param[in] node_removed (node): Node that should be removed.

    @return (node) Minimal node in line with coordinate that is defined by discriminator,
             or None when a leaf was removed.

    """
    # A leaf can be dropped outright.
    if (node_removed.right is None) and (node_removed.left is None):
        return None

    discriminator = node_removed.disc

    # If only the left branch exists, move it to the right so the standard
    # "minimum of the right subtree" replacement applies (valid because equal
    # keys live on the right side in this tree).
    if node_removed.right is None:
        node_removed.right = node_removed.left
        node_removed.left = None

    minimal_node = self.__find_minimal_node(node_removed.right, discriminator)
    parent = minimal_node.parent

    # Detach the replacement from its old position (recursively fixing that spot).
    if parent.left is minimal_node:
        parent.left = self.__recursive_remove(minimal_node)
    elif parent.right is minimal_node:
        parent.right = self.__recursive_remove(minimal_node)

    # Splice the replacement into the removed node's position.
    minimal_node.parent = node_removed.parent
    minimal_node.disc = node_removed.disc
    minimal_node.right = node_removed.right
    minimal_node.left = node_removed.left

    # Keep children's parent pointers consistent.
    if minimal_node.right is not None:
        minimal_node.right.parent = minimal_node

    if minimal_node.left is not None:
        minimal_node.left.parent = minimal_node

    return minimal_node
def __find_minimal_node(self, node_head, discriminator):
    """!
    @brief Find minimal node in line with coordinate that is defined by discriminator.

    @param[in] node_head (node): Node of KD tree from which search should be started.
    @param[in] discriminator (uint): Coordinate number that is used for comparison.

    @return (node) Minimal node in line with discriminator from the specified node.

    """
    min_key = lambda cur_node: cur_node.data[discriminator]

    # Iterative in-order walk: every node of the subtree is a candidate because the
    # subtree is not ordered by `discriminator` in general.
    stack, candidates = [], []

    while node_head is not None or len(stack) != 0:
        if node_head is not None:
            stack.append(node_head)
            node_head = node_head.left
        else:
            node_head = stack.pop()
            candidates.append(node_head)
            node_head = node_head.right

    return min(candidates, key=min_key)