[GRASS-SVN] r71358 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Wed Aug 9 10:15:22 PDT 2017
Author: spawley
Date: 2017-08-09 10:15:22 -0700 (Wed, 09 Aug 2017)
New Revision: 71358
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
bug fixes to r.learn.ml for cross validation and categorical layers from diff by Jaan Janno and Mait Lang
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-08-09 15:03:33 UTC (rev 71357)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-08-09 17:15:22 UTC (rev 71358)
@@ -12,6 +12,9 @@
# for details.
#
#############################################################################
+# July, 2017. Jaan Janno, Mait Lang. Bugfixes concerning cross-validation failure
+# when numeric class IDs were not consecutive (i.e. not increasing by 1 each).
+# Bugfix for processing the index list of nominal layers.
#%module
#% description: Supervised classification and regression of GRASS rasters using the python scikit-learn package
@@ -427,7 +430,8 @@
def cleanup():
for rast in tmp_rast:
- gscript.run_command("g.remove", name=rast, type='raster', flags='f', quiet=True)
+ gscript.run_command(
+ "g.remove", name=rast, type='raster', flags='f', quiet=True)
def warn(*args, **kwargs):
pass
@@ -514,10 +518,29 @@
balance = flags['b']
# convert to lists
+ """
if ',' in categorymaps:
categorymaps = [int(i) for i in categorymaps.split(',')]
+ print(categorymaps)
else: categorymaps = None
-
+ """
+
+ if categorymaps.strip() == '':
+ categorymaps = None
+ else:
+ try:
+ categorymaps = [int(i.strip()) for i in categorymaps.split(',')]
+ # check for negative indices, check the largest index against maplist, _ = maps_from_group(group), and check for duplicates (unique)
+ nCategories = len(maps_from_group(group)[0])
+ if min(categorymaps) < 0:
+ gscript.fatal('Category map index can not be negative.')
+ if max(categorymaps) > nCategories - 1:
+ gscript.fatal('Category map index input can not exceed ' + str(nCategories - 1))
+ if not len(np.unique(categorymaps)) == len(categorymaps):
+ gscript.fatal('Duplicate indices in category map index list.')
+ except:
+ gscript.fatal('Error in category map list input.')
+
if ',' in indexes:
indexes = [int(i) for i in indexes.split(',')]
else:
@@ -550,6 +573,8 @@
param_grid = deepcopy(hyperparams_type)
param_grid = dict.fromkeys(param_grid, None)
+
+
for key, val in hyperparams.iteritems():
# split any comma separated strings and add them to the param_grid
if ',' in val:
@@ -562,6 +587,7 @@
if hyperparams['max_depth'] == 0: hyperparams['max_depth'] = None
if hyperparams['max_features'] == 0: hyperparams['max_features'] = 'auto'
param_grid = {k: v for k, v in param_grid.iteritems() if v is not None}
+
# retrieve sklearn classifier object and parameters
clf, mode = model_classifiers(
@@ -573,6 +599,7 @@
key: value for key, value in param_grid.iteritems()
if key in clf_params}
+
# scoring metrics
if mode == 'classification':
scoring = ['accuracy', 'precision', 'recall', 'f1', 'kappa',\
@@ -588,7 +615,7 @@
# fetch individual raster names from group
maplist, _ = maps_from_group(group)
-
+
if model_load == '':
# Sample training data and group id
@@ -686,6 +713,7 @@
inner = GroupShuffleSplit(n_splits=1, test_size=0.33, random_state=random_state)
else:
inner = None
+
# ---------------------------------------------------------------------
# define the outer search resampling method
@@ -714,21 +742,26 @@
# ---------------------------------------------------------------------
# standardization
if norm_data is True and categorymaps is None:
+ gscript.message('norm_data is True and categorymaps is None:')
clf = Pipeline([('scaling', StandardScaler()),
('classifier', clf)])
# onehot encoding
if categorymaps is not None and norm_data is False:
+ gscript.message('categorymaps is not None and norm_data is False:')
enc = OneHotEncoder(categorical_features=categorymaps)
+ # print(enc.n_values)
enc.fit(X)
clf = Pipeline([('onehot', OneHotEncoder(
categorical_features=categorymaps,
n_values=enc.n_values_, handle_unknown='ignore',
+ # handle_unknown='ignore',
sparse=False)), # dense because not all clf can use sparse
('classifier', clf)])
# standardization and onehot encoding
if norm_data is True and categorymaps is not None:
+ gscript.message('norm_data is True and categorymaps is not None:')
enc = OneHotEncoder(categorical_features=categorymaps)
enc.fit(X)
clf = Pipeline([('onehot', OneHotEncoder(
@@ -737,7 +770,7 @@
sparse=False)),
('scaling', StandardScaler()),
('classifier', clf)])
-
+ # print(clf)
# ---------------------------------------------------------------------
# create the hyperparameter grid search method
# ---------------------------------------------------------------------
@@ -751,6 +784,10 @@
newkey = 'classifier__' + key
param_grid[newkey] = param_grid.pop(key)
+ # print(param_grid)
+ # print(inner)
+ # gscript.fatal('Põmm!')
+
# create grid search method
clf = GridSearchCV(
estimator=clf, param_grid=param_grid, scoring=search_scorer,
@@ -793,7 +830,14 @@
# If cv > 1 then use cross-validation to generate performance measures
if cv > 1 and tune_only is not True:
if mode == 'classification' and cv > np.histogram(
- y, bins=len(np.unique(y)))[0].min():
+ # The line below was the original code; it misbehaved when the class codes had gaps of 2 or more
+ # y, bins=len(np.unique(y)))[0].min():
+ y, bins=np.unique(y))[0].min():
+ # print(np.unique(y))
+ # print(len(np.unique(y)))
+ # print(np.histogram(y, bins=len(np.unique(y)))[0].min())
+ print(np.histogram(y, bins=len(np.unique(y))))
+ print(np.histogram(y, bins=np.unique(y)))
gscript.message(os.linesep)
gscript.message('Number of cv folds is greater than number of '
'samples in some classes. Cross-validation is being'
More information about the grass-commit
mailing list