[GRASS-SVN] r70466 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Tue Jan 31 22:48:13 PST 2017


Author: spawley
Date: 2017-01-31 22:48:13 -0800 (Tue, 31 Jan 2017)
New Revision: 70466

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'Fixed bug where cross-validation altered the estimator; cross-validation folds now fit a deep copy of it'

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-01 05:16:32 UTC (rev 70465)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-01 06:48:13 UTC (rev 70466)
@@ -289,6 +289,7 @@
 #%rules
 #% exclusive: trainingmap,load_model
 #% exclusive: load_training,save_training
+#% exclusive: trainingmap,load_training
 #%end
 
 import atexit
@@ -297,6 +298,7 @@
 import copy
 import grass.script as grass
 import tempfile
+from copy import deepcopy
 from grass.pygrass.modules.shortcuts import imagery as im
 from grass.pygrass.modules.shortcuts import raster as r
 from subprocess import PIPE
@@ -624,6 +626,9 @@
         random_state: Seed to pass to the random number generator
         """
 
+        # create copy of fitting estimator for cross-val fitting
+        fit_train = deepcopy(self.estimator)
+
         # dictionary of lists to store metrics
         if scorers == 'binary':
             self.scores = {
@@ -693,11 +698,11 @@
                 param_search = False
 
             if self.groups is not None and param_search is True:
-                fit = self.estimator.fit(X_train, y_train, groups=groups_train)
+                fit_train.fit(X_train, y_train, groups=groups_train)
             else:
-                fit = self.estimator.fit(X_train, y_train)
+                fit_train.fit(X_train, y_train)
 
-            y_pred = fit.predict(X_test)
+            y_pred = fit_train.predict(X_test)
 
             y_test_agg = np.append(y_test_agg, y_test)
             y_pred_agg = np.append(y_pred_agg, y_pred)
@@ -710,7 +715,7 @@
                     self.scores['accuracy'],
                     metrics.accuracy_score(y_test, y_pred))
 
-                y_pred_proba = fit.predict_proba(X_test)[:, 1]
+                y_pred_proba = fit_train.predict_proba(X_test)[:, 1]
                 self.scores['auc'] = np.append(
                     self.scores['auc'],
                     metrics.roc_auc_score(y_test, y_pred_proba))
@@ -753,12 +758,12 @@
             if feature_importances is True:
                 if (self.fimp==0).all() == True:
                     self.fimp = self.varImp_permutation(
-                        fit, X_test, y_test, n_permutations, scorers,
+                        fit_train, X_test, y_test, n_permutations, scorers,
                         random_state)
                 else:
                     self.fimp = np.row_stack(
                         (self.fimp, self.varImp_permutation(
-                            fit, X_test, y_test,
+                            fit_train, X_test, y_test,
                             n_permutations, scorers, random_state)))
 
         self.scores_cm = metrics.classification_report(y_test_agg, y_pred_agg)
@@ -766,8 +771,6 @@
         # convert onehot-encoded feature importances back to original vars
         if self.fimp is not None and self.enc is not None:
 
-            from copy import deepcopy
-
             # get start,end positions of each suite of onehot-encoded vars
             feature_ranges = deepcopy(self.enc.feature_indices_)
             for i in range(0, len(self.enc.feature_indices_)-1):
@@ -1313,7 +1316,7 @@
     # because cross-validation will be performed spatially
     # ---------------------------------------------------------------
     if group_raster != '':
-        maplist2 = copy.deepcopy(maplist)
+        maplist2 = deepcopy(maplist)
         maplist2.append(group_raster)
         X, y, sample_coords = sample_predictors(response=response,
                                                 predictors=maplist2,



More information about the grass-commit mailing list