[GRASS-SVN] r71003 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Tue May 2 15:34:41 PDT 2017


Author: spawley
Date: 2017-05-02 15:34:41 -0700 (Tue, 02 May 2017)
New Revision: 71003

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
   grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
Log:
r.learn.ml: added a toggle (flag -n) for nested cross-validation

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-05-02 21:58:07 UTC (rev 71002)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-05-02 22:34:41 UTC (rev 71003)
@@ -248,8 +248,15 @@
 #%end
 
 #%flag
+#% key: n
+#% label: Use nested cross validation
+#% description: Use nested cross validation as part of hyperparameter tuning
+#% guisection: Cross validation
+#%end
+
+#%flag
 #% key: f
-#% label: Estimator permutation-based feature importances
+#% label: Estimate permutation-based feature importances
 #% description: Estimate feature importance using a permutation-based method
 #% guisection: Cross validation
 #%end
@@ -423,7 +430,12 @@
     for rast in tmp_rast:
         gscript.run_command("g.remove", name=rast, type='raster', flags='f', quiet=True)
 
+def warn(*args, **kwargs):
+    pass
 
+import warnings
+warnings.warn = warn
+
 def main():
     try:
         from sklearn.externals import joblib
@@ -436,9 +448,7 @@
         from sklearn.pipeline import Pipeline
         from sklearn.utils import shuffle
         from sklearn import metrics
-        from sklearn.metrics import make_scorer, confusion_matrix
-        import warnings
-        warnings.filterwarnings('ignore')
+        from sklearn.metrics import make_scorer
     except:
         gscript.fatal("Scikit learn 0.18 or newer is not installed")
 
@@ -479,6 +489,7 @@
     tune_only = flags['t']
     predict_resamples = flags['r']
     importances = flags['f']
+    nested_cv = flags['n']
     n_permutations = int(options['n_permutations'])
     errors_file = options['errors_file']
     preds_file = options['preds_file']
@@ -776,6 +787,8 @@
             if param_file != '':
                 param_df = pd.DataFrame(clf.cv_results_)
                 param_df.to_csv(param_file)
+            if nested_cv is False:
+                clf = clf.best_estimator_
 
         # ---------------------------------------------------------------------
         # cross-validation

Modified: grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py	2017-05-02 21:58:07 UTC (rev 71002)
+++ grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py	2017-05-02 22:34:41 UTC (rev 71003)
@@ -243,18 +243,13 @@
 
     trains, tests = [], []
     for train_indices, test_indices in k_fold:
-        trains.append(deepcopy(train_indices))
-        tests.append(deepcopy(test_indices))
+        trains.append(train_indices)
+        tests.append(test_indices)
 
     # -------------------------------------------------------------------------
     # Perform multiprocessing fitting of clf on each fold
     # -------------------------------------------------------------------------
 
-    # Multiprocessing-backed parallel loops cannot be nested, setting n_jobs=1
-    if isinstance(clf, (GridSearchCV, RandomizedSearchCV)):
-        n_jobs = 1
-        print(n_jobs)
-
     clf_resamples = Parallel(n_jobs=n_jobs)(
         delayed(parallel_fit)(clf, X, y, groups, train_indices,
                               test_indices, sample_weight)



More information about the grass-commit mailing list