Search moodle.org's Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 4.2.x will end 22 April 2024 (12 months).
  • Bug fixes for security issues in 4.2.x will end 7 October 2024 (18 months).
  • PHP version: minimum PHP 8.0.0 Note: minimum PHP version has increased since Moodle 4.1. PHP 8.1.x is supported too.

Differences Between: [Versions 310 and 402] [Versions 311 and 402] [Versions 39 and 402]

   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Helper\Optimizer;
   6  
   7  use Closure;
   8  use Phpml\Exception\InvalidOperationException;
   9  
  /**
   * Batch version of Gradient Descent to optimize the weights
   * of a classifier given samples, targets and the objective function to minimize.
   *
   * Each iteration evaluates the gradient callback on every sample and then applies a
   * single combined update to all weights: theta[0] is the bias, theta[i] (i >= 1) is
   * the weight of feature i - 1.
   */
  class GD extends StochasticGD
  {
      /**
       * Number of samples given to the current optimization run; reset to null by clear().
       *
       * @var int|null
       */
      protected $sampleCount;

      /**
       * Runs batch gradient descent on the given training data.
       *
       * @param array   $samples    Training samples; each sample's features are read by
       *                            0-based position.
       * @param array   $targets    Target values, looked up with the same keys as $samples.
       * @param Closure $gradientCb Invoked as ($theta, $sample, $target); expected to return
       *                            [cost, gradient term, penalty]. Missing trailing entries
       *                            default to 0.
       *
       * @return array The optimized weights, bias first.
       *
       * @throws InvalidOperationException When $samples is empty and at least one iteration runs.
       */
      public function runOptimization(array $samples, array $targets, Closure $gradientCb): array
      {
          $sampleCount = count($samples);

          $this->samples = $samples;
          $this->targets = $targets;
          $this->gradientCb = $gradientCb;
          $this->sampleCount = $sampleCount;
          $this->costValues = [];

          try {
              for ($iteration = 0; $iteration < $this->maxIterations; ++$iteration) {
                  // Snapshot the weights before the update so earlyStop() can compare.
                  $theta = $this->theta;

                  // Calculate update terms for each sample.
                  [$errors, $updates, $totalPenalty] = $this->gradient($theta);

                  $this->updateWeightsWithUpdates($updates, $totalPenalty);

                  // gradient() rejects an empty sample set, so $sampleCount is non-zero here.
                  $this->costValues[] = array_sum($errors) / $sampleCount;

                  // NOTE(review): earlyStop() is consulted unconditionally here; confirm whether
                  // any early-stop toggle defined on the parent class should also be honoured.
                  if ($this->earlyStop($theta)) {
                      break;
                  }
              }
          } finally {
              // Release the training data and callback even if an iteration throws.
              $this->clear();
          }

          return $this->theta;
      }

      /**
       * Evaluates the gradient callback on every sample with the given weights.
       *
       * @param array $theta Weights passed to the callback.
       *
       * @return array [per-sample costs, per-sample gradient terms, penalty averaged over the
       *               samples]; the first two are lists in sample iteration order.
       *
       * @throws InvalidOperationException When no gradient callback is set or there are no samples.
       */
      protected function gradient(array $theta): array
      {
          if ($this->gradientCb === null) {
              throw new InvalidOperationException('Gradient callback is not defined');
          }

          // Averaging the penalty (and the cost in runOptimization) needs a non-zero count.
          if (($this->sampleCount ?? 0) <= 0) {
              throw new InvalidOperationException('Gradient descent requires at least one sample');
          }

          $costs = [];
          $gradient = [];
          $totalPenalty = 0;

          foreach ($this->samples as $index => $sample) {
              $target = $this->targets[$index];

              $result = ($this->gradientCb)($theta, $sample, $target);
              // Callbacks may omit the gradient and/or penalty; treat those as 0.
              [$cost, $grad, $penalty] = array_pad($result, 3, 0);

              $costs[] = $cost;
              $gradient[] = $grad;
              $totalPenalty += $penalty;
          }

          $totalPenalty /= $this->sampleCount;

          return [$costs, $gradient, $totalPenalty];
      }

      /**
       * Applies one batch update to all weights at once.
       *
       * theta[0] (bias) moves by the sum of the per-sample gradient terms. Each feature
       * weight theta[i] moves by the sum of those terms scaled by each sample's feature
       * i - 1, plus $penalty times the current weight. All moves are scaled by the
       * learning rate.
       *
       * @param array $updates Per-sample gradient terms, as a list in sample order.
       * @param float $penalty Penalty factor applied to the non-bias weights only.
       */
      protected function updateWeightsWithUpdates(array $updates, float $penalty): void
      {
          if ($this->dimensions < 0) {
              return;
          }

          // The bias is not penalized.
          $this->theta[0] -= $this->learningRate * array_sum($updates);

          for ($i = 1; $i <= $this->dimensions; ++$i) {
              // array_column() reindexes from 0, matching the list order of $updates.
              // Assumes every sample has feature $i - 1; otherwise the indices would shift.
              $error = 0;
              foreach (array_column($this->samples, $i - 1) as $index => $value) {
                  $error += $value * $updates[$index];
              }

              $this->theta[$i] -= $this->learningRate *
                  ($error + $penalty * $this->theta[$i]);
          }
      }

      /**
       * Clears the optimizer internal vars after the optimization process.
       */
      protected function clear(): void
      {
          $this->sampleCount = null;
          parent::clear();
      }
  }