Search moodle.org's
Developer Documentation

See Release Notes
Long Term Support Release

  • Bug fixes for general core bugs in 4.1.x will end 13 November 2023 (12 months).
  • Bug fixes for security issues in 4.1.x will end 10 November 2025 (36 months).
  • PHP version: minimum PHP 7.4.0. Note: the minimum PHP version has increased since Moodle 4.0. PHP 8.0.x is also supported.

Differences Between: [Versions 310 and 401] [Versions 311 and 401] [Versions 39 and 401]

   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Classification\Linear;
   6  
   7  use Closure;
   8  use Exception;
   9  use Phpml\Exception\InvalidArgumentException;
  10  use Phpml\Helper\Optimizer\ConjugateGradient;
  11  
  12  class LogisticRegression extends Adaline
  13  {
  14      /**
  15       * Batch training: Gradient descent algorithm (default)
  16       */
  17      public const BATCH_TRAINING = 1;
  18  
  19      /**
  20       * Online training: Stochastic gradient descent learning
  21       */
  22      public const ONLINE_TRAINING = 2;
  23  
  24      /**
  25       * Conjugate Batch: Conjugate Gradient algorithm
  26       */
  27      public const CONJUGATE_GRAD_TRAINING = 3;
  28  
  29      /**
  30       * Cost function to optimize: 'log' and 'sse' are supported <br>
  31       *  - 'log' : log likelihood <br>
  32       *  - 'sse' : sum of squared errors <br>
  33       *
  34       * @var string
  35       */
  36      protected $costFunction = 'log';
  37  
  38      /**
  39       * Regularization term: only 'L2' is supported
  40       *
  41       * @var string
  42       */
  43      protected $penalty = 'L2';
  44  
  45      /**
  46       * Lambda (λ) parameter of regularization term. If λ is set to 0, then
  47       * regularization term is cancelled.
  48       *
  49       * @var float
  50       */
  51      protected $lambda = 0.5;
  52  
  53      /**
  54       * Initalize a Logistic Regression classifier with maximum number of iterations
  55       * and learning rule to be applied <br>
  56       *
  57       * Maximum number of iterations can be an integer value greater than 0 <br>
  58       * If normalizeInputs is set to true, then every input given to the algorithm will be standardized
  59       * by use of standard deviation and mean calculation <br>
  60       *
  61       * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
  62       *
  63       * Penalty (Regularization term) can be 'L2' or empty string to cancel penalty term
  64       *
  65       * @throws InvalidArgumentException
  66       */
  67      public function __construct(
  68          int $maxIterations = 500,
  69          bool $normalizeInputs = true,
  70          int $trainingType = self::CONJUGATE_GRAD_TRAINING,
  71          string $cost = 'log',
  72          string $penalty = 'L2'
  73      ) {
  74          $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
  75          if (!in_array($trainingType, $trainingTypes, true)) {
  76              throw new InvalidArgumentException(
  77                  'Logistic regression can only be trained with '.
  78                  'batch (gradient descent), online (stochastic gradient descent) '.
  79                  'or conjugate batch (conjugate gradients) algorithms'
  80              );
  81          }
  82  
  83          if (!in_array($cost, ['log', 'sse'], true)) {
  84              throw new InvalidArgumentException(
  85                  "Logistic regression cost function can be one of the following: \n".
  86                  "'log' for log-likelihood and 'sse' for sum of squared errors"
  87              );
  88          }
  89  
  90          if ($penalty !== '' && strtoupper($penalty) !== 'L2') {
  91              throw new InvalidArgumentException('Logistic regression supports only \'L2\' regularization');
  92          }
  93  
  94          $this->learningRate = 0.001;
  95  
  96          parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);
  97  
  98          $this->trainingType = $trainingType;
  99          $this->costFunction = $cost;
 100          $this->penalty = $penalty;
 101      }
 102  
 103      /**
 104       * Sets the learning rate if gradient descent algorithm is
 105       * selected for training
 106       */
 107      public function setLearningRate(float $learningRate): void
 108      {
 109          $this->learningRate = $learningRate;
 110      }
 111  
 112      /**
 113       * Lambda (λ) parameter of regularization term. If 0 is given,
 114       * then the regularization term is cancelled
 115       */
 116      public function setLambda(float $lambda): void
 117      {
 118          $this->lambda = $lambda;
 119      }
 120  
 121      /**
 122       * Adapts the weights with respect to given samples and targets
 123       * by use of selected solver
 124       *
 125       * @throws \Exception
 126       */
 127      protected function runTraining(array $samples, array $targets): void
 128      {
 129          $callback = $this->getCostFunction();
 130  
 131          switch ($this->trainingType) {
 132              case self::BATCH_TRAINING:
 133                  $this->runGradientDescent($samples, $targets, $callback, true);
 134  
 135                  return;
 136  
 137              case self::ONLINE_TRAINING:
 138                  $this->runGradientDescent($samples, $targets, $callback, false);
 139  
 140                  return;
 141  
 142              case self::CONJUGATE_GRAD_TRAINING:
 143                  $this->runConjugateGradient($samples, $targets, $callback);
 144  
 145                  return;
 146  
 147              default:
 148                  // Not reached
 149                  throw new Exception(sprintf('Logistic regression has invalid training type: %d.', $this->trainingType));
 150          }
 151      }
 152  
 153      /**
 154       * Executes Conjugate Gradient method to optimize the weights of the LogReg model
 155       */
 156      protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
 157      {
 158          if ($this->optimizer === null) {
 159              $this->optimizer = (new ConjugateGradient($this->featureCount))
 160                  ->setMaxIterations($this->maxIterations);
 161          }
 162  
 163          $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
 164          $this->costValues = $this->optimizer->getCostValues();
 165      }
 166  
 167      /**
 168       * Returns the appropriate callback function for the selected cost function
 169       *
 170       * @throws \Exception
 171       */
 172      protected function getCostFunction(): Closure
 173      {
 174          $penalty = 0;
 175          if ($this->penalty === 'L2') {
 176              $penalty = $this->lambda;
 177          }
 178  
 179          switch ($this->costFunction) {
 180              case 'log':
 181                  /*
 182                   * Negative of Log-likelihood cost function to be minimized:
 183                   *	 	 J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
 184                   *
 185                   * If regularization term is given, then it will be added to the cost:
 186                   *	 	 for L2 : J(x) = J(x) +  λ/m . w
 187                   *
 188                   * The gradient of the cost function to be used with gradient descent:
 189                   *	 	 ∇J(x) = -(y - h(x)) = (h(x) - y)
 190                   */
 191                  return function ($weights, $sample, $y) use ($penalty): array {
 192                      $this->weights = $weights;
 193                      $hX = $this->output($sample);
 194  
 195                      // In cases where $hX = 1 or $hX = 0, the log-likelihood
 196                      // value will give a NaN, so we fix these values
 197                      if ($hX == 1) {
 198                          $hX = 1 - 1e-10;
 199                      }
 200  
 201                      if ($hX == 0) {
 202                          $hX = 1e-10;
 203                      }
 204  
 205                      $y = $y < 0 ? 0 : 1;
 206  
 207                      $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
 208                      $gradient = $hX - $y;
 209  
 210                      return [$error, $gradient, $penalty];
 211                  };
 212              case 'sse':
 213                  /*
 214                   * Sum of squared errors or least squared errors cost function:
 215                   *	 	 J(x) = ∑ (y - h(x))^2
 216                   *
 217                   * If regularization term is given, then it will be added to the cost:
 218                   *	 	 for L2 : J(x) = J(x) +  λ/m . w
 219                   *
 220                   * The gradient of the cost function:
 221                   *	 	 ∇J(x) = -(h(x) - y) . h(x) . (1 - h(x))
 222                   */
 223                  return function ($weights, $sample, $y) use ($penalty): array {
 224                      $this->weights = $weights;
 225                      $hX = $this->output($sample);
 226  
 227                      $y = $y < 0 ? 0 : 1;
 228  
 229                      $error = (($y - $hX) ** 2);
 230                      $gradient = -($y - $hX) * $hX * (1 - $hX);
 231  
 232                      return [$error, $gradient, $penalty];
 233                  };
 234              default:
 235                  // Not reached
 236                  throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
 237          }
 238      }
 239  
 240      /**
 241       * Returns the output of the network, a float value between 0.0 and 1.0
 242       */
 243      protected function output(array $sample): float
 244      {
 245          $sum = parent::output($sample);
 246  
 247          return 1.0 / (1.0 + exp(-$sum));
 248      }
 249  
 250      /**
 251       * Returns the class value (either -1 or 1) for the given input
 252       */
 253      protected function outputClass(array $sample): int
 254      {
 255          $output = $this->output($sample);
 256  
 257          if ($output > 0.5) {
 258              return 1;
 259          }
 260  
 261          return -1;
 262      }
 263  
 264      /**
 265       * Returns the probability of the sample of belonging to the given label.
 266       *
 267       * The probability is simply taken as the distance of the sample
 268       * to the decision plane.
 269       *
 270       * @param mixed $label
 271       */
 272      protected function predictProbability(array $sample, $label): float
 273      {
 274          $sample = $this->checkNormalizedSample($sample);
 275          $probability = $this->output($sample);
 276  
 277          if (array_search($label, $this->labels, true) > 0) {
 278              return $probability;
 279          }
 280  
 281          return 1 - $probability;
 282      }
 283  }