[GRASS-SVN] r70490 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Feb 6 12:15:10 PST 2017


Author: spawley
Date: 2017-02-06 12:15:10 -0800 (Mon, 06 Feb 2017)
New Revision: 70490

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'Added a flag to r.learn.ml to impute missing values in the training data'

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-05 21:11:16 UTC (rev 70489)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-06 20:15:10 UTC (rev 70490)
@@ -141,6 +141,12 @@
 #% guisection: Optional
 #%end
 
+#%flag
+#% key: i
+#% label: Impute missing values in training data
+#% guisection: Optional
+#%end
+
 #%option string
 #% key: categorymaps
 #% required: no
@@ -295,7 +301,6 @@
 import atexit
 import os
 import numpy as np
-import copy
 import grass.script as grass
 import tempfile
 from copy import deepcopy
@@ -1172,10 +1177,11 @@
     return(X, y, groups)
 
 
-def sample_predictors(response, predictors, shuffle_data=True, lowmem=False,
-                      random_state=None):
+def sample_predictors(response, predictors, impute=False, shuffle_data=True,
+                      lowmem=False, random_state=None):
 
     from sklearn.utils import shuffle
+    from sklearn.preprocessing import Imputer
 
     """
     Samples a list of GRASS rasters using a labelled raster
@@ -1261,6 +1267,11 @@
     # convert indexes of training pixels from tuple to n*2 np array
     is_train = np.array(is_train).T
 
+    # impute missing values
+    if impute is True:
+        missing = Imputer(strategy='median')
+        training_data = missing.fit_transform(training_data)
+
     # Remove nan rows from training data
     X = training_data[~np.isnan(training_data).any(axis=1)]
     y = training_labels[~np.isnan(training_data).any(axis=1)]
@@ -1277,7 +1288,8 @@
 
 
 def sample_training_data(response, maplist, group_raster='', n_partitions=3,
-                         cvtype='', lowmem=False, random_state=None):
+                         cvtype='', impute=False, lowmem=False,
+                         random_state=None):
 
     """
     Samples predictor and optional group id raster for cross-val
@@ -1320,6 +1332,7 @@
         maplist2.append(group_raster)
         X, y, sample_coords = sample_predictors(response=response,
                                                 predictors=maplist2,
+                                                impute=impute,
                                                 shuffle_data=False,
                                                 lowmem=False,
                                                 random_state=random_state)
@@ -1341,6 +1354,7 @@
     else:
         X, y, sample_coords = sample_predictors(
             response=response, predictors=maplist,
+            impute=impute,
             shuffle_data=True,
             lowmem=lowmem,
             random_state=random_state)
@@ -1418,6 +1432,7 @@
     tune_cv = int(options['tune_cv'])
     n_permutations = int(options['n_permutations'])
     lowmem = flags['l']
+    impute = flags['i']
     errors_file = options['errors_file']
     fimp_file = options['fimp_file']
     balance = flags['b']
@@ -1534,7 +1549,7 @@
         else:
             X, y, group_id = sample_training_data(
                 response, maplist, group_raster, n_partitions, cvtype,
-                lowmem, random_state)
+                impute, lowmem, random_state)
 
         # option to save extracted data to .csv file
         if save_training != '':



More information about the grass-commit mailing list