using NUnit.Framework;
using Algorithms.MachineLearning;
using System;

namespace Algorithms.Tests.MachineLearning;

[TestFixture]
public class DecisionTreeTests
{
    [Test]
    public void Fit_ThrowsOnEmptyInput()
    {
        var tree = new DecisionTree();

        // Fit requires at least one sample with at least one feature.
        Assert.Throws<ArgumentException>(() => tree.Fit(Array.Empty<int[]>(), Array.Empty<int>()));
    }

    [Test]
    public void Fit_ThrowsOnMismatchedLabels()
    {
        var tree = new DecisionTree();
        int[][] X = { new[] { 1, 2 } };
        int[] y = { 1, 0 };

        Assert.Throws<ArgumentException>(() => tree.Fit(X, y));
    }

    [Test]
    public void Predict_ThrowsIfNotTrained()
    {
        var tree = new DecisionTree();

        Assert.Throws<InvalidOperationException>(() => tree.Predict(new[] { 1, 2 }));
    }

    [Test]
    public void Predict_ThrowsOnFeatureMismatch()
    {
        var tree = new DecisionTree();
        int[][] X = { new[] { 1, 2 } };
        int[] y = { 1 };
        tree.Fit(X, y);

        // Trained on two features; a one-feature sample must be rejected.
        Assert.Throws<ArgumentException>(() => tree.Predict(new[] { 1 }));
    }

    [Test]
    public void FitAndPredict_WorksOnSimpleData()
    {
        // Simple OR logic.
        int[][] X =
        {
            new[] { 0, 0 },
            new[] { 0, 1 },
            new[] { 1, 0 },
            new[] { 1, 1 },
        };
        int[] y = { 0, 1, 1, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);

        Assert.That(tree.Predict(new[] { 0, 0 }), Is.EqualTo(0));
        Assert.That(tree.Predict(new[] { 0, 1 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 1, 0 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 1, 1 }), Is.EqualTo(1));
    }

    [Test]
    public void FeatureCount_ReturnsCorrectValue()
    {
        var tree = new DecisionTree();
        int[][] X = { new[] { 1, 2, 3 } };
        int[] y = { 1 };
        tree.Fit(X, y);

        Assert.That(tree.FeatureCount, Is.EqualTo(3));
    }

    [Test]
    public void Predict_FallbacksToZeroForUnseenValue()
    {
        int[][] X = { new[] { 0, 0 }, new[] { 1, 1 } };
        int[] y = { 0, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);

        // Value 2 is unseen in feature 0, so traversal falls back to 0.
        Assert.That(tree.Predict(new[] { 2, 0 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft()
    {
        int[][] X = { new[] { 0 }, new[] { 1 }, new[] { 2 } };
        int[] y = { 1, 0, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);

        // Value 3 is unseen at the root split, so traversal falls back to 0.
        Assert.That(tree.Predict(new[] { 3 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithMostCommonLabel_WhenNoFeaturesLeft_MultipleLabels()
    {
        int[][] X = { new[] { 0 }, new[] { 1 }, new[] { 2 }, new[] { 3 } };
        int[] y = { 1, 0, 1, 0 };
        var tree = new DecisionTree();
        tree.Fit(X, y);

        // Value 4 is unseen at the root split, so traversal falls back to 0.
        Assert.That(tree.Predict(new[] { 4 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsZero()
    {
        int[][] X = { new[] { 0 }, new[] { 1 } };
        int[] y = { 0, 0 };
        var tree = new DecisionTree();
        tree.Fit(X, y);

        Assert.That(tree.Predict(new[] { 0 }), Is.EqualTo(0));
        Assert.That(tree.Predict(new[] { 1 }), Is.EqualTo(0));
    }

    [Test]
    public void Entropy_ReturnsZero_WhenAllZeroOrAllOne()
    {
        // Use reflection to call the private static Entropy helper.
        var method = typeof(DecisionTree).GetMethod("Entropy", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);

        Assert.That(method!.Invoke(null, new object[] { new int[] { 0, 0, 0 } }), Is.EqualTo(0d));
        Assert.That(method!.Invoke(null, new object[] { new int[] { 1, 1, 1 } }), Is.EqualTo(0d));
    }

    [Test]
    public void MostCommon_ReturnsCorrectLabel()
    {
        var method = typeof(DecisionTree).GetMethod("MostCommon", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);

        Assert.That(method!.Invoke(null, new object[] { new int[] { 1, 0, 1, 1, 0, 0, 0 } }), Is.EqualTo(0));
        Assert.That(method!.Invoke(null, new object[] { new int[] { 1, 1, 1, 0 } }), Is.EqualTo(1));
    }

    [Test]
    public void Traverse_FallbacksToZero_WhenChildrenIsNull()
    {
        // Create a node with Children = null and Label = null via reflection.
        var nodeType = typeof(DecisionTree).GetNestedType("Node", System.Reflection.BindingFlags.NonPublic);
        var node = Activator.CreateInstance(nodeType!);
        nodeType!.GetProperty("Feature")!.SetValue(node, 0);
        nodeType!.GetProperty("Label")!.SetValue(node, null);
        nodeType!.GetProperty("Children")!.SetValue(node, null);

        var method = typeof(DecisionTree).GetMethod("Traverse", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);

        Assert.That(method!.Invoke(null, new[] { node!, new int[] { 99 } }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ReturnsNodeWithSingleLabel_WhenAllLabelsSame()
    {
        int[][] X = { new[] { 0 }, new[] { 1 }, new[] { 2 } };
        int[] y = { 1, 1, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);

        Assert.That(tree.Predict(new[] { 0 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 1 }), Is.EqualTo(1));
        Assert.That(tree.Predict(new[] { 2 }), Is.EqualTo(1));
    }

    [Test]
    public void Entropy_ReturnsZero_WhenEmptyLabels()
    {
        // Use reflection to call private static Entropy with an empty label array.
        var method = typeof(DecisionTree).GetMethod("Entropy", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);

        Assert.That(method!.Invoke(null, new object[] { Array.Empty<int>() }), Is.EqualTo(0d));
    }

    [Test]
    public void BestFeature_SkipsEmptyIdxBranch()
    {
        // Feature 1 has a single constant value, so feature 0 (perfect split) wins.
        int[][] X = { new[] { 0, 1 }, new[] { 1, 1 } };
        int[] y = { 0, 1 };
        var method = typeof(DecisionTree).GetMethod("BestFeature", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        var features = new System.Collections.Generic.List<int> { 0, 1 };

        var resultObj = method!.Invoke(null, new object[] { X, y, features });

        Assert.That(resultObj, Is.Not.Null);
        Assert.That((int)resultObj!, Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_MostCommonLabelBranch_IsCovered()
    {
        int[][] X = { new[] { 0 }, new[] { 1 } };
        int[] y = { 0, 1 };
        var tree = new DecisionTree();
        tree.Fit(X, y);

        // Value 2 is unseen, so traversal falls back to 0.
        Assert.That(tree.Predict(new[] { 2 }), Is.EqualTo(0));
    }

    [Test]
    public void BuildTree_ContinueBranch_IsCovered()
    {
        int[][] X = { new[] { 0 }, new[] { 1 } };
        int[] y = { 0, 1 };
        var method = typeof(DecisionTree).GetMethod("BuildTree", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        var features = new System.Collections.Generic.List<int> { 0 };

        Assert.DoesNotThrow(() => method!.Invoke(null, new object[] { X, y, features }));
    }

    [Test]
    public void BestFeature_ContinueBranch_IsCovered()
    {
        int[][] X = { new[] { 0, 1 }, new[] { 1, 1 } };
        int[] y = { 0, 1 };
        var method = typeof(DecisionTree).GetMethod("BestFeature", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static);
        var features = new System.Collections.Generic.List<int> { 0, 1 };

        Assert.DoesNotThrow(() => method!.Invoke(null, new object[] { X, y, features }));
    }
}
/// <summary>
/// Simple decision tree for binary classification using the ID3 algorithm.
/// Supports categorical features encoded as <see cref="int"/> values; labels
/// are expected to be 0 or 1.
/// </summary>
public class DecisionTree
{
    private Node? root;

    /// <summary>
    /// Gets the number of features used in training, or 0 before <see cref="Fit"/> is called.
    /// </summary>
    public int FeatureCount => root?.FeatureCount ?? 0;

    /// <summary>
    /// Trains the decision tree using the ID3 algorithm.
    /// </summary>
    /// <param name="x">2D array of features (samples x features), categorical (int).</param>
    /// <param name="y">Array of labels (0 or 1), one per sample.</param>
    /// <exception cref="ArgumentException">
    /// Thrown when the input is empty or the number of samples and labels differ.
    /// </exception>
    public void Fit(int[][] x, int[] y)
    {
        if (x.Length == 0 || x[0].Length == 0)
        {
            throw new ArgumentException("Input features cannot be empty.");
        }

        if (x.Length != y.Length)
        {
            throw new ArgumentException("Number of samples and labels must match.");
        }

        root = BuildTree(x, y, Enumerable.Range(0, x[0].Length).ToList());
    }

    /// <summary>
    /// Predicts the class label (0 or 1) for a single sample.
    /// </summary>
    /// <param name="x">Feature vector with the same length used during training.</param>
    /// <returns>The predicted label; 0 when a feature value was never seen during training.</returns>
    /// <exception cref="InvalidOperationException">Thrown when the model has not been trained.</exception>
    /// <exception cref="ArgumentException">Thrown when the feature count does not match training data.</exception>
    public int Predict(int[] x)
    {
        if (root is null)
        {
            throw new InvalidOperationException("Model not trained.");
        }

        if (x.Length != FeatureCount)
        {
            throw new ArgumentException("Feature count mismatch.");
        }

        return Traverse(root, x);
    }

    /// <summary>
    /// Recursively builds the ID3 tree over the given samples and remaining feature indices.
    /// </summary>
    private static Node BuildTree(int[][] x, int[] y, List<int> features)
    {
        // Pure node: every sample carries the same label.
        if (y.All(l => l == y[0]))
        {
            return new Node { Label = y[0], FeatureCount = x[0].Length };
        }

        // No features left to split on: fall back to the majority label.
        if (features.Count == 0)
        {
            return new Node { Label = MostCommon(y), FeatureCount = x[0].Length };
        }

        int bestFeature = BestFeature(x, y, features);
        var node = new Node { Feature = bestFeature, FeatureCount = x[0].Length };
        var values = x.Select(row => row[bestFeature]).Distinct();
        node.Children = new();
        foreach (var v in values)
        {
            var idx = x.Select((row, i) => (row, i)).Where(t => t.row[bestFeature] == v).Select(t => t.i).ToArray();
            if (idx.Length == 0)
            {
                continue;
            }

            var subX = idx.Select(i => x[i]).ToArray();
            var subY = idx.Select(i => y[i]).ToArray();
            var subFeatures = features.Where(f => f != bestFeature).ToList();
            node.Children[v] = BuildTree(subX, subY, subFeatures);
        }

        return node;
    }

    /// <summary>
    /// Walks the tree for one sample; returns 0 when a feature value has no branch.
    /// </summary>
    private static int Traverse(Node node, int[] x)
    {
        if (node.Label is not null)
        {
            return node.Label.Value;
        }

        int v = x[node.Feature!.Value];
        if (node.Children != null && node.Children.TryGetValue(v, out var child))
        {
            return Traverse(child, x);
        }

        // Fallback to 0 when the value was unseen during training or Children is null.
        return 0;
    }

    /// <summary>
    /// Returns the most frequent label in <paramref name="y"/>.
    /// </summary>
    private static int MostCommon(int[] y) => y.GroupBy(l => l).OrderByDescending(g => g.Count()).First().Key;

    /// <summary>
    /// Selects the feature with the highest information gain (ID3 criterion).
    /// </summary>
    private static int BestFeature(int[][] x, int[] y, List<int> features)
    {
        double baseEntropy = Entropy(y);
        double bestGain = double.MinValue;
        int bestFeature = features[0];
        foreach (var f in features)
        {
            var values = x.Select(row => row[f]).Distinct();
            double splitEntropy = 0;
            foreach (var v in values)
            {
                var idx = x.Select((row, i) => (row, i)).Where(t => t.row[f] == v).Select(t => t.i).ToArray();
                if (idx.Length == 0)
                {
                    continue;
                }

                var subY = idx.Select(i => y[i]).ToArray();

                // Weighted entropy of the partition induced by value v.
                splitEntropy += (double)subY.Length / y.Length * Entropy(subY);
            }

            double gain = baseEntropy - splitEntropy;
            if (gain > bestGain)
            {
                bestGain = gain;
                bestFeature = f;
            }
        }

        return bestFeature;
    }

    /// <summary>
    /// Shannon entropy of a binary label array; 0 for an empty array.
    /// </summary>
    private static double Entropy(int[] y)
    {
        int n = y.Length;
        if (n == 0)
        {
            return 0;
        }

        double p0 = y.Count(l => l == 0) / (double)n;
        double p1 = y.Count(l => l == 1) / (double)n;
        double e = 0;
        if (p0 > 0)
        {
            e -= p0 * Math.Log2(p0);
        }

        if (p1 > 0)
        {
            e -= p1 * Math.Log2(p1);
        }

        return e;
    }

    /// <summary>
    /// Tree node: either an internal split (Feature + Children) or a leaf (Label).
    /// </summary>
    private class Node
    {
        // Index of the feature this node splits on; null for leaves.
        public int? Feature { get; set; }

        // Class label for leaves; null for internal nodes.
        public int? Label { get; set; }

        // Number of training features, propagated so the root can report it.
        public int FeatureCount { get; set; }

        // Child nodes keyed by the feature value seen during training; null until built.
        public Dictionary<int, Node>? Children { get; set; }
    }
}