[GRASS-SVN] r73914 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Sat Jan 5 23:32:02 PST 2019


Author: spawley
Date: 2019-01-05 23:32:02 -0800 (Sat, 05 Jan 2019)
New Revision: 73914

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
r.learn.ml: bug fix for handling multiple categorical rasters; changed dict iteritems method to items for Python 3 compatibility

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2019-01-05 00:01:42 UTC (rev 73913)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2019-01-06 07:32:02 UTC (rev 73914)
@@ -1254,7 +1254,7 @@
     if isinstance(scoring, basestring):
         scoring = [scoring]
     scores = dict.fromkeys(scoring)
-    scores = {key: [] for key, value in scores.iteritems()}
+    scores = {key: [] for key, value in scores.items()}
     scoring_methods = {'accuracy': metrics.accuracy_score,
                        'balanced_accuracy': metrics.recall_score,
                        'average_precision': metrics.average_precision_score,
@@ -1287,10 +1287,10 @@
     n_classes = len(np.unique(y))
     labels = np.unique(y)
     byclass_scores = dict.fromkeys(byclass_methods)
-    byclass_scores = {key: np.zeros((0, n_classes)) for key, value in byclass_scores.iteritems()}
+    byclass_scores = {key: np.zeros((0, n_classes)) for key, value in byclass_scores.items()}
 
     # remove any byclass_scorers that are not in the scoring list
-    byclass_scores = {key: value for key, value in byclass_scores.iteritems() if key in scores}
+    byclass_scores = {key: value for key, value in byclass_scores.items() if key in scores}
 
     # check if remaining scorers are valid sklearn metrics
     for i in scores.keys():
@@ -1494,15 +1494,21 @@
     balance = flags['b']
 
     # fetch individual raster names from group
-    maplist, mapnames = maps_from_group(group)
+    maplist, _ = maps_from_group(group)
 
     # extract indices of category maps
     if categorymaps.strip() == '':
         categorymaps = None
     else:
-        if isinstance(categorymaps, str):
+        # split string into list
+        if ',' in categorymaps is False:
             categorymaps = [categorymaps]
+        else:
+            categorymaps = categorymaps.split(',')
+        
         cat_indexes = []
+
+        # check that each category map is also in the imagery group
         for cat in categorymaps:
             try:
                 cat_indexes.append(maplist.index(cat))
@@ -1553,7 +1559,7 @@
     param_grid = deepcopy(hyperparams_type)
     param_grid = dict.fromkeys(param_grid, None)
 
-    for key, val in hyperparams.iteritems():
+    for key, val in hyperparams.items():
         # split any comma separated strings and add them to the param_grid
         if ',' in val:
             param_grid[key] = [hyperparams_type[key](i) for i in val.split(',')] # add all vals to param_grid
@@ -1564,7 +1570,7 @@
 
     if hyperparams['max_depth'] == 0: hyperparams['max_depth'] = None
     if hyperparams['max_features'] == 0: hyperparams['max_features'] = 'auto'
-    param_grid = {k: v for k, v in param_grid.iteritems() if v is not None}
+    param_grid = {k: v for k, v in param_grid.items() if v is not None}
 
     # retrieve sklearn classifier object and parameters
     clf, mode = model_classifiers(
@@ -1573,7 +1579,7 @@
     # remove dict keys that are incompatible for the selected classifier
     clf_params = clf.get_params()
     param_grid = {
-        key: value for key, value in param_grid.iteritems()
+        key: value for key, value in param_grid.items()
         if key in clf_params}
 
     # scoring metrics
@@ -1828,7 +1834,7 @@
 
                 preds = np.hstack((preds, sample_coords))
 
-                for method, val in scores.iteritems():
+                for method, val in scores.items():
                     gs.message(
                         method+":\t%0.3f\t+/-SD\t%0.3f" %
                         (val.mean(), val.std()))
@@ -1840,7 +1846,7 @@
                         'Cross validation class performance measures......:')
                     gs.message('Class \t' + '\t'.join(map(str, np.unique(y))))
 
-                    for method, val in cscores.iteritems():
+                    for method, val in cscores.items():
                         mat_cscores = np.matrix(val)
                         gs.message(
                             method+':\t' + '\t'.join(



More information about the grass-commit mailing list