[GRASS-SVN] r71003 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Tue May 2 15:34:41 PDT 2017
Author: spawley
Date: 2017-05-02 15:34:41 -0700 (Tue, 02 May 2017)
New Revision: 71003
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
Log:
r.learn.ml: added toggle for nested cross validation
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-05-02 21:58:07 UTC (rev 71002)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-05-02 22:34:41 UTC (rev 71003)
@@ -248,8 +248,15 @@
#%end
#%flag
+#% key: n
+#% label: Use nested cross validation
+#% description: Use nested cross validation as part of hyperparameter tuning
+#% guisection: Cross validation
+#%end
+
+#%flag
#% key: f
-#% label: Estimator permutation-based feature importances
+#% label: Estimate permutation-based feature importances
#% description: Estimate feature importance using a permutation-based method
#% guisection: Cross validation
#%end
@@ -423,7 +430,12 @@
for rast in tmp_rast:
gscript.run_command("g.remove", name=rast, type='raster', flags='f', quiet=True)
+def warn(*args, **kwargs):
+ pass
+import warnings
+warnings.warn = warn
+
def main():
try:
from sklearn.externals import joblib
@@ -436,9 +448,7 @@
from sklearn.pipeline import Pipeline
from sklearn.utils import shuffle
from sklearn import metrics
- from sklearn.metrics import make_scorer, confusion_matrix
- import warnings
- warnings.filterwarnings('ignore')
+ from sklearn.metrics import make_scorer
except:
gscript.fatal("Scikit learn 0.18 or newer is not installed")
@@ -479,6 +489,7 @@
tune_only = flags['t']
predict_resamples = flags['r']
importances = flags['f']
+ nested_cv = flags['n']
n_permutations = int(options['n_permutations'])
errors_file = options['errors_file']
preds_file = options['preds_file']
@@ -776,6 +787,8 @@
if param_file != '':
param_df = pd.DataFrame(clf.cv_results_)
param_df.to_csv(param_file)
+ if nested_cv is False:
+ clf = clf.best_estimator_
# ---------------------------------------------------------------------
# cross-validation
Modified: grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py 2017-05-02 21:58:07 UTC (rev 71002)
+++ grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py 2017-05-02 22:34:41 UTC (rev 71003)
@@ -243,18 +243,13 @@
trains, tests = [], []
for train_indices, test_indices in k_fold:
- trains.append(deepcopy(train_indices))
- tests.append(deepcopy(test_indices))
+ trains.append(train_indices)
+ tests.append(test_indices)
# -------------------------------------------------------------------------
# Perform multiprocessing fitting of clf on each fold
# -------------------------------------------------------------------------
- # Multiprocessing-backed parallel loops cannot be nested, setting n_jobs=1
- if isinstance(clf, (GridSearchCV, RandomizedSearchCV)):
- n_jobs = 1
- print(n_jobs)
-
clf_resamples = Parallel(n_jobs=n_jobs)(
delayed(parallel_fit)(clf, X, y, groups, train_indices,
test_indices, sample_weight)
More information about the grass-commit mailing list