<?php

declare(strict_types=1);

namespace Phpml\CrossValidation;

use Phpml\Dataset\ArrayDataset;
use Phpml\Dataset\Dataset;

/**
 * Variant of RandomSplit that first groups the dataset by target value and
 * then hands each group to the parent split separately, using the same test
 * size for every group.
 */
class StratifiedRandomSplit extends RandomSplit
{
    /**
     * Splits the dataset one target group at a time.
     *
     * Each group produced by splitByTarget() is passed to
     * RandomSplit::splitDataset() together with the unchanged $testSize.
     */
    protected function splitDataset(Dataset $dataset, float $testSize): void
    {
        foreach ($this->splitByTarget($dataset) as $stratum) {
            parent::splitDataset($stratum, $testSize);
        }
    }

    /**
     * Groups the samples of $dataset by their target value.
     *
     * A sample and its target are matched through the shared array key of
     * getSamples() and getTargets().
     *
     * @return Dataset[] one dataset per distinct target, keyed by that target
     */
    private function splitByTarget(Dataset $dataset): array
    {
        $labels = $dataset->getTargets();
        $rows = $dataset->getSamples();

        // array_unique() keeps the first occurrence of every distinct target,
        // so groups appear in order of first appearance.
        $distinctLabels = array_unique($labels);

        // Start every group empty so that each distinct target has an entry.
        /** @var array $buckets */
        $buckets = array_fill_keys($distinctLabels, []);

        foreach ($rows as $index => $row) {
            $buckets[$labels[$index]][] = $row;
        }

        return $this->createDatasets($distinctLabels, $buckets);
    }

    /**
     * Wraps each group of samples in an ArrayDataset whose targets are all
     * the group's own target value.
     *
     * @param array $uniqueTargets distinct target values
     * @param array $split         samples grouped under their target value
     *
     * @return Dataset[] datasets keyed by target value
     */
    private function createDatasets(array $uniqueTargets, array $split): array
    {
        $result = [];

        foreach ($uniqueTargets as $label) {
            $groupSamples = $split[$label];
            // Every sample in the group carries the same target.
            $groupTargets = array_fill(0, count($groupSamples), $label);

            $result[$label] = new ArrayDataset($groupSamples, $groupTargets);
        }

        return $result;
    }
}
title
Description
Body
title
Description
Body
title
Description
Body
title
Body