[GRASS-SVN] r71358 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org
Wed Aug 9 10:15:22 PDT 2017


Author: spawley
Date: 2017-08-09 10:15:22 -0700 (Wed, 09 Aug 2017)
New Revision: 71358

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
bug fixes to r.learn.ml for cross-validation and categorical layers, from a diff by Jaan Janno and Mait Lang
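
For context, the cross-validation fix concerns how per-class sample counts are
derived before checking whether the number of cv folds exceeds the smallest
class. A minimal sketch of the failure mode (the class IDs below are invented
for illustration and are not taken from the commit):

    import numpy as np

    y = np.array([1, 1, 2, 2, 2, 5])   # class IDs with a gap (no 3 or 4)

    # previous check: an integer bin count spreads the samples over
    # equal-width bins, so a gap in the IDs yields a spurious empty bin
    np.histogram(y, bins=len(np.unique(y)))[0]   # -> array([5, 0, 1])

    # committed fix: use the unique class values themselves as bin edges
    np.histogram(y, bins=np.unique(y))[0]        # -> array([2, 4])

    # exact per-class counts could also be obtained with
    np.unique(y, return_counts=True)[1]          # -> array([2, 3, 1])

With the old form the minimum count is zero whenever the class IDs are not
consecutive, so the "folds exceed class size" branch was triggered even for
well-populated classes.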

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-08-09 15:03:33 UTC (rev 71357)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-08-09 17:15:22 UTC (rev 71358)
@@ -12,6 +12,9 @@
 #                for details.
 #
 #############################################################################
+# July 2017. Jaan Janno, Mait Lang. Bugfix for a cross-validation failure
+# when numeric class IDs were not consecutive (increasing by 1).
+# Bugfix for processing the index list of nominal layers.
 
 #%module
 #% description: Supervised classification and regression of GRASS rasters using the python scikit-learn package
@@ -427,7 +430,8 @@
 
 def cleanup():
     for rast in tmp_rast:
-        gscript.run_command("g.remove", name=rast, type='raster', flags='f', quiet=True)
+        gscript.run_command(
+            "g.remove", name=rast, type='raster', flags='f', quiet=True)
 
 def warn(*args, **kwargs):
     pass
@@ -514,10 +518,29 @@
     balance = flags['b']
 
     # convert to lists
-    if ',' in categorymaps:
-        categorymaps = [int(i) for i in categorymaps.split(',')]
-    else: categorymaps = None
-
+    if categorymaps.strip() == '':
+        categorymaps = None
+    else:
+        try:
+            categorymaps = [int(i.strip()) for i in categorymaps.split(',')]
+        except ValueError:
+            gscript.fatal('Error in category map list input.')
+        # check for negative indices, indices beyond the last map in the
+        # imagery group, and duplicate indices
+        nCategories = len(maps_from_group(group)[0])
+        if min(categorymaps) < 0:
+            gscript.fatal('Category map index cannot be negative.')
+        if max(categorymaps) > nCategories - 1:
+            gscript.fatal('Category map index cannot exceed ' + str(nCategories - 1))
+        if len(np.unique(categorymaps)) != len(categorymaps):
+            gscript.fatal('Duplicate indices in category map index list.')
+
     if ',' in indexes:
         indexes = [int(i) for i in indexes.split(',')]
     else:
@@ -714,21 +742,26 @@
         # ---------------------------------------------------------------------
         # standardization
         if norm_data is True and categorymaps is None:
             clf = Pipeline([('scaling', StandardScaler()),
                             ('classifier', clf)])
 
         # onehot encoding
         if categorymaps is not None and norm_data is False:
             enc = OneHotEncoder(categorical_features=categorymaps)
             enc.fit(X)
             clf = Pipeline([('onehot', OneHotEncoder(
                 categorical_features=categorymaps,
                 n_values=enc.n_values_, handle_unknown='ignore',
                 sparse=False)),  # dense because not all clf can use sparse
                             ('classifier', clf)])
 
         # standardization and onehot encoding
         if norm_data is True and categorymaps is not None:
             enc = OneHotEncoder(categorical_features=categorymaps)
             enc.fit(X)
             clf = Pipeline([('onehot', OneHotEncoder(
@@ -737,7 +770,7 @@
                 sparse=False)),
                             ('scaling', StandardScaler()),
                             ('classifier', clf)])
 
         # ---------------------------------------------------------------------
         # create the hyperparameter grid search method
         # ---------------------------------------------------------------------
@@ -751,6 +784,10 @@
                     newkey = 'classifier__' + key
                     param_grid[newkey] = param_grid.pop(key)
 
             # create grid search method
             clf = GridSearchCV(
                 estimator=clf, param_grid=param_grid, scoring=search_scorer,
@@ -793,7 +830,14 @@
         # If cv > 1 then use cross-validation to generate performance measures
         if cv > 1 and tune_only is not True:
             if mode == 'classification' and cv > np.histogram(
-                    y, bins=len(np.unique(y)))[0].min():
+                    # bins=len(np.unique(y)) miscounted the classes when the
+                    # class IDs were not consecutive (gaps of 2 or more);
+                    # use the unique class values as the bin edges instead
+                    y, bins=np.unique(y))[0].min():
                 gscript.message(os.linesep)
                 gscript.message('Number of cv folds is greater than number of '
                                 'samples in some classes. Cross-validation is being'
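
The categorical-layer handling shown in the diff above wraps the estimator in
a scikit-learn Pipeline built on the pre-0.20 OneHotEncoder API
(categorical_features / n_values). Below is a condensed sketch of that
construction with made-up data; RandomForestClassifier stands in for whatever
estimator model_classifiers() returns, and the whole block assumes
scikit-learn < 0.20:

    import numpy as np
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OneHotEncoder, StandardScaler

    # toy training data: column 1 holds integer category codes (a nominal
    # raster), column 0 is a continuous predictor
    X = np.array([[0.2, 3], [0.7, 1], [0.4, 3], [0.9, 1]])
    y = np.array([1, 2, 1, 2])
    categorymaps = [1]          # index of the nominal layer within the group

    # fit a throwaway encoder first so that n_values_ can be passed
    # explicitly to the pipeline's encoder, as the module does;
    # handle_unknown='ignore' stops prediction from failing on category
    # values that never occurred in the training sample
    enc = OneHotEncoder(categorical_features=categorymaps)
    enc.fit(X)

    clf = Pipeline([
        ('onehot', OneHotEncoder(categorical_features=categorymaps,
                                 n_values=enc.n_values_,
                                 handle_unknown='ignore',
                                 sparse=False)),   # dense output: not every
                                                   # estimator accepts sparse
        ('scaling', StandardScaler()),
        ('classifier', RandomForestClassifier(n_estimators=10))])
    clf.fit(X, y)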


