We have already seen the power and flexibility of decision trees for adding decision-making to our games. They can also be built dynamically through supervised learning, which is why we're revisiting them in this chapter.
There are several algorithms for building decision trees that are suited for different uses such as prediction and classification. In our case, we'll explore decision-tree learning by implementing the ID3 algorithm.
Although we built decision trees in a previous chapter, and the trees we implement now are based on the same principles, we will use different data types to suit the needs of the learning algorithm.
We will need two data types: one for the decision nodes and one for storing the examples to be learned.
The code for the DecisionNode data type is as follows:
using System.Collections.Generic;

public class DecisionNode
{
    public string testValue;
    public Dictionary<float, DecisionNode> children;

    public DecisionNode(string testValue = "")
    {
        this.testValue = testValue;
        children = new Dictionary<float, DecisionNode>();
    }
}
The code for the example data type, ID3Example, is as follows:
using UnityEngine;
using System.Collections.Generic;

public enum ID3Action
{
    STOP, WALK, RUN
}

public class ID3Example : MonoBehaviour
{
    public ID3Action action;
    // initialized here so GetValue doesn't throw a null reference
    // before the dictionary is populated
    public Dictionary<string, float> values =
            new Dictionary<string, float>();

    public float GetValue(string attribute)
    {
        return values[attribute];
    }
}
We will create the ID3 class with several functions for computing the resulting decision tree. The skeleton of the ID3 class is as follows:

using UnityEngine;
using System.Collections.Generic;

public class ID3 : MonoBehaviour
{
    // next steps
}
public Dictionary<float, List<ID3Example>> SplitByAttribute(
        ID3Example[] examples,
        string attribute)
{
    Dictionary<float, List<ID3Example>> sets;
    sets = new Dictionary<float, List<ID3Example>>();
    // next step
}
foreach (ID3Example e in examples)
{
    float key = e.GetValue(attribute);
    if (!sets.ContainsKey(key))
        sets.Add(key, new List<ID3Example>());
    sets[key].Add(e);
}
return sets;
public float GetEntropy(ID3Example[] examples)
{
    if (examples.Length == 0)
        return 0f;
    int numExamples = examples.Length;
    Dictionary<ID3Action, int> actionTallies;
    actionTallies = new Dictionary<ID3Action, int>();
    // next steps
}
foreach (ID3Example e in examples)
{
    if (!actionTallies.ContainsKey(e.action))
        actionTallies.Add(e.action, 0);
    actionTallies[e.action]++;
}
int actionCount = actionTallies.Keys.Count;
if (actionCount == 0)
    return 0f;
float entropy = 0f;
float proportion = 0f;
foreach (int tally in actionTallies.Values)
{
    proportion = tally / (float)numExamples;
    entropy -= proportion * Mathf.Log(proportion, 2);
}
return entropy;
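For reference, the loop above implements the standard Shannon entropy over the actions observed in the example set:

```latex
H(S) = -\sum_{a \in A(S)} p_a \log_2 p_a , \qquad p_a = \frac{\mathit{tally}_a}{|S|}
```

Here, A(S) denotes the set of actions that actually occur in S, and each proportion p_a is the tally for that action divided by the total number of examples.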
public float GetEntropy(
        Dictionary<float, List<ID3Example>> sets,
        int numExamples)
{
    float entropy = 0f;
    foreach (List<ID3Example> s in sets.Values)
    {
        float proportion;
        proportion = s.Count / (float)numExamples;
        // accumulate (not subtract) the weighted entropy of each
        // subset; GetEntropy already returns a non-negative value
        entropy += proportion * GetEntropy(s.ToArray());
    }
    return entropy;
}
public void MakeTree(
        ID3Example[] examples,
        List<string> attributes,
        DecisionNode node)
{
    float initEntropy = GetEntropy(examples);
    if (initEntropy <= 0)
        return;
    // next steps
}
int numExamples = examples.Length;
float bestInfoGain = 0f;
string bestSplitAttribute = "";
float infoGain = 0f;
float overallEntropy = 0f;
Dictionary<float, List<ID3Example>> bestSets;
bestSets = new Dictionary<float, List<ID3Example>>();
Dictionary<float, List<ID3Example>> sets;
foreach (string a in attributes)
{
    sets = SplitByAttribute(examples, a);
    overallEntropy = GetEntropy(sets, numExamples);
    infoGain = initEntropy - overallEntropy;
    if (infoGain > bestInfoGain)
    {
        bestInfoGain = infoGain;
        bestSplitAttribute = a;
        bestSets = sets;
    }
}
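This loop selects the split using information gain, which the ID3 algorithm defines as the reduction in entropy achieved by partitioning the examples S on an attribute a:

```latex
\mathit{Gain}(S, a) = H(S) - \sum_{v \in \mathit{Values}(a)} \frac{|S_v|}{|S|}\, H(S_v)
```

Here, S_v is the subset of examples whose attribute a takes the value v; the second GetEntropy overload computes the summation term, and the attribute with the highest gain wins the split.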
node.testValue = bestSplitAttribute;
List<string> newAttributes = new List<string>(attributes);
newAttributes.Remove(bestSplitAttribute);
foreach (List<ID3Example> set in bestSets.Values)
{
    float val = set[0].GetValue(bestSplitAttribute);
    DecisionNode child = new DecisionNode();
    node.children.Add(val, child);
    MakeTree(set.ToArray(), newAttributes, child);
}
The class is modular in terms of functionality. It doesn't store any state; instead, it computes and retrieves everything needed by the function that builds the decision tree. SplitByAttribute takes a set of examples and divides them into subsets, one per observed value of a given attribute; these subsets are needed for computing entropy. GetEntropy is an overloaded function: one version computes the entropy of a list of examples, and the other computes the weighted entropy of a collection of example subsets, using the formulae defined in the ID3 algorithm. Finally, MakeTree works recursively to build the decision tree, splitting at each level on the most significant attribute, that is, the one with the highest information gain.
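To see how these pieces fit together, the following sketch builds a small tree from three examples. The component name ID3Test, the attribute name "distance", and the sample values are illustrative assumptions, not part of the recipe; only the ID3, ID3Example, and DecisionNode types come from the code above.

```csharp
using UnityEngine;
using System.Collections.Generic;

// Hypothetical test component: maps a "distance" attribute to an action
// and builds a decision tree from the resulting examples.
public class ID3Test : MonoBehaviour
{
    void Start()
    {
        ID3Action[] actions = { ID3Action.STOP, ID3Action.WALK, ID3Action.RUN };
        float[] distances = { 0f, 10f, 30f };

        // ID3Example derives from MonoBehaviour, so instances are
        // created with AddComponent rather than the new operator.
        ID3Example[] examples = new ID3Example[actions.Length];
        for (int i = 0; i < actions.Length; i++)
        {
            ID3Example e = gameObject.AddComponent<ID3Example>();
            e.action = actions[i];
            e.values = new Dictionary<string, float>
            {
                { "distance", distances[i] }
            };
            examples[i] = e;
        }

        // Build the tree; with a single attribute, the root splits on
        // "distance" and each child corresponds to one observed value.
        ID3 id3 = gameObject.AddComponent<ID3>();
        DecisionNode root = new DecisionNode();
        id3.MakeTree(examples, new List<string> { "distance" }, root);
        Debug.Log(root.testValue + " has " + root.children.Count + " children");
    }
}
```

Because every example here has a distinct action, the initial entropy is positive and each single-example subset has zero entropy, so the recursion terminates after one split.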