Search moodle.org's
Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 4.3.x will end 7 October 2024 (12 months).
  • Bug fixes for security issues in 4.3.x will end 21 April 2025 (18 months).
  • PHP version: minimum PHP 8.0.0. Note: the minimum PHP version has increased since Moodle 4.1. PHP 8.2.x is also supported.

Differences Between: [Versions 310 and 403] [Versions 311 and 403] [Versions 39 and 403]

   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Metric;
   6  
   7  use Phpml\Exception\InvalidArgumentException;
   8  
   9  class ClassificationReport
  10  {
  11      public const MICRO_AVERAGE = 1;
  12  
  13      public const MACRO_AVERAGE = 2;
  14  
  15      public const WEIGHTED_AVERAGE = 3;
  16  
  17      /**
  18       * @var array
  19       */
  20      private $truePositive = [];
  21  
  22      /**
  23       * @var array
  24       */
  25      private $falsePositive = [];
  26  
  27      /**
  28       * @var array
  29       */
  30      private $falseNegative = [];
  31  
  32      /**
  33       * @var array
  34       */
  35      private $support = [];
  36  
  37      /**
  38       * @var array
  39       */
  40      private $precision = [];
  41  
  42      /**
  43       * @var array
  44       */
  45      private $recall = [];
  46  
  47      /**
  48       * @var array
  49       */
  50      private $f1score = [];
  51  
  52      /**
  53       * @var array
  54       */
  55      private $average = [];
  56  
  57      public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
  58      {
  59          $averagingMethods = range(self::MICRO_AVERAGE, self::WEIGHTED_AVERAGE);
  60          if (!in_array($average, $averagingMethods, true)) {
  61              throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
  62          }
  63  
  64          $this->aggregateClassificationResults($actualLabels, $predictedLabels);
  65          $this->computeMetrics();
  66          $this->computeAverage($average);
  67      }
  68  
  69      public function getPrecision(): array
  70      {
  71          return $this->precision;
  72      }
  73  
  74      public function getRecall(): array
  75      {
  76          return $this->recall;
  77      }
  78  
  79      public function getF1score(): array
  80      {
  81          return $this->f1score;
  82      }
  83  
  84      public function getSupport(): array
  85      {
  86          return $this->support;
  87      }
  88  
  89      public function getAverage(): array
  90      {
  91          return $this->average;
  92      }
  93  
  94      private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
  95      {
  96          $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);
  97  
  98          foreach ($actualLabels as $index => $actual) {
  99              $predicted = $predictedLabels[$index];
 100              ++$support[$actual];
 101  
 102              if ($actual === $predicted) {
 103                  ++$truePositive[$actual];
 104              } else {
 105                  ++$falsePositive[$predicted];
 106                  ++$falseNegative[$actual];
 107              }
 108          }
 109  
 110          $this->truePositive = $truePositive;
 111          $this->falsePositive = $falsePositive;
 112          $this->falseNegative = $falseNegative;
 113          $this->support = $support;
 114      }
 115  
 116      private function computeMetrics(): void
 117      {
 118          foreach ($this->truePositive as $label => $tp) {
 119              $this->precision[$label] = $this->computePrecision($tp, $this->falsePositive[$label]);
 120              $this->recall[$label] = $this->computeRecall($tp, $this->falseNegative[$label]);
 121              $this->f1score[$label] = $this->computeF1Score((float) $this->precision[$label], (float) $this->recall[$label]);
 122          }
 123      }
 124  
 125      private function computeAverage(int $average): void
 126      {
 127          switch ($average) {
 128              case self::MICRO_AVERAGE:
 129                  $this->computeMicroAverage();
 130  
 131                  return;
 132              case self::MACRO_AVERAGE:
 133                  $this->computeMacroAverage();
 134  
 135                  return;
 136              case self::WEIGHTED_AVERAGE:
 137                  $this->computeWeightedAverage();
 138  
 139                  return;
 140          }
 141      }
 142  
 143      private function computeMicroAverage(): void
 144      {
 145          $truePositive = (int) array_sum($this->truePositive);
 146          $falsePositive = (int) array_sum($this->falsePositive);
 147          $falseNegative = (int) array_sum($this->falseNegative);
 148  
 149          $precision = $this->computePrecision($truePositive, $falsePositive);
 150          $recall = $this->computeRecall($truePositive, $falseNegative);
 151          $f1score = $this->computeF1Score($precision, $recall);
 152  
 153          $this->average = compact('precision', 'recall', 'f1score');
 154      }
 155  
 156      private function computeMacroAverage(): void
 157      {
 158          foreach (['precision', 'recall', 'f1score'] as $metric) {
 159              $values = $this->{$metric};
 160              if (count($values) == 0) {
 161                  $this->average[$metric] = 0.0;
 162  
 163                  continue;
 164              }
 165  
 166              $this->average[$metric] = array_sum($values) / count($values);
 167          }
 168      }
 169  
 170      private function computeWeightedAverage(): void
 171      {
 172          foreach (['precision', 'recall', 'f1score'] as $metric) {
 173              $values = $this->{$metric};
 174              if (count($values) == 0) {
 175                  $this->average[$metric] = 0.0;
 176  
 177                  continue;
 178              }
 179  
 180              $sum = 0;
 181              foreach ($values as $i => $value) {
 182                  $sum += $value * $this->support[$i];
 183              }
 184  
 185              $this->average[$metric] = $sum / array_sum($this->support);
 186          }
 187      }
 188  
 189      private function computePrecision(int $truePositive, int $falsePositive): float
 190      {
 191          $divider = $truePositive + $falsePositive;
 192          if ($divider == 0) {
 193              return 0.0;
 194          }
 195  
 196          return $truePositive / $divider;
 197      }
 198  
 199      private function computeRecall(int $truePositive, int $falseNegative): float
 200      {
 201          $divider = $truePositive + $falseNegative;
 202          if ($divider == 0) {
 203              return 0.0;
 204          }
 205  
 206          return $truePositive / $divider;
 207      }
 208  
 209      private function computeF1Score(float $precision, float $recall): float
 210      {
 211          $divider = $precision + $recall;
 212          if ($divider == 0) {
 213              return 0.0;
 214          }
 215  
 216          return 2.0 * (($precision * $recall) / $divider);
 217      }
 218  
 219      private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
 220      {
 221          $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
 222          sort($labels);
 223  
 224          return (array) array_combine($labels, array_fill(0, count($labels), 0));
 225      }
 226  }