[GRASS-SVN] r70944 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
Mon Apr 24 13:38:07 PDT 2017
Author: spawley
Date: 2017-04-24 13:38:06 -0700 (Mon, 24 Apr 2017)
New Revision: 70944
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
Log:
r.learn.ml: added an option to choose the inner search method. Fixed bugs that arose after moving functions into separate modules
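For context, the new grid_search option selects the resampling scheme that is handed to scikit-learn's GridSearchCV as its cv argument. A minimal standalone sketch of that choice follows; the variable values are illustrative stand-ins for the module's GUI options, and the 0.33 holdout fraction matches the diff below:

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import (GridSearchCV, StratifiedKFold, GroupKFold,
                                     ShuffleSplit, GroupShuffleSplit)

# illustrative stand-ins for the module's options (hypothetical values)
grid_search = 'holdout'        # or 'cross-validation'
cv, random_state, group_id = 3, 12345, None
param_grid = {'n_estimators': [100, 200]}

if grid_search == 'cross-validation':
    # k-fold, or group k-fold when spatial group ids are supplied
    if group_id is None:
        inner = StratifiedKFold(n_splits=cv)
    else:
        inner = GroupKFold(n_splits=cv)
else:
    # holdout: a single 67/33 split, cheaper than repeated k-fold fits
    if group_id is None:
        inner = ShuffleSplit(n_splits=1, test_size=0.33,
                             random_state=random_state)
    else:
        inner = GroupShuffleSplit(n_splits=1, test_size=0.33,
                                  random_state=random_state)

clf = GridSearchCV(estimator=RandomForestClassifier(random_state=random_state),
                   param_grid=param_grid, cv=inner)

The holdout option trades some robustness in the tuned parameters for a single, much cheaper fit per candidate parameter set.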
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-04-24 05:03:47 UTC (rev 70943)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-04-24 20:38:06 UTC (rev 70944)
@@ -144,6 +144,15 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+#%option string
+#% key: grid_search
+#% label: Resampling method to use for hyperparameter optimization
+#% description: Resampling method to use for hyperparameter optimization
+#% options: cross-validation,holdout
+#% answer: cross-validation
+#% multiple: no
+#% guisection: Classifier settings
+#%end
#%option integer
#% key: categorymaps
#% multiple: yes
@@ -319,7 +328,6 @@
import atexit
import os
-import tempfile
import itertools
from copy import deepcopy
import numpy as np
@@ -345,9 +353,10 @@
try:
from sklearn.externals import joblib
from sklearn.cluster import KMeans
- from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.preprocessing import StandardScaler, Imputer
- from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
+ from sklearn.model_selection import (
+ GridSearchCV, RandomizedSearchCV, GroupShuffleSplit, ShuffleSplit,
+ StratifiedKFold, GroupKFold)
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import Pipeline
from sklearn.utils import shuffle
@@ -367,6 +376,7 @@
# classifier gui section
classifier = options['classifier']
+ grid_search = options['grid_search']
hyperparams = {
'C': options['c'],
'min_samples_split': options['min_samples_split'],
@@ -558,21 +568,38 @@
save_training_data(X, y, group_id, save_training)
# ---------------------------------------------------------------------
- # define the hyperparameter inner search cross validation method
+ # define the inner search resampling method
# ---------------------------------------------------------------------
- # define model selection cross-validation method
- if any(param_grid) is True and cv == 1:
- gscript.fatal('Hyperparameter search requires cv > 1')
- if any(param_grid) is True or cv > 1:
+ if any(param_grid) is True and cv == 1 and grid_search == 'cross-validation':
+ gscript.fatal('Hyperparameter search using cross validation requires cv > 1')
+
+ # define inner resampling using cross-validation method
+ if any(param_grid) is True and grid_search == 'cross-validation':
if group_id is None:
- resampling = StratifiedKFold(
- n_splits=cv, random_state=random_state)
+ inner = StratifiedKFold(n_splits=cv, random_state=random_state)
else:
- resampling = GroupKFold(n_splits=cv)
+ inner = GroupKFold(n_splits=cv)
+
+ # define inner resampling using the holdout method
+ if any(param_grid) is True and grid_search == 'holdout':
+ if group_id is None:
+ inner = ShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
+ else:
+ inner = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
+
+ # ---------------------------------------------------------------------
+ # define the outer search resampling method
+ # ---------------------------------------------------------------------
+ if group_id is None:
+ outer = StratifiedKFold(n_splits=cv, random_state=random_state)
else:
- resampling = None
+ outer = GroupKFold(n_splits=cv)
+ # ---------------------------------------------------------------------
+ # define sample weights for gradient boosting classifiers
+ # ---------------------------------------------------------------------
+
# sample weights for GradientBoosting or XGBClassifier
if balance is True and mode == 'classification' and classifier in (
'GradientBoostingClassifier', 'XGBClassifier'):
@@ -628,7 +655,7 @@
# create grid search method
clf = GridSearchCV(
estimator=clf, param_grid=param_grid, scoring=search_scorer,
- n_jobs=n_jobs, cv=resampling)
+ n_jobs=n_jobs, cv=inner)
# ---------------------------------------------------------------------
# classifier training
@@ -638,7 +665,7 @@
gscript.message(('Fitting model using ' + classifier))
# pass groups to fit parameter if GroupKFold and param_grid are present
- if isinstance(resampling, GroupKFold) and any(param_grid) is True:
+ if isinstance(inner, GroupKFold) and any(param_grid) is True:
if balance is True and classifier in (
'GradientBoostingClassifier', 'XGBClassifier'):
clf.fit(X=X, y=y, groups=group_id, sample_weight=class_weights)
@@ -691,7 +718,7 @@
# perform the cross-validation
scores, cscores, fimp, models = cross_val_scores(
- clf, X, y, group_id, class_weights, resampling, scoring,
+ clf, X, y, group_id, class_weights, outer, scoring,
importances, n_permutations, predict_resamples, random_state)
# global scores
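The hunks above also separate the resampler used for tuning (inner) from the one used to report accuracy (outer), and forward the group ids to fit() when a group-aware splitter is active. A rough standalone sketch of that pattern on toy data (not the module's cross_val_scores helper):

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import (GridSearchCV, GroupKFold,
                                     GroupShuffleSplit, cross_val_score)

# toy data with one "group" id per sample (illustration only)
rng = np.random.RandomState(0)
X, y = rng.rand(60, 4), rng.randint(0, 2, 60)
groups = np.repeat(np.arange(6), 10)

# inner, group-aware resampler for hyperparameter tuning: the group ids
# must be forwarded through fit(), as the diff above does
inner = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=0)
search = GridSearchCV(RandomForestClassifier(random_state=0),
                      param_grid={'n_estimators': [50, 100]}, cv=inner)
search.fit(X, y, groups=groups)

# outer, independent resampler for reporting scores of the tuned model
outer = GroupKFold(n_splits=3)
scores = cross_val_score(search.best_estimator_, X, y,
                         groups=groups, cv=outer)
print(scores.mean())

The groups argument to fit() is what lets GroupShuffleSplit keep whole spatial groups together in the inner split.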
Modified: grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py 2017-04-24 05:03:47 UTC (rev 70943)
+++ grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py 2017-04-24 20:38:06 UTC (rev 70944)
@@ -2,6 +2,9 @@
# -- coding: utf-8 --
import numpy as np
+import os
+import tempfile
+from copy import deepcopy
from numpy.random import RandomState
from grass.pygrass.modules.shortcuts import raster as r
from grass.pygrass.raster import RasterRow
@@ -11,6 +14,7 @@
from grass.pygrass.vector import VectorTopo
from grass.pygrass.vector.table import Link
from grass.pygrass.utils import get_raster_for_points
+import grass.script as gscript
from subprocess import PIPE
def specificity_score(y_true, y_pred):
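The specificity_score helper whose definition starts in the context above is, by the usual definition, the true-negative rate. A sketch of that standard formula using sklearn's confusion_matrix (an assumption about the intent, not the module's exact code):

from sklearn.metrics import confusion_matrix

def specificity_score(y_true, y_pred):
    # specificity = TN / (TN + FP), i.e. the true-negative rate
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    return tn / float(tn + fp)

print(specificity_score([0, 0, 1, 1], [0, 1, 1, 1]))  # 0.5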