See Release Notes
Long Term Support Release
<?php

declare(strict_types=1);

namespace Phpml\DimensionReduction;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Exception\InvalidOperationException;
use Phpml\Math\Matrix;

/**
 * Linear Discriminant Analysis (LDA) dimensionality reducer.
 *
 * fit() derives within-class and between-class scatter matrices from labelled
 * samples, hands inverse(S_W) * S_B to the eigen decomposition provided by the
 * parent class and returns the reduced training data. transform() reduces new
 * samples with the result of the last fit() call.
 */
class LDA extends EigenTransformerBase
{
    /**
     * Set to true at the end of fit(); transform() refuses to run before that.
     *
     * @var bool
     */
    public $fit = false;

    /**
     * Unique class labels found by the last fit() call.
     *
     * @var array
     */
    public $labels = [];

    /**
     * Per-class column means, keyed by the label's position in $labels.
     *
     * @var array
     */
    public $means = [];

    /**
     * Number of samples per class, keyed by the label's position in $labels.
     *
     * @var array
     */
    public $counts = [];

    /**
     * Column means computed over the whole dataset.
     *
     * @var float[]
     */
    public $overallMean = [];

    /**
     * Linear Discriminant Analysis (LDA) is used to reduce the dimensionality
     * of the data. Unlike Principal Component Analysis (PCA), it is a supervised
     * technique that requires the class labels in order to fit the data to a
     * lower dimensional space.
     *
     * The algorithm is initialized by specifying exactly one of
     * totalVariance (a value between 0.1 and 0.99) or
     * numFeatures (number of features to be preserved).
     *
     * @param float|null $totalVariance Total explained variance to be preserved
     * @param int|null   $numFeatures   Number of features to be preserved
     *
     * @throws InvalidArgumentException when a value is out of range, or when
     *                                  both or neither of the options are given
     */
    public function __construct(?float $totalVariance = null, ?int $numFeatures = null)
    {
        if ($totalVariance !== null && ($totalVariance < 0.1 || $totalVariance > 0.99)) {
            throw new InvalidArgumentException('Total variance can be a value between 0.1 and 0.99');
        }

        if ($numFeatures !== null && $numFeatures <= 0) {
            throw new InvalidArgumentException('Number of features to be preserved should be greater than 0');
        }

        // Exactly one of the two options must be supplied: both-null and both-set are rejected.
        if (($totalVariance !== null) === ($numFeatures !== null)) {
            throw new InvalidArgumentException('Either totalVariance or numFeatures should be specified in order to run the algorithm');
        }

        // NOTE(review): numFeatures / totalVariance are not declared in this class;
        // presumably they are properties of EigenTransformerBase - confirm.
        if ($numFeatures !== null) {
            $this->numFeatures = $numFeatures;
        }

        if ($totalVariance !== null) {
            $this->totalVariance = $totalVariance;
        }
    }

    /**
     * Trains the algorithm to transform the given data to a lower dimensional space.
     *
     * @param array $data    Samples; each entry is an array of numeric column values
     * @param array $classes Class label of each sample, indexed like $data. Labels must
     *                       be ints or strings (a requirement of array_count_values())
     *
     * @return array The training data after reduction
     *
     * @throws InvalidArgumentException when $data is empty
     */
    public function fit(array $data, array $classes): array
    {
        // The column count is read from $data[0] further down, so an empty
        // dataset cannot be processed; fail early with a clear message.
        if ($data === []) {
            throw new InvalidArgumentException('Cannot fit LDA on an empty dataset');
        }

        $this->labels = $this->getLabels($classes);
        $this->means = $this->calculateMeans($data, $classes);

        $sW = $this->calculateClassVar($data, $classes);
        $sB = $this->calculateClassCov();

        $S = $sW->inverse()->multiply($sB);
        $this->eigenDecomposition($S->toArray());

        $this->fit = true;

        return $this->reduce($data);
    }

    /**
     * Transforms the given sample(s) to a lower dimensional vector by using
     * the result of the last run of fit().
     *
     * @param array $sample A single sample (flat array) or a list of samples
     *
     * @throws InvalidOperationException when fit() has not been run yet
     */
    public function transform(array $sample): array
    {
        if (!$this->fit) {
            throw new InvalidOperationException('LDA has not been fitted with respect to original dataset, please run LDA::fit() first');
        }

        // A flat array is a single sample: wrap it so reduce() always receives a list of rows.
        if (!is_array($sample[0])) {
            $sample = [$sample];
        }

        return $this->reduce($sample);
    }

    /**
     * Returns unique labels in the dataset.
     *
     * Because the labels are taken from array keys, numeric strings such as
     * '1' come back as the integer 1.
     */
    protected function getLabels(array $classes): array
    {
        $counts = array_count_values($classes);

        return array_keys($counts);
    }

    /**
     * Calculates mean of each column for each class and returns
     * an n by m array where n is number of labels and m is number of columns.
     *
     * Also stores the per-class sample counts in $this->counts and the
     * column means of the whole dataset in $this->overallMean.
     */
    protected function calculateMeans(array $data, array $classes): array
    {
        $means = [];
        $counts = [];
        $overallMean = array_fill(0, count($data[0]), 0.0);
        $labelIndex = $this->labelIndex();

        foreach ($data as $index => $row) {
            // Keyed lookup applies the same key normalisation as getLabels(),
            // so the class '1' correctly resolves to the label 1.
            $label = $labelIndex[$classes[$index]];

            foreach ($row as $col => $val) {
                $means[$label][$col] = ($means[$label][$col] ?? 0.0) + $val;
                $overallMean[$col] += $val;
            }

            $counts[$label] = ($counts[$label] ?? 0) + 1;
        }

        // Turn the per-class column sums into means.
        foreach ($means as $index => $row) {
            foreach ($row as $col => $sum) {
                $means[$index][$col] = $sum / $counts[$index];
            }
        }

        // Calculate overall mean of the dataset for each column
        $numElements = array_sum($counts);
        $this->overallMean = array_map(
            static function ($el) use ($numElements) {
                return $el / $numElements;
            },
            $overallMean
        );
        $this->counts = $counts;

        return $means;
    }

    /**
     * Returns the within-class scatter matrix: the sum, over all samples, of
     * calculateVar(sample, mean of the sample's class).
     *
     * The accumulator is an m by m matrix where m is the number of columns.
     */
    protected function calculateClassVar(array $data, array $classes): Matrix
    {
        // Zero-filled m by m accumulator (m = number of columns).
        $dimension = count($data[0]);
        $sW = new Matrix(array_fill(0, $dimension, array_fill(0, $dimension, 0)), false);
        $labelIndex = $this->labelIndex();

        foreach ($data as $index => $row) {
            $label = $labelIndex[$classes[$index]];
            $scatter = $this->calculateVar($row, $this->means[$label]);

            $sW = $sW->add($scatter);
        }

        return $sW;
    }

    /**
     * Returns the between-class scatter matrix: the sum, over all classes, of
     * calculateVar(class mean, overall mean) weighted by the class sample count.
     *
     * The accumulator is an m by m matrix where m is the number of columns.
     */
    protected function calculateClassCov(): Matrix
    {
        // Zero-filled m by m accumulator (m = number of columns).
        $dimension = count($this->overallMean);
        $sB = new Matrix(array_fill(0, $dimension, array_fill(0, $dimension, 0)), false);

        foreach ($this->means as $index => $classMeans) {
            $scatter = $this->calculateVar($classMeans, $this->overallMean);
            $N = $this->counts[$index];
            $sB = $sB->add($scatter->multiplyByScalar($N));
        }

        return $sB;
    }

    /**
     * Returns the result of the calculation (x - m)T.(x - m)
     *
     * NOTE(review): presumably Matrix treats a flat array as a single row
     * vector, which makes the result an m by m matrix - verify against Matrix.
     */
    protected function calculateVar(array $row, array $means): Matrix
    {
        $x = new Matrix($row, false);
        $m = new Matrix($means, false);
        $diff = $x->subtract($m);

        return $diff->transpose()->multiply($diff);
    }

    /**
     * Builds a label => position map from $this->labels for O(1) lookups.
     */
    private function labelIndex(): array
    {
        return array_flip($this->labels);
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body