[GRASS-SVN] r70100 - grass-addons/grass7/raster/r.randomforest

svn_grass at osgeo.org svn_grass at osgeo.org
Sun Dec 18 21:46:02 PST 2016


Author: spawley
Date: 2016-12-18 21:46:02 -0800 (Sun, 18 Dec 2016)
New Revision: 70100

Modified:
   grass-addons/grass7/raster/r.randomforest/r.randomforest.py
Log:
'lazy import of Pandas'

Modified: grass-addons/grass7/raster/r.randomforest/r.randomforest.py
===================================================================
--- grass-addons/grass7/raster/r.randomforest/r.randomforest.py	2016-12-18 22:10:49 UTC (rev 70099)
+++ grass-addons/grass7/raster/r.randomforest/r.randomforest.py	2016-12-19 05:46:02 UTC (rev 70100)
@@ -7,7 +7,6 @@
 #
 # COPYRIGHT: (c) 2016 Steven Pawley, and the GRASS Development Team
 #                This program is free software under the GNU General Public
-#                License (>=v2). Read the file COPYING that comes with GRASS
 #                for details.
 #
 #############################################################################
@@ -261,11 +260,6 @@
 from grass.pygrass.modules.shortcuts import raster as r
 
 try:
-    import pandas as pd
-except:
-    grass.fatal("Pandas not installed")
-
-try:
     import sklearn
     from sklearn.externals import joblib
     from sklearn import metrics
@@ -1001,7 +995,8 @@
                 metric = 'r2'
             
             X, X_devel, y, y_devel, Id, Id_devel, clf = \
-                tune_split(X, y, Id, clf, metric, param_grid, ratio, random_state)
+                tune_split(X, y, Id, clf, metric, param_grid,
+                           ratio, random_state)
 
             grass.message('\n')
             grass.message('Searched parameters:')
@@ -1043,19 +1038,25 @@
                 # classification report
                 grass.message("\n")
                 grass.message("Classification report:")
-                grass.message(metrics.classification_report(y_test, y_pred))
+                grass.message(metrics.classification_report(y_test, y_pred))                   
 
-                if errors_file != '':
-                    errors = pd.DataFrame({'accuracy': scores['accuracy']})
-                    errors.to_csv(errors_file, mode='w')
-
             else:
                 grass.message("R2:\t%0.2f\t+/-\t%0.2f" %
                               (scores['r2'].mean(), scores['r2'].std()))
-
-                if errors_file != '':
-                    errors = pd.DataFrame({'r2': scores['accuracy']})
+            
+            # write cross-validation results to a csv file
+            if errors_file != '':
+                try:
+                    import pandas as pd
+                    
+                    if mode == 'classification':
+                        errors = pd.DataFrame({'accuracy': scores['accuracy'],
+                                               'auc': scores['auc']})
+                    else:
+                        errors = pd.DataFrame({'r2': scores['r2']})
                     errors.to_csv(errors_file, mode='w')
+                except:
+                    grass.warning("Pandas is not installed. Pandas is required to write the cross-validation results to file")
 
         # train classifier
         clf.fit(X, y)



More information about the grass-commit mailing list