Differences Between: [Versions 400 and 402] [Versions 400 and 403]
<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Phpml\Classification\DecisionTree;
use Phpml\Classification\WeightedClassifier;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\OneVsRest;
use Phpml\Helper\Predictable;
use Phpml\Math\Comparison;

class DecisionStump extends WeightedClassifier
{
    use Predictable;
    use OneVsRest;

    public const AUTO_SELECT = -1;

    /**
     * @var int
     */
    protected $givenColumnIndex;

    /**
     * @var array
     */
    protected $binaryLabels = [];

    /**
     * Lowest error rate obtained while training/optimizing the model
     *
     * @var float
     */
    protected $trainingErrorRate;

    /**
     * @var int
     */
    protected $column;

    /**
     * @var mixed
     */
    protected $value;

    /**
     * @var string
     */
    protected $operator;

    /**
     * @var array
     */
    protected $columnTypes = [];

    /**
     * @var int
     */
    protected $featureCount;

    /**
     * @var float
     */
    protected $numSplitCount = 100.0;

    /**
     * Distribution of samples in the leaves
     *
     * @var array
     */
    protected $prob = [];

    /**
     * A DecisionStump classifier is a one-level deep DecisionTree. It is generally
     * used with ensemble algorithms in the weak-classifier role. <br>
     *
     * If columnIndex is given, the stump tries to produce a decision node
     * on that column. If it is -1 (AUTO_SELECT), the stump itself
     * decides which column to use for the decision (default DecisionTree behaviour).
     */
    public function __construct(int $columnIndex = self::AUTO_SELECT)
    {
        $this->givenColumnIndex = $columnIndex;
    }

    /**
     * Human-readable description of the learned rule, e.g. "IF 0 <= 3.5 THEN a ELSE b".
     */
    public function __toString(): string
    {
        // Braced interpolation is required: a bare "$this" inside the string would
        // cast the object itself to string and recurse into __toString forever.
        return "IF {$this->column} {$this->operator} {$this->value} ".
            'THEN '.$this->binaryLabels[0].' '.
            'ELSE '.$this->binaryLabels[1];
    }

    /**
     * While finding the best split point for a numerical column,
     * DecisionStump probes equally spaced values between the minimum and maximum
     * values in the column. The given <i>$count</i> determines how many split
     * points are probed. More split points can find a better threshold at the cost
     * of processing time (default value is 100.0). A non-positive count disables
     * the probing, so only the column mean is tried.
     */
    public function setNumericalSplitCount(float $count): void
    {
        $this->numSplitCount = $count;
    }

    /**
     * Trains the stump on a binary problem by choosing the column, operator and
     * threshold with the lowest weighted training error.
     *
     * @throws InvalidArgumentException when the number of sample weights does not match the number of samples
     */
    protected function trainBinary(array $samples, array $targets, array $labels): void
    {
        $this->binaryLabels = $labels;
        $this->featureCount = count($samples[0]);

        // A given column index is only honoured when it refers to an existing column;
        // anything out of range (including negative values) falls back to auto selection.
        if ($this->givenColumnIndex < 0 || $this->givenColumnIndex > $this->featureCount - 1) {
            $this->givenColumnIndex = self::AUTO_SELECT;
        }

        // Check the size of the weights given.
        // If none are given, assign a weight of 1 to each sample
        if (count($this->weights) === 0) {
            $this->weights = array_fill(0, count($samples), 1);
        } else {
            $numWeights = count($this->weights);
            if ($numWeights !== count($samples)) {
                throw new InvalidArgumentException('Number of sample weights does not match with number of samples');
            }
        }

        // Determine the type of each column as either "continuous" or "nominal"
        $this->columnTypes = DecisionTree::getColumnTypes($samples);

        // Try to find the best split in the columns of the dataset
        // by calculating the error rate for each split point in each column
        $columns = range(0, $this->featureCount - 1);
        if ($this->givenColumnIndex !== self::AUTO_SELECT) {
            $columns = [$this->givenColumnIndex];
        }

        $bestSplit = [
            'value' => 0,
            'operator' => '',
            'prob' => [],
            'column' => 0,
            'trainingErrorRate' => 1.0,
        ];
        foreach ($columns as $col) {
            if ($this->columnTypes[$col] == DecisionTree::CONTINUOUS) {
                $split = $this->getBestNumericalSplit($samples, $targets, $col);
            } else {
                $split = $this->getBestNominalSplit($samples, $targets, $col);
            }

            if ($split['trainingErrorRate'] < $bestSplit['trainingErrorRate']) {
                $bestSplit = $split;
            }
        }

        // Assign the best values found to the stump's properties
        // (value, operator, prob, column, trainingErrorRate)
        foreach ($bestSplit as $name => $value) {
            $this->{$name} = $value;
        }
    }

    /**
     * Determines the best split point for the given numerical column.
     *
     * @return array{value: mixed, operator: string, prob: array, column: int, trainingErrorRate: float}
     */
    protected function getBestNumericalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        // Trying all possible points may be accomplished in two general ways:
        // 1- Try all values in the $samples array ($values)
        // 2- Artificially split the range of values into several parts and try them
        // We choose the second one because it is faster in larger datasets
        $minValue = min($values);
        $maxValue = max($values);

        // A constant column (min == max) yields a zero step, and a non-positive split
        // count yields a zero/negative step (or a division by zero). In all of these cases
        // the probing loop below could never terminate, so it is skipped.
        $stepSize = $this->numSplitCount > 0
            ? ($maxValue - $minValue) / $this->numSplitCount
            : 0.0;

        // The mean does not depend on the operator, so compute it once.
        $mean = array_sum($values) / (float) count($values);

        $split = [];

        foreach (['<=', '>'] as $operator) {
            // Before trying all possible split points, first try
            // the average value as the cut point
            [$errorRate, $prob] = $this->calculateErrorRate($targets, $mean, $operator, $values);
            if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                $split = [
                    'value' => $mean,
                    'operator' => $operator,
                    'prob' => $prob,
                    'column' => $col,
                    'trainingErrorRate' => $errorRate,
                ];
            }

            if ($stepSize <= 0) {
                continue;
            }

            // Try other possible points one by one
            for ($step = $minValue; $step <= $maxValue; $step += $stepSize) {
                $threshold = (float) $step;
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $threshold, $operator, $values);
                if ($errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $threshold,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Determines the best equality/inequality split for the given nominal column by
     * trying every distinct value with both '=' and '!='.
     *
     * @return array{value: mixed, operator: string, prob: array, column: int, trainingErrorRate: float}
     */
    protected function getBestNominalSplit(array $samples, array $targets, int $col): array
    {
        $values = array_column($samples, $col);
        $valueCounts = array_count_values($values);
        $distinctVals = array_keys($valueCounts);

        $split = [];

        foreach (['=', '!='] as $operator) {
            foreach ($distinctVals as $val) {
                [$errorRate, $prob] = $this->calculateErrorRate($targets, $val, $operator, $values);
                // Keep the candidate with the LOWEST error, consistent with the numerical split
                // and with the minimisation performed in trainBinary().
                if (!isset($split['trainingErrorRate']) || $errorRate < $split['trainingErrorRate']) {
                    $split = [
                        'value' => $val,
                        'operator' => $operator,
                        'prob' => $prob,
                        'column' => $col,
                        'trainingErrorRate' => $errorRate,
                    ];
                }
            }
        }

        return $split;
    }

    /**
     * Calculates the weighted ratio of wrong predictions for the given threshold
     * and operator, together with the per-leaf proportion of correctly labelled samples.
     *
     * @param mixed $threshold numeric threshold for continuous columns, or a category
     *                         value (string|int) for nominal columns; left untyped
     *                         because a float type would reject string categories under strict_types
     *
     * @return array{0: float, 1: array} [weighted error rate, distribution keyed by label]
     */
    protected function calculateErrorRate(array $targets, $threshold, string $operator, array $values): array
    {
        $wrong = 0.0;
        $prob = [];
        $leftLabel = $this->binaryLabels[0];
        $rightLabel = $this->binaryLabels[1];

        foreach ($values as $index => $value) {
            if (Comparison::compare($value, $threshold, $operator)) {
                $predicted = $leftLabel;
            } else {
                $predicted = $rightLabel;
            }

            $target = $targets[$index];
            if ((string) $predicted != (string) $target) {
                $wrong += $this->weights[$index];
            }

            if (!isset($prob[$predicted][$target])) {
                $prob[$predicted][$target] = 0;
            }

            ++$prob[$predicted][$target];
        }

        // Calculate probabilities: proportion of the leaf's own label in each leaf
        $dist = array_combine($this->binaryLabels, array_fill(0, 2, 0.0));
        foreach ($prob as $leaf => $counts) {
            $leafTotal = (float) array_sum($counts);
            foreach ($counts as $label => $count) {
                if ((string) $leaf == (string) $label) {
                    $dist[$leaf] = $count / $leafTotal;
                }
            }
        }

        return [$wrong / (float) array_sum($this->weights), $dist];
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The probability of a sample is calculated as the proportion of the label
     * within the labels of the training samples in the decision node.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $predicted = $this->predictSampleBinary($sample);
        if ((string) $predicted == (string) $label) {
            return $this->prob[$label];
        }

        return 0.0;
    }

    /**
     * Applies the learned rule to a sample: first label when the comparison holds,
     * second label otherwise.
     *
     * @return mixed
     */
    protected function predictSampleBinary(array $sample)
    {
        if (Comparison::compare($sample[$this->column], $this->value, $this->operator)) {
            return $this->binaryLabels[0];
        }

        return $this->binaryLabels[1];
    }

    /**
     * No per-class state needs resetting for a stump.
     */
    protected function resetBinary(): void
    {
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body