[GRASS-SVN] r71990 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Wed Dec 27 23:08:31 PST 2017


Author: spawley
Date: 2017-12-27 23:08:31 -0800 (Wed, 27 Dec 2017)
New Revision: 71990

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
r.learn.ml: remove the nested cross-validation option

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-12-28 06:38:22 UTC (rev 71989)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-12-28 07:08:31 UTC (rev 71990)
@@ -253,13 +253,6 @@
 #%end
 
 #%flag
-#% key: n
-#% label: Use nested cross validation
-#% description: Use nested cross validation as part of hyperparameter tuning
-#% guisection: Cross validation
-#%end
-
-#%flag
 #% key: f
 #% label: Estimate permutation-based feature importances
 #% description: Estimate feature importance using a permutation-based method
@@ -667,20 +660,17 @@
 
 def save_model(estimator, X, y, sample_coords, groups, filename):
     from sklearn.externals import joblib
-
     joblib.dump((estimator, X, y, sample_coords, group_id), filename)
 
 
 def load_model(filename):
     from sklearn.externals import joblib
-    
     estimator, X, y, sample_coords, groups = joblib.load(filename)
 
     return (estimator, X, y, sample_coords, groups)
 
 def extract_pixels(response, predictors, lowmem=False, na_rm=False):
     """
-
     Samples a list of GRASS rasters using a labelled raster
     Per raster sampling
 
@@ -696,9 +686,8 @@
     training_data (2d numpy array): Extracted raster values
     training_labels (1d numpy array): Numpy array of labels
     is_train (2d numpy array): Row and Columns of label positions
-
     """
-    
+
     from grass.pygrass.utils import pixel2coor
 
     current = Region()
@@ -1479,7 +1468,6 @@
     tune_only = flags['t']
     predict_resamples = flags['r']
     importances = flags['f']
-    nested_cv = flags['n']
     n_permutations = int(options['n_permutations'])
     errors_file = options['errors_file']
     preds_file = options['preds_file']
@@ -1795,8 +1783,6 @@
             if param_file != '':
                 param_df = pd.DataFrame(clf.cv_results_)
                 param_df.to_csv(param_file)
-            if nested_cv is False:
-                clf = clf.best_estimator_
 
         # ---------------------------------------------------------------------
         # cross-validation
@@ -1805,7 +1791,7 @@
         # If cv > 1 then use cross-validation to generate performance measures
         if cv > 1 and tune_only is not True:
             if mode == 'classification' and cv > np.histogram(
-		    y, bins=np.unique(y))[0].min():
+                y, bins=np.unique(y))[0].min():
                 gs.message(os.linesep)
                 gs.message('Number of cv folds is greater than number of '
                             'samples in some classes. Cross-validation is being'



More information about the grass-commit mailing list