[GRASS-SVN] r70972 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Thu Apr 27 20:55:23 PDT 2017


Author: spawley
Date: 2017-04-27 20:55:23 -0700 (Thu, 27 Apr 2017)
New Revision: 70972

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
r.learn.ml: fixed hyperparameter tuning with the knn classifier

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-28 01:34:01 UTC (rev 70971)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-04-28 03:55:23 UTC (rev 70972)
@@ -20,6 +20,7 @@
 #% keyword: machine learning
 #% keyword: scikit-learn
 #%end
+
 #%option G_OPT_I_GROUP
 #% key: group
 #% label: Imagery group to be classified
@@ -27,6 +28,7 @@
 #% required: yes
 #% multiple: no
 #%end
+
 #%option G_OPT_R_INPUT
 #% key: trainingmap
 #% label: Labelled pixels
@@ -34,6 +36,7 @@
 #% required: no
 #% guisection: Required
 #%end
+
 #%option G_OPT_V_INPUT
 #% key: trainingpoints
 #% label: Training point vector
@@ -41,6 +44,7 @@
 #% required: no
 #% guisection: Required
 #%end
+
 #%option G_OPT_DB_COLUMN
 #% key: field
 #% label: Response attribute column
@@ -48,6 +52,7 @@
 #% required: no
 #% guisection: Required
 #%end
+
 #%option G_OPT_R_OUTPUT
 #% key: output
 #% label: Output Map
@@ -55,6 +60,7 @@
 #% guisection: Required
 #% required: no
 #%end
+
 #%option string
 #% key: classifier
 #% label: Classifier
@@ -64,6 +70,7 @@
 #% guisection: Classifier settings
 #% required: no
 #%end
+
 #%option
 #% key: c
 #% type: double
@@ -73,6 +80,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option
 #% key: max_features
 #% type: integer
@@ -82,6 +90,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option
 #% key: max_depth
 #% type: integer
@@ -91,6 +100,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option
 #% key: min_samples_split
 #% type: integer
@@ -100,6 +110,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option
 #% key: min_samples_leaf
 #% type: integer
@@ -109,6 +120,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option
 #% key: n_estimators
 #% type: integer
@@ -118,6 +130,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option
 #% key: learning_rate
 #% type: double
@@ -127,6 +140,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option
 #% key: subsample
 #% type: double
@@ -136,22 +150,27 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
-#%option integer
+
+#%option
 #% key: max_degree
+#% type: integer
 #% label: The maximum degree of terms in forward pass
 #% description: The maximum degree of terms in forward pass for Py-earth
 #% answer: 1
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
-#%option integer
+
+#%option
 #% key: n_neighbors
+#% type: integer
 #% label: Number of neighbors to use
 #% description: Number of neighbors to use
 #% answer: 5
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option string
 #% key: weights
 #% label: weight function
@@ -161,6 +180,7 @@
 #% multiple: yes
 #% guisection: Classifier settings
 #%end
+
 #%option string
 #% key: grid_search
 #% label: Resampling method to use for hyperparameter optimization
@@ -170,6 +190,7 @@
 #% multiple: no
 #% guisection: Classifier settings
 #%end
+
 #%option integer
 #% key: categorymaps
 #% multiple: yes
@@ -177,6 +198,7 @@
 #% description: Indices of categorical rasters within the imagery group (0..n) that will be one-hot encoded
 #% guisection: Optional
 #%end
+
 #%option string
 #% key: cvtype
 #% label: Non-spatial or spatial cross-validation
@@ -185,6 +207,7 @@
 #% options: non-spatial,clumped,kmeans
 #% guisection: Cross validation
 #%end
+
 #%option
 #% key: n_partitions
 #% type: integer
@@ -193,6 +216,7 @@
 #% answer: 10
 #% guisection: Cross validation
 #%end
+
 #%option G_OPT_R_INPUT
 #% key: group_raster
 #% label: Custom group ids for training samples from GRASS raster
@@ -200,6 +224,7 @@
 #% required: no
 #% guisection: Cross validation
 #%end
+
 #%option
 #% key: cv
 #% type: integer
@@ -207,6 +232,7 @@
 #% answer: 1
 #% guisection: Cross validation
 #%end
+
 #%option
 #% key: n_permutations
 #% type: integer
@@ -214,39 +240,46 @@
 #% answer: 50
 #% guisection: Cross validation
 #%end
+
 #%flag
 #% key: t
 #% description: Perform hyperparameter tuning only
 #% guisection: Cross validation
 #%end
+
 #%flag
 #% key: f
 #% description: Calculate permutation importances during cross validation
 #% guisection: Cross validation
 #%end
+
 #%flag
 #% key: r
 #% label: Make predictions for cross validation resamples
 #% guisection: Cross validation
 #%end
+
 #%option G_OPT_F_OUTPUT
 #% key: errors_file
 #% label: Save cross-validation global accuracy results to csv
 #% required: no
 #% guisection: Cross validation
 #%end
+
 #%option G_OPT_F_OUTPUT
 #% key: fimp_file
 #% label: Save feature importances to csv
 #% required: no
 #% guisection: Cross validation
 #%end
+
 #%option G_OPT_F_OUTPUT
 #% key: param_file
 #% label: Save hyperparameter search scores to csv
 #% required: no
 #% guisection: Cross validation
 #%end
+
 #%option
 #% key: random_state
 #% type: integer
@@ -254,6 +287,7 @@
 #% answer: 1
 #% guisection: Optional
 #%end
+
 #%option
 #% key: lines
 #% type: integer
@@ -261,6 +295,7 @@
 #% answer: 25
 #% guisection: Optional
 #%end
+
 #%option
 #% key: indexes
 #% type: integer
@@ -269,6 +304,7 @@
 #% guisection: Optional
 #% multiple: yes
 #%end
+
 #%option
 #% key: n_jobs
 #% type: integer
@@ -276,65 +312,77 @@
 #% answer: -2
 #% guisection: Optional
 #%end
+
 #%flag
 #% key: s
 #% label: Standardization preprocessing
 #% guisection: Optional
 #%end
+
 #%flag
 #% key: i
 #% label: Impute training data preprocessing
 #% guisection: Optional
 #%end
+
 #%flag
 #% key: p
 #% label: Output class membership probabilities
 #% guisection: Optional
 #%end
+
 #%flag
 #% key: z
 #% label: Only predict class probabilities
 #% guisection: Optional
 #%end
+
 #%flag
 #% key: m
 #% description: Build model only - do not perform prediction
 #% guisection: Optional
 #%end
+
 #%flag
 #% key: b
 #% description: Balance training data using class weights
 #% guisection: Optional
 #%end
+
 #%flag
 #% key: l
 #% label: Use memory swap
 #% guisection: Optional
 #%end
+
 #%option G_OPT_F_OUTPUT
 #% key: save_training
 #% label: Save training data to csv
 #% required: no
 #% guisection: Optional
 #%end
+
 #%option G_OPT_F_INPUT
 #% key: load_training
 #% label: Load training data from csv
 #% required: no
 #% guisection: Optional
 #%end
+
 #%option G_OPT_F_OUTPUT
 #% key: save_model
 #% label: Save model from file
 #% required: no
 #% guisection: Optional
 #%end
+
 #%option G_OPT_F_INPUT
 #% key: load_model
 #% label: Load model from file
 #% required: no
 #% guisection: Optional
 #%end
+
 #%rules
 #% exclusive: trainingmap,load_model
 #% exclusive: load_training,save_training
@@ -384,7 +432,7 @@
         warnings.filterwarnings('ignore')
     except:
         gscript.fatal("Scikit learn 0.18 or newer is not installed")
-    
+
     try:
         import pandas as pd
     except:
@@ -485,7 +533,8 @@
     for key, val in hyperparams.iteritems():
         # split any comma separated strings and add them to the param_grid
         if ',' in val:
-            param_grid[key] = [hyperparams_type[key](i) for i in val.split(',')]
+            param_grid[key] = [hyperparams_type[key](i) for i in val.split(',')] # add all vals to param_grid
+            hyperparams[key] = [hyperparams_type[key](i) for i in val.split(',')][0] # use first param for default
         # else convert the single strings to int or float
         else:
             hyperparams[key] = hyperparams_type[key](val)



More information about the grass-commit mailing list