[GRASS-SVN] r70327 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Jan 9 09:21:56 PST 2017


Author: spawley
Date: 2017-01-09 09:21:56 -0800 (Mon, 09 Jan 2017)
New Revision: 70327

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'bug fix to tuning parameters when using spatial cross-validation'

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-01-09 15:53:30 UTC (rev 70326)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-01-09 17:21:56 UTC (rev 70327)
@@ -225,6 +225,14 @@
 #%end
 
 #%option
+#% key: tune_cv
+#% type: integer
+#% description: Number of cross-validation folds used for parameter tuning
+#% answer: 3
+#% guisection: Optional
+#%end
+
+#%option
 #% key: n_permutations
 #% type: integer
 #% description: Number of permutations to perform for feature importances
@@ -356,11 +364,9 @@
 
 
     def fit(self, param_distribution=None, n_iter=3, scorers='multiclass',
-            cv=3, feature_importances=False, n_permutations=1,
+            cv=3, tune_cv=3, feature_importances=False, n_permutations=1,
             random_state=None):
 
-        from sklearn.model_selection import RandomizedSearchCV
-
         """
         Main fit method for the train object. Performs fitting, hyperparameter
         search and cross-validation in one step (inspired from R's CARET)
@@ -372,21 +378,35 @@
         n_iter: Number of randomized search iterations
         scorers: Suite of metrics to obtain
         cv: Number of cross-validation folds
+        tune_cv: Number of cross-validation folds for parameter tuning
         feature_importances: Boolean to perform permuatation-based importances
         during cross-validation
         n_permutations: Number of random permutations during feature importance
-
         random_state: seed to be used during random number generation
         """
+        from sklearn.model_selection import RandomizedSearchCV
+        from sklearn.model_selection import GroupKFold
 
         if param_distribution is not None and n_iter > 1:
+            
+            # use groupkfold for hyperparameter search if groups are present
+            if self.groups is not None:
+                cv_search = GroupKFold(n_splits=tune_cv)
+            else:
+                cv_search = tune_cv
+                
             self.estimator = RandomizedSearchCV(
                 estimator=self.estimator,
                 param_distributions=param_distribution,
-                n_iter=n_iter, cv=n_iter, random_state=random_state)
+                n_iter=n_iter, cv=cv_search, random_state=random_state)
+        
+            if self.groups is None:
+                self.estimator.fit(self.X, self.y)
+            else:
+                self.estimator.fit(self.X, self.y, groups=self.groups)
+        else:
+            self.estimator.fit(self.X, self.y)
 
-        self.estimator.fit(self.X, self.y)
-
         if cv > 1:
             self.cross_val(
                 scorers, cv, feature_importances, n_permutations, random_state)
@@ -512,6 +532,7 @@
 
         from sklearn.model_selection import StratifiedKFold
         from sklearn.model_selection import GroupKFold
+        from sklearn.model_selection import RandomizedSearchCV
         from sklearn import metrics
 
         """
@@ -569,11 +590,20 @@
 
         for train_indices, test_indices in k_fold:
 
+            # get indices for train and test partitions
             X_train, X_test = self.X[train_indices], self.X[test_indices]
             y_train, y_test = self.y[train_indices], self.y[test_indices]
+            
+            # also get indices of groups for the training partition
+            # fit the model on the training data and predict the test data
+            # need the groups parameter because the estimator can be a 
+            # RandomizedSearchCV estimator where cv=GroupKFold
+            if self.groups is not None and isinstance(self.estimator, RandomizedSearchCV):
+                groups_train = self.groups[train_indices]
+                fit = self.estimator.fit(X_train, y_train, groups=groups_train)            
+            else:
+                fit = self.estimator.fit(X_train, y_train)   
 
-            # fit the model on the training data and predict the test data
-            fit = self.estimator.fit(X_train, y_train)
             y_pred = fit.predict(X_test)
 
             y_test_agg = np.append(y_test_agg, y_test)
@@ -1326,6 +1356,7 @@
     save_training = options['save_training']
     importances = flags['f']
     n_iter = int(options['n_iter'])
+    tune_cv = int(options['tune_cv'])
     n_permutations = int(options['n_permutations'])
     lowmem = flags['l']
     errors_file = options['errors_file']
@@ -1450,7 +1481,7 @@
         """
 
         # fit, search and cross-validate the training object
-        learn_m.fit(param_grid, n_iter, scorers, cv,
+        learn_m.fit(param_grid, n_iter, scorers, cv, tune_cv,
                     feature_importances=importances,
                     n_permutations=n_permutations,
                     random_state=random_state)



More information about the grass-commit mailing list