Differences Between: [Versions 311 and 400] [Versions 311 and 401] [Versions 311 and 402] [Versions 311 and 403]
<?php

declare(strict_types=1);

namespace Phpml\Metric;

use Phpml\Exception\InvalidArgumentException;

/**
 * Computes per-label precision, recall and F1 score for a set of classification
 * results, the per-label support (how often each label occurs in the actual
 * labels), and one averaged value of each metric (micro, macro or weighted).
 *
 * Every label that appears in either the actual or the predicted labels gets an
 * entry; entries are ordered by label as produced by sort().
 */
class ClassificationReport
{
    public const MICRO_AVERAGE = 1;

    public const MACRO_AVERAGE = 2;

    public const WEIGHTED_AVERAGE = 3;

    /**
     * @var array<int|string, int> number of correct predictions per label
     */
    private $truePositive = [];

    /**
     * @var array<int|string, int> number of times a label was predicted but was not the actual label
     */
    private $falsePositive = [];

    /**
     * @var array<int|string, int> number of times a label was the actual label but was not predicted
     */
    private $falseNegative = [];

    /**
     * @var array<int|string, int> number of occurrences of each label in the actual labels
     */
    private $support = [];

    /**
     * @var array<int|string, float>
     */
    private $precision = [];

    /**
     * @var array<int|string, float>
     */
    private $recall = [];

    /**
     * @var array<int|string, float>
     */
    private $f1score = [];

    /**
     * @var array<string, float> averaged metrics keyed by 'precision', 'recall' and 'f1score'
     */
    private $average = [];

    /**
     * @param array $actualLabels    ground-truth labels
     * @param array $predictedLabels predicted labels; must contain exactly the same keys as $actualLabels
     * @param int   $average         averaging method: MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE
     *
     * @throws InvalidArgumentException when the averaging method is unknown or the label arrays do not match
     */
    public function __construct(array $actualLabels, array $predictedLabels, int $average = self::MACRO_AVERAGE)
    {
        $averagingMethods = [self::MICRO_AVERAGE, self::MACRO_AVERAGE, self::WEIGHTED_AVERAGE];
        if (!in_array($average, $averagingMethods, true)) {
            throw new InvalidArgumentException('Averaging method must be MICRO_AVERAGE, MACRO_AVERAGE or WEIGHTED_AVERAGE');
        }

        $this->aggregateClassificationResults($actualLabels, $predictedLabels);
        $this->computeMetrics();
        $this->computeAverage($average);
    }

    /**
     * @return array<int|string, float> precision per label
     */
    public function getPrecision(): array
    {
        return $this->precision;
    }

    /**
     * @return array<int|string, float> recall per label
     */
    public function getRecall(): array
    {
        return $this->recall;
    }

    /**
     * @return array<int|string, float> F1 score per label
     */
    public function getF1score(): array
    {
        return $this->f1score;
    }

    /**
     * @return array<int|string, int> number of occurrences of each label in the actual labels
     */
    public function getSupport(): array
    {
        return $this->support;
    }

    /**
     * @return array<string, float> averaged 'precision', 'recall' and 'f1score'
     */
    public function getAverage(): array
    {
        return $this->average;
    }

    /**
     * Counts true positives, false positives, false negatives and support per label.
     *
     * @throws InvalidArgumentException when the two arrays differ in size or keys
     */
    private function aggregateClassificationResults(array $actualLabels, array $predictedLabels): void
    {
        // Without this check a missing prediction is read as null, raising a
        // warning and silently counting a bogus '' label.
        if (count($actualLabels) !== count($predictedLabels)) {
            throw new InvalidArgumentException('Size of given arrays does not match');
        }

        $truePositive = $falsePositive = $falseNegative = $support = self::getLabelIndexedArray($actualLabels, $predictedLabels);

        foreach ($actualLabels as $index => $actual) {
            if (!array_key_exists($index, $predictedLabels)) {
                throw new InvalidArgumentException(sprintf('Missing predicted label for index "%s"', $index));
            }

            $predicted = $predictedLabels[$index];
            ++$support[$actual];

            if ($actual === $predicted) {
                ++$truePositive[$actual];
            } else {
                // A miss is a false positive for the predicted label and a
                // false negative for the actual label.
                ++$falsePositive[$predicted];
                ++$falseNegative[$actual];
            }
        }

        $this->truePositive = $truePositive;
        $this->falsePositive = $falsePositive;
        $this->falseNegative = $falseNegative;
        $this->support = $support;
    }

    /**
     * Fills the per-label precision, recall and F1 arrays from the aggregated counts.
     */
    private function computeMetrics(): void
    {
        foreach ($this->truePositive as $label => $tp) {
            $precision = $this->computePrecision($tp, $this->falsePositive[$label]);
            $recall = $this->computeRecall($tp, $this->falseNegative[$label]);

            $this->precision[$label] = $precision;
            $this->recall[$label] = $recall;
            $this->f1score[$label] = $this->computeF1Score($precision, $recall);
        }
    }

    /**
     * Dispatches to the averaging method selected in the constructor.
     */
    private function computeAverage(int $average): void
    {
        switch ($average) {
            case self::MICRO_AVERAGE:
                $this->computeMicroAverage();

                return;
            case self::MACRO_AVERAGE:
                $this->computeMacroAverage();

                return;
            case self::WEIGHTED_AVERAGE:
                $this->computeWeightedAverage();

                return;
        }
    }

    /**
     * Micro average: metrics computed from the counts summed over all labels.
     */
    private function computeMicroAverage(): void
    {
        $truePositive = (int) array_sum($this->truePositive);
        $falsePositive = (int) array_sum($this->falsePositive);
        $falseNegative = (int) array_sum($this->falseNegative);

        $precision = $this->computePrecision($truePositive, $falsePositive);
        $recall = $this->computeRecall($truePositive, $falseNegative);

        $this->average = [
            'precision' => $precision,
            'recall' => $recall,
            'f1score' => $this->computeF1Score($precision, $recall),
        ];
    }

    /**
     * Macro average: unweighted mean of the per-label metrics (0.0 when there are no labels).
     */
    private function computeMacroAverage(): void
    {
        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            if (count($values) === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $this->average[$metric] = array_sum($values) / count($values);
        }
    }

    /**
     * Weighted average: mean of the per-label metrics weighted by each label's support.
     * Yields 0.0 when there are no labels or the total support is zero.
     */
    private function computeWeightedAverage(): void
    {
        $totalSupport = array_sum($this->support);

        foreach (['precision', 'recall', 'f1score'] as $metric) {
            $values = $this->{$metric};
            // Guard against dividing by a zero total support.
            if (count($values) === 0 || $totalSupport === 0) {
                $this->average[$metric] = 0.0;

                continue;
            }

            $sum = 0.0;
            foreach ($values as $label => $value) {
                $sum += $value * $this->support[$label];
            }

            $this->average[$metric] = $sum / $totalSupport;
        }
    }

    /**
     * Precision = TP / (TP + FP), or 0.0 when nothing was predicted for the label.
     */
    private function computePrecision(int $truePositive, int $falsePositive): float
    {
        $divider = $truePositive + $falsePositive;
        if ($divider === 0) {
            return 0.0;
        }

        // The float return type also normalises exact integer quotients (e.g. 2 / 2 === 1) to float.
        return $truePositive / $divider;
    }

    /**
     * Recall = TP / (TP + FN), or 0.0 when the label never occurs in the actual labels.
     */
    private function computeRecall(int $truePositive, int $falseNegative): float
    {
        $divider = $truePositive + $falseNegative;
        if ($divider === 0) {
            return 0.0;
        }

        return $truePositive / $divider;
    }

    /**
     * F1 = harmonic mean of precision and recall, or 0.0 when both are zero.
     */
    private function computeF1Score(float $precision, float $recall): float
    {
        $divider = $precision + $recall;
        if ($divider == 0) {
            return 0.0;
        }

        return 2.0 * (($precision * $recall) / $divider);
    }

    /**
     * Builds a zero-filled array keyed by every distinct label from both inputs, in sort() order.
     *
     * @return array<int|string, int>
     */
    private static function getLabelIndexedArray(array $actualLabels, array $predictedLabels): array
    {
        $labels = array_values(array_unique(array_merge($actualLabels, $predictedLabels)));
        sort($labels);

        return array_fill_keys($labels, 0);
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body