Search moodle.org's
Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 4.3.x will end 7 October 2024 (12 months).
  • Bug fixes for security issues in 4.3.x will end 21 April 2025 (18 months).
  • PHP version: minimum PHP 8.0.0. Note: the minimum PHP version has increased since Moodle 4.1. PHP 8.2.x is also supported.
   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\Classification;
   6  
   7  use Phpml\Exception\InvalidArgumentException;
   8  use Phpml\Helper\Predictable;
   9  use Phpml\Helper\Trainable;
  10  use Phpml\Math\Statistic\Mean;
  11  use Phpml\Math\Statistic\StandardDeviation;
  12  
  13  class NaiveBayes implements Classifier
  14  {
  15      use Trainable;
  16      use Predictable;
  17  
  18      public const CONTINUOS = 1;
  19  
  20      public const NOMINAL = 2;
  21  
  22      public const EPSILON = 1e-10;
  23  
  24      /**
  25       * @var array
  26       */
  27      private $std = [];
  28  
  29      /**
  30       * @var array
  31       */
  32      private $mean = [];
  33  
  34      /**
  35       * @var array
  36       */
  37      private $discreteProb = [];
  38  
  39      /**
  40       * @var array
  41       */
  42      private $dataType = [];
  43  
  44      /**
  45       * @var array
  46       */
  47      private $p = [];
  48  
  49      /**
  50       * @var int
  51       */
  52      private $sampleCount = 0;
  53  
  54      /**
  55       * @var int
  56       */
  57      private $featureCount = 0;
  58  
  59      /**
  60       * @var array
  61       */
  62      private $labels = [];
  63  
  64      public function train(array $samples, array $targets): void
  65      {
  66          $this->samples = array_merge($this->samples, $samples);
  67          $this->targets = array_merge($this->targets, $targets);
  68          $this->sampleCount = count($this->samples);
  69          $this->featureCount = count($this->samples[0]);
  70  
  71          $this->labels = array_map('strval', array_flip(array_flip($this->targets)));
  72          foreach ($this->labels as $label) {
  73              $samples = $this->getSamplesByLabel($label);
  74              $this->p[$label] = count($samples) / $this->sampleCount;
  75              $this->calculateStatistics($label, $samples);
  76          }
  77      }
  78  
  79      /**
  80       * @return mixed
  81       */
  82      protected function predictSample(array $sample)
  83      {
  84          // Use NaiveBayes assumption for each label using:
  85          //	 P(label|features) = P(label) * P(feature0|label) * P(feature1|label) .... P(featureN|label)
  86          // Then compare probability for each class to determine which label is most likely
  87          $predictions = [];
  88          foreach ($this->labels as $label) {
  89              $p = $this->p[$label];
  90              for ($i = 0; $i < $this->featureCount; ++$i) {
  91                  $Plf = $this->sampleProbability($sample, $i, $label);
  92                  $p += $Plf;
  93              }
  94  
  95              $predictions[$label] = $p;
  96          }
  97  
  98          arsort($predictions, SORT_NUMERIC);
  99          reset($predictions);
 100  
 101          return key($predictions);
 102      }
 103  
 104      /**
 105       * Calculates vital statistics for each label & feature. Stores these
 106       * values in private array in order to avoid repeated calculation
 107       */
 108      private function calculateStatistics(string $label, array $samples): void
 109      {
 110          $this->std[$label] = array_fill(0, $this->featureCount, 0);
 111          $this->mean[$label] = array_fill(0, $this->featureCount, 0);
 112          $this->dataType[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
 113          $this->discreteProb[$label] = array_fill(0, $this->featureCount, self::CONTINUOS);
 114          for ($i = 0; $i < $this->featureCount; ++$i) {
 115              // Get the values of nth column in the samples array
 116              // Mean::arithmetic is called twice, can be optimized
 117              $values = array_column($samples, $i);
 118              $numValues = count($values);
 119              // if the values contain non-numeric data,
 120              // then it should be treated as nominal/categorical/discrete column
 121              if ($values !== array_filter($values, 'is_numeric')) {
 122                  $this->dataType[$label][$i] = self::NOMINAL;
 123                  $this->discreteProb[$label][$i] = array_count_values($values);
 124                  $db = &$this->discreteProb[$label][$i];
 125                  $db = array_map(function ($el) use ($numValues) {
 126                      return $el / $numValues;
 127                  }, $db);
 128              } else {
 129                  $this->mean[$label][$i] = Mean::arithmetic($values);
 130                  // Add epsilon in order to avoid zero stdev
 131                  $this->std[$label][$i] = 1e-10 + StandardDeviation::population($values, false);
 132              }
 133          }
 134      }
 135  
 136      /**
 137       * Calculates the probability P(label|sample_n)
 138       */
 139      private function sampleProbability(array $sample, int $feature, string $label): float
 140      {
 141          if (!isset($sample[$feature])) {
 142              throw new InvalidArgumentException('Missing feature. All samples must have equal number of features');
 143          }
 144  
 145          $value = $sample[$feature];
 146          if ($this->dataType[$label][$feature] == self::NOMINAL) {
 147              if (!isset($this->discreteProb[$label][$feature][$value]) ||
 148                  $this->discreteProb[$label][$feature][$value] == 0) {
 149                  return self::EPSILON;
 150              }
 151  
 152              return $this->discreteProb[$label][$feature][$value];
 153          }
 154  
 155          $std = $this->std[$label][$feature];
 156          $mean = $this->mean[$label][$feature];
 157          // Calculate the probability density by use of normal/Gaussian distribution
 158          // Ref: https://en.wikipedia.org/wiki/Normal_distribution
 159          //
 160          // In order to avoid numerical errors because of small or zero values,
 161          // some libraries adopt taking log of calculations such as
 162          // scikit-learn did.
 163          // (See : https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/naive_bayes.py)
 164          $pdf = -0.5 * log(2.0 * M_PI * $std * $std);
 165          $pdf -= 0.5 * (($value - $mean) ** 2) / ($std * $std);
 166  
 167          return $pdf;
 168      }
 169  
 170      /**
 171       * Return samples belonging to specific label
 172       */
 173      private function getSamplesByLabel(string $label): array
 174      {
 175          $samples = [];
 176          for ($i = 0; $i < $this->sampleCount; ++$i) {
 177              if ($this->targets[$i] == $label) {
 178                  $samples[] = $this->samples[$i];
 179              }
 180          }
 181  
 182          return $samples;
 183      }
 184  }