Differences Between: [Versions 310 and 402] [Versions 311 and 402] [Versions 39 and 402] [Versions 400 and 402] [Versions 401 and 402]
<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Classification\DecisionTree\DecisionTreeLeaf;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;

/**
 * Binary decision tree classifier whose splits are chosen by Gini impurity.
 *
 * Continuous columns are discretised at their median ("<= m" / "> m") before a
 * split is evaluated; for every candidate column the most frequent value is
 * tested against all other values, and the column with the lowest weighted
 * Gini index becomes the split of the node.
 */
class DecisionTree implements Classifier
{
    use Trainable;
    use Predictable;

    public const CONTINUOUS = 1;

    public const NOMINAL = 2;

    /**
     * Deepest level reached while building the current tree (the root is level 0).
     *
     * @var int
     */
    public $actualDepth = 0;

    /**
     * Type of each feature column, indexed by column: self::CONTINUOUS or self::NOMINAL.
     *
     * @var array
     */
    protected $columnTypes = [];

    /**
     * Root node of the trained tree.
     *
     * @var DecisionTreeLeaf
     */
    protected $tree;

    /**
     * Nodes created at this depth are always made terminal.
     *
     * @var int
     */
    protected $maxDepth;

    /**
     * Distinct target labels seen during training.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Number of columns in each training sample.
     *
     * @var int
     */
    private $featureCount = 0;

    /**
     * Number of randomly chosen columns considered per split; 0 means all columns.
     *
     * @var int
     */
    private $numUsableFeatures = 0;

    /**
     * Explicit list of columns considered for splits; empty means "not set".
     *
     * @var array
     */
    private $selectedFeatures = [];

    /**
     * Cached result of getFeatureImportances(); null when it has to be recomputed.
     *
     * @var array|null
     */
    private $featureImportances;

    /**
     * Display names of the columns, indexed by column.
     *
     * @var array
     */
    private $columnNames = [];

    /**
     * @param int $maxDepth depth at which nodes are forced to become leaves
     */
    public function __construct(int $maxDepth = 10)
    {
        $this->maxDepth = $maxDepth;
    }

    /**
     * Appends the given samples/targets to the ones seen before and rebuilds
     * the whole tree from the combined data.
     */
    public function train(array $samples, array $targets): void
    {
        $this->samples = array_merge($this->samples, $samples);
        $this->targets = array_merge($this->targets, $targets);

        $this->featureCount = count($this->samples[0]);
        $this->columnTypes = self::getColumnTypes($this->samples);
        $this->labels = array_keys(array_count_values($this->targets));

        // The tree is rebuilt from scratch, so the depth of a previously
        // trained tree must not leak into the depth of the new one
        $this->actualDepth = 0;
        $this->tree = $this->getSplitLeaf(range(0, count($this->samples) - 1));

        // Each time the tree is trained, feature importances are reset so that
        // we will have to compute it again depending on the new data
        $this->featureImportances = null;

        // If column names are given or computed before, then there is no
        // need to init it and accidentally remove the previous given names
        if ($this->columnNames === []) {
            $this->columnNames = range(0, $this->featureCount - 1);
        } elseif (count($this->columnNames) > $this->featureCount) {
            $this->columnNames = array_slice($this->columnNames, 0, $this->featureCount);
        } elseif (count($this->columnNames) < $this->featureCount) {
            $this->columnNames = array_merge(
                $this->columnNames,
                range(count($this->columnNames), $this->featureCount - 1)
            );
        }
    }

    /**
     * Classifies every column of $samples as self::NOMINAL or self::CONTINUOUS.
     * The number of columns is taken from the first sample.
     *
     * @return int[] one type per column, in column order
     */
    public static function getColumnTypes(array $samples): array
    {
        $types = [];
        $featureCount = count($samples[0]);
        for ($i = 0; $i < $featureCount; ++$i) {
            $values = array_column($samples, $i);
            $types[] = self::isCategoricalColumn($values) ? self::NOMINAL : self::CONTINUOUS;
        }

        return $types;
    }

    /**
     * Weighted Gini impurity of splitting $colValues into "equals $baseValue"
     * and "everything else".
     *
     * Each side contributes (1 - sum of squared label proportions) weighted
     * by its size; the result is divided by the total number of values.
     *
     * @param mixed $baseValue value defining the left side of the split
     * @param array $colValues column values keyed by record index
     * @param array $targets   target labels keyed by the same record index
     *
     * @return float impurity in [0, 1); 0.0 for an empty column
     */
    public function getGiniIndex($baseValue, array $colValues, array $targets): float
    {
        if ($colValues === []) {
            return 0.0;
        }

        $countMatrix = [];
        foreach ($this->labels as $label) {
            $countMatrix[$label] = [0, 0];
        }

        foreach ($colValues as $index => $value) {
            $label = $targets[$index];
            $rowIndex = self::isSameSplitValue($value, $baseValue) ? 0 : 1;
            ++$countMatrix[$label][$rowIndex];
        }

        $giniParts = [0, 0];
        for ($i = 0; $i <= 1; ++$i) {
            $part = 0;
            $sum = array_sum(array_column($countMatrix, $i));
            if ($sum > 0) {
                foreach ($this->labels as $label) {
                    $part += ($countMatrix[$label][$i] / (float) $sum) ** 2;
                }
            }

            $giniParts[$i] = (1 - $part) * $sum;
        }

        return array_sum($giniParts) / count($colValues);
    }

    /**
     * This method is used to set number of columns to be used
     * when deciding a split at an internal node of the tree. <br>
     * If the value is given 0, then all features are used (default behaviour),
     * otherwise the given value will be used as a maximum for number of columns
     * randomly selected for each split operation.
     *
     * @return $this
     *
     * @throws InvalidArgumentException when $numFeatures is negative
     */
    public function setNumFeatures(int $numFeatures)
    {
        if ($numFeatures < 0) {
            throw new InvalidArgumentException('Selected column count should be greater or equal to zero');
        }

        $this->numUsableFeatures = $numFeatures;

        return $this;
    }

    /**
     * A string array to represent columns. Useful when HTML output or
     * column importances are desired to be inspected.
     *
     * @return $this
     *
     * @throws InvalidArgumentException when the tree is trained and the name count differs from the feature count
     */
    public function setColumnNames(array $names)
    {
        if ($this->featureCount !== 0 && count($names) !== $this->featureCount) {
            throw new InvalidArgumentException(sprintf('Length of the given array should be equal to feature count %s', $this->featureCount));
        }

        $this->columnNames = $names;

        return $this;
    }

    /**
     * HTML rendering of the trained tree, labelled with the current column names.
     */
    public function getHtml(): string
    {
        return $this->tree->getHTML($this->columnNames);
    }

    /**
     * This will return an array including an importance value for
     * each column in the given dataset, keyed by column name.<br/>
     * When at least one split decreases impurity, the values are normalized
     * so that their total is 1 and sorted in descending order; otherwise the
     * raw (all-zero) values are returned in column order.
     */
    public function getFeatureImportances(): array
    {
        if ($this->featureImportances !== null) {
            return $this->featureImportances;
        }

        $sampleCount = count($this->samples);
        $importances = [];
        foreach ($this->columnNames as $column => $columnName) {
            $importance = 0;
            foreach ($this->getSplitNodesByColumn($column, $this->tree) as $node) {
                $importance += $node->getNodeImpurityDecrease($sampleCount);
            }

            $importances[$columnName] = $importance;
        }

        // Normalize & sort the importances
        $total = array_sum($importances);
        if ($total > 0) {
            // array_map() with a single array preserves the (column name) keys
            $importances = array_map(static function ($importance) use ($total) {
                return $importance / $total;
            }, $importances);
            arsort($importances);
        }

        return $this->featureImportances = $importances;
    }

    /**
     * Builds the subtree for the given records.
     *
     * The node becomes terminal (classified by the most frequent target) when
     * all records are identical, all share one target, or $depth reached the
     * maximum depth; otherwise the records are split and children are built.
     *
     * @param int[] $records indices into the training samples, in ascending order
     */
    protected function getSplitLeaf(array $records, int $depth = 0): DecisionTreeLeaf
    {
        $split = $this->getBestSplit($records);
        $split->level = $depth;
        if ($this->actualDepth < $depth) {
            $this->actualDepth = $depth;
        }

        // Traverse all records to see if all records belong to the same class,
        // otherwise group the records so that we can classify the leaf
        // in case maximum depth is reached
        $leftRecords = [];
        $rightRecords = [];
        $remainingTargets = [];
        $prevRecord = null;
        $allSame = true;

        foreach ($records as $recordNo) {
            // Check if the previous record is the same with the current one
            $record = $this->samples[$recordNo];
            if ($prevRecord !== null && $prevRecord != $record) {
                $allSame = false;
            }

            $prevRecord = $record;

            // According to the split criterion, this record will
            // belong to either left or the right side in the next split
            if ($split->evaluate($record)) {
                $leftRecords[] = $recordNo;
            } else {
                $rightRecords[] = $recordNo;
            }

            // Group remaining targets
            $target = $this->targets[$recordNo];
            if (!array_key_exists($target, $remainingTargets)) {
                $remainingTargets[$target] = 1;
            } else {
                ++$remainingTargets[$target];
            }
        }

        if ($allSame || $depth >= $this->maxDepth || count($remainingTargets) === 1) {
            $split->isTerminal = true;
            arsort($remainingTargets);
            $split->classValue = (string) key($remainingTargets);
        } else {
            if (isset($leftRecords[0])) {
                $split->leftLeaf = $this->getSplitLeaf($leftRecords, $depth + 1);
            }

            if (isset($rightRecords[0])) {
                $split->rightLeaf = $this->getSplitLeaf($rightRecords, $depth + 1);
            }
        }

        return $split;
    }

    /**
     * Picks, among the selected features, the column/value pair with the
     * lowest Gini index for the given records.
     *
     * For every column the most frequent (preprocessed) value is used as the
     * split value. For continuous columns the comparison operator and the
     * numeric threshold are also parsed out of the "<= m" / "> m" label.
     *
     * NOTE(review): returns null (a TypeError under the declared return type)
     * if no column yields a countable value; unreachable for non-empty
     * records of int/string/float data.
     *
     * @param int[] $records indices into the training samples, in ascending order
     */
    protected function getBestSplit(array $records): DecisionTreeLeaf
    {
        $recordKeys = array_flip($records);
        $targets = array_intersect_key($this->targets, $recordKeys);
        $samples = (array) array_combine(
            $records,
            $this->preprocess(array_intersect_key($this->samples, $recordKeys))
        );
        $bestGiniVal = 1;
        $bestSplit = null;
        foreach ($this->getSelectedFeatures() as $i) {
            $colValues = [];
            foreach ($samples as $index => $row) {
                $colValues[$index] = $row[$i];
            }

            // Most frequent value of the column; note that array_count_values()
            // turns numeric-string keys into integers
            $counts = array_count_values($colValues);
            arsort($counts);
            $baseValue = key($counts);
            if ($baseValue === null) {
                continue;
            }

            $gini = $this->getGiniIndex($baseValue, $colValues, $targets);
            if ($bestSplit === null || $bestGiniVal > $gini) {
                $split = new DecisionTreeLeaf();
                $split->value = $baseValue;
                $split->giniIndex = $gini;
                $split->columnIndex = $i;
                $split->isContinuous = $this->columnTypes[$i] === self::CONTINUOUS;
                $split->records = $records;

                // If a numeric column is to be selected, then
                // the original numeric value and the selected operator
                // will also be saved into the leaf for future access
                if ($this->columnTypes[$i] === self::CONTINUOUS) {
                    $matches = [];
                    preg_match("/^([<>=]{1,2})\s*(.*)/", (string) $split->value, $matches);
                    $split->operator = $matches[1];
                    $split->numericValue = (float) $matches[2];
                }

                $bestSplit = $split;
                $bestGiniVal = $gini;
            }
        }

        return $bestSplit;
    }

    /**
     * Returns available features/columns to the tree for the decision making
     * process. <br>
     *
     * If some features are manually selected by use of setSelectedFeatures(),
     * then only these features are returned. <br>
     *
     * Otherwise, if a number is given with setNumFeatures() method, then a random
     * selection of features up to this number is returned, sorted ascending. <br>
     *
     * If neither method was called beforehand, then all features
     * are returned by default.
     */
    protected function getSelectedFeatures(): array
    {
        $allFeatures = range(0, $this->featureCount - 1);
        if ($this->numUsableFeatures === 0 && count($this->selectedFeatures) === 0) {
            return $allFeatures;
        }

        if (count($this->selectedFeatures) > 0) {
            return $this->selectedFeatures;
        }

        $numFeatures = min($this->numUsableFeatures, $this->featureCount);

        shuffle($allFeatures);
        $selectedFeatures = array_slice($allFeatures, 0, $numFeatures);
        sort($selectedFeatures);

        return $selectedFeatures;
    }

    /**
     * Converts continuous column values into the discrete labels
     * "<= median" / "> median" (median of the given samples) and returns the
     * samples as a list of rows (re-indexed from 0). Nominal columns are kept as-is.
     */
    protected function preprocess(array $samples): array
    {
        $columns = [];
        for ($i = 0; $i < $this->featureCount; ++$i) {
            $values = array_column($samples, $i);
            if ($this->columnTypes[$i] === self::CONTINUOUS) {
                $median = Mean::median($values);
                $values = array_map(static function ($value) use ($median): string {
                    return $value <= $median ? "<= {$median}" : "> {$median}";
                }, $values);
            }

            $columns[] = $values;
        }

        return self::transposeColumns($columns, count($samples));
    }

    /**
     * A column is treated as categorical (nominal) when:
     *  - it contains no float values, and
     *  - either some value is non-numeric (e.g. a label string), or the number
     *    of distinct values is at most 20% of all values.
     */
    protected static function isCategoricalColumn(array $columnValues): bool
    {
        $count = count($columnValues);

        $numericValues = array_filter($columnValues, 'is_numeric');
        $floatValues = array_filter($columnValues, 'is_float');
        if (count($floatValues) > 0) {
            return false;
        }

        if (count($numericValues) !== $count) {
            return true;
        }

        $distinctValues = array_count_values($columnValues);

        return count($distinctValues) <= $count / 5;
    }

    /**
     * Used to set predefined features to consider while deciding which column to use for a split
     */
    protected function setSelectedFeatures(array $selectedFeatures): void
    {
        $this->selectedFeatures = $selectedFeatures;
    }

    /**
     * Collects and returns an array of internal (non-terminal) nodes of the
     * subtree rooted at $node that use the given column as a split criterion,
     * in pre-order (node, left subtree, right subtree).
     */
    protected function getSplitNodesByColumn(int $column, DecisionTreeLeaf $node): array
    {
        if ($node->isTerminal) {
            return [];
        }

        $nodes = [];
        if ($node->columnIndex === $column) {
            $nodes[] = $node;
        }

        $lNodes = [];
        $rNodes = [];
        if ($node->leftLeaf !== null) {
            $lNodes = $this->getSplitNodesByColumn($column, $node->leftLeaf);
        }

        if ($node->rightLeaf !== null) {
            $rNodes = $this->getSplitNodesByColumn($column, $node->rightLeaf);
        }

        return array_merge($nodes, $lNodes, $rNodes);
    }

    /**
     * Walks the tree from the root until a terminal node is reached and
     * returns its class value; falls back to the first training label when a
     * branch has no child for the sample.
     *
     * @return mixed
     */
    protected function predictSample(array $sample)
    {
        $node = $this->tree;
        do {
            if ($node->isTerminal) {
                return $node->classValue;
            }

            if ($node->evaluate($sample)) {
                $node = $node->leftLeaf;
            } else {
                $node = $node->rightLeaf;
            }
        } while ($node !== null);

        return $this->labels[0];
    }

    /**
     * Compares a column value with a split value the way array_count_values()
     * groups them: int/string pairs are compared by their string form (so the
     * normalised key 1 matches the value "1", while "01" stays distinct);
     * every other combination is compared strictly.
     *
     * @param mixed $value
     * @param mixed $baseValue
     */
    private static function isSameSplitValue($value, $baseValue): bool
    {
        if ((is_int($value) || is_string($value)) && (is_int($baseValue) || is_string($baseValue))) {
            return (string) $value === (string) $baseValue;
        }

        return $value === $baseValue;
    }

    /**
     * Turns a list of equally long columns into a list of $rowCount rows.
     *
     * array_map(null, ...$columns) is not used on purpose: with a single
     * column it returns that column unchanged (a list of scalars instead of
     * one-element rows), and it fails when there are no columns at all.
     */
    private static function transposeColumns(array $columns, int $rowCount): array
    {
        $rows = array_fill(0, $rowCount, []);
        foreach ($columns as $columnIndex => $values) {
            foreach ($values as $rowIndex => $value) {
                $rows[$rowIndex][$columnIndex] = $value;
            }
        }

        return $rows;
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body