[GRASS-SVN] r70944 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Apr 24 13:38:07 PDT 2017


Author: spawley
Date: 2017-04-24 13:38:06 -0700 (Mon, 24 Apr 2017)
New Revision: 70944

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
   grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
Log:
r.learn.ml: added an option to choose the inner (hyperparameter search) resampling method. Fixed bugs introduced by moving functions into separate modules.

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-24 05:03:47 UTC (rev 70943)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-24 20:38:06 UTC (rev 70944)
@@ -144,6 +144,15 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+#%option string
+#% key: grid_search
+#% label: Resampling method to use for hyperparameter optimization
+#% description: Resampling method to use for hyperparameter optimization
+#% options: cross-validation,holdout
+#% answer: cross-validation
+#% multiple: no
+#% guisection: Classifier settings
+#%end
 #%option integer
 #% key: categorymaps
 #% multiple: yes
@@ -319,7 +328,6 @@
 
 import atexit
 import os
-import tempfile
 import itertools
 from copy import deepcopy
 import numpy as np
@@ -345,9 +353,10 @@
     try:
         from sklearn.externals import joblib
         from sklearn.cluster import KMeans
-        from sklearn.model_selection import StratifiedKFold, GroupKFold
         from sklearn.preprocessing import StandardScaler, Imputer
-        from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
+        from sklearn.model_selection import (
+            GridSearchCV, RandomizedSearchCV, GroupShuffleSplit, ShuffleSplit,
+            StratifiedKFold, GroupKFold)
         from sklearn.preprocessing import OneHotEncoder
         from sklearn.pipeline import Pipeline
         from sklearn.utils import shuffle
@@ -367,6 +376,7 @@
 
     # classifier gui section
     classifier = options['classifier']
+    grid_search = options['grid_search']
     hyperparams = {
         'C': options['c'],
         'min_samples_split': options['min_samples_split'],
@@ -558,21 +568,38 @@
                 save_training_data(X, y, group_id, save_training)
 
         # ---------------------------------------------------------------------
-        # define the hyperparameter inner search cross validation method
+        # define the inner search resampling method
         # ---------------------------------------------------------------------
 
-        # define model selection cross-validation method
-        if any(param_grid) is True and cv == 1:
-            gscript.fatal('Hyperparameter search requires cv > 1')
-        if any(param_grid) is True or cv > 1:
+        if any(param_grid) is True and cv == 1 and grid_search == 'cross-validation':
+            gscript.fatal('Hyperparameter search using cross validation requires cv > 1')
+
+        # define inner resampling using cross-validation method
+        if any(param_grid) is True and grid_search == 'cross-validation':
             if group_id is None:
-                resampling = StratifiedKFold(
-                    n_splits=cv, random_state=random_state)
+                inner = StratifiedKFold(n_splits=cv, random_state=random_state)
             else:
-                resampling = GroupKFold(n_splits=cv)
+                inner = GroupKFold(n_splits=cv)
+
+        # define inner resampling using the holdout method
+        if any(param_grid) is True and grid_search == 'holdout':
+            if group_id is None:
+                inner = ShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
+            else:
+                inner = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
+
+        # ---------------------------------------------------------------------
+        # define the outer search resampling method
+        # ---------------------------------------------------------------------
+        if group_id is None:
+            outer = StratifiedKFold(n_splits=cv, random_state=random_state)
         else:
-            resampling = None
+            outer = GroupKFold(n_splits=cv)
 
+        # ---------------------------------------------------------------------
+        # define sample weights for gradient boosting classifiers
+        # ---------------------------------------------------------------------
+
         # sample weights for GradientBoosting or XGBClassifier
         if balance is True and mode == 'classification' and classifier in (
                 'GradientBoostingClassifier', 'XGBClassifier'):
@@ -628,7 +655,7 @@
             # create grid search method
             clf = GridSearchCV(
                 estimator=clf, param_grid=param_grid, scoring=search_scorer,
-                n_jobs=n_jobs, cv=resampling)
+                n_jobs=n_jobs, cv=inner)
 
         # ---------------------------------------------------------------------
         # classifier training
@@ -638,7 +665,7 @@
         gscript.message(('Fitting model using ' + classifier))
 
         # pass groups to fit parameter GroupKFold and param_grid are present
-        if isinstance(resampling, GroupKFold) and any(param_grid) is True:
+        if isinstance(inner, GroupKFold) and any(param_grid) is True:
             if balance is True and classifier in (
                     'GradientBoostingClassifier', 'XGBClassifier'):
                 clf.fit(X=X, y=y, groups=group_id, sample_weight=class_weights)
@@ -691,7 +718,7 @@
 
                 # perform the cross-validatation
                 scores, cscores, fimp, models = cross_val_scores(
-                    clf, X, y, group_id, class_weights, resampling, scoring,
+                    clf, X, y, group_id, class_weights, outer, scoring,
                     importances, n_permutations, predict_resamples, random_state)
 
                 # global scores

Modified: grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py	2017-04-24 05:03:47 UTC (rev 70943)
+++ grass-addons/grass7/raster/r.learn.ml/r_learn_utils.py	2017-04-24 20:38:06 UTC (rev 70944)
@@ -2,6 +2,9 @@
 # -- coding: utf-8 --
 
 import numpy as np
+import os
+import tempfile
+from copy import deepcopy
 from numpy.random import RandomState
 from grass.pygrass.modules.shortcuts import raster as r
 from grass.pygrass.raster import RasterRow
@@ -11,6 +14,7 @@
 from grass.pygrass.vector import VectorTopo
 from grass.pygrass.vector.table import Link
 from grass.pygrass.utils import get_raster_for_points
+import grass.script as gscript
 from subprocess import PIPE
 
 def specificity_score(y_true, y_pred):



More information about the grass-commit mailing list