[GRASS-SVN] r70267 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Thu Jan 5 12:40:24 PST 2017


Author: spawley
Date: 2017-01-05 12:40:24 -0800 (Thu, 05 Jan 2017)
New Revision: 70267

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'allow saving of feature importances per cross-validation fold'

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-01-05 18:11:51 UTC (rev 70266)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-01-05 20:40:24 UTC (rev 70267)
@@ -431,7 +431,7 @@
         best_score = self.pred_func(estimator, X_test, y_true, scorers)
 
         np.random.seed(seed=random_state)
-        scores = np.zeros((X_test.shape[1], n_permutations))
+        scores = np.zeros((n_permutations, X_test.shape[1]))
 
         # outer loop to repeat the pemutation rep times
         for rep in range(n_permutations):
@@ -443,12 +443,12 @@
                 Xscram[:, i] = np.random.choice(X_test[:, i], X_test.shape[0])
 
                 # fit the model on the training data and predict the test data
-                scores[i, rep] = best_score-self.pred_func(
+                scores[rep, i] = best_score-self.pred_func(
                     estimator, Xscram, y_true, scorers)
-                if scores[i, rep] < 0: scores[i, rep] = 0
+                if scores[rep, i] < 0: scores[rep, i] = 0
 
         # average the repetitions
-        scores = scores.mean(axis=1)
+        scores = scores.mean(axis=0)
 
         return(scores)
 
@@ -522,7 +522,7 @@
 
         y_test_agg = []
         y_pred_agg = []
-        self.fimp = np.zeros((self.X.shape[1], cv))
+        self.fimp = np.zeros((cv, self.X.shape[1]))
 
         # generate Kfold indices
         if self.groups is None:
@@ -612,14 +612,11 @@
                         fit, X_test, y_test, n_permutations, scorers,
                         random_state)
                 else:
-                    self.fimp = np.column_stack(
+                    self.fimp = np.row_stack(
                         (self.fimp, self.varImp_permutation(
                             fit, X_test, y_test,
                             n_permutations, scorers, random_state)))
 
-        if feature_importances == True:
-            self.fimp = self.fimp.mean(axis=1)
-
         self.scores_cm = metrics.classification_report(y_test_agg, y_pred_agg)
 
 
@@ -1467,21 +1464,26 @@
 
             # feature importances
             if importances is True:
+                import pandas as pd
                 grass.message("\r\n")
                 grass.message("Feature importances")
                 grass.message("id" + "\t" + "Raster" + "\t" + "Importance")
 
                 for i in range(len(learn_m.fimp)):
+                    # mean of cross-validation feature importances
                     grass.message(
                         str(i) + "\t" + maplist[i] +
-                        "\t" + str(round(learn_m.fimp[i], 4)))
+                        "\t" + str(round(learn_m.fimp[i].mean(axis=0), 4)))
 
                 if fimp_file != '':
-                    fimp_output = pd.DataFrame(
-                        {'grass raster': maplist, 'importance': learn_m.fimp})
-                    fimp_output.to_csv(
-                        path_or_buf=fimp_file,
-                        header=['grass raster', 'importance'])
+#                    fimp_output = pd.DataFrame(
+#                        {'grass raster': maplist, 'importance': learn_m.fimp})
+#                    fimp_output = pd.DataFrame(learn_m.fimp)
+#                    fimp_output.to_csv(
+#                        path_or_buf=fimp_file,
+#                        header=['grass raster', 'importance'])
+                    np.savetxt(fname=fimp_file, X=learn_m.fimp, delimiter=',',
+                               header=','.join(maplist), comments='')
     else:
         # load a previously fitted train object
         # -------------------------------------



More information about the grass-commit mailing list