[GRASS-SVN] r70361 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Thu Jan 12 21:56:55 PST 2017


Author: spawley
Date: 2017-01-12 21:56:55 -0800 (Thu, 12 Jan 2017)
New Revision: 70361

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'fix bug with balancing during cross-validation'

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-01-13 04:51:43 UTC (rev 70360)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-01-13 05:56:55 UTC (rev 70361)
@@ -224,14 +224,6 @@
 #%end
 
 #%option
-#% key: n_iter
-#% type: integer
-#% description: Number of randomized parameter tuning steps
-#% answer: 1
-#% guisection: Optional
-#%end
-
-#%option
 #% key: tune_cv
 #% type: integer
 #% description: Number of cross-validation folds used for parameter tuning
@@ -621,7 +613,7 @@
 
         from sklearn.model_selection import StratifiedKFold
         from sklearn.model_selection import GroupKFold
-        from sklearn.model_selection import RandomizedSearchCV
+        from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
         from sklearn import metrics
 
         """
@@ -683,23 +675,31 @@
             X_train, X_test = self.X[train_indices], self.X[test_indices]
             y_train, y_test = self.y[train_indices], self.y[test_indices]         
             
-            # also get indices of groups for the training partition
-            if self.groups is not None:
-                groups_train = self.groups[train_indices]
-
             # balance the fold
             if self.balance == True:
                 X_train, y_train = self.random_oversampling(X_train, y_train, random_state=random_state)                
                 if self.groups is not None:
+                    groups_train = self.groups[train_indices]
                     groups_train, _ = self.random_oversampling(
-                        groups_train, y_train, random_state=random_state) 
-                
+                        groups_train, self.y[train_indices], random_state=random_state) 
+
+            else:
+                # also get indices of groups for the training partition
+                if self.groups is not None:
+                    groups_train = self.groups[train_indices]
+                    
             # fit the model on the training data and predict the test data
             # need the groups parameter because the estimator can be a 
             # RandomizedSearchCV estimator where cv=GroupKFold
-            if self.groups is not None and isinstance(self.estimator, RandomizedSearchCV):
-                fit = self.estimator.fit(X_train, y_train, groups=groups_train)            
+            if isinstance(self.estimator, RandomizedSearchCV) == True \
+            or isinstance(self.estimator, GridSearchCV):
+                param_search = True
             else:
+                param_search = False
+            
+            if self.groups is not None and param_search == True:
+                fit = self.estimator.fit(X_train, y_train, groups=groups_train)
+            else:
                 fit = self.estimator.fit(X_train, y_train)   
 
             y_pred = fit.predict(X_test)



More information about the grass-commit mailing list