[GRASS-SVN] r70972 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Thu Apr 27 20:55:23 PDT 2017
Author: spawley
Date: 2017-04-27 20:55:23 -0700 (Thu, 27 Apr 2017)
New Revision: 70972
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
r.learn.ml fixed hyperparameter tuning with knn classifier
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-04-28 01:34:01 UTC (rev 70971)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-04-28 03:55:23 UTC (rev 70972)
@@ -20,6 +20,7 @@
#% keyword: machine learning
#% keyword: scikit-learn
#%end
+
#%option G_OPT_I_GROUP
#% key: group
#% label: Imagery group to be classified
@@ -27,6 +28,7 @@
#% required: yes
#% multiple: no
#%end
+
#%option G_OPT_R_INPUT
#% key: trainingmap
#% label: Labelled pixels
@@ -34,6 +36,7 @@
#% required: no
#% guisection: Required
#%end
+
#%option G_OPT_V_INPUT
#% key: trainingpoints
#% label: Training point vector
@@ -41,6 +44,7 @@
#% required: no
#% guisection: Required
#%end
+
#%option G_OPT_DB_COLUMN
#% key: field
#% label: Response attribute column
@@ -48,6 +52,7 @@
#% required: no
#% guisection: Required
#%end
+
#%option G_OPT_R_OUTPUT
#% key: output
#% label: Output Map
@@ -55,6 +60,7 @@
#% guisection: Required
#% required: no
#%end
+
#%option string
#% key: classifier
#% label: Classifier
@@ -64,6 +70,7 @@
#% guisection: Classifier settings
#% required: no
#%end
+
#%option
#% key: c
#% type: double
@@ -73,6 +80,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option
#% key: max_features
#% type: integer
@@ -82,6 +90,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option
#% key: max_depth
#% type: integer
@@ -91,6 +100,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option
#% key: min_samples_split
#% type: integer
@@ -100,6 +110,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option
#% key: min_samples_leaf
#% type: integer
@@ -109,6 +120,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option
#% key: n_estimators
#% type: integer
@@ -118,6 +130,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option
#% key: learning_rate
#% type: double
@@ -127,6 +140,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option
#% key: subsample
#% type: double
@@ -136,22 +150,27 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
-#%option integer
+
+#%option
#% key: max_degree
+#% type: integer
#% label: The maximum degree of terms in forward pass
#% description: The maximum degree of terms in forward pass for Py-earth
#% answer: 1
#% multiple: yes
#% guisection: Classifier settings
#%end
-#%option integer
+
+#%option
#% key: n_neighbors
+#% type: integer
#% label: Number of neighbors to use
#% description: Number of neighbors to use
#% answer: 5
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option string
#% key: weights
#% label: weight function
@@ -161,6 +180,7 @@
#% multiple: yes
#% guisection: Classifier settings
#%end
+
#%option string
#% key: grid_search
#% label: Resampling method to use for hyperparameter optimization
@@ -170,6 +190,7 @@
#% multiple: no
#% guisection: Classifier settings
#%end
+
#%option integer
#% key: categorymaps
#% multiple: yes
@@ -177,6 +198,7 @@
#% description: Indices of categorical rasters within the imagery group (0..n) that will be one-hot encoded
#% guisection: Optional
#%end
+
#%option string
#% key: cvtype
#% label: Non-spatial or spatial cross-validation
@@ -185,6 +207,7 @@
#% options: non-spatial,clumped,kmeans
#% guisection: Cross validation
#%end
+
#%option
#% key: n_partitions
#% type: integer
@@ -193,6 +216,7 @@
#% answer: 10
#% guisection: Cross validation
#%end
+
#%option G_OPT_R_INPUT
#% key: group_raster
#% label: Custom group ids for training samples from GRASS raster
@@ -200,6 +224,7 @@
#% required: no
#% guisection: Cross validation
#%end
+
#%option
#% key: cv
#% type: integer
@@ -207,6 +232,7 @@
#% answer: 1
#% guisection: Cross validation
#%end
+
#%option
#% key: n_permutations
#% type: integer
@@ -214,39 +240,46 @@
#% answer: 50
#% guisection: Cross validation
#%end
+
#%flag
#% key: t
#% description: Perform hyperparameter tuning only
#% guisection: Cross validation
#%end
+
#%flag
#% key: f
#% description: Calculate permutation importances during cross validation
#% guisection: Cross validation
#%end
+
#%flag
#% key: r
#% label: Make predictions for cross validation resamples
#% guisection: Cross validation
#%end
+
#%option G_OPT_F_OUTPUT
#% key: errors_file
#% label: Save cross-validation global accuracy results to csv
#% required: no
#% guisection: Cross validation
#%end
+
#%option G_OPT_F_OUTPUT
#% key: fimp_file
#% label: Save feature importances to csv
#% required: no
#% guisection: Cross validation
#%end
+
#%option G_OPT_F_OUTPUT
#% key: param_file
#% label: Save hyperparameter search scores to csv
#% required: no
#% guisection: Cross validation
#%end
+
#%option
#% key: random_state
#% type: integer
@@ -254,6 +287,7 @@
#% answer: 1
#% guisection: Optional
#%end
+
#%option
#% key: lines
#% type: integer
@@ -261,6 +295,7 @@
#% answer: 25
#% guisection: Optional
#%end
+
#%option
#% key: indexes
#% type: integer
@@ -269,6 +304,7 @@
#% guisection: Optional
#% multiple: yes
#%end
+
#%option
#% key: n_jobs
#% type: integer
@@ -276,65 +312,77 @@
#% answer: -2
#% guisection: Optional
#%end
+
#%flag
#% key: s
#% label: Standardization preprocessing
#% guisection: Optional
#%end
+
#%flag
#% key: i
#% label: Impute training data preprocessing
#% guisection: Optional
#%end
+
#%flag
#% key: p
#% label: Output class membership probabilities
#% guisection: Optional
#%end
+
#%flag
#% key: z
#% label: Only predict class probabilities
#% guisection: Optional
#%end
+
#%flag
#% key: m
#% description: Build model only - do not perform prediction
#% guisection: Optional
#%end
+
#%flag
#% key: b
#% description: Balance training data using class weights
#% guisection: Optional
#%end
+
#%flag
#% key: l
#% label: Use memory swap
#% guisection: Optional
#%end
+
#%option G_OPT_F_OUTPUT
#% key: save_training
#% label: Save training data to csv
#% required: no
#% guisection: Optional
#%end
+
#%option G_OPT_F_INPUT
#% key: load_training
#% label: Load training data from csv
#% required: no
#% guisection: Optional
#%end
+
#%option G_OPT_F_OUTPUT
#% key: save_model
#% label: Save model from file
#% required: no
#% guisection: Optional
#%end
+
#%option G_OPT_F_INPUT
#% key: load_model
#% label: Load model from file
#% required: no
#% guisection: Optional
#%end
+
#%rules
#% exclusive: trainingmap,load_model
#% exclusive: load_training,save_training
@@ -384,7 +432,7 @@
warnings.filterwarnings('ignore')
except:
gscript.fatal("Scikit learn 0.18 or newer is not installed")
-
+
try:
import pandas as pd
except:
@@ -485,7 +533,8 @@
for key, val in hyperparams.iteritems():
# split any comma separated strings and add them to the param_grid
if ',' in val:
- param_grid[key] = [hyperparams_type[key](i) for i in val.split(',')]
+ param_grid[key] = [hyperparams_type[key](i) for i in val.split(',')] # add all vals to param_grid
+ hyperparams[key] = [hyperparams_type[key](i) for i in val.split(',')][0] # use first param for default
# else convert the single strings to int or float
else:
hyperparams[key] = hyperparams_type[key](val)
More information about the grass-commit
mailing list