Search moodle.org's
Developer Documentation

See Release Notes
Long Term Support Release

  • Bug fixes for general core bugs in 4.1.x will end 13 November 2023 (12 months).
  • Bug fixes for security issues in 4.1.x will end 10 November 2025 (36 months).
  • PHP version: minimum PHP 7.4.0. Note: the minimum PHP version has increased since Moodle 4.0. PHP 8.0.x is supported too.
   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Classification\Ensemble;
   6  
   7  use Phpml\Classification\Classifier;
   8  use Phpml\Classification\Linear\DecisionStump;
   9  use Phpml\Classification\WeightedClassifier;
  10  use Phpml\Exception\InvalidArgumentException;
  11  use Phpml\Helper\Predictable;
  12  use Phpml\Helper\Trainable;
  13  use Phpml\Math\Statistic\Mean;
  14  use Phpml\Math\Statistic\StandardDeviation;
  15  use ReflectionClass;
  16  
  17  class AdaBoost implements Classifier
  18  {
  19      use Predictable;
  20      use Trainable;
  21  
  22      /**
  23       * Actual labels given in the targets array
  24       *
  25       * @var array
  26       */
  27      protected $labels = [];
  28  
  29      /**
  30       * @var int
  31       */
  32      protected $sampleCount;
  33  
  34      /**
  35       * @var int
  36       */
  37      protected $featureCount;
  38  
  39      /**
  40       * Number of maximum iterations to be done
  41       *
  42       * @var int
  43       */
  44      protected $maxIterations;
  45  
  46      /**
  47       * Sample weights
  48       *
  49       * @var array
  50       */
  51      protected $weights = [];
  52  
  53      /**
  54       * List of selected 'weak' classifiers
  55       *
  56       * @var array
  57       */
  58      protected $classifiers = [];
  59  
  60      /**
  61       * Base classifier weights
  62       *
  63       * @var array
  64       */
  65      protected $alpha = [];
  66  
  67      /**
  68       * @var string
  69       */
  70      protected $baseClassifier = DecisionStump::class;
  71  
  72      /**
  73       * @var array
  74       */
  75      protected $classifierOptions = [];
  76  
  77      /**
  78       * ADAptive BOOSTing (AdaBoost) is an ensemble algorithm to
  79       * improve classification performance of 'weak' classifiers such as
  80       * DecisionStump (default base classifier of AdaBoost).
  81       */
  82      public function __construct(int $maxIterations = 50)
  83      {
  84          $this->maxIterations = $maxIterations;
  85      }
  86  
  87      /**
  88       * Sets the base classifier that will be used for boosting (default = DecisionStump)
  89       */
  90      public function setBaseClassifier(string $baseClassifier = DecisionStump::class, array $classifierOptions = []): void
  91      {
  92          $this->baseClassifier = $baseClassifier;
  93          $this->classifierOptions = $classifierOptions;
  94      }
  95  
  96      /**
  97       * @throws InvalidArgumentException
  98       */
  99      public function train(array $samples, array $targets): void
 100      {
 101          // Initialize usual variables
 102          $this->labels = array_keys(array_count_values($targets));
 103          if (count($this->labels) !== 2) {
 104              throw new InvalidArgumentException('AdaBoost is a binary classifier and can classify between two classes only');
 105          }
 106  
 107          // Set all target values to either -1 or 1
 108          $this->labels = [
 109              1 => $this->labels[0],
 110              -1 => $this->labels[1],
 111          ];
 112          foreach ($targets as $target) {
 113              $this->targets[] = $target == $this->labels[1] ? 1 : -1;
 114          }
 115  
 116          $this->samples = array_merge($this->samples, $samples);
 117          $this->featureCount = count($samples[0]);
 118          $this->sampleCount = count($this->samples);
 119  
 120          // Initialize AdaBoost parameters
 121          $this->weights = array_fill(0, $this->sampleCount, 1.0 / $this->sampleCount);
 122          $this->classifiers = [];
 123          $this->alpha = [];
 124  
 125          // Execute the algorithm for a maximum number of iterations
 126          $currIter = 0;
 127          while ($this->maxIterations > $currIter++) {
 128              // Determine the best 'weak' classifier based on current weights
 129              $classifier = $this->getBestClassifier();
 130              $errorRate = $this->evaluateClassifier($classifier);
 131  
 132              // Update alpha & weight values at each iteration
 133              $alpha = $this->calculateAlpha($errorRate);
 134              $this->updateWeights($classifier, $alpha);
 135  
 136              $this->classifiers[] = $classifier;
 137              $this->alpha[] = $alpha;
 138          }
 139      }
 140  
 141      /**
 142       * @return mixed
 143       */
 144      public function predictSample(array $sample)
 145      {
 146          $sum = 0;
 147          foreach ($this->alpha as $index => $alpha) {
 148              $h = $this->classifiers[$index]->predict($sample);
 149              $sum += $h * $alpha;
 150          }
 151  
 152          return $this->labels[$sum > 0 ? 1 : -1];
 153      }
 154  
 155      /**
 156       * Returns the classifier with the lowest error rate with the
 157       * consideration of current sample weights
 158       */
 159      protected function getBestClassifier(): Classifier
 160      {
 161          $ref = new ReflectionClass($this->baseClassifier);
 162          /** @var Classifier $classifier */
 163          $classifier = count($this->classifierOptions) === 0 ? $ref->newInstance() : $ref->newInstanceArgs($this->classifierOptions);
 164  
 165          if ($classifier instanceof WeightedClassifier) {
 166              $classifier->setSampleWeights($this->weights);
 167              $classifier->train($this->samples, $this->targets);
 168          } else {
 169              [$samples, $targets] = $this->resample();
 170              $classifier->train($samples, $targets);
 171          }
 172  
 173          return $classifier;
 174      }
 175  
 176      /**
 177       * Resamples the dataset in accordance with the weights and
 178       * returns the new dataset
 179       */
 180      protected function resample(): array
 181      {
 182          $weights = $this->weights;
 183          $std = StandardDeviation::population($weights);
 184          $mean = Mean::arithmetic($weights);
 185          $min = min($weights);
 186          $minZ = (int) round(($min - $mean) / $std);
 187  
 188          $samples = [];
 189          $targets = [];
 190          foreach ($weights as $index => $weight) {
 191              $z = (int) round(($weight - $mean) / $std) - $minZ + 1;
 192              for ($i = 0; $i < $z; ++$i) {
 193                  if (random_int(0, 1) == 0) {
 194                      continue;
 195                  }
 196  
 197                  $samples[] = $this->samples[$index];
 198                  $targets[] = $this->targets[$index];
 199              }
 200          }
 201  
 202          return [$samples, $targets];
 203      }
 204  
 205      /**
 206       * Evaluates the classifier and returns the classification error rate
 207       */
 208      protected function evaluateClassifier(Classifier $classifier): float
 209      {
 210          $total = (float) array_sum($this->weights);
 211          $wrong = 0;
 212          foreach ($this->samples as $index => $sample) {
 213              $predicted = $classifier->predict($sample);
 214              if ($predicted != $this->targets[$index]) {
 215                  $wrong += $this->weights[$index];
 216              }
 217          }
 218  
 219          return $wrong / $total;
 220      }
 221  
 222      /**
 223       * Calculates alpha of a classifier
 224       */
 225      protected function calculateAlpha(float $errorRate): float
 226      {
 227          if ($errorRate == 0) {
 228              $errorRate = 1e-10;
 229          }
 230  
 231          return 0.5 * log((1 - $errorRate) / $errorRate);
 232      }
 233  
 234      /**
 235       * Updates the sample weights
 236       */
 237      protected function updateWeights(Classifier $classifier, float $alpha): void
 238      {
 239          $sumOfWeights = array_sum($this->weights);
 240          $weightsT1 = [];
 241          foreach ($this->weights as $index => $weight) {
 242              $desired = $this->targets[$index];
 243              $output = $classifier->predict($this->samples[$index]);
 244  
 245              $weight *= exp(-$alpha * $desired * $output) / $sumOfWeights;
 246  
 247              $weightsT1[] = $weight;
 248          }
 249  
 250          $this->weights = $weightsT1;
 251      }
 252  }