[GRASS-SVN] r73911 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Fri Jan 4 14:30:22 PST 2019
Author: spawley
Date: 2019-01-04 14:30:21 -0800 (Fri, 04 Jan 2019)
New Revision: 73911
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
Fix bug where cross-validation attempted to use the StratifiedKFold method with regressors; switched to KFold for continuous target variables.
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2019-01-04 19:39:53 UTC (rev 73910)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2019-01-04 22:30:21 UTC (rev 73911)
@@ -1424,7 +1424,7 @@
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import (
GridSearchCV, GroupShuffleSplit, ShuffleSplit,
- StratifiedKFold, GroupKFold)
+ StratifiedKFold, GroupKFold, KFold)
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.utils import shuffle
@@ -1671,8 +1671,10 @@
# define inner resampling using cross-validation method
elif any(param_grid) is True and grid_search == 'cross-validation':
- if group_id is None:
+ if group_id is None and mode == 'classification':
inner = StratifiedKFold(n_splits=cv, random_state=random_state)
+ elif group_id is None and mode == 'regression':
+ inner = KFold(n_splits=cv, random_state=random_state)
else:
inner = GroupKFold(n_splits=cv)
@@ -1691,8 +1693,10 @@
# define the outer search resampling method
# ---------------------------------------------------------------------
if cv > 1:
- if group_id is None:
+ if group_id is None and mode == 'classification':
outer = StratifiedKFold(n_splits=cv, random_state=random_state)
+ elif group_id is None and mode == 'regression':
+ outer = KFold(n_splits=cv, random_state=random_state)
else:
outer = GroupKFold(n_splits=cv)
More information about the grass-commit
mailing list