Differences Between: [Versions 310 and 400] [Versions 311 and 400] [Versions 39 and 400]
<?php

declare(strict_types=1);

namespace Phpml\Classification\Ensemble;

use Phpml\Classification\Classifier;
use Phpml\Classification\DecisionTree;
use Phpml\Exception\InvalidArgumentException;

/**
 * Bagging ensemble that only accepts DecisionTree base classifiers and
 * configures every tree with a limited number of features.
 */
class RandomForest extends Bagging
{
    /**
     * Either a float ratio (0.1 - 1.0) or one of the strategy names 'sqrt' / 'log'.
     *
     * @var float|string
     */
    protected $featureSubsetRatio = 'log';

    /**
     * Column names handed to every tree; filled with 0..featureCount-1 on
     * first use when the caller never supplied any.
     *
     * @var array|null
     */
    protected $columnNames;

    /**
     * Initializes RandomForest with the given number of trees. More trees
     * may increase the prediction performance while it will also substantially
     * increase the processing time and the required memory.
     *
     * @param int $numClassifier Number of trees in the forest
     */
    public function __construct(int $numClassifier = 50)
    {
        parent::__construct($numClassifier);

        // setSubsetRatio() is not defined in this class (presumably inherited
        // from Bagging); 1.0 is passed unconditionally.
        $this->setSubsetRatio(1.0);
    }

    /**
     * Determines how many of the original columns (features) will be used
     * to construct the subsets on which base classifiers are trained.
     *
     * Allowed values: 'sqrt', 'log' or any float number between 0.1 and 1.0.
     *
     * Default value for the ratio is 'log' which results in log(numFeatures, 2) + 1
     * features to be taken into consideration while selecting subspace of features.
     *
     * @param mixed $ratio
     *
     * @throws InvalidArgumentException When $ratio is neither a string nor a float,
     *                                  is a float outside [0.1, 1.0], or is a string
     *                                  other than 'sqrt' / 'log'
     */
    public function setFeatureSubsetRatio($ratio): self
    {
        if (!is_string($ratio) && !is_float($ratio)) {
            throw new InvalidArgumentException('Feature subset ratio must be a string or a float');
        }

        if (is_float($ratio) && ($ratio < 0.1 || $ratio > 1.0)) {
            throw new InvalidArgumentException('When a float is given, feature subset ratio should be between 0.1 and 1.0');
        }

        if (is_string($ratio) && $ratio !== 'sqrt' && $ratio !== 'log') {
            throw new InvalidArgumentException("When a string is given, feature subset ratio can only be 'sqrt' or 'log'");
        }

        $this->featureSubsetRatio = $ratio;

        return $this;
    }

    /**
     * RandomForest algorithm is usable *only* with DecisionTree.
     *
     * @param string $classifier        Must be DecisionTree::class
     * @param array  $classifierOptions Forwarded unchanged to the parent implementation
     *
     * @return $this
     *
     * @throws InvalidArgumentException When any other classifier class name is given
     */
    public function setClassifer(string $classifier, array $classifierOptions = [])
    {
        if ($classifier !== DecisionTree::class) {
            throw new InvalidArgumentException('RandomForest can only use DecisionTree as base classifier');
        }

        parent::setClassifer($classifier, $classifierOptions);

        return $this;
    }

    /**
     * Returns an array holding an importance value for each column reported
     * by the trees. Per-column importances are summed over all trees and then
     * divided by the grand total, so the values add up to 1. The result is
     * sorted in descending order with keys preserved.
     *
     * When the grand total is zero (no trees, or every reported importance is
     * zero) the normalization step is skipped and the summed values are
     * returned as they are, instead of dividing by zero.
     *
     * @return array Importance values keyed by column
     */
    public function getFeatureImportances(): array
    {
        // Traverse each tree and sum importance of the columns
        $sum = [];
        foreach ($this->classifiers as $tree) {
            /** @var DecisionTree $tree */
            foreach ($tree->getFeatureImportances() as $column => $importance) {
                $sum[$column] = ($sum[$column] ?? 0) + $importance;
            }
        }

        // Normalize the importance values; a zero total would otherwise raise
        // DivisionByZeroError on PHP 8 (warning + NAN on PHP 7).
        $total = array_sum($sum);
        if ((float) $total !== 0.0) {
            array_walk($sum, static function (&$importance) use ($total): void {
                $importance /= $total;
            });
        }

        arsort($sum);

        return $sum;
    }

    /**
     * A string array to represent the columns is given. They are useful
     * when trying to print some information about the trees such as feature importances.
     *
     * @param array $names Column names, stored as given
     *
     * @return $this
     */
    public function setColumnNames(array $names)
    {
        $this->columnNames = $names;

        return $this;
    }

    /**
     * Configures a freshly created tree with the column names and with the
     * number of features derived from the feature subset ratio.
     *
     * @return DecisionTree
     *
     * @throws InvalidArgumentException When $classifier is not a DecisionTree
     */
    protected function initSingleClassifier(Classifier $classifier): Classifier
    {
        if (!$classifier instanceof DecisionTree) {
            throw new InvalidArgumentException(
                sprintf('Classifier %s expected, got %s', DecisionTree::class, get_class($classifier))
            );
        }

        if (is_float($this->featureSubsetRatio)) {
            // NOTE(review): the int cast truncates, so a small ratio on few
            // features yields 0 here - confirm how DecisionTree::setNumFeatures()
            // treats 0.
            $featureCount = (int) ($this->featureSubsetRatio * $this->featureCount);
        } elseif ($this->featureSubsetRatio === 'sqrt') {
            $featureCount = (int) ($this->featureCount ** .5) + 1;
        } else {
            $featureCount = (int) log($this->featureCount, 2) + 1;
        }

        // Never request more features than there are available
        if ($featureCount >= $this->featureCount) {
            $featureCount = $this->featureCount;
        }

        // Fall back to numeric column names when none were supplied
        if ($this->columnNames === null) {
            $this->columnNames = range(0, $this->featureCount - 1);
        }

        return $classifier
            ->setColumnNames($this->columnNames)
            ->setNumFeatures($featureCount);
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body