<?php

declare(strict_types=1);

namespace Phpml\Classification;

use Phpml\Exception\InvalidArgumentException;
use Phpml\Helper\Predictable;
use Phpml\Helper\Trainable;
use Phpml\Math\Statistic\Mean;
use Phpml\Math\Statistic\StandardDeviation;

/**
 * Naive Bayes classifier handling continuous and nominal feature columns.
 *
 * A column is treated as continuous when every training value for a label is
 * numeric; it is then modelled with a Gaussian (per-label mean and standard
 * deviation). Any other column is treated as nominal and modelled with the
 * relative frequency of each observed value.
 *
 * Scoring is done entirely in the log domain:
 *   log P(label) + sum_i log P(feature_i | label)
 * which avoids underflow from multiplying many small probabilities.
 *
 * NOTE(review): $this->samples and $this->targets are not declared here;
 * presumably they are provided by the Trainable trait - confirm.
 */
class NaiveBayes implements Classifier
{
    use Trainable;
    use Predictable;

    /**
     * Data type marker for numeric (Gaussian) feature columns.
     * The spelling is part of the public API and is kept for compatibility.
     */
    public const CONTINUOS = 1;

    /**
     * Data type marker for categorical feature columns.
     */
    public const NOMINAL = 2;

    /**
     * Small positive constant: added to every standard deviation so it is
     * never zero, and used as the probability of an unseen nominal value.
     */
    public const EPSILON = 1e-10;

    /**
     * Standard deviation (plus EPSILON) per label and feature index.
     *
     * @var array
     */
    private $std = [];

    /**
     * Arithmetic mean per label and feature index.
     *
     * @var array
     */
    private $mean = [];

    /**
     * For nominal features: value => relative frequency, per label and feature index.
     *
     * @var array
     */
    private $discreteProb = [];

    /**
     * self::CONTINUOS or self::NOMINAL per label and feature index.
     *
     * @var array
     */
    private $dataType = [];

    /**
     * Prior probability of each label (share of training samples).
     *
     * @var array
     */
    private $p = [];

    /**
     * @var int
     */
    private $sampleCount = 0;

    /**
     * @var int
     */
    private $featureCount = 0;

    /**
     * Distinct target labels, each cast to string.
     *
     * @var array
     */
    private $labels = [];

    /**
     * Adds the given samples to the training set and recomputes the priors and
     * per-feature statistics of every label over the accumulated data.
     *
     * @param array $samples list of feature vectors
     * @param array $targets label of each sample, in the same order
     *
     * @throws InvalidArgumentException when the sample and target counts differ
     *                                  or when there is nothing to train on
     */
    public function train(array $samples, array $targets): void
    {
        // Validate before touching state so a bad call cannot corrupt a trained model.
        if (count($samples) !== count($targets)) {
            throw new InvalidArgumentException('Number of samples must match the number of targets');
        }

        // array_values() guarantees positional (0..n-1) keys, which the index
        // based lookups below rely on.
        $this->samples = array_merge($this->samples, array_values($samples));
        $this->targets = array_merge($this->targets, array_values($targets));

        if ($this->samples === []) {
            throw new InvalidArgumentException('Cannot train NaiveBayes on an empty sample set');
        }

        $this->sampleCount = count($this->samples);
        $this->featureCount = count($this->samples[0]);

        // Labels are handled as strings internally; array_unique() compares
        // as strings by default, so '10' and '1e1' remain distinct labels.
        $this->labels = array_values(array_unique(array_map('strval', $this->targets)));

        foreach ($this->labels as $label) {
            $labelSamples = $this->getSamplesByLabel($label);
            $this->p[$label] = count($labelSamples) / $this->sampleCount;
            $this->calculateStatistics($label, $labelSamples);
        }
    }

    /**
     * Returns the most likely label for one sample.
     *
     * Under the naive Bayes assumption
     *   P(label|features) ~ P(label) * P(feature0|label) * ... * P(featureN|label)
     * The product is evaluated as a sum of logarithms, so the prior must be
     * converted to log space as well before the log-likelihoods are added.
     *
     * @return int|string|null the winning label (numeric labels come back as
     *                         int because PHP array keys coerce numeric
     *                         strings); null when the model has no labels
     */
    protected function predictSample(array $sample)
    {
        $predictions = [];
        foreach ($this->labels as $label) {
            $logPosterior = log($this->p[$label]);
            for ($i = 0; $i < $this->featureCount; ++$i) {
                $logPosterior += $this->sampleProbability($sample, $i, $label);
            }

            $predictions[$label] = $logPosterior;
        }

        // Highest score first; the first key is the predicted label.
        arsort($predictions, SORT_NUMERIC);
        reset($predictions);

        return key($predictions);
    }

    /**
     * Calculates the per-feature statistics of one label and caches them in
     * the private arrays so prediction does not have to recompute them.
     *
     * @param string $label   label the samples belong to
     * @param array  $samples training samples of that label
     */
    private function calculateStatistics(string $label, array $samples): void
    {
        $this->std[$label] = array_fill(0, $this->featureCount, 0);
        $this->mean[$label] = array_fill(0, $this->featureCount, 0);
        $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
        $this->discreteProb[$label] = array_fill(0, $this->featureCount, []);

        for ($i = 0; $i < $this->featureCount; ++$i) {
            // Values of the i-th column across the label's samples.
            $values = array_column($samples, $i);
            $numValues = count($values);

            // A column containing any non-numeric value is treated as
            // nominal/categorical/discrete.
            $isNumericColumn = count(array_filter($values, 'is_numeric')) === $numValues;

            if (!$isNumericColumn) {
                $this->dataType[$label][$i] = self::NOMINAL;

                $frequencies = [];
                foreach (array_count_values($values) as $value => $occurrences) {
                    $frequencies[$value] = $occurrences / $numValues;
                }

                $this->discreteProb[$label][$i] = $frequencies;

                continue;
            }

            $this->mean[$label][$i] = Mean::arithmetic($values);
            // Add epsilon in order to avoid a zero standard deviation, which
            // would make the Gaussian density undefined.
            $this->std[$label][$i] = self::EPSILON + StandardDeviation::population($values, false);
        }
    }

    /**
     * Calculates the log-likelihood log P(feature value | label) of a single
     * feature of the sample.
     *
     * Nominal features use the logarithm of the value's relative frequency
     * (EPSILON for values never seen with this label); continuous features
     * use the logarithm of the Gaussian probability density.
     *
     * @param array  $sample  feature vector being classified
     * @param int    $feature index of the feature to evaluate
     * @param string $label   label whose statistics are used
     *
     * @throws InvalidArgumentException when the sample has no value for the feature
     */
    private function sampleProbability(array $sample, int $feature, string $label): float
    {
        if (!isset($sample[$feature])) {
            throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
        }

        $value = $sample[$feature];

        if ($this->dataType[$label][$feature] === self::NOMINAL) {
            $probability = $this->discreteProb[$label][$feature][$value] ?? 0;
            if ($probability <= 0) {
                $probability = self::EPSILON;
            }

            // Log space, so the result can be summed with Gaussian log-densities.
            return log($probability);
        }

        $std = $this->std[$label][$feature];
        $mean = $this->mean[$label][$feature];
        $variance = $std * $std;

        // Logarithm of the normal/Gaussian probability density:
        //   log N(x; mean, std) = -0.5 * log(2 * pi * std^2) - (x - mean)^2 / (2 * std^2)
        // Ref: https://en.wikipedia.org/wiki/Normal_distribution
        // Working with logs avoids numerical errors caused by very small
        // densities, as scikit-learn's naive_bayes module does.
        $logDensity = -0.5 * log(2.0 * M_PI * $variance);
        $logDensity -= 0.5 * (($value - $mean) ** 2) / $variance;

        return $logDensity;
    }

    /**
     * Returns the training samples whose target equals the given label.
     *
     * Targets are compared by their string form, the same form in which the
     * labels were collected, so loosely-equal but distinct labels (such as
     * '10' and '1e1') are not mixed up.
     */
    private function getSamplesByLabel(string $label): array
    {
        $samples = [];
        for ($i = 0; $i < $this->sampleCount; ++$i) {
            if ((string) $this->targets[$i] === $label) {
                $samples[] = $this->samples[$i];
            }
        }

        return $samples;
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body