Search moodle.org's
Developer Documentation

See Release Notes
Long Term Support Release

  • Bug fixes for general core bugs in 4.1.x will end 13 November 2023 (12 months).
  • Bug fixes for security issues in 4.1.x will end 10 November 2025 (36 months).
  • PHP version: minimum PHP 7.4.0. Note: the minimum PHP version has increased since Moodle 4.0. PHP 8.0.x is also supported.
   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\DimensionReduction;
   6  
   7  use Phpml\Exception\InvalidArgumentException;
   8  use Phpml\Exception\InvalidOperationException;
   9  use Phpml\Math\Matrix;
  10  
  11  class LDA extends EigenTransformerBase
  12  {
  13      /**
  14       * @var bool
  15       */
  16      public $fit = false;
  17  
  18      /**
  19       * @var array
  20       */
  21      public $labels = [];
  22  
  23      /**
  24       * @var array
  25       */
  26      public $means = [];
  27  
  28      /**
  29       * @var array
  30       */
  31      public $counts = [];
  32  
  33      /**
  34       * @var float[]
  35       */
  36      public $overallMean = [];
  37  
  38      /**
  39       * Linear Discriminant Analysis (LDA) is used to reduce the dimensionality
  40       * of the data. Unlike Principal Component Analysis (PCA), it is a supervised
  41       * technique that requires the class labels in order to fit the data to a
  42       * lower dimensional space. <br><br>
  43       * The algorithm can be initialized by speciyfing
  44       * either with the totalVariance(a value between 0.1 and 0.99)
  45       * or numFeatures (number of features in the dataset) to be preserved.
  46       *
  47       * @param float|null $totalVariance Total explained variance to be preserved
  48       * @param int|null   $numFeatures   Number of features to be preserved
  49       *
  50       * @throws InvalidArgumentException
  51       */
  52      public function __construct(?float $totalVariance = null, ?int $numFeatures = null)
  53      {
  54          if ($totalVariance !== null && ($totalVariance < 0.1 || $totalVariance > 0.99)) {
  55              throw new InvalidArgumentException('Total variance can be a value between 0.1 and 0.99');
  56          }
  57  
  58          if ($numFeatures !== null && $numFeatures <= 0) {
  59              throw new InvalidArgumentException('Number of features to be preserved should be greater than 0');
  60          }
  61  
  62          if (($totalVariance !== null) === ($numFeatures !== null)) {
  63              throw new InvalidArgumentException('Either totalVariance or numFeatures should be specified in order to run the algorithm');
  64          }
  65  
  66          if ($numFeatures !== null) {
  67              $this->numFeatures = $numFeatures;
  68          }
  69  
  70          if ($totalVariance !== null) {
  71              $this->totalVariance = $totalVariance;
  72          }
  73      }
  74  
  75      /**
  76       * Trains the algorithm to transform the given data to a lower dimensional space.
  77       */
  78      public function fit(array $data, array $classes): array
  79      {
  80          $this->labels = $this->getLabels($classes);
  81          $this->means = $this->calculateMeans($data, $classes);
  82  
  83          $sW = $this->calculateClassVar($data, $classes);
  84          $sB = $this->calculateClassCov();
  85  
  86          $S = $sW->inverse()->multiply($sB);
  87          $this->eigenDecomposition($S->toArray());
  88  
  89          $this->fit = true;
  90  
  91          return $this->reduce($data);
  92      }
  93  
  94      /**
  95       * Transforms the given sample to a lower dimensional vector by using
  96       * the eigenVectors obtained in the last run of <code>fit</code>.
  97       *
  98       * @throws InvalidOperationException
  99       */
 100      public function transform(array $sample): array
 101      {
 102          if (!$this->fit) {
 103              throw new InvalidOperationException('LDA has not been fitted with respect to original dataset, please run LDA::fit() first');
 104          }
 105  
 106          if (!is_array($sample[0])) {
 107              $sample = [$sample];
 108          }
 109  
 110          return $this->reduce($sample);
 111      }
 112  
 113      /**
 114       * Returns unique labels in the dataset
 115       */
 116      protected function getLabels(array $classes): array
 117      {
 118          $counts = array_count_values($classes);
 119  
 120          return array_keys($counts);
 121      }
 122  
 123      /**
 124       * Calculates mean of each column for each class and returns
 125       * n by m matrix where n is number of labels and m is number of columns
 126       */
 127      protected function calculateMeans(array $data, array $classes): array
 128      {
 129          $means = [];
 130          $counts = [];
 131          $overallMean = array_fill(0, count($data[0]), 0.0);
 132  
 133          foreach ($data as $index => $row) {
 134              $label = array_search($classes[$index], $this->labels, true);
 135  
 136              foreach ($row as $col => $val) {
 137                  if (!isset($means[$label][$col])) {
 138                      $means[$label][$col] = 0.0;
 139                  }
 140  
 141                  $means[$label][$col] += $val;
 142                  $overallMean[$col] += $val;
 143              }
 144  
 145              if (!isset($counts[$label])) {
 146                  $counts[$label] = 0;
 147              }
 148  
 149              ++$counts[$label];
 150          }
 151  
 152          foreach ($means as $index => $row) {
 153              foreach ($row as $col => $sum) {
 154                  $means[$index][$col] = $sum / $counts[$index];
 155              }
 156          }
 157  
 158          // Calculate overall mean of the dataset for each column
 159          $numElements = array_sum($counts);
 160          $map = function ($el) use ($numElements) {
 161              return $el / $numElements;
 162          };
 163          $this->overallMean = array_map($map, $overallMean);
 164          $this->counts = $counts;
 165  
 166          return $means;
 167      }
 168  
 169      /**
 170       * Returns in-class scatter matrix for each class, which
 171       * is a n by m matrix where n is number of classes and
 172       * m is number of columns
 173       */
 174      protected function calculateClassVar(array $data, array $classes): Matrix
 175      {
 176          // s is an n (number of classes) by m (number of column) matrix
 177          $s = array_fill(0, count($data[0]), array_fill(0, count($data[0]), 0));
 178          $sW = new Matrix($s, false);
 179  
 180          foreach ($data as $index => $row) {
 181              $label = array_search($classes[$index], $this->labels, true);
 182              $means = $this->means[$label];
 183  
 184              $row = $this->calculateVar($row, $means);
 185  
 186              $sW = $sW->add($row);
 187          }
 188  
 189          return $sW;
 190      }
 191  
 192      /**
 193       * Returns between-class scatter matrix for each class, which
 194       * is an n by m matrix where n is number of classes and
 195       * m is number of columns
 196       */
 197      protected function calculateClassCov(): Matrix
 198      {
 199          // s is an n (number of classes) by m (number of column) matrix
 200          $s = array_fill(0, count($this->overallMean), array_fill(0, count($this->overallMean), 0));
 201          $sB = new Matrix($s, false);
 202  
 203          foreach ($this->means as $index => $classMeans) {
 204              $row = $this->calculateVar($classMeans, $this->overallMean);
 205              $N = $this->counts[$index];
 206              $sB = $sB->add($row->multiplyByScalar($N));
 207          }
 208  
 209          return $sB;
 210      }
 211  
 212      /**
 213       * Returns the result of the calculation (x - m)T.(x - m)
 214       */
 215      protected function calculateVar(array $row, array $means): Matrix
 216      {
 217          $x = new Matrix($row, false);
 218          $m = new Matrix($means, false);
 219          $diff = $x->subtract($m);
 220  
 221          return $diff->transpose()->multiply($diff);
 222      }
 223  }