Search moodle.org's
Developer Documentation

See Release Notes
Long Term Support Release

  • Bug fixes for general core bugs in 4.1.x will end 13 November 2023 (12 months).
  • Bug fixes for security issues in 4.1.x will end 10 November 2025 (36 months).
  • PHP version: minimum PHP 7.4.0. Note: the minimum PHP version has increased since Moodle 4.0. PHP 8.0.x is also supported.

Differences Between: [Versions 401 and 402] [Versions 401 and 403]

   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Classification\Linear;
   6  
   7  use Phpml\Classification\DecisionTree;
   8  use Phpml\Classification\WeightedClassifier;
   9  use Phpml\Exception\InvalidArgumentException;
  10  use Phpml\Helper\OneVsRest;
  11  use Phpml\Helper\Predictable;
  12  use Phpml\Math\Comparison;
  13  
  14  class DecisionStump extends WeightedClassifier
  15  {
  16      use Predictable;
  17      use OneVsRest;
  18  
  19      public const AUTO_SELECT = -1;
  20  
  21      /**
  22       * @var int
  23       */
  24      protected $givenColumnIndex;
  25  
  26      /**
  27       * @var array
  28       */
  29      protected $binaryLabels = [];
  30  
  31      /**
  32       * Lowest error rate obtained while training/optimizing the model
  33       *
  34       * @var float
  35       */
  36      protected $trainingErrorRate;
  37  
  38      /**
  39       * @var int
  40       */
  41      protected $column;
  42  
  43      /**
  44       * @var mixed
  45       */
  46      protected $value;
  47  
  48      /**
  49       * @var string
  50       */
  51      protected $operator;
  52  
  53      /**
  54       * @var array
  55       */
  56      protected $columnTypes = [];
  57  
  58      /**
  59       * @var int
  60       */
  61      protected $featureCount;
  62  
  63      /**
  64       * @var float
  65       */
  66      protected $numSplitCount = 100.0;
  67  
  68      /**
  69       * Distribution of samples in the leaves
  70       *
  71       * @var array
  72       */
  73      protected $prob = [];
  74  
  75      /**
  76       * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
  77       * used with ensemble algorithms as in the weak classifier role. <br>
  78       *
  79       * If columnIndex is given, then the stump tries to produce a decision node
  80       * on this column, otherwise in cases given the value of -1, the stump itself
  81       * decides which column to take for the decision (Default DecisionTree behaviour)
  82       */
  83      public function __construct(int $columnIndex = self::AUTO_SELECT)
  84      {
  85          $this->givenColumnIndex = $columnIndex;
  86      }
  87  
  88      public function __toString(): string
  89      {
  90          return "IF $this}->column $this}->operator $this}->value ".
  91              'THEN '.$this->binaryLabels[0].' '.
  92              'ELSE '.$this->binaryLabels[1];
  93      }
  94  
  95      /**
  96       * While finding best split point for a numerical valued column,
  97       * DecisionStump looks for equally distanced values between minimum and maximum
  98       * values in the column. Given <i>$count</i> value determines how many split
  99       * points to be probed. The more split counts, the better performance but
 100       * worse processing time (Default value is 10.0)
 101       */
 102      public function setNumericalSplitCount(float $count): void
 103      {
 104          $this->numSplitCount = $count;
 105      }
 106  
 107      /**
 108       * @throws InvalidArgumentException
 109       */
 110      protected function trainBinary(array $samples, array $targets, array $labels): void
 111      {
 112          $this->binaryLabels = $labels;
 113          $this->featureCount = count($samples[0]);
 114  
 115          // If a column index is given, it should be among the existing columns
 116          if ($this->givenColumnIndex > count($samples[0]) - 1) {
 117              $this->givenColumnIndex = self::AUTO_SELECT;
 118          }
 119  
 120          // Check the size of the weights given.
 121          // If none given, then assign 1 as a weight to each sample
 122          if (count($this->weights) === 0) {
 123              $this->weights = array_fill(0, count($samples), 1);
 124          } else {
 125              $numWeights = count($this->weights);
 126              if ($numWeights !== count($samples)) {
 127                  throw new InvalidArgumentException('Number of sample weights does not match with number of samples');
 128              }
 129          }
 130  
 131          // Determine type of each column as either "continuous" or "nominal"
 132          $this->columnTypes = DecisionTree::getColumnTypes($samples);
 133  
 134          // Try to find the best split in the columns of the dataset
 135          // by calculating error rate for each split point in each column
 136          $columns = range(0, count($samples[0]) - 1);
 137          if ($this->givenColumnIndex !== self::AUTO_SELECT) {
 138              $columns = [$this->givenColumnIndex];
 139          }
 140  
 141          $bestSplit = [
 142              'value' => 0,
 143              'operator' => '',
 144              'prob' => [],
 145              'column' => 0,
 146              'trainingErrorRate' => 1.0,
 147          ];
 148          foreach ($columns as $col) {
 149              if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
 150                  $split = $this->getBestNumericalSplit($samples, $targets, $col);
 151              } else {
 152                  $split = $this->getBestNominalSplit($samples, $targets, $col);
 153              }
 154  
 155              if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
 156                  $bestSplit = $split;
 157              }
 158          }
 159  
 160          // Assign determined best values to the stump
 161          foreach ($bestSplit as $name => $value) {
 162              $this->{$name} = $value;
 163          }
 164      }
 165  
 166      /**
 167       * Determines best split point for the given column
 168       */
 169      protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
 170      {
 171          $values = array_column($samples, $col);
 172          // Trying all possible points may be accomplished in two general ways:
 173          // 1- Try all values in the $samples array ($values)
 174          // 2- Artificially split the range of values into several parts and try them
 175          // We choose the second one because it is faster in larger datasets
 176          $minValue = min($values);
 177          $maxValue = max($values);
 178          $stepSize = ($maxValue - $minValue) / $this->numSplitCount;
 179  
 180          $split = [];
 181  
 182          foreach (['<=', '>'] as $operator) {
 183              // Before trying all possible split points, let's first try
 184              // the average value for the cut point
 185              $threshold = array_sum($values) / (float) count($values);
 186              [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);
 187              if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
 188                  $split = [
 189                      'value' => $threshold,
 190                      'operator' => $operator,
 191                      'prob' => $prob,
 192                      'column' => $col,
 193                      'trainingErrorRate' => $errorRate,
 194                  ];
 195              }
 196  
 197              // Try other possible points one by one
 198              for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
 199                  $threshold = (float) $step;
 200                  [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);
 201                  if ($errorRate < $split['trainingErrorRate']) {
 202                      $split = [
 203                          'value' => $threshold,
 204                          'operator' => $operator,
 205                          'prob' => $prob,
 206                          'column' => $col,
 207                          'trainingErrorRate' => $errorRate,
 208                      ];
 209                  }
 210              }// for
 211          }
 212  
 213          return $split;
 214      }
 215  
 216      protected function getBestNominalSplit(array $samples, array $targets, int $col): array
 217      {
 218          $values = array_column($samples, $col);
 219          $valueCounts = array_count_values($values);
 220          $distinctVals = array_keys($valueCounts);
 221  
 222          $split = [];
 223  
 224          foreach (['=', '!='] as $operator) {
 225              foreach ($distinctVals as $val) {
 226                  [$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);
 227                  if (!isset($split['trainingErrorRate']) || $split['trainingErrorRate'] < $errorRate) {
 228                      $split = [
 229                          'value' => $val,
 230                          'operator' => $operator,
 231                          'prob' => $prob,
 232                          'column' => $col,
 233                          'trainingErrorRate' => $errorRate,
 234                      ];
 235                  }
 236              }
 237          }
 238  
 239          return $split;
 240      }
 241  
 242      /**
 243       * Calculates the ratio of wrong predictions based on the new threshold
 244       * value given as the parameter
 245       */
 246      protected function calculateErrorRate(array $targets, float $threshold, string $operator, array $values): array
 247      {
 248          $wrong = 0.0;
 249          $prob = [];
 250          $leftLabel = $this->binaryLabels[0];
 251          $rightLabel = $this->binaryLabels[1];
 252  
 253          foreach ($values as $index => $value) {
 254              if (Comparison::compare($value, $threshold, $operator)) {
 255                  $predicted = $leftLabel;
 256              } else {
 257                  $predicted = $rightLabel;
 258              }
 259  
 260              $target = $targets[$index];
 261              if ((string) $predicted != (string) $targets[$index]) {
 262                  $wrong += $this->weights[$index];
 263              }
 264  
 265              if (!isset($prob[$predicted][$target])) {
 266                  $prob[$predicted][$target] = 0;
 267              }
 268  
 269              ++$prob[$predicted][$target];
 270          }
 271  
 272          // Calculate probabilities: Proportion of labels in each leaf
 273          $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
 274          foreach ($prob as $leaf => $counts) {
 275              $leafTotal = (float) array_sum($prob[$leaf]);
 276              foreach ($counts as $label => $count) {
 277                  if ((string) $leaf == (string) $label) {
 278                      $dist[$leaf] = $count / $leafTotal;
 279                  }
 280              }
 281          }
 282  
 283          return [$wrong / (float) array_sum($this->weights), $dist];
 284      }
 285  
 286      /**
 287       * Returns the probability of the sample of belonging to the given label
 288       *
 289       * Probability of a sample is calculated as the proportion of the label
 290       * within the labels of the training samples in the decision node
 291       *
 292       * @param mixed $label
 293       */
 294      protected function predictProbability(array $sample, $label): float
 295      {
 296          $predicted = $this->predictSampleBinary($sample);
 297          if ((string) $predicted == (string) $label) {
 298              return $this->prob[$label];
 299          }
 300  
 301          return 0.0;
 302      }
 303  
 304      /**
 305       * @return mixed
 306       */
 307      protected function predictSampleBinary(array $sample)
 308      {
 309          if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
 310              return $this->binaryLabels[0];
 311          }
 312  
 313          return $this->binaryLabels[1];
 314      }
 315  
 316      protected function resetBinary(): void
 317      {
 318      }
 319  }