Search moodle.org's
Developer Documentation

See Release Notes
Long Term Support Release

  • Bug fixes for general core bugs in 3.9.x will end* 10 May 2021 (12 months).
  • Bug fixes for security issues in 3.9.x will end* 8 May 2023 (36 months).
  • PHP version: minimum PHP 7.2.0. Note: the minimum PHP version has increased since Moodle 3.8. PHP 7.3.x and 7.4.x are also supported.
<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Optimizer\ConjugateGradient;

/**
 * Binary logistic regression classifier built on top of Adaline.
 *
 * The model output is the logistic (sigmoid) function applied to the linear
 * output of the parent class. Weights are fitted with one of three solvers
 * (batch gradient descent, stochastic gradient descent or conjugate gradient)
 * against either a log-likelihood or a sum-of-squared-errors cost function,
 * optionally with an L2 regularization term.
 */
class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm (default)
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     *  - 'log' : log likelihood <br>
     *  - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: 'L2' (stored upper-cased) or an empty string
     * when no regularization is applied
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of regularization term. If λ is set to 0, then
     * regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initialize a Logistic Regression classifier with maximum number of iterations
     * and learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, then every input given to the algorithm will be standardized
     * by use of standard deviation and mean calculation <br>
     *
     * Training type must be one of the BATCH_TRAINING, ONLINE_TRAINING or
     * CONJUGATE_GRAD_TRAINING constants <br>
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
     *
     * Penalty (Regularization term) can be 'L2' (case-insensitive) or empty string to cancel penalty term
     *
     * @throws InvalidArgumentException if the training type, cost function or penalty is not supported
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes, true)) {
            throw new InvalidArgumentException(
                'Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms'
            );
        }

        if (!in_array($cost, ['log', 'sse'], true)) {
            throw new InvalidArgumentException(
                "Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors"
            );
        }

        // The penalty name is matched case-insensitively; the normalized form
        // is what gets stored so that later strict comparisons agree with
        // what was accepted here.
        $normalizedPenalty = strtoupper($penalty);
        if ($normalizedPenalty !== '' && $normalizedPenalty !== 'L2') {
            throw new InvalidArgumentException('Logistic regression supports only \'L2\' regularization');
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        $this->trainingType = $trainingType;
        $this->costFunction = $cost;
        $this->penalty = $normalizedPenalty;
    }

    /**
     * Sets the learning rate if gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of regularization term. If 0 is given,
     * then the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of selected solver
     *
     * @throws \Exception if the stored training type is not one of the supported constants
     */
    protected function runTraining(array $samples, array $targets): void
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                // Last argument selects batch (true) vs. online (false) updates
                $this->runGradientDescent($samples, $targets, $callback, true);

                return;

            case self::ONLINE_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, false);

                return;

            case self::CONJUGATE_GRAD_TRAINING:
                $this->runConjugateGradient($samples, $targets, $callback);

                return;

            default:
                // Not reached: the constructor rejects any other training type
                throw new Exception(sprintf('Logistic regression has invalid training type: %d.', $this->trainingType));
        }
    }

    /**
     * Executes Conjugate Gradient method to optimize the weights of the LogReg model
     *
     * The optimizer is created lazily on first use and reused afterwards.
     * After the run, the resulting weights and the optimizer's cost values
     * are stored on this instance.
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if ($this->optimizer === null) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function
     *
     * The returned closure takes ($weights, $sample, $y), assigns $weights to
     * this model, and returns the array [error, gradient, penalty], where
     * penalty is λ when L2 regularization is active and 0 otherwise.
     *
     * @throws \Exception if the stored cost function is neither 'log' nor 'sse'
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty === 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of Log-likelihood cost function to be minimized:
                 *      J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) +  λ/m . w
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *      ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value will give a NaN, so we fix these values
                    if ($hX === 1.0) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX === 0.0) {
                        $hX = 1e-10;
                    }

                    // Map the target onto {0, 1}: negative targets become 0, all others 1
                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };

            case 'sse':
                /*
                 * Sum of squared errors or least squared errors cost function:
                 *      J(x) = ∑ (y - h(x))^2
                 *
                 * If regularization term is given, then it will be added to the cost:
                 *      for L2 : J(x) = J(x) +  λ/m . w
                 *
                 * The gradient of the cost function:
                 *      ∇J(x) = -(h(x) - y) . h(x) . (1 - h(x))
                 */
                return function ($weights, $sample, $y) use ($penalty): array {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // Map the target onto {0, 1}: negative targets become 0, all others 1
                    $y = $y < 0 ? 0 : 1;

                    $error = (($y - $hX) ** 2);
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };

            default:
                // Not reached: the constructor rejects any other cost function
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }

    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     *
     * This is the logistic function applied to the parent's output.
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }

    /**
     * Returns the class value (either -1 or 1) for the given input
     *
     * An output strictly greater than 0.5 maps to 1, anything else to -1.
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }

    /**
     * Returns the probability of the sample of belonging to the given label.
     *
     * The model output for the (normalization-checked) sample is returned
     * when the label is found in $this->labels at an index greater than 0;
     * otherwise (index 0, or label not found) its complement is returned.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        // array_search() yields false for an unknown label, and false > 0 is
        // false, so unknown labels fall through to the complement branch.
        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
}