[GRASS-SVN] r71990 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Wed Dec 27 23:08:31 PST 2017
Author: spawley
Date: 2017-12-27 23:08:31 -0800 (Wed, 27 Dec 2017)
New Revision: 71990
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
r.learn.ml: removed the nested cross-validation option
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-12-28 06:38:22 UTC (rev 71989)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-12-28 07:08:31 UTC (rev 71990)
@@ -253,13 +253,6 @@
#%end
#%flag
-#% key: n
-#% label: Use nested cross validation
-#% description: Use nested cross validation as part of hyperparameter tuning
-#% guisection: Cross validation
-#%end
-
-#%flag
#% key: f
#% label: Estimate permutation-based feature importances
#% description: Estimate feature importance using a permutation-based method
@@ -667,20 +660,17 @@
def save_model(estimator, X, y, sample_coords, groups, filename):
from sklearn.externals import joblib
-
joblib.dump((estimator, X, y, sample_coords, group_id), filename)
def load_model(filename):
from sklearn.externals import joblib
-
estimator, X, y, sample_coords, groups = joblib.load(filename)
return (estimator, X, y, sample_coords, groups)
def extract_pixels(response, predictors, lowmem=False, na_rm=False):
"""
-
Samples a list of GRASS rasters using a labelled raster
Per raster sampling
@@ -696,9 +686,8 @@
training_data (2d numpy array): Extracted raster values
training_labels (1d numpy array): Numpy array of labels
is_train (2d numpy array): Row and Columns of label positions
-
"""
-
+
from grass.pygrass.utils import pixel2coor
current = Region()
@@ -1479,7 +1468,6 @@
tune_only = flags['t']
predict_resamples = flags['r']
importances = flags['f']
- nested_cv = flags['n']
n_permutations = int(options['n_permutations'])
errors_file = options['errors_file']
preds_file = options['preds_file']
@@ -1795,8 +1783,6 @@
if param_file != '':
param_df = pd.DataFrame(clf.cv_results_)
param_df.to_csv(param_file)
- if nested_cv is False:
- clf = clf.best_estimator_
# ---------------------------------------------------------------------
# cross-validation
@@ -1805,7 +1791,7 @@
# If cv > 1 then use cross-validation to generate performance measures
if cv > 1 and tune_only is not True:
if mode == 'classification' and cv > np.histogram(
- y, bins=np.unique(y))[0].min():
+ y, bins=np.unique(y))[0].min():
gs.message(os.linesep)
gs.message('Number of cv folds is greater than number of '
'samples in some classes. Cross-validation is being'
More information about the grass-commit
mailing list