Search moodle.org's
Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 4.2.x will end 22 April 2024 (12 months).
  • Bug fixes for security issues in 4.2.x will end 7 October 2024 (18 months).
  • PHP version: minimum PHP 8.0.0. Note: the minimum PHP version has increased since Moodle 4.1. PHP 8.1.x is also supported.
<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;

class DecisionTree implements Classifier
{
    use Trainable;
    use Predictable;

    public const CONTINUOUS = 1;

    public const NOMINAL = 2;

    /**
     * @var int
     */
    public $actualDepth = 0;

    /**
     * @var array
     */
    protected $columnTypes = [];

    /**
     * @var DecisionTreeLeaf
     */
    protected $tree;

    /**
     * @var int
     */
    protected $maxDepth;

    /**
     * @var array
     */
    private $labels = [];

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * @var array
     */
    private $selectedFeatures = [];

    /**
     * @var array|null
     */
    private $featureImportances;

    /**
     * @var array
     */
    private $columnNames = [];

    /**
     * Creates an untrained tree.
     *
     * @param int $maxDepth Depth at which getSplitLeaf() stops splitting and
     *                      marks the node as terminal
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Appends the given samples/targets to the ones already stored and
     * rebuilds the whole tree from the accumulated data.
     *
     * @param array $samples Rows of feature values; the first stored row determines the feature count
     * @param array $targets Label of each sample, in the same order as $samples
     *
     * @throws InvalidArgumentException when there is no sample to train on
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        // Without at least one row the feature count cannot be derived:
        // count($this->samples[0]) would fail with an opaque TypeError
        if ($this->samples === []) {
            throw new InvalidArgumentException('Cannot train a decision tree without any samples');
        }

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        // getSplitLeaf() only ever raises actualDepth, so start from zero to
        // make it describe the tree built below rather than a previous one
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Each time the tree is trained, feature importances are reset so that
        // they are computed again from the new data
        $this->featureImportances = null;

        // If column names were given or computed before, keep them and only
        // adjust their number to the current feature count
        if ($this->columnNames === []) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif (count($this->columnNames) > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif (count($this->columnNames) < $this->featureCount) {
            // Pad the missing names with the numeric indexes of the extra columns
            $this->columnNames = array_merge(
                $this->columnNames,
                range(count($this->columnNames), $this->featureCount - 1)
            );
        }
    }

    /**
     * Detects, column by column, whether the data looks nominal or continuous.
     *
     * @param array $samples Rows of feature values; the first row determines the number of columns
     *
     * @return array One self::NOMINAL or self::CONTINUOUS entry per column; empty when there are no samples
     */
    public static function getColumnTypes(array $samples): array
    {
        // No rows means no columns to classify; without this guard
        // count($samples[0]) fails on the missing first row
        if ($samples === []) {
            return [];
        }

        $types = [];
        $featureCount = count($samples[0]);
        for ($i = 0; $i < $featureCount; ++$i) {
            $values = array_column($samples, $i);
            $isCategorical = self::isCategoricalColumn($values);
            $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS;
        }

        return $types;
    }

    /**
     * @param mixed $baseValue
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        // One [equal to base value, different from base value] counter pair per known label
        $tally = array_fill_keys($this->labels, [0, 0]);

        foreach ($colValues as $position => $cell) {
            $side = 1;
            if ($cell === $baseValue) {
                $side = 0;
            }

            ++$tally[$targets[$position]][$side];
        }

        // Impurity of each side, weighted by the number of values on that side:
        // (1 - sum of squared label shares) * side size
        $weighted = [0, 0];
        foreach ([0, 1] as $side) {
            $sideSize = array_sum(array_column($tally, $side));
            $squaredShares = 0;
            if ($sideSize > 0) {
                foreach ($this->labels as $label) {
                    $squaredShares += ($tally[$label][$side] / (float) $sideSize) ** 2;
                }
            }

            $weighted[$side] = (1 - $squaredShares) * $sideSize;
        }

        // Average the two weighted parts over all values of the column
        return array_sum($weighted) / count($colValues);
    }

    /**
     * This method is used to set number of columns to be used
     * when deciding a split at an internal node of the tree.  <br>
     * If the value is given 0, then all features are used (default behaviour),
     * otherwise the given value will be used as a maximum for number of columns
     * randomly selected for each split operation.
     *
     * @return $this
     *
     * @throws InvalidArgumentException
     */
    public function setNumFeatures(int $numFeatures)
    {
        // Zero is accepted (getSelectedFeatures() then returns every feature);
        // only negative counts are rejected
        if ($numFeatures >= 0) {
            $this->numUsableFeatures = $numFeatures;

            return $this;
        }

        throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
    }

    /**
     * A string array to represent columns. Useful when HTML output or
     * column importances are desired to be inspected.
     *
     * @return $this
     *
     * @throws InvalidArgumentException
     */
    public function setColumnNames(array $names)
    {
        // While featureCount is still 0 there is nothing to validate against,
        // so a list of any length is accepted
        $featureCountKnown = $this->featureCount !== 0;
        $lengthMismatch = count($names) !== $this->featureCount;

        if ($featureCountKnown && $lengthMismatch) {
            throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
        }

        $this->columnNames = $names;

        return $this;
    }

    /**
     * Returns the result of getHTML() on the root node stored in $this->tree,
     * handing it the column names.
     */
    public function getHtml(): string
    {
        $root = $this->tree;

        return $root->getHTML($this->columnNames);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset. The importance values are
     * normalized and their total makes 1.<br/>
     */
    public function getFeatureImportances(): array
    {
        // Serve the cached result until train() resets it to null
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $this->featureImportances = [];
        $totalSamples = count($this->samples);

        // Raw importance of a column = sum of the impurity decreases of
        // every internal node that splits on that column
        foreach ($this->columnNames as $columnIndex => $label) {
            $score = 0;
            foreach ($this->getSplitNodesByColumn($columnIndex, $this->tree) as $splitNode) {
                $score += $splitNode->getNodeImpurityDecrease($totalSamples);
            }

            $this->featureImportances[$label] = $score;
        }

        // Nothing to normalize when no column contributed anything
        $scoreSum = array_sum($this->featureImportances);
        if (!($scoreSum > 0)) {
            return $this->featureImportances;
        }

        // Scale so the values add up to 1, then order from most to least important
        foreach ($this->featureImportances as $label => $score) {
            $this->featureImportances[$label] = $score / $scoreSum;
        }

        arsort($this->featureImportances);

        return $this->featureImportances;
    }

    /**
     * Recursively builds the (sub)tree for the given record indexes.
     *
     * The node becomes terminal when all records are equal, when the maximum
     * depth is reached or when only one target value is left; its class value
     * is then the most frequent remaining target. Otherwise the records are
     * divided by the node's split criterion and each non-empty side is built
     * one level deeper.
     *
     * @param array $records Indexes into the stored samples/targets
     * @param int   $depth   Level of the node being built (0 for the root)
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $node = $this->getBestSplit($records);
        $node->level = $depth;
        $this->actualDepth = max($this->actualDepth, $depth);

        $goLeft = [];
        $goRight = [];
        $labelTally = [];
        $identical = true;
        $previous = null;

        foreach ($records as $recordNo) {
            $row = $this->samples[$recordNo];

            // One differing pair of neighbours is enough to know that
            // the records are not all the same
            if ($previous !== null && $previous != $row) {
                $identical = false;
            }

            $previous = $row;

            // The split criterion decides the side this record goes to
            if ($node->evaluate($row)) {
                $goLeft[] = $recordNo;
            } else {
                $goRight[] = $recordNo;
            }

            // Count how often each target occurs among the records
            $label = $this->targets[$recordNo];
            if (array_key_exists($label, $labelTally)) {
                ++$labelTally[$label];
            } else {
                $labelTally[$label] = 1;
            }
        }

        $mustStop = $identical || $depth >= $this->maxDepth || count($labelTally) === 1;
        if ($mustStop) {
            // Classify the leaf with the most frequent remaining target
            $node->isTerminal = true;
            arsort($labelTally);
            $node->classValue = (string) key($labelTally);

            return $node;
        }

        if (isset($goLeft[0])) {
            $node->leftLeaf = $this->getSplitLeaf($goLeft, $depth + 1);
        }

        if (isset($goRight[0])) {
            $node->rightLeaf = $this->getSplitLeaf($goRight, $depth + 1);
        }

        return $node;
    }

    /**
     * Builds the leaf describing the best split for the given records.
     *
     * For every feature returned by getSelectedFeatures() the most frequent
     * (preprocessed) value of the column is taken as split value, and the
     * feature with the lowest Gini index is kept.
     *
     * @param array $records Indexes into the stored samples/targets
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        $recordLookup = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordLookup);
        $samples = (array) array_combine(
            $records,
            $this->preprocess(array_intersect_key($this->samples, $recordLookup))
        );

        $winner = null;
        $lowestGini = 1;
        foreach ($this->getSelectedFeatures() as $column) {
            // Values of this column, still keyed by record index so that
            // they line up with $targets
            $colValues = array_map(
                static function ($row) use ($column) {
                    return $row[$column];
                },
                $samples
            );

            // The most frequent value of the column is the split candidate
            $frequencies = array_count_values($colValues);
            arsort($frequencies);
            $candidate = key($frequencies);
            if ($candidate === null) {
                continue;
            }

            $gini = $this->getGiniIndex($candidate, $colValues, $targets);
            $improves = $winner === null || $lowestGini > $gini;
            if (!$improves) {
                continue;
            }

            $isContinuous = $this->columnTypes[$column] === self::CONTINUOUS;

            $leaf = new DecisionTreeLeaf();
            $leaf->value = $candidate;
            $leaf->giniIndex = $gini;
            $leaf->columnIndex = $column;
            $leaf->isContinuous = $isContinuous;
            $leaf->records = $records;

            // For a numeric column the value is an "<operator> <number>" string
            // produced by preprocess(); store both parts on the leaf as well
            if ($isContinuous) {
                $parts = [];
                preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $leaf->value, $parts);
                $leaf->operator = $parts[1];
                $leaf->numericValue = (float) $parts[2];
            }

            $winner = $leaf;
            $lowestGini = $gini;
        }

        return $winner;
    }

    /**
     * Returns available features/columns to the tree for the decision making
     * process. <br>
     *
     * If a number is given with setNumFeatures() method, then a random selection
     * of features up to this number is returned. <br>
     *
     * If some features are manually selected by use of setSelectedFeatures(),
     * then only these features are returned <br>
     *
     * If any of above methods were not called beforehand, then all features
     * are returned by default.
     */
    protected function getSelectedFeatures(): array
    {
        $everyFeature = range(0, $this->featureCount - 1);

        // An explicit selection always takes precedence over a random one
        if (count($this->selectedFeatures) > 0) {
            return $this->selectedFeatures;
        }

        // No limit configured: every feature takes part in the split decision
        if ($this->numUsableFeatures === 0) {
            return $everyFeature;
        }

        // Never ask for more features than there are
        $limit = min($this->numUsableFeatures, $this->featureCount);

        // Random subset, handed back in ascending column order
        shuffle($everyFeature);
        $picked = array_slice($everyFeature, 0, $limit);
        sort($picked);

        return $picked;
    }

    /**
     * Replaces the values of every continuous column by one of two discrete
     * strings ("<= median" or "> median"); other columns are kept as they are.
     *
     * @param array $samples Rows of feature values
     *
     * @return array The rows, re-indexed from zero, with continuous values discretized
     */
    protected function preprocess(array $samples): array
    {
        // Detect and convert continuous data column values into
        // discrete values by using the median as a threshold value
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOUS) {
                $median = Mean::median($values);
                foreach ($values as &$value) {
                    if ($value <= $median) {
                        $value = "<= {$median}";
                    } else {
                        $value = "> {$median}";
                    }
                }

                // Drop the reference so a later write to $value cannot
                // touch the last element of this column
                unset($value);
            }

            $columns[] = $values;
        }

        // array_map(null, $array) with a single array returns that array
        // unchanged instead of zipping it, which would turn every row into a
        // scalar; wrap the values explicitly for single-feature data
        if (count($columns) === 1) {
            return array_map(
                static function ($value): array {
                    return [$value];
                },
                $columns[0]
            );
        }

        // Below method is a strange yet very simple & efficient method
        // to get the transpose of a 2D array
        return array_map(null, ...$columns);
    }

    /**
     * Guesses whether a column holds a discrete set of values.
     *
     * @param array $columnValues All values of one column
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        $count = count($columnValues);

        // There are two main indicators that *may* show whether a
        // column is composed of discrete set of values:
        // 1- Column may contain string values and non-float values
        // 2- Number of unique values in the column is only a small fraction of
        //    all values in that column (Lower than or equal to %20 of all values)
        $numericValues = array_filter($columnValues, 'is_numeric');
        $floatValues = array_filter($columnValues, 'is_float');

        // A single float is enough to treat the column as continuous
        if (count($floatValues) > 0) {
            return false;
        }

        // At least one non-numeric value: treat the column as categorical
        if (count($numericValues) !== $count) {
            return true;
        }

        $distinctValues = array_count_values($columnValues);

        return count($distinctValues) <= $count / 5;
    }

    /**
     * Used to set predefined features to consider while deciding which column to use for a split
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $this->selectedFeatures = $selectedFeatures;
    }

    /**
     * Collects and returns an array of internal nodes that use the given
     * column as a split criterion
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        // Terminal nodes do not split on anything and have nothing below them to visit
        if ($node->isTerminal) {
            return [];
        }

        $nodes = [];
        if ($node->columnIndex === $column) {
            $nodes[] = $node;
        }

        $lNodes = [];
        $rNodes = [];
        if ($node->leftLeaf !== null) {
            $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
        }

        if ($node->rightLeaf !== null) {
            $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
        }

        return array_merge($nodes, $lNodes, $rNodes);
    }

    /**
     * Walks the tree from the root, following the left child when the node's
     * criterion evaluates to true for the sample and the right child otherwise.
     *
     * @return mixed Class value of the terminal node reached, or the first
     *               known label when the walk runs into a missing child
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        do {
            if ($node->isTerminal) {
                return $node->classValue;
            }

            if ($node->evaluate($sample)) {
                $node = $node->leftLeaf;
            } else {
                $node = $node->rightLeaf;
            }
        } while ($node);

        return $this->labels[0];
    }
}