[GRASS-SVN] r70490 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Mon Feb 6 12:15:10 PST 2017
Author: spawley
Date: 2017-02-06 12:15:10 -0800 (Mon, 06 Feb 2017)
New Revision: 70490
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'added flag to impute missing training data to r.learn.ml'
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-02-05 21:11:16 UTC (rev 70489)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-02-06 20:15:10 UTC (rev 70490)
@@ -141,6 +141,12 @@
#% guisection: Optional
#%end
+#%flag
+#% key: i
+#% label: Impute missing values in training data
+#% guisection: Optional
+#%end
+
#%option string
#% key: categorymaps
#% required: no
@@ -295,7 +301,6 @@
import atexit
import os
import numpy as np
-import copy
import grass.script as grass
import tempfile
from copy import deepcopy
@@ -1172,10 +1177,11 @@
return(X, y, groups)
-def sample_predictors(response, predictors, shuffle_data=True, lowmem=False,
- random_state=None):
+def sample_predictors(response, predictors, impute=False, shuffle_data=True,
+ lowmem=False, random_state=None):
from sklearn.utils import shuffle
+ from sklearn.preprocessing import Imputer
"""
Samples a list of GRASS rasters using a labelled raster
@@ -1261,6 +1267,11 @@
# convert indexes of training pixels from tuple to n*2 np array
is_train = np.array(is_train).T
+ # impute missing values
+ if impute is True:
+ missing = Imputer(strategy='median')
+ training_data = missing.fit_transform(training_data)
+
# Remove nan rows from training data
X = training_data[~np.isnan(training_data).any(axis=1)]
y = training_labels[~np.isnan(training_data).any(axis=1)]
@@ -1277,7 +1288,8 @@
def sample_training_data(response, maplist, group_raster='', n_partitions=3,
- cvtype='', lowmem=False, random_state=None):
+ cvtype='', impute=False, lowmem=False,
+ random_state=None):
"""
Samples predictor and optional group id raster for cross-val
@@ -1320,6 +1332,7 @@
maplist2.append(group_raster)
X, y, sample_coords = sample_predictors(response=response,
predictors=maplist2,
+ impute=impute,
shuffle_data=False,
lowmem=False,
random_state=random_state)
@@ -1341,6 +1354,7 @@
else:
X, y, sample_coords = sample_predictors(
response=response, predictors=maplist,
+ impute=impute,
shuffle_data=True,
lowmem=lowmem,
random_state=random_state)
@@ -1418,6 +1432,7 @@
tune_cv = int(options['tune_cv'])
n_permutations = int(options['n_permutations'])
lowmem = flags['l']
+ impute = flags['i']
errors_file = options['errors_file']
fimp_file = options['fimp_file']
balance = flags['b']
@@ -1534,7 +1549,7 @@
else:
X, y, group_id = sample_training_data(
response, maplist, group_raster, n_partitions, cvtype,
- lowmem, random_state)
+ impute, lowmem, random_state)
# option to save extracted data to .csv file
if save_training != '':
More information about the grass-commit mailing list