Search moodle.org's Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 3.10.x will end 8 November 2021 (12 months).
  • Bug fixes for security issues in 3.10.x will end 9 May 2022 (18 months).
  • PHP version: minimum PHP 7.2.0. Note: the minimum PHP version has increased since Moodle 3.8; PHP 7.3.x and 7.4.x are also supported.

Differences Between: [Versions 310 and 402] [Versions 310 and 403]

   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Classification;
   6  
   7  use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
   8  use Phpml\Exception\InvalidArgumentException;
   9  use Phpml\Helper\Predictable;
  10  use Phpml\Helper\Trainable;
  11  use Phpml\Math\Statistic\Mean;
  12  
  13  class DecisionTree implements Classifier
  14  {
  15      use Trainable;
  16      use Predictable;
  17  
  18      public const CONTINUOUS = 1;
  19  
  20      public const NOMINAL = 2;
  21  
  22      /**
  23       * @var int
  24       */
  25      public $actualDepth = 0;
  26  
  27      /**
  28       * @var array
  29       */
  30      protected $columnTypes = [];
  31  
  32      /**
  33       * @var DecisionTreeLeaf
  34       */
  35      protected $tree;
  36  
  37      /**
  38       * @var int
  39       */
  40      protected $maxDepth;
  41  
  42      /**
  43       * @var array
  44       */
  45      private $labels = [];
  46  
  47      /**
  48       * @var int
  49       */
  50      private $featureCount = 0;
  51  
  52      /**
  53       * @var int
  54       */
  55      private $numUsableFeatures = 0;
  56  
  57      /**
  58       * @var array
  59       */
  60      private $selectedFeatures = [];
  61  
  62      /**
  63       * @var array|null
  64       */
  65      private $featureImportances;
  66  
  67      /**
  68       * @var array
  69       */
  70      private $columnNames = [];
  71  
  72      public function __construct(int $maxDepth = 10)
  73      {
  74          $this->maxDepth = $maxDepth;
  75      }
  76  
  77      public function train(array $samples, array $targets): void
  78      {
  79          $this->samples = array_merge($this->samples, $samples);
  80          $this->targets = array_merge($this->targets, $targets);
  81  
  82          $this->featureCount = count($this->samples[0]);
  83          $this->columnTypes = self::getColumnTypes($this->samples);
  84          $this->labels = array_keys(array_count_values($this->targets));
  85          $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));
  86  
  87          // Each time the tree is trained, feature importances are reset so that
  88          // we will have to compute it again depending on the new data
  89          $this->featureImportances = null;
  90  
  91          // If column names are given or computed before, then there is no
  92          // need to init it and accidentally remove the previous given names
  93          if ($this->columnNames === []) {
  94              $this->columnNames = range(0, $this->featureCount - 1);
  95          } elseif (count($this->columnNames) > $this->featureCount) {
  96              $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
  97          } elseif (count($this->columnNames) < $this->featureCount) {
  98              $this->columnNames = array_merge(
  99                  $this->columnNames,
 100                  range(count($this->columnNames), $this->featureCount - 1)
 101              );
 102          }
 103      }
 104  
 105      public static function getColumnTypes(array $samples): array
 106      {
 107          $types = [];
 108          $featureCount = count($samples[0]);
 109          for ($i = 0; $i < $featureCount; ++$i) {
 110              $values = array_column($samples, $i);
 111              $isCategorical = self::isCategoricalColumn($values);
 112              $types[] = $isCategorical ? self::NOMINAL : self::CONTINUOUS;
 113          }
 114  
 115          return $types;
 116      }
 117  
 118      /**
 119       * @param mixed $baseValue
 120       */
 121      public function getGiniIndex($baseValue, array $colValues, array $targets): float
 122      {
 123          $countMatrix = [];
 124          foreach ($this->labels as $label) {
 125              $countMatrix[$label] = [0, 0];
 126          }
 127  
 128          foreach ($colValues as $index => $value) {
 129              $label = $targets[$index];
 130              $rowIndex = $value === $baseValue ? 0 : 1;
 131              ++$countMatrix[$label][$rowIndex];
 132          }
 133  
 134          $giniParts = [0, 0];
 135          for ($i = 0; $i <= 1; ++$i) {
 136              $part = 0;
 137              $sum = array_sum(array_column($countMatrix, $i));
 138              if ($sum > 0) {
 139                  foreach ($this->labels as $label) {
 140                      $part += ($countMatrix[$label][$i] / (float) $sum) ** 2;
 141                  }
 142              }
 143  
 144              $giniParts[$i] = (1 - $part) * $sum;
 145          }
 146  
 147          return array_sum($giniParts) / count($colValues);
 148      }
 149  
 150      /**
 151       * This method is used to set number of columns to be used
 152       * when deciding a split at an internal node of the tree.  <br>
 153       * If the value is given 0, then all features are used (default behaviour),
 154       * otherwise the given value will be used as a maximum for number of columns
 155       * randomly selected for each split operation.
 156       *
 157       * @return $this
 158       *
 159       * @throws InvalidArgumentException
 160       */
 161      public function setNumFeatures(int $numFeatures)
 162      {
 163          if ($numFeatures < 0) {
 164              throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
 165          }
 166  
 167          $this->numUsableFeatures = $numFeatures;
 168  
 169          return $this;
 170      }
 171  
 172      /**
 173       * A string array to represent columns. Useful when HTML output or
 174       * column importances are desired to be inspected.
 175       *
 176       * @return $this
 177       *
 178       * @throws InvalidArgumentException
 179       */
 180      public function setColumnNames(array $names)
 181      {
 182          if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
 183              throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
 184          }
 185  
 186          $this->columnNames = $names;
 187  
 188          return $this;
 189      }
 190  
 191      public function getHtml(): string
 192      {
 193          return $this->tree->getHTML($this->columnNames);
 194      }
 195  
 196      /**
 197       * This will return an array including an importance value for
 198       * each column in the given dataset. The importance values are
 199       * normalized and their total makes 1.<br/>
 200       */
 201      public function getFeatureImportances(): array
 202      {
 203          if ($this->featureImportances !== null) {
 204              return $this->featureImportances;
 205          }
 206  
 207          $sampleCount = count($this->samples);
 208          $this->featureImportances = [];
 209          foreach ($this->columnNames as $column => $columnName) {
 210              $nodes = $this->getSplitNodesByColumn($column, $this->tree);
 211  
 212              $importance = 0;
 213              foreach ($nodes as $node) {
 214                  $importance += $node->getNodeImpurityDecrease($sampleCount);
 215              }
 216  
 217              $this->featureImportances[$columnName] = $importance;
 218          }
 219  
 220          // Normalize & sort the importances
 221          $total = array_sum($this->featureImportances);
 222          if ($total > 0) {
 223              array_walk($this->featureImportances, function (&$importance) use ($total): void {
 224                  $importance /= $total;
 225              });
 226              arsort($this->featureImportances);
 227          }
 228  
 229          return $this->featureImportances;
 230      }
 231  
 232      protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
 233      {
 234          $split = $this->getBestSplit($records);
 235          $split->level = $depth;
 236          if ($this->actualDepth < $depth) {
 237              $this->actualDepth = $depth;
 238          }
 239  
 240          // Traverse all records to see if all records belong to the same class,
 241          // otherwise group the records so that we can classify the leaf
 242          // in case maximum depth is reached
 243          $leftRecords = [];
 244          $rightRecords = [];
 245          $remainingTargets = [];
 246          $prevRecord = null;
 247          $allSame = true;
 248  
 249          foreach ($records as $recordNo) {
 250              // Check if the previous record is the same with the current one
 251              $record = $this->samples[$recordNo];
 252              if ($prevRecord !== null && $prevRecord != $record) {
 253                  $allSame = false;
 254              }
 255  
 256              $prevRecord = $record;
 257  
 258              // According to the split criteron, this record will
 259              // belong to either left or the right side in the next split
 260              if ($split->evaluate($record)) {
 261                  $leftRecords[] = $recordNo;
 262              } else {
 263                  $rightRecords[] = $recordNo;
 264              }
 265  
 266              // Group remaining targets
 267              $target = $this->targets[$recordNo];
 268              if (!array_key_exists($target, $remainingTargets)) {
 269                  $remainingTargets[$target] = 1;
 270              } else {
 271                  ++$remainingTargets[$target];
 272              }
 273          }
 274  
 275          if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
 276              $split->isTerminal = true;
 277              arsort($remainingTargets);
 278              $split->classValue = (string) key($remainingTargets);
 279          } else {
 280              if (isset($leftRecords[0])) {
 281                  $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
 282              }
 283  
 284              if (isset($rightRecords[0])) {
 285                  $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
 286              }
 287          }
 288  
 289          return $split;
 290      }
 291  
 292      protected function getBestSplit(array $records): DecisionTreeLeaf
 293      {
 294          $targets = array_intersect_key($this->targets, array_flip($records));
 295          $samples = (array) array_combine(
 296              $records,
 297              $this->preprocess(array_intersect_key($this->samples, array_flip($records)))
 298          );
 299          $bestGiniVal = 1;
 300          $bestSplit = null;
 301          $features = $this->getSelectedFeatures();
 302          foreach ($features as $i) {
 303              $colValues = [];
 304              foreach ($samples as $index => $row) {
 305                  $colValues[$index] = $row[$i];
 306              }
 307  
 308              $counts = array_count_values($colValues);
 309              arsort($counts);
 310              $baseValue = key($counts);
 311              if ($baseValue === null) {
 312                  continue;
 313              }
 314  
 315              $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
 316              if ($bestSplit === null || $bestGiniVal > $gini) {
 317                  $split = new DecisionTreeLeaf();
 318                  $split->value = $baseValue;
 319                  $split->giniIndex = $gini;
 320                  $split->columnIndex = $i;
 321                  $split->isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;
 322                  $split->records = $records;
 323  
 324                  // If a numeric column is to be selected, then
 325                  // the original numeric value and the selected operator
 326                  // will also be saved into the leaf for future access
 327                  if ($this->columnTypes[$i] === self::CONTINUOUS) {
 328                      $matches = [];
 329                      preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
 330                      $split->operator = $matches[1];
 331                      $split->numericValue = (float) $matches[2];
 332                  }
 333  
 334                  $bestSplit = $split;
 335                  $bestGiniVal = $gini;
 336              }
 337          }
 338  
 339          return $bestSplit;
 340      }
 341  
 342      /**
 343       * Returns available features/columns to the tree for the decision making
 344       * process. <br>
 345       *
 346       * If a number is given with setNumFeatures() method, then a random selection
 347       * of features up to this number is returned. <br>
 348       *
 349       * If some features are manually selected by use of setSelectedFeatures(),
 350       * then only these features are returned <br>
 351       *
 352       * If any of above methods were not called beforehand, then all features
 353       * are returned by default.
 354       */
 355      protected function getSelectedFeatures(): array
 356      {
 357          $allFeatures = range(0, $this->featureCount - 1);
 358          if ($this->numUsableFeatures === 0 && count($this->selectedFeatures) === 0) {
 359              return $allFeatures;
 360          }
 361  
 362          if (count($this->selectedFeatures) > 0) {
 363              return $this->selectedFeatures;
 364          }
 365  
 366          $numFeatures = $this->numUsableFeatures;
 367          if ($numFeatures > $this->featureCount) {
 368              $numFeatures = $this->featureCount;
 369          }
 370  
 371          shuffle($allFeatures);
 372          $selectedFeatures = array_slice($allFeatures, 0, $numFeatures);
 373          sort($selectedFeatures);
 374  
 375          return $selectedFeatures;
 376      }
 377  
 378      protected function preprocess(array $samples): array
 379      {
 380          // Detect and convert continuous data column values into
 381          // discrete values by using the median as a threshold value
 382          $columns = [];
 383          for ($i = 0; $i < $this->featureCount; ++$i) {
 384              $values = array_column($samples, $i);
 385              if ($this->columnTypes[$i] == self::CONTINUOUS) {
 386                  $median = Mean::median($values);
 387                  foreach ($values as &$value) {
 388                      if ($value <= $median) {
 389                          $value = "<= $median}";
 390                      } else {
 391                          $value = "> $median}";
 392                      }
 393                  }
 394              }
 395  
 396              $columns[] = $values;
 397          }
 398  
 399          // Below method is a strange yet very simple & efficient method
 400          // to get the transpose of a 2D array
 401          return array_map(null, ...$columns);
 402      }
 403  
 404      protected static function isCategoricalColumn(array $columnValues): bool
 405      {
 406          $count = count($columnValues);
 407  
 408          // There are two main indicators that *may* show whether a
 409          // column is composed of discrete set of values:
 410          // 1- Column may contain string values and non-float values
 411          // 2- Number of unique values in the column is only a small fraction of
 412          //	   all values in that column (Lower than or equal to %20 of all values)
 413          $numericValues = array_filter($columnValues, 'is_numeric');
 414          $floatValues = array_filter($columnValues, 'is_float');
 415          if (count($floatValues) > 0) {
 416              return false;
 417          }
 418  
 419          if (count($numericValues) !== $count) {
 420              return true;
 421          }
 422  
 423          $distinctValues = array_count_values($columnValues);
 424  
 425          return count($distinctValues) <= $count / 5;
 426      }
 427  
 428      /**
 429       * Used to set predefined features to consider while deciding which column to use for a split
 430       */
 431      protected function setSelectedFeatures(array $selectedFeatures): void
 432      {
 433          $this->selectedFeatures = $selectedFeatures;
 434      }
 435  
 436      /**
 437       * Collects and returns an array of internal nodes that use the given
 438       * column as a split criterion
 439       */
 440      protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
 441      {
 442          if ($node->isTerminal) {
 443              return [];
 444          }
 445  
 446          $nodes = [];
 447          if ($node->columnIndex === $column) {
 448              $nodes[] = $node;
 449          }
 450  
 451          $lNodes = [];
 452          $rNodes = [];
 453          if ($node->leftLeaf !== null) {
 454              $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
 455          }
 456  
 457          if ($node->rightLeaf !== null) {
 458              $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
 459          }
 460  
 461          return array_merge($nodes, $lNodes, $rNodes);
 462      }
 463  
 464      /**
 465       * @return mixed
 466       */
 467      protected function predictSample(array $sample)
 468      {
 469          $node = $this->tree;
 470          do {
 471              if ($node->isTerminal) {
 472                  return $node->classValue;
 473              }
 474  
 475              if ($node->evaluate($sample)) {
 476                  $node = $node->leftLeaf;
 477              } else {
 478                  $node = $node->rightLeaf;
 479              }
 480          } while ($node);
 481  
 482          return $this->labels[0];
 483      }
 484  }