[GRASS-SVN] r70946 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Mon Apr 24 21:55:40 PDT 2017
Author: spawley
Date: 2017-04-24 21:55:40 -0700 (Mon, 24 Apr 2017)
New Revision: 70946
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
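r.learn.ml: fixed an issue when no inner (hyperparameter) search is performed — the inner resampling object is now set to None when neither inner search method applies, and the pandas import was moved to startup.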
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-04-24 23:20:24 UTC (rev 70945)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-04-25 04:55:40 UTC (rev 70946)
@@ -366,6 +366,11 @@
warnings.filterwarnings('ignore')
except:
gscript.fatal("Scikit learn 0.18 or newer is not installed")
+
+ try:
+ import pandas as pd
+ except:
+ gscript.fatal("Pandas is not installed ")
# required gui section
group = options['group']
@@ -575,18 +580,21 @@
gscript.fatal('Hyperparameter search using cross validation requires cv > 1')
# define inner resampling using cross-validation method
- if any(param_grid) is True and grid_search == 'cross-validation':
+ elif any(param_grid) is True and grid_search == 'cross-validation':
if group_id is None:
inner = StratifiedKFold(n_splits=cv, random_state=random_state)
else:
inner = GroupKFold(n_splits=cv)
# define inner resampling using the holdout method
- if any(param_grid) is True and grid_search == 'holdout':
+ elif any(param_grid) is True and grid_search == 'holdout':
if group_id is None:
inner = ShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
else:
inner = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
+
+ else:
+ inner = None
# ---------------------------------------------------------------------
# define the outer search resampling method
@@ -684,14 +692,8 @@
gscript.message('Best parameters:')
gscript.message(str(clf.best_params_))
if param_file != '':
- try:
- import pandas as pd
- param_df = pd.DataFrame(clf.cv_results_)
- param_df.to_csv(param_file)
- except:
- gscript.message((
- "Pandas is not installed ",
- "cannot export hyperparameter search results to csv"))
+ param_df = pd.DataFrame(clf.cv_results_)
+ param_df.to_csv(param_file)
# ---------------------------------------------------------------------
# cross-validation
More information about the grass-commit mailing list