Differences Between: [Versions 311 and 402] [Versions 311 and 403]
<?php

declare(strict_types=1);

namespace Phpml\Helper;

use Phpml\Classification\Classifier;

/**
 * One-vs-Rest (OvR) multi-class support for binary classifiers.
 *
 * The using class supplies a binary learner (trainBinary / predictSampleBinary /
 * predictProbability / resetBinary).
 *
 * - With exactly two labels, a single clone of that learner is trained.
 * - Otherwise one clone is trained per label, each separating "label" from
 *   "everything else". Prediction then picks the label whose clone reports the
 *   highest predictProbability().
 */
trait OneVsRest
{
    /**
     * Trained binary classifiers.
     *
     * In the two-label case there is a single entry at key 0. Otherwise there is
     * one entry per label, keyed by that label.
     *
     * @var array
     */
    protected $classifiers = [];

    /**
     * All provided training targets' labels, sorted as strings.
     *
     * @var array
     */
    protected $allLabels = [];

    /**
     * Cost values reported by the first trained classifier, if it exposes getCostValues().
     *
     * @var array
     */
    protected $costValues = [];

    /**
     * Train a binary classifier in the OvR style, discarding any previous training.
     */
    public function train(array $samples, array $targets): void
    {
        // Clears previous stuff.
        $this->reset();

        $this->trainByLabel($samples, $targets);
    }

    /**
     * Resets the classifier and the vars internally used by OneVsRest to create multiple classifiers.
     */
    public function reset(): void
    {
        $this->classifiers = [];
        $this->allLabels = [];
        $this->costValues = [];

        $this->resetBinary();
    }

    /**
     * Trains the underlying binary classifier(s) on the given samples.
     *
     * @param array $samples   Training samples, passed through unchanged to trainBinary().
     * @param array $targets   One target label per sample.
     * @param array $allLabels Full label set. When empty, it is derived from $targets.
     *                         Must be provided on every partial-training run.
     */
    protected function trainByLabel(array $samples, array $targets, array $allLabels = []): void
    {
        // Overwrites the current value if it exists. $allLabels must be provided for each partialTrain run.
        $this->allLabels = count($allLabels) === 0 ? array_keys(array_count_values($targets)) : $allLabels;
        sort($this->allLabels, SORT_STRING);

        if (count($this->allLabels) === 2) {
            // With only two labels there is no need to perform OvR: one binary classifier suffices.
            if (count($this->classifiers) === 0) {
                $this->classifiers[0] = $this->getClassifierCopy();
            }

            $this->classifiers[0]->trainBinary($samples, $targets, $this->allLabels);
        } else {
            // Train a separate "label vs. rest" classifier for each label and memorize them.
            foreach ($this->allLabels as $label) {
                // Reuse an already-created classifier (presumably from an earlier partial-training run).
                if (!isset($this->classifiers[$label])) {
                    $this->classifiers[$label] = $this->getClassifierCopy();
                }

                [$binarizedTargets, $classifierLabels] = $this->binarizeTargets($targets, $label);
                $this->classifiers[$label]->trainBinary($samples, $binarizedTargets, $classifierLabels);
            }
        }

        // If the underlying classifier can report cost values during training, keep only the
        // first classifier's values. This avoids complex averaging across classifiers.
        // Note: this calls the global reset() function, not $this->reset(). It returns false for
        // an empty array (e.g. no targets), which method_exists() rejects with a TypeError on PHP 8.
        $firstClassifier = reset($this->classifiers);
        if (is_object($firstClassifier) && method_exists($firstClassifier, 'getCostValues')) {
            $this->costValues = $firstClassifier->getCostValues();
        }
    }

    /**
     * Returns an instance of the current class after cleaning up OneVsRest stuff.
     */
    protected function getClassifierCopy(): Classifier
    {
        // Clone the current classifier, so that
        // we don't mess up its variables while training
        // multiple instances of this classifier
        $classifier = clone $this;
        $classifier->reset();

        return $classifier;
    }

    /**
     * Predicts the label of a single sample.
     *
     * In the two-label case this delegates to the single binary classifier.
     * Otherwise it returns the label whose classifier gives the highest
     * predictProbability(), or null if no classifiers have been trained.
     *
     * Labels come back from array keys, so numeric-string labels are returned as ints.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        if (count($this->allLabels) === 2) {
            return $this->classifiers[0]->predictSampleBinary($sample);
        }

        $probs = [];

        foreach ($this->classifiers as $label => $predictor) {
            $probs[$label] = $predictor->predictProbability($sample, $label);
        }

        arsort($probs, SORT_NUMERIC);

        return key($probs);
    }

    /**
     * Each classifier should implement this method instead of train(samples, targets)
     */
    abstract protected function trainBinary(array $samples, array $targets, array $labels);

    /**
     * To be overwritten by OneVsRest classifiers.
     */
    abstract protected function resetBinary(): void;

    /**
     * Each classifier that make use of OvR approach should be able to
     * return a probability for a sample to belong to the given label.
     *
     * @return mixed
     */
    abstract protected function predictProbability(array $sample, string $label);

    /**
     * Each classifier should implement this method instead of predictSample()
     *
     * @return mixed
     */
    abstract protected function predictSampleBinary(array $sample);

    /**
     * Groups all targets into two groups: targets equal to the given label and the others.
     *
     * $targets is not passed by reference nor contains objects, so this method's
     * changes will not affect the caller's $targets array.
     *
     * @param mixed $label
     *
     * @return array Binarized targets and the two labels [$label, "not_<label>"]
     */
    private function binarizeTargets(array $targets, $label): array
    {
        // Name of the synthetic "everything else" class for this label.
        $notLabel = "not_{$label}";

        foreach ($targets as $key => $target) {
            // Loose comparison on purpose: labels that went through array keys may have been
            // converted from numeric strings to ints, while $targets keep their original type.
            $targets[$key] = $target == $label ? $label : $notLabel;
        }

        return [$targets, [$label, $notLabel]];
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body