Search moodle.org's
Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 3.11.x will end 14 Nov 2022 (12 months plus 6 months extension).
  • Bug fixes for security issues in 3.11.x will end 13 Nov 2023 (18 months plus 12 months extension).
  • PHP version: minimum PHP 7.3.0. Note: the minimum PHP version has increased since Moodle 3.10. PHP 7.4.x is also supported.

Differences Between: [Versions 311 and 402] [Versions 311 and 403]

   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Helper;
   6  
   7  use Phpml\Classification\Classifier;
   8  
   9  trait OneVsRest
  10  {
  11      /**
  12       * @var array
  13       */
  14      protected $classifiers = [];
  15  
  16      /**
  17       * All provided training targets' labels.
  18       *
  19       * @var array
  20       */
  21      protected $allLabels = [];
  22  
  23      /**
  24       * @var array
  25       */
  26      protected $costValues = [];
  27  
  28      /**
  29       * Train a binary classifier in the OvR style
  30       */
  31      public function train(array $samples, array $targets): void
  32      {
  33          // Clears previous stuff.
  34          $this->reset();
  35  
  36          $this->trainByLabel($samples, $targets);
  37      }
  38  
  39      /**
  40       * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
  41       */
  42      public function reset(): void
  43      {
  44          $this->classifiers = [];
  45          $this->allLabels = [];
  46          $this->costValues = [];
  47  
  48          $this->resetBinary();
  49      }
  50  
  51      protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
  52      {
  53          // Overwrites the current value if it exist. $allLabels must be provided for each partialTrain run.
  54          $this->allLabels = count($allLabels) === 0 ? array_keys(array_count_values($targets)) : $allLabels;
  55          sort($this->allLabels, SORT_STRING);
  56  
  57          // If there are only two targets, then there is no need to perform OvR
  58          if (count($this->allLabels) === 2) {
  59              // Init classifier if required.
  60              if (count($this->classifiers) === 0) {
  61                  $this->classifiers[0] = $this->getClassifierCopy();
  62              }
  63  
  64              $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
  65          } else {
  66              // Train a separate classifier for each label and memorize them
  67  
  68              foreach ($this->allLabels as $label) {
  69                  // Init classifier if required.
  70                  if (!isset($this->classifiers[$label])) {
  71                      $this->classifiers[$label] = $this->getClassifierCopy();
  72                  }
  73  
  74                  [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
  75                  $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
  76              }
  77          }
  78  
  79          // If the underlying classifier is capable of giving the cost values
  80          // during the training, then assign it to the relevant variable
  81          // Adding just the first classifier cost values to avoid complex average calculations.
  82          $classifierref = reset($this->classifiers);
  83          if (method_exists($classifierref, 'getCostValues')) {
  84              $this->costValues = $classifierref->getCostValues();
  85          }
  86      }
  87  
  88      /**
  89       * Returns an instance of the current class after cleaning up OneVsRest stuff.
  90       */
  91      protected function getClassifierCopy(): Classifier
  92      {
  93          // Clone the current classifier, so that
  94          // we don't mess up its variables while training
  95          // multiple instances of this classifier
  96          $classifier = clone $this;
  97          $classifier->reset();
  98  
  99          return $classifier;
 100      }
 101  
 102      /**
 103       * @return mixed
 104       */
 105      protected function predictSample(array $sample)
 106      {
 107          if (count($this->allLabels) === 2) {
 108              return $this->classifiers[0]->predictSampleBinary($sample);
 109          }
 110  
 111          $probs = [];
 112  
 113          foreach ($this->classifiers as $label => $predictor) {
 114              $probs[$label] = $predictor->predictProbability($sample, $label);
 115          }
 116  
 117          arsort($probs, SORT_NUMERIC);
 118  
 119          return key($probs);
 120      }
 121  
 122      /**
 123       * Each classifier should implement this method instead of train(samples, targets)
 124       */
 125      abstract protected function trainBinary(array $samples, array $targets, array $labels);
 126  
 127      /**
 128       * To be overwritten by OneVsRest classifiers.
 129       */
 130      abstract protected function resetBinary(): void;
 131  
 132      /**
 133       * Each classifier that make use of OvR approach should be able to
 134       * return a probability for a sample to belong to the given label.
 135       *
 136       * @return mixed
 137       */
 138      abstract protected function predictProbability(array $sample, string $label);
 139  
 140      /**
 141       * Each classifier should implement this method instead of predictSample()
 142       *
 143       * @return mixed
 144       */
 145      abstract protected function predictSampleBinary(array $sample);
 146  
 147      /**
 148       * Groups all targets into two groups: Targets equal to
 149       * the given label and the others
 150       *
 151       * $targets is not passed by reference nor contains objects so this method
 152       * changes will not affect the caller $targets array.
 153       *
 154       * @param mixed $label
 155       *
 156       * @return array Binarized targets and target's labels
 157       */
 158      private function binarizeTargets(array $targets, $label): array
 159      {
 160          $notLabel = "not_$label}";
 161          foreach ($targets as $key => $target) {
 162              $targets[$key] = $target == $label ? $label : $notLabel;
 163          }
 164  
 165          $labels = [$label, $notLabel];
 166  
 167          return [$targets, $labels];
 168      }
 169  }