[GRASS-SVN] r70327 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Mon Jan 9 09:21:56 PST 2017
Author: spawley
Date: 2017-01-09 09:21:56 -0800 (Mon, 09 Jan 2017)
New Revision: 70327
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'bug fix to tuning parameters when using spatial cross-validation'
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-01-09 15:53:30 UTC (rev 70326)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-01-09 17:21:56 UTC (rev 70327)
@@ -225,6 +225,14 @@
#%end
#%option
+#% key: tune_cv
+#% type: integer
+#% description: Number of cross-validation folds used for parameter tuning
+#% answer: 3
+#% guisection: Optional
+#%end
+
+#%option
#% key: n_permutations
#% type: integer
#% description: Number of permutations to perform for feature importances
@@ -356,11 +364,9 @@
def fit(self, param_distribution=None, n_iter=3, scorers='multiclass',
- cv=3, feature_importances=False, n_permutations=1,
+ cv=3, tune_cv=3, feature_importances=False, n_permutations=1,
random_state=None):
- from sklearn.model_selection import RandomizedSearchCV
-
"""
Main fit method for the train object. Performs fitting, hyperparameter
search and cross-validation in one step (inspired from R's CARET)
@@ -372,21 +378,35 @@
n_iter: Number of randomized search iterations
scorers: Suite of metrics to obtain
cv: Number of cross-validation folds
+ tune_cv: Number of cross-validation folds for parameter tuning
feature_importances: Boolean to perform permuatation-based importances
during cross-validation
n_permutations: Number of random permutations during feature importance
-
random_state: seed to be used during random number generation
"""
+ from sklearn.model_selection import RandomizedSearchCV
+ from sklearn.model_selection import GroupKFold
if param_distribution is not None and n_iter > 1:
+
+ # use groupkfold for hyperparameter search if groups are present
+ if self.groups is not None:
+ cv_search = GroupKFold(n_splits=tune_cv)
+ else:
+ cv_search = tune_cv
+
self.estimator = RandomizedSearchCV(
estimator=self.estimator,
param_distributions=param_distribution,
- n_iter=n_iter, cv=n_iter, random_state=random_state)
+ n_iter=n_iter, cv=cv_search, random_state=random_state)
+
+ if self.groups is None:
+ self.estimator.fit(self.X, self.y)
+ else:
+ self.estimator.fit(self.X, self.y, groups=self.groups)
+ else:
+ self.estimator.fit(self.X, self.y)
- self.estimator.fit(self.X, self.y)
-
if cv > 1:
self.cross_val(
scorers, cv, feature_importances, n_permutations, random_state)
@@ -512,6 +532,7 @@
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GroupKFold
+ from sklearn.model_selection import RandomizedSearchCV
from sklearn import metrics
"""
@@ -569,11 +590,20 @@
for train_indices, test_indices in k_fold:
+ # get indices for train and test partitions
X_train, X_test = self.X[train_indices], self.X[test_indices]
y_train, y_test = self.y[train_indices], self.y[test_indices]
+
+ # also get indices of groups for the training partition
+ # fit the model on the training data and predict the test data
+ # need the groups parameter because the estimator can be a
+ # RandomizedSearchCV estimator where cv=GroupKFold
+ if self.groups is not None and isinstance(self.estimator, RandomizedSearchCV):
+ groups_train = self.groups[train_indices]
+ fit = self.estimator.fit(X_train, y_train, groups=groups_train)
+ else:
+ fit = self.estimator.fit(X_train, y_train)
- # fit the model on the training data and predict the test data
- fit = self.estimator.fit(X_train, y_train)
y_pred = fit.predict(X_test)
y_test_agg = np.append(y_test_agg, y_test)
@@ -1326,6 +1356,7 @@
save_training = options['save_training']
importances = flags['f']
n_iter = int(options['n_iter'])
+ tune_cv = int(options['tune_cv'])
n_permutations = int(options['n_permutations'])
lowmem = flags['l']
errors_file = options['errors_file']
@@ -1450,7 +1481,7 @@
"""
# fit, search and cross-validate the training object
- learn_m.fit(param_grid, n_iter, scorers, cv,
+ learn_m.fit(param_grid, n_iter, scorers, cv, tune_cv,
feature_importances=importances,
n_permutations=n_permutations,
random_state=random_state)
More information about the grass-commit mailing list.