See Release Notes
Long Term Support Release
Differences Between: [Versions 39 and 400] [Versions 39 and 401] [Versions 39 and 402] [Versions 39 and 403]
<?php

declare(strict_types=1);

namespace Phpml\Classification\Linear;

use Closure;
use Exception;
use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Optimizer\ConjugateGradient;

class LogisticRegression extends Adaline
{
    /**
     * Batch training: Gradient descent algorithm
     */
    public const BATCH_TRAINING = 1;

    /**
     * Online training: Stochastic gradient descent learning
     */
    public const ONLINE_TRAINING = 2;

    /**
     * Conjugate Batch: Conjugate Gradient algorithm (default for this classifier)
     */
    public const CONJUGATE_GRAD_TRAINING = 3;

    /**
     * Cost function to optimize: 'log' and 'sse' are supported <br>
     * - 'log' : log likelihood <br>
     * - 'sse' : sum of squared errors <br>
     *
     * @var string
     */
    protected $costFunction = 'log';

    /**
     * Regularization term: only 'L2' is supported; an empty string disables it.
     * The constructor stores this value upper-cased.
     *
     * @var string
     */
    protected $penalty = 'L2';

    /**
     * Lambda (λ) parameter of regularization term. If λ is set to 0, then
     * regularization term is cancelled.
     *
     * @var float
     */
    protected $lambda = 0.5;

    /**
     * Initialize a Logistic Regression classifier with maximum number of iterations
     * and learning rule to be applied <br>
     *
     * Maximum number of iterations can be an integer value greater than 0 <br>
     * If normalizeInputs is set to true, then every input given to the algorithm will be standardized
     * by use of standard deviation and mean calculation <br>
     *
     * Cost function can be 'log' for log-likelihood and 'sse' for sum of squared errors <br>
     *
     * Penalty (Regularization term) can be 'L2' (case-insensitive) or empty string to cancel penalty term
     *
     * @param int    $maxIterations   maximum number of optimizer iterations
     * @param bool   $normalizeInputs whether inputs are standardized before training/prediction
     * @param int    $trainingType    one of BATCH_TRAINING, ONLINE_TRAINING, CONJUGATE_GRAD_TRAINING
     * @param string $cost            'log' or 'sse'
     * @param string $penalty         'L2' (any case) or '' for no regularization
     *
     * @throws InvalidArgumentException when the training type, cost function or penalty is unsupported
     */
    public function __construct(
        int $maxIterations = 500,
        bool $normalizeInputs = true,
        int $trainingType = self::CONJUGATE_GRAD_TRAINING,
        string $cost = 'log',
        string $penalty = 'L2'
    ) {
        $trainingTypes = range(self::BATCH_TRAINING, self::CONJUGATE_GRAD_TRAINING);
        if (!in_array($trainingType, $trainingTypes, true)) {
            throw new InvalidArgumentException(
                'Logistic regression can only be trained with '.
                'batch (gradient descent), online (stochastic gradient descent) '.
                'or conjugate batch (conjugate gradients) algorithms'
            );
        }

        if (!in_array($cost, ['log', 'sse'], true)) {
            throw new InvalidArgumentException(
                "Logistic regression cost function can be one of the following: \n".
                "'log' for log-likelihood and 'sse' for sum of squared errors"
            );
        }

        if ($penalty !== '' && strtoupper($penalty) !== 'L2') {
            throw new InvalidArgumentException('Logistic regression supports only \'L2\' regularization');
        }

        $this->learningRate = 0.001;

        parent::__construct($this->learningRate, $maxIterations, $normalizeInputs);

        // The parent constructor is called with its default training type; the
        // logistic-regression specific type (validated above) is applied afterwards.
        $this->trainingType = $trainingType;
        $this->costFunction = $cost;
        // The penalty check above is case-insensitive ('l2' is accepted), while
        // getCostFunction() compares against 'L2' exactly. Normalizing here keeps
        // a lower-case 'l2' from silently disabling regularization.
        $this->penalty = strtoupper($penalty);
    }

    /**
     * Sets the learning rate if gradient descent algorithm is
     * selected for training
     */
    public function setLearningRate(float $learningRate): void
    {
        $this->learningRate = $learningRate;
    }

    /**
     * Lambda (λ) parameter of regularization term. If 0 is given,
     * then the regularization term is cancelled
     */
    public function setLambda(float $lambda): void
    {
        $this->lambda = $lambda;
    }

    /**
     * Adapts the weights with respect to given samples and targets
     * by use of selected solver
     *
     * @throws \Exception if the training type is not one of the supported constants
     */
    protected function runTraining(array $samples, array $targets): void
    {
        $callback = $this->getCostFunction();

        switch ($this->trainingType) {
            case self::BATCH_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, true);

                return;

            case self::ONLINE_TRAINING:
                $this->runGradientDescent($samples, $targets, $callback, false);

                return;

            case self::CONJUGATE_GRAD_TRAINING:
                $this->runConjugateGradient($samples, $targets, $callback);

                return;

            default:
                // Not reached: the constructor rejects any other training type
                throw new Exception(sprintf('Logistic regression has invalid training type: %d.', $this->trainingType));
        }
    }

    /**
     * Executes Conjugate Gradient method to optimize the weights of the LogReg model.
     *
     * The optimizer is created lazily on first use and reused afterwards; the
     * resulting weights and per-iteration cost values are stored on the model.
     */
    protected function runConjugateGradient(array $samples, array $targets, Closure $gradientFunc): void
    {
        if ($this->optimizer === null) {
            $this->optimizer = (new ConjugateGradient($this->featureCount))
                ->setMaxIterations($this->maxIterations);
        }

        $this->weights = $this->optimizer->runOptimization($samples, $targets, $gradientFunc);
        $this->costValues = $this->optimizer->getCostValues();
    }

    /**
     * Returns the appropriate callback function for the selected cost function.
     *
     * The returned closure takes ($weights, $sample, $y), installs $weights on the
     * model, and returns [$error, $gradient, $penalty]. $penalty is λ when the L2
     * penalty is enabled and 0 otherwise; applying it is left to the optimizer.
     *
     * @throws \Exception if the cost function is neither 'log' nor 'sse'
     */
    protected function getCostFunction(): Closure
    {
        $penalty = 0;
        if ($this->penalty === 'L2') {
            $penalty = $this->lambda;
        }

        switch ($this->costFunction) {
            case 'log':
                /*
                 * Negative of Log-likelihood cost function to be minimized:
                 *      J(x) = ∑( - y . log(h(x)) - (1 - y) . log(1 - h(x)))
                 *
                 * If regularization term is given, λ is handed to the optimizer
                 * (presumably applied as λ/m . w for L2 — see the optimizer).
                 *
                 * The gradient of the cost function to be used with gradient descent:
                 *      ∇J(x) = -(y - h(x)) = (h(x) - y)
                 */
                return function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // In cases where $hX = 1 or $hX = 0, the log-likelihood
                    // value will give a NaN, so we fix these values
                    if ($hX === 1.0) {
                        $hX = 1 - 1e-10;
                    }

                    if ($hX === 0.0) {
                        $hX = 1e-10;
                    }

                    // Targets arrive as -1/1; the log-likelihood expects 0/1
                    $y = $y < 0 ? 0 : 1;

                    $error = -$y * log($hX) - (1 - $y) * log(1 - $hX);
                    $gradient = $hX - $y;

                    return [$error, $gradient, $penalty];
                };
            case 'sse':
                /*
                 * Sum of squared errors or least squared errors cost function:
                 *      J(x) = ∑ (y - h(x))^2
                 *
                 * If regularization term is given, λ is handed to the optimizer
                 * (presumably applied as λ/m . w for L2 — see the optimizer).
                 *
                 * The gradient of the cost function:
                 *      ∇J(x) = -(y - h(x)) . h(x) . (1 - h(x))
                 */
                return function ($weights, $sample, $y) use ($penalty) {
                    $this->weights = $weights;
                    $hX = $this->output($sample);

                    // Targets arrive as -1/1; the sigmoid output lives in [0, 1]
                    $y = $y < 0 ? 0 : 1;

                    $error = ($y - $hX) ** 2;
                    $gradient = -($y - $hX) * $hX * (1 - $hX);

                    return [$error, $gradient, $penalty];
                };
            default:
                // Not reached: the constructor rejects any other cost function
                throw new Exception(sprintf('Logistic regression has invalid cost function: %s.', $this->costFunction));
        }
    }

    /**
     * Returns the output of the network, a float value between 0.0 and 1.0
     * (the logistic sigmoid of the parent's linear output)
     */
    protected function output(array $sample): float
    {
        $sum = parent::output($sample);

        return 1.0 / (1.0 + exp(-$sum));
    }

    /**
     * Returns the class value (either -1 or 1) for the given input;
     * 1 only when the sigmoid output is strictly greater than 0.5
     */
    protected function outputClass(array $sample): int
    {
        $output = $this->output($sample);

        if ($output > 0.5) {
            return 1;
        }

        return -1;
    }

    /**
     * Returns the probability of the sample belonging to the given label.
     *
     * The sample is passed through checkNormalizedSample() and the model's
     * sigmoid output p is computed. A label found at a positive index of
     * $this->labels gets p; any other label (index 0, or not found) gets 1 - p.
     *
     * @param mixed $label
     */
    protected function predictProbability(array $sample, $label): float
    {
        $sample = $this->checkNormalizedSample($sample);
        $probability = $this->output($sample);

        if (array_search($label, $this->labels, true) > 0) {
            return $probability;
        }

        return 1 - $probability;
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body