[GRASS-SVN] r73911 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Fri Jan 4 14:30:22 PST 2019


Author: spawley
Date: 2019-01-04 14:30:21 -0800 (Fri, 04 Jan 2019)
New Revision: 73911

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
Fix a bug where cross-validation attempted to use the StratifiedKFold method with regressors; switched to KFold for continuous target variables

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2019-01-04 19:39:53 UTC (rev 73910)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2019-01-04 22:30:21 UTC (rev 73911)
@@ -1424,7 +1424,7 @@
         from sklearn.preprocessing import StandardScaler
         from sklearn.model_selection import (
             GridSearchCV, GroupShuffleSplit, ShuffleSplit,
-            StratifiedKFold, GroupKFold)
+            StratifiedKFold, GroupKFold, KFold)
         from sklearn.preprocessing import OneHotEncoder
         from sklearn.pipeline import Pipeline
         from sklearn.utils import shuffle
@@ -1671,8 +1671,10 @@
 
         # define inner resampling using cross-validation method
         elif any(param_grid) is True and grid_search == 'cross-validation':
-            if group_id is None:
+            if group_id is None and mode == 'classification':
                 inner = StratifiedKFold(n_splits=cv, random_state=random_state)
+            elif group_id is None and mode == 'regression':
+                inner = KFold(n_splits=cv, random_state=random_state)
             else:
                 inner = GroupKFold(n_splits=cv)
 
@@ -1691,8 +1693,10 @@
         # define the outer search resampling method
         # ---------------------------------------------------------------------
         if cv > 1:
-            if group_id is None:
+            if group_id is None and mode == 'classification':
                 outer = StratifiedKFold(n_splits=cv, random_state=random_state)
+            elif group_id is None and mode == 'regression':
+                outer = KFold(n_splits=cv, random_state=random_state)
             else:
                 outer = GroupKFold(n_splits=cv)
 



More information about the grass-commit mailing list