Search moodle.org's
Developer Documentation

See Release Notes

  • Bug fixes for general core bugs in 4.2.x will end 22 April 2024 (12 months).
  • Bug fixes for security issues in 4.2.x will end 7 October 2024 (18 months).
  • PHP version: minimum PHP 8.0.0. Note: the minimum PHP version has increased since Moodle 4.1. PHP 8.1.x is also supported.
   1  <?php
   2  
   3  declare(strict_types=1);
   4  
   5  namespace Phpml\SupportVectorMachine;
   6  
   7  use Phpml\Exception\InvalidArgumentException;
   8  use Phpml\Exception\InvalidOperationException;
   9  use Phpml\Exception\LibsvmCommandException;
  10  use Phpml\Helper\Trainable;
  11  
  12  class SupportVectorMachine
  13  {
  14      use Trainable;
  15  
  16      /**
  17       * @var int
  18       */
  19      private $type;
  20  
  21      /**
  22       * @var int
  23       */
  24      private $kernel;
  25  
  26      /**
  27       * @var float
  28       */
  29      private $cost;
  30  
  31      /**
  32       * @var float
  33       */
  34      private $nu;
  35  
  36      /**
  37       * @var int
  38       */
  39      private $degree;
  40  
  41      /**
  42       * @var float|null
  43       */
  44      private $gamma;
  45  
  46      /**
  47       * @var float
  48       */
  49      private $coef0;
  50  
  51      /**
  52       * @var float
  53       */
  54      private $epsilon;
  55  
  56      /**
  57       * @var float
  58       */
  59      private $tolerance;
  60  
  61      /**
  62       * @var int
  63       */
  64      private $cacheSize;
  65  
  66      /**
  67       * @var bool
  68       */
  69      private $shrinking;
  70  
  71      /**
  72       * @var bool
  73       */
  74      private $probabilityEstimates;
  75  
  76      /**
  77       * @var string
  78       */
  79      private $binPath;
  80  
  81      /**
  82       * @var string
  83       */
  84      private $varPath;
  85  
  86      /**
  87       * @var string
  88       */
  89      private $model;
  90  
  91      /**
  92       * @var array
  93       */
  94      private $targets = [];
  95  
  96      public function __construct(
  97          int $type,
  98          int $kernel,
  99          float $cost = 1.0,
 100          float $nu = 0.5,
 101          int $degree = 3,
 102          ?float $gamma = null,
 103          float $coef0 = 0.0,
 104          float $epsilon = 0.1,
 105          float $tolerance = 0.001,
 106          int $cacheSize = 100,
 107          bool $shrinking = true,
 108          bool $probabilityEstimates = false
 109      ) {
 110          $this->type = $type;
 111          $this->kernel = $kernel;
 112          $this->cost = $cost;
 113          $this->nu = $nu;
 114          $this->degree = $degree;
 115          $this->gamma = $gamma;
 116          $this->coef0 = $coef0;
 117          $this->epsilon = $epsilon;
 118          $this->tolerance = $tolerance;
 119          $this->cacheSize = $cacheSize;
 120          $this->shrinking = $shrinking;
 121          $this->probabilityEstimates = $probabilityEstimates;
 122  
 123          $rootPath = realpath(implode(DIRECTORY_SEPARATOR, [__DIR__, '..', '..'])).DIRECTORY_SEPARATOR;
 124  
 125          $this->binPath = $rootPath.'bin'.DIRECTORY_SEPARATOR.'libsvm'.DIRECTORY_SEPARATOR;
 126          $this->varPath = $rootPath.'var'.DIRECTORY_SEPARATOR;
 127      }
 128  
 129      public function setBinPath(string $binPath): void
 130      {
 131          $this->ensureDirectorySeparator($binPath);
 132          $this->verifyBinPath($binPath);
 133  
 134          $this->binPath = $binPath;
 135      }
 136  
 137      public function setVarPath(string $varPath): void
 138      {
 139          if (!is_writable($varPath)) {
 140              throw new InvalidArgumentException(sprintf('The specified path "%s" is not writable', $varPath));
 141          }
 142  
 143          $this->ensureDirectorySeparator($varPath);
 144          $this->varPath = $varPath;
 145      }
 146  
 147      public function train(array $samples, array $targets): void
 148      {
 149          $this->samples = array_merge($this->samples, $samples);
 150          $this->targets = array_merge($this->targets, $targets);
 151  
 152          $trainingSet = DataTransformer::trainingSet($this->samples, $this->targets, in_array($this->type, [Type::EPSILON_SVR, Type::NU_SVR], true));
 153          file_put_contents($trainingSetFileName = $this->varPath.uniqid('phpml', true), $trainingSet);
 154          $modelFileName = $trainingSetFileName.'-model';
 155  
 156          $command = $this->buildTrainCommand($trainingSetFileName, $modelFileName);
 157          $output = [];
 158          exec(escapeshellcmd($command).' 2>&1', $output, $return);
 159  
 160          unlink($trainingSetFileName);
 161  
 162          if ($return !== 0) {
 163              throw new LibsvmCommandException(
 164                  sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
 165              );
 166          }
 167  
 168          $this->model = (string) file_get_contents($modelFileName);
 169  
 170          unlink($modelFileName);
 171      }
 172  
 173      public function getModel(): string
 174      {
 175          return $this->model;
 176      }
 177  
 178      /**
 179       * @return array|string
 180       *
 181       * @throws LibsvmCommandException
 182       */
 183      public function predict(array $samples)
 184      {
 185          $predictions = $this->runSvmPredict($samples, false);
 186  
 187          if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
 188              $predictions = DataTransformer::predictions($predictions, $this->targets);
 189          } else {
 190              $predictions = explode(PHP_EOL, trim($predictions));
 191          }
 192  
 193          if (!is_array($samples[0])) {
 194              return $predictions[0];
 195          }
 196  
 197          return $predictions;
 198      }
 199  
 200      /**
 201       * @return array|string
 202       *
 203       * @throws LibsvmCommandException
 204       */
 205      public function predictProbability(array $samples)
 206      {
 207          if (!$this->probabilityEstimates) {
 208              throw new InvalidOperationException('Model does not support probabiliy estimates');
 209          }
 210  
 211          $predictions = $this->runSvmPredict($samples, true);
 212  
 213          if (in_array($this->type, [Type::C_SVC, Type::NU_SVC], true)) {
 214              $predictions = DataTransformer::probabilities($predictions, $this->targets);
 215          } else {
 216              $predictions = explode(PHP_EOL, trim($predictions));
 217          }
 218  
 219          if (!is_array($samples[0])) {
 220              return $predictions[0];
 221          }
 222  
 223          return $predictions;
 224      }
 225  
 226      private function runSvmPredict(array $samples, bool $probabilityEstimates): string
 227      {
 228          $testSet = DataTransformer::testSet($samples);
 229          file_put_contents($testSetFileName = $this->varPath.uniqid('phpml', true), $testSet);
 230          file_put_contents($modelFileName = $testSetFileName.'-model', $this->model);
 231          $outputFileName = $testSetFileName.'-output';
 232  
 233          $command = $this->buildPredictCommand(
 234              $testSetFileName,
 235              $modelFileName,
 236              $outputFileName,
 237              $probabilityEstimates
 238          );
 239          $output = [];
 240          exec(escapeshellcmd($command).' 2>&1', $output, $return);
 241  
 242          unlink($testSetFileName);
 243          unlink($modelFileName);
 244          $predictions = (string) file_get_contents($outputFileName);
 245  
 246          unlink($outputFileName);
 247  
 248          if ($return !== 0) {
 249              throw new LibsvmCommandException(
 250                  sprintf('Failed running libsvm command: "%s" with reason: "%s"', $command, array_pop($output))
 251              );
 252          }
 253  
 254          return $predictions;
 255      }
 256  
 257      private function getOSExtension(): string
 258      {
 259          $os = strtoupper(substr(PHP_OS, 0, 3));
 260          if ($os === 'WIN') {
 261              return '.exe';
 262          } elseif ($os === 'DAR') {
 263              return '-osx';
 264          }
 265  
 266          return '';
 267      }
 268  
 269      private function buildTrainCommand(string $trainingSetFileName, string $modelFileName): string
 270      {
 271          return sprintf(
 272              '%ssvm-train%s -s %s -t %s -c %s -n %F -d %s%s -r %s -p %F -m %F -e %F -h %d -b %d %s %s',
 273              $this->binPath,
 274              $this->getOSExtension(),
 275              $this->type,
 276              $this->kernel,
 277              $this->cost,
 278              $this->nu,
 279              $this->degree,
 280              $this->gamma !== null ? ' -g '.$this->gamma : '',
 281              $this->coef0,
 282              $this->epsilon,
 283              $this->cacheSize,
 284              $this->tolerance,
 285              $this->shrinking,
 286              $this->probabilityEstimates,
 287              escapeshellarg($trainingSetFileName),
 288              escapeshellarg($modelFileName)
 289          );
 290      }
 291  
 292      private function buildPredictCommand(
 293          string $testSetFileName,
 294          string $modelFileName,
 295          string $outputFileName,
 296          bool $probabilityEstimates
 297      ): string {
 298          return sprintf(
 299              '%ssvm-predict%s -b %d %s %s %s',
 300              $this->binPath,
 301              $this->getOSExtension(),
 302              $probabilityEstimates ? 1 : 0,
 303              escapeshellarg($testSetFileName),
 304              escapeshellarg($modelFileName),
 305              escapeshellarg($outputFileName)
 306          );
 307      }
 308  
 309      private function ensureDirectorySeparator(string &$path): void
 310      {
 311          if (substr($path, -1) !== DIRECTORY_SEPARATOR) {
 312              $path .= DIRECTORY_SEPARATOR;
 313          }
 314      }
 315  
 316      private function verifyBinPath(string $path): void
 317      {
 318          if (!is_dir($path)) {
 319              throw new InvalidArgumentException(sprintf('The specified path "%s" does not exist', $path));
 320          }
 321  
 322          $osExtension = $this->getOSExtension();
 323          foreach (['svm-predict', 'svm-scale', 'svm-train'] as $filename) {
 324              $filePath = $path.$filename.$osExtension;
 325              if (!file_exists($filePath)) {
 326                  throw new InvalidArgumentException(sprintf('File "%s" not found', $filePath));
 327              }
 328  
 329              if (!is_executable($filePath)) {
 330                  throw new InvalidArgumentException(sprintf('File "%s" is not executable', $filePath));
 331              }
 332          }
 333      }
 334  }