[GRASS-SVN] r70946 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Apr 24 21:55:40 PDT 2017


Author: spawley
Date: 2017-04-24 21:55:40 -0700 (Mon, 24 Apr 2017)
New Revision: 70946

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
r.learn.ml: fixed an error when no inner (hyperparameter) search is performed

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-24 23:20:24 UTC (rev 70945)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-25 04:55:40 UTC (rev 70946)
@@ -366,6 +366,11 @@
         warnings.filterwarnings('ignore')
     except:
         gscript.fatal("Scikit learn 0.18 or newer is not installed")
+    
+    try:
+        import pandas as pd
+    except:
+        gscript.fatal("Pandas is not installed ")
 
     # required gui section
     group = options['group']
@@ -575,18 +580,21 @@
             gscript.fatal('Hyperparameter search using cross validation requires cv > 1')
 
         # define inner resampling using cross-validation method
-        if any(param_grid) is True and grid_search == 'cross-validation':
+        elif any(param_grid) is True and grid_search == 'cross-validation':
             if group_id is None:
                 inner = StratifiedKFold(n_splits=cv, random_state=random_state)
             else:
                 inner = GroupKFold(n_splits=cv)
 
         # define inner resampling using the holdout method
-        if any(param_grid) is True and grid_search == 'holdout':
+        elif any(param_grid) is True and grid_search == 'holdout':
             if group_id is None:
                 inner = ShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
             else:
                 inner = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
+        
+        else:
+            inner = None
 
         # ---------------------------------------------------------------------
         # define the outer search resampling method
@@ -684,14 +692,8 @@
             gscript.message('Best parameters:')
             gscript.message(str(clf.best_params_))
             if param_file != '':
-                try:
-                    import pandas as pd
-                    param_df = pd.DataFrame(clf.cv_results_)
-                    param_df.to_csv(param_file)
-                except:
-                    gscript.message((
-                        "Pandas is not installed ",
-                        "cannot export hyperparameter search results to csv"))
+                param_df = pd.DataFrame(clf.cv_results_)
+                param_df.to_csv(param_file)
 
         # ---------------------------------------------------------------------
         # cross-validation



More information about the grass-commit mailing list