Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 36 additions & 41 deletions 1.cpp
Original file line number Diff line number Diff line change
@@ -1,60 +1,59 @@
#include <iostream>
#include <cassert>
#include <vector>
#include<memory>
using namespace std;

template <typename T>
class AVLTree {
class Node {
Node* left = nullptr;
Node* right = nullptr;
std::shared_ptr<Node> left = nullptr;
std::shared_ptr<Node> right = nullptr;
T _value;
int _height = 1;

public:
Node(T value): _value(value) {}

Node(const Node& other): _value(other._value), _height(other._height) {
if (other.left) left = new Node{*other.left};
if (other.right) right = new Node{*other.right};
if (other.left) left = make_shared<Node>(*other.left);
if (other.right) right = make_shared<Node>(*other.right);
}
friend AVLTree;

public:
int height() const { return _height; }
T value() const { return _value; }
};

Node* root = nullptr;
std::shared_ptr<Node> root = nullptr;

Node* _insert(Node* node, T value) {
if (!node) return new Node{value};
std::shared_ptr<Node> _insert(std::shared_ptr<Node> node, T value) {
if (!node) return make_shared<Node>(value);
if (value > node->_value) node->right = _insert(node->right, value);
else node->left = _insert(node->left, value);
return balance(node);
}

Node* _find(Node* node, T value) {
std::shared_ptr<Node> _find(std::shared_ptr<Node> node, T value) {
if (!node) return nullptr;
if (node->_value == value) return node;
if (value > node->_value) return _find(node->right, value);
return _find(node->left, value);
}

Node* _remove(Node* node, T value) {
std::shared_ptr<Node> _remove(std::shared_ptr<Node> node, T value) {
if (!node) return nullptr;
if (value < node->_value) node->left = _remove(node->left, value);
else if (value > node->_value) node->right = _remove(node->right, value);
else {
if (!node->left and !node->right) {
delete node;
node = nullptr;
return node;
}
else if (!node->left or !node->right) {
Node* res;
std::shared_ptr<Node> res;
if (node->left) res = node->left;
else res = node->right;
delete node;
node = nullptr;
return res;
}
Expand All @@ -67,17 +66,17 @@ class AVLTree {
return balance(node);
}

T _min(Node* node) {
T _min(std::shared_ptr<Node> node) {
if (!node->left) return node->_value;
return _min(node->left);
}

T _max(Node* node) {
T _max(std::shared_ptr<Node> node) {
if (!node->right) return node->_value;
return _max(node->right);
}

Node* balance(Node* node) {
std::shared_ptr<Node> balance(std::shared_ptr<Node> node) {
calc_height(node);
if (get_balance(node) == 2) {
if (get_balance(node->right) == -1)
Expand All @@ -93,58 +92,57 @@ class AVLTree {
return node;
}

Node* right_rotate(Node* node) {
Node* left = node->left;
std::shared_ptr<Node> right_rotate(std::shared_ptr<Node> node) {
std::shared_ptr<Node> left = node->left;
node->left = left->right;
left->right = node;
calc_height(node);
calc_height(left);
return left;
}

Node* left_rotate(Node* node) {
Node* right = node->right;
std::shared_ptr<Node> left_rotate(std::shared_ptr<Node> node) {
std::shared_ptr<Node> right = node->right;
node->right = right->left;
right->left = node;
calc_height(node);
calc_height(right);
return right;
}

int get_balance(Node* node) {
int get_balance(std::shared_ptr<Node> node) {
return height(node->right) - height(node->left);
}

void calc_height(Node* node) {
void calc_height(std::shared_ptr<Node> node) {
int left_height = height(node->left);
int right_height = height(node->right);
node->_height = std::max(left_height, right_height) + 1;
}

int height(Node* node) {
int height(std::shared_ptr<Node> node) {
if (!node) return 0;
return node->_height;
}

void _inorder(Node* node, vector<T>& result) {
void _inorder(std::shared_ptr<Node> node, vector<T>& result) {
if (!node) return;
_inorder(node->left, result);
result.push_back(node->_value);
_inorder(node->right, result);
}

void _delete(Node* node) {
if (!node) return;
_delete(node->left);
_delete(node->right);
delete node;
}

public:
AVLTree() = default;

AVLTree(std::initializer_list<T> il) {
for(auto&& x: il) {
insert(x);
}
}

AVLTree(const AVLTree &other) {
if (other.root) root = new Node(*other.root);
if (other.root) root = make_shared<Node>(*other.root);
}

~AVLTree() {
Expand All @@ -164,7 +162,7 @@ class AVLTree {
root = _insert(root, value);
}

Node* find(T value) {
std::shared_ptr<Node> find(T value) {
return _find(root, value);
}

Expand All @@ -185,18 +183,17 @@ class AVLTree {
}

void clear() {
_delete(root);
root = nullptr;
}

class iterator {
vector<Node*> path;
vector<std::shared_ptr<Node>> path;
public:
iterator(Node* root) {
iterator(std::shared_ptr<Node> root) {
go_left(root);
}

void go_left(Node* node) {
void go_left(std::shared_ptr<Node> node) {
while (node) {
path.push_back(node);
node = node->left;
Expand All @@ -210,7 +207,7 @@ class AVLTree {
}

iterator& operator++() {
Node* leaf = path.back();
std::shared_ptr<Node> leaf = path.back();
path.pop_back();
go_left(leaf->right);
return *this;
Expand All @@ -235,9 +232,7 @@ class AVLTree {
};

int main() {
AVLTree<int> int_tree{};
int_tree.insert(100); int_tree.insert(50); int_tree.insert(25); int_tree.insert(75); int_tree.insert(65);
int_tree.insert(85); int_tree.insert(150); int_tree.insert(125); int_tree.insert(175);
AVLTree<int> int_tree{100, 50, 25, 75, 65, 85, 150, 125, 175};
auto it = int_tree.begin();
auto end = int_tree.end();
for (; it != end; ++it) {
Expand Down