Search moodle.org's
Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 4.3.x will end 7 October 2024 (12 months).
  • Bug fixes for security issues in 4.3.x will end 21 April 2025 (18 months).
  • PHP version: minimum PHP 8.0.0. Note: the minimum PHP version has increased since Moodle 4.1. PHP 8.2.x is also supported.

Differences Between: [Versions 310 and 403] [Versions 311 and 403] [Versions 39 and 403]

   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Classification\Ensemble;
   6  
   7  use Phpml\Classification\Classifier;
   8  use Phpml\Classification\DecisionTree;
   9  use Phpml\Exception\InvalidArgumentException;
  10  
  11  class RandomForest extends Bagging
  12  {
  13      /**
  14       * @var float|string
  15       */
  16      protected $featureSubsetRatio = 'log';
  17  
  18      /**
  19       * @var array|null
  20       */
  21      protected $columnNames;
  22  
  23      /**
  24       * Initializes RandomForest with the given number of trees. More trees
  25       * may increase the prediction performance while it will also substantially
  26       * increase the processing time and the required memory
  27       */
  28      public function __construct(int $numClassifier = 50)
  29      {
  30          parent::__construct($numClassifier);
  31  
  32          $this->setSubsetRatio(1.0);
  33      }
  34  
  35      /**
  36       * This method is used to determine how many of the original columns (features)
  37       * will be used to construct subsets to train base classifiers.<br>
  38       *
  39       * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0 <br>
  40       *
  41       * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
  42       * features to be taken into consideration while selecting subspace of features
  43       *
  44       * @param mixed $ratio
  45       */
  46      public function setFeatureSubsetRatio($ratio): self
  47      {
  48          if (!is_string($ratio) && !is_float($ratio)) {
  49              throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
  50          }
  51  
  52          if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
  53              throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
  54          }
  55  
  56          if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
  57              throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
  58          }
  59  
  60          $this->featureSubsetRatio = $ratio;
  61  
  62          return $this;
  63      }
  64  
  65      /**
  66       * RandomForest algorithm is usable *only* with DecisionTree
  67       *
  68       * @return $this
  69       */
  70      public function setClassifer(string $classifier, array $classifierOptions = [])
  71      {
  72          if ($classifier !== DecisionTree::class) {
  73              throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
  74          }
  75  
  76          parent::setClassifer($classifier, $classifierOptions);
  77  
  78          return $this;
  79      }
  80  
  81      /**
  82       * This will return an array including an importance value for
  83       * each column in the given dataset. Importance values for a column
  84       * is the average importance of that column in all trees in the forest
  85       */
  86      public function getFeatureImportances(): array
  87      {
  88          // Traverse each tree and sum importance of the columns
  89          $sum = [];
  90          foreach ($this->classifiers as $tree) {
  91              /** @var DecisionTree $tree */
  92              $importances = $tree->getFeatureImportances();
  93  
  94              foreach ($importances as $column => $importance) {
  95                  if (array_key_exists($column, $sum)) {
  96                      $sum[$column] += $importance;
  97                  } else {
  98                      $sum[$column] = $importance;
  99                  }
 100              }
 101          }
 102  
 103          // Normalize & sort the importance values
 104          $total = array_sum($sum);
 105          array_walk($sum, function (&$importance) use ($total): void {
 106              $importance /= $total;
 107          });
 108          arsort($sum);
 109  
 110          return $sum;
 111      }
 112  
 113      /**
 114       * A string array to represent the columns is given. They are useful
 115       * when trying to print some information about the trees such as feature importances
 116       *
 117       * @return $this
 118       */
 119      public function setColumnNames(array $names)
 120      {
 121          $this->columnNames = $names;
 122  
 123          return $this;
 124      }
 125  
 126      /**
 127       * @return DecisionTree
 128       */
 129      protected function initSingleClassifier(Classifier $classifier): Classifier
 130      {
 131          if (!$classifier instanceof DecisionTree) {
 132              throw new InvalidArgumentException(
 133                  sprintf('Classifier %s expected, got %s', DecisionTree::class, get_class($classifier))
 134              );
 135          }
 136  
 137          if (is_float($this->featureSubsetRatio)) {
 138              $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
 139          } elseif ($this->featureSubsetRatio === 'sqrt') {
 140              $featureCount = (int) ($this->featureCount ** .5) + 1;
 141          } else {
 142              $featureCount = (int) log($this->featureCount, 2) + 1;
 143          }
 144  
 145          if ($featureCount >= $this->featureCount) {
 146              $featureCount = $this->featureCount;
 147          }
 148  
 149          if ($this->columnNames === null) {
 150              $this->columnNames = range(0, $this->featureCount - 1);
 151          }
 152  
 153          return $classifier
 154              ->setColumnNames($this->columnNames)
 155              ->setNumFeatures($featureCount);
 156      }
 157  }