[GRASS-SVN] r70945 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Apr 24 16:20:24 PDT 2017


Author: spawley
Date: 2017-04-24 16:20:24 -0700 (Mon, 24 Apr 2017)
New Revision: 70945

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
r.learn.ml fixed issue with using holdout method for hyperparameter tuning

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-24 20:38:06 UTC (rev 70944)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-24 23:20:24 UTC (rev 70945)
@@ -664,8 +664,8 @@
         gscript.message(os.linesep)
         gscript.message(('Fitting model using ' + classifier))
 
-        # pass groups to fit parameter GroupKFold and param_grid are present
-        if isinstance(inner, GroupKFold) and any(param_grid) is True:
+        # pass groups to fit parameter GroupKFold/GroupShuffleSplit and param_grid are present
+        if isinstance(inner, (GroupKFold, GroupShuffleSplit)) and any(param_grid) is True:
             if balance is True and classifier in (
                     'GradientBoostingClassifier', 'XGBClassifier'):
                 clf.fit(X=X, y=y, groups=group_id, sample_weight=class_weights)
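Background on the fix: group-based splitters such as GroupKFold and GroupShuffleSplit (the holdout case) can only split once they have the group labels. The commit therefore widens the isinstance check so that `groups=group_id` is also forwarded to `fit` when the inner splitter is a GroupShuffleSplit. The sketch below is a minimal, stdlib-only illustration of what a group-aware holdout split does. It is not scikit-learn's implementation. `group_holdout_split` and its parameters are hypothetical names. The idea is that whole groups, not individual samples, go to the test side, so no group ever appears in both training and validation data.

```python
import random
from typing import Hashable, List, Sequence, Tuple


def group_holdout_split(
    groups: Sequence[Hashable], test_size: float = 0.25, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Split sample indices into train/test so that no group lands in both.

    Hypothetical helper illustrating the idea behind a group-aware holdout
    (e.g. scikit-learn's GroupShuffleSplit); not that library's code.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be between 0 and 1")
    # Unique group labels in first-seen order
    unique = list(dict.fromkeys(groups))
    if len(unique) < 2:
        raise ValueError("need at least two groups for a holdout split")
    # Shuffle the groups reproducibly
    rng = random.Random(seed)
    rng.shuffle(unique)
    # Hold out a fraction of *groups*, keeping at least one group on each side
    n_test = min(max(1, round(len(unique) * test_size)), len(unique) - 1)
    test_groups = set(unique[:n_test])
    train_idx = [i for i, g in enumerate(groups) if g not in test_groups]
    test_idx = [i for i, g in enumerate(groups) if g in test_groups]
    return train_idx, test_idx


# Usage: four groups (e.g. training polygons) with three samples each
group_id = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
train, test = group_holdout_split(group_id, test_size=0.25, seed=42)
print("held-out groups:", sorted({group_id[i] for i in test}))
```

If the group labels were not supplied, a splitter of this kind would have nothing to split on. That is the situation the commit avoids by passing `groups` to `fit` for GroupShuffleSplit as well as GroupKFold.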



More information about the grass-commit mailing list