See Release Notes
Long Term Support Release
<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;

/**
 * Binary decision tree classifier.
 *
 * At every node the most frequent value of each candidate column is used as
 * the split value ("equal to it" goes left, everything else goes right) and
 * the column with the lowest Gini index is selected. Continuous columns are
 * first discretised around their median (see preprocess()).
 */
class DecisionTree implements Classifier
{
    use Trainable;
    use Predictable;

    public const CONTINUOUS = 1;

    public const NOMINAL = 2;

    /**
     * Deepest level reached while building the tree (root is level 0).
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * Type (self::CONTINUOUS or self::NOMINAL) of every column, by column index.
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Root node of the trained tree.
     *
     * @var DecisionTreeLeaf
     */
    protected $tree;

    /**
     * Depth at which a node is forced to become terminal.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * Distinct target values seen during training.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Number of columns in a training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Maximum number of randomly chosen columns per split; 0 means "all".
     *
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * Column indices explicitly chosen through setSelectedFeatures().
     *
     * @var array
     */
    private $selectedFeatures = [];

    /**
     * Cached result of getFeatureImportances(); null until computed.
     *
     * @var array|null
     */
    private $featureImportances;

    /**
     * Display name of every column, by column index.
     *
     * @var array
     */
    private $columnNames = [];

    /**
     * @param int $maxDepth Depth at which splitting stops and a node becomes terminal
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Appends the given samples/targets to the ones already stored and
     * rebuilds the whole tree from the accumulated data.
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Each time the tree is trained, feature importances are reset so that
        // we will have to compute it again depending on the new data
        $this->featureImportances = null;

        // If column names are given or computed before, then there is no
        // need to init it and accidentally remove the previous given names
        $nameCount = count($this->columnNames);
        if ($this->columnNames === []) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif ($nameCount > $this->featureCount) {
            // Too many names: drop the surplus ones
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif ($nameCount < $this->featureCount) {
            // Too few names: fall back to the column index for the missing ones
            $this->columnNames = array_merge(
                $this->columnNames,
                range($nameCount, $this->featureCount - 1)
            );
        }
    }

    /**
     * Detects, for every column of the given samples, whether it is
     * categorical (self::NOMINAL) or not (self::CONTINUOUS).
     *
     * @return array Column types indexed by column position
     */
    public static function getColumnTypes(array $samples): array
    {
        $types = [];
        $featureCount = count($samples[0]);
        for ($i = 0; $i < $featureCount; ++$i) {
            $values = array_column($samples, $i);
            $types[] = self::isCategoricalColumn($values) ? self::NOMINAL : self::CONTINUOUS;
        }

        return $types;
    }

    /**
     * Computes the size-weighted Gini impurity obtained by dividing the column
     * values into "identical to $baseValue" and "everything else".
     *
     * @param mixed $baseValue Value that defines the first group (strict comparison)
     * @param array $colValues Column values, keyed like $targets
     * @param array $targets   Target of each value, keyed like $colValues
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        // $countMatrix[label] = [count in "equal" group, count in "other" group]
        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }

        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            $rowIndex = $value === $baseValue ? 0 : 1;
            ++$countMatrix[$label][$rowIndex];
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; ++$i) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += ($countMatrix[$label][$i] / (float) $sum) ** 2;
                }
            }

            // Impurity of the group, weighted by the group size
            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / count($colValues);
    }

    /**
     * This method is used to set number of columns to be used
     * when deciding a split at an internal node of the tree. <br>
     * If the value is given 0, then all features are used (default behaviour),
     * otherwise the given value will be used as a maximum for number of columns
     * randomly selected for each split operation.
     *
     * @return $this
     *
     * @throws InvalidArgumentException When $numFeatures is negative
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures < 0) {
            throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
        }

        $this->numUsableFeatures = $numFeatures;

        return $this;
    }

    /**
     * A string array to represent columns. Useful when HTML output or
     * column importances are desired to be inspected.
     *
     * @return $this
     *
     * @throws InvalidArgumentException When the tree is already trained and the
     *                                  number of names differs from the feature count
     */
    public function setColumnNames(array $names)
    {
        if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
            throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
        }

        $this->columnNames = $names;

        return $this;
    }

    /**
     * Returns the HTML produced by the root node for the current column names.
     */
    public function getHtml(): string
    {
        return $this->tree->getHTML($this->columnNames);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset. The importance values are
     * normalized and their total makes 1.<br/>
     *
     * The result is cached until the next call to train().
     *
     * @return array Importances keyed by column name; sorted in descending
     *               order when their total is positive
     */
    public function getFeatureImportances(): array
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);
        $this->featureImportances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $nodes = $this->getSplitNodesByColumn($column, $this->tree);

            $importance = 0;
            foreach ($nodes as $node) {
                $importance += $node->getNodeImpurityDecrease($sampleCount);
            }

            $this->featureImportances[$columnName] = $importance;
        }

        // Normalize & sort the importances
        $total = array_sum($this->featureImportances);
        if ($total > 0) {
            array_walk($this->featureImportances, function (&$importance) use ($total): void {
                $importance /= $total;
            });
            arsort($this->featureImportances);
        }

        return $this->featureImportances;
    }

    /**
     * Recursively builds the (sub)tree for the given records.
     *
     * @param array $records Indices into the stored samples/targets
     * @param int   $depth   Level of the node being created (root is 0)
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        // Traverse all records to see if all records belong to the same class,
        // otherwise group the records so that we can classify the leaf
        // in case maximum depth is reached
        $leftRecords = [];
        $rightRecords = [];
        $remainingTargets = [];
        $prevRecord = null;
        $allSame = true;

        foreach ($records as $recordNo) {
            // Check if the previous record is the same with the current one
            $record = $this->samples[$recordNo];
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }

            $prevRecord = $record;

            // According to the split criterion, this record will
            // belong to either left or the right side in the next split
            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            // Group remaining targets
            $target = $this->targets[$recordNo];
            if (!array_key_exists($target, $remainingTargets)) {
                $remainingTargets[$target] = 1;
            } else {
                ++$remainingTargets[$target];
            }
        }

        if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
            // Nothing left to separate (or not allowed to go deeper):
            // the node predicts the most frequent remaining target
            $split->isTerminal = true;
            arsort($remainingTargets);
            $split->classValue = (string) key($remainingTargets);
        } else {
            if (isset($leftRecords[0])) {
                $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
            }

            if (isset($rightRecords[0])) {
                $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
            }
        }

        return $split;
    }

    /**
     * Finds, among the selected features, the column whose most frequent
     * value yields the lowest Gini index for the given records.
     *
     * NOTE(review): when no candidate column produces a split (e.g. no
     * features or no records) $bestSplit stays null and the declared return
     * type is violated; callers are assumed to pass at least one record.
     *
     * @param array $records Indices into the stored samples/targets
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        $targets = array_intersect_key($this->targets, array_flip($records));
        $samples = (array) array_combine(
            $records,
            $this->preprocess(array_intersect_key($this->samples, array_flip($records)))
        );
        $bestGiniVal = 1;
        $bestSplit = null;
        $features = $this->getSelectedFeatures();
        foreach ($features as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }

            // The most frequent value of the column is the split candidate
            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            if ($baseValue === null) {
                continue;
            }

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;
                $split->records = $records;

                // If a numeric column is to be selected, then
                // the original numeric value and the selected operator
                // will also be saved into the leaf for future access
                if ($this->columnTypes[$i] === self::CONTINUOUS) {
                    $matches = [];
                    preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                    $split->operator = $matches[1];
                    $split->numericValue = (float) $matches[2];
                }

                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Returns available features/columns to the tree for the decision making
     * process. <br>
     *
     * If a number is given with setNumFeatures() method, then a random selection
     * of features up to this number is returned. <br>
     *
     * If some features are manually selected by use of setSelectedFeatures(),
     * then only these features are returned <br>
     *
     * If any of above methods were not called beforehand, then all features
     * are returned by default.
     */
    protected function getSelectedFeatures(): array
    {
        $allFeatures = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0 && count($this->selectedFeatures) === 0) {
            return $allFeatures;
        }

        // A manual selection takes precedence over the random one
        if (count($this->selectedFeatures) > 0) {
            return $this->selectedFeatures;
        }

        $numFeatures = $this->numUsableFeatures;
        if ($numFeatures > $this->featureCount) {
            $numFeatures = $this->featureCount;
        }

        shuffle($allFeatures);
        $selectedFeatures = array_slice($allFeatures, 0, $numFeatures);
        sort($selectedFeatures);

        return $selectedFeatures;
    }

    /**
     * Detects and converts continuous data column values into discrete
     * values by using the median as a threshold value: every value becomes
     * either "<= {median}" or "> {median}". Nominal columns are kept as is.
     *
     * @param array $samples Rows to convert
     *
     * @return array The converted rows, re-indexed from 0 in the same order
     */
    protected function preprocess(array $samples): array
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOUS) {
                $median = Mean::median($values);
                foreach ($values as &$value) {
                    // The textual form is parsed back by getBestSplit()
                    if ($value <= $median) {
                        $value = "<= {$median}";
                    } else {
                        $value = "> {$median}";
                    }
                }

                // Break the reference so later writes cannot leak into $values
                unset($value);
            }

            $columns[] = $values;
        }

        // With a single array, array_map(null, ...) hands the array back
        // untouched instead of wrapping every value into a one-element row,
        // so a single-feature dataset has to be transposed explicitly
        if (count($columns) === 1) {
            return array_map(
                static function ($value): array {
                    return [$value];
                },
                $columns[0]
            );
        }

        // Below method is a strange yet very simple & efficient method
        // to get the transpose of a 2D array
        return array_map(null, ...$columns);
    }

    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * There are two main indicators that *may* show whether a
     * column is composed of discrete set of values:
     * 1- Column may contain string values and non-float values
     * 2- Number of unique values in the column is only a small fraction of
     *    all values in that column (Lower than or equal to %20 of all values)
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        $count = count($columnValues);

        $numericValues = array_filter($columnValues, 'is_numeric');
        $floatValues = array_filter($columnValues, 'is_float');

        // Any float value makes the column continuous
        if (count($floatValues) > 0) {
            return false;
        }

        // At least one non-numeric value makes the column categorical
        if (count($numericValues) !== $count) {
            return true;
        }

        $distinctValues = array_count_values($columnValues);

        return count($distinctValues) <= $count / 5;
    }

    /**
     * Used to set predefined features to consider while deciding which column to use for a split
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $this->selectedFeatures = $selectedFeatures;
    }

    /**
     * Collects and returns an array of internal nodes that use the given
     * column as a split criterion
     *
     * @param int              $column Column index to look for
     * @param DecisionTreeLeaf $node   Root of the subtree to search
     *
     * @return array Matching non-terminal nodes, parent first, then the left
     *               subtree's matches, then the right subtree's
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        if ($node->isTerminal) {
            return [];
        }

        $nodes = [];
        if ($node->columnIndex === $column) {
            $nodes[] = $node;
        }

        $lNodes = [];
        $rNodes = [];
        if ($node->leftLeaf !== null) {
            $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
        }

        if ($node->rightLeaf !== null) {
            $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
        }

        return array_merge($nodes, $lNodes, $rNodes);
    }

    /**
     * Walks the tree from the root until a terminal node is reached and
     * returns its class value. If the walk runs into a missing child, the
     * first known label is returned as a fallback.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        do {
            if ($node->isTerminal) {
                return $node->classValue;
            }

            if ($node->evaluate($sample)) {
                $node = $node->leftLeaf;
            } else {
                $node = $node->rightLeaf;
            }
        } while ($node);

        return $this->labels[0];
    }
}