[GRASS-SVN] r70466 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Tue Jan 31 22:48:13 PST 2017
Author: spawley
Date: 2017-01-31 22:48:13 -0800 (Tue, 31 Jan 2017)
New Revision: 70466
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'fixed bug with estimator being altered via cross-validation'
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-02-01 05:16:32 UTC (rev 70465)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-02-01 06:48:13 UTC (rev 70466)
@@ -289,6 +289,7 @@
#%rules
#% exclusive: trainingmap,load_model
#% exclusive: load_training,save_training
+#% exclusive: trainingmap,load_training
#%end
import atexit
@@ -297,6 +298,7 @@
import copy
import grass.script as grass
import tempfile
+from copy import deepcopy
from grass.pygrass.modules.shortcuts import imagery as im
from grass.pygrass.modules.shortcuts import raster as r
from subprocess import PIPE
@@ -624,6 +626,9 @@
random_state: Seed to pass to the random number generator
"""
+ # create copy of fitting estimator for cross-val fitting
+ fit_train = deepcopy(self.estimator)
+
# dictionary of lists to store metrics
if scorers == 'binary':
self.scores = {
@@ -693,11 +698,11 @@
param_search = False
if self.groups is not None and param_search is True:
- fit = self.estimator.fit(X_train, y_train, groups=groups_train)
+ fit_train.fit(X_train, y_train, groups=groups_train)
else:
- fit = self.estimator.fit(X_train, y_train)
+ fit_train.fit(X_train, y_train)
- y_pred = fit.predict(X_test)
+ y_pred = fit_train.predict(X_test)
y_test_agg = np.append(y_test_agg, y_test)
y_pred_agg = np.append(y_pred_agg, y_pred)
@@ -710,7 +715,7 @@
self.scores['accuracy'],
metrics.accuracy_score(y_test, y_pred))
- y_pred_proba = fit.predict_proba(X_test)[:, 1]
+ y_pred_proba = fit_train.predict_proba(X_test)[:, 1]
self.scores['auc'] = np.append(
self.scores['auc'],
metrics.roc_auc_score(y_test, y_pred_proba))
@@ -753,12 +758,12 @@
if feature_importances is True:
if (self.fimp==0).all() == True:
self.fimp = self.varImp_permutation(
- fit, X_test, y_test, n_permutations, scorers,
+ fit_train, X_test, y_test, n_permutations, scorers,
random_state)
else:
self.fimp = np.row_stack(
(self.fimp, self.varImp_permutation(
- fit, X_test, y_test,
+ fit_train, X_test, y_test,
n_permutations, scorers, random_state)))
self.scores_cm = metrics.classification_report(y_test_agg, y_pred_agg)
@@ -766,8 +771,6 @@
# convert onehot-encoded feature importances back to original vars
if self.fimp is not None and self.enc is not None:
- from copy import deepcopy
-
# get start,end positions of each suite of onehot-encoded vars
feature_ranges = deepcopy(self.enc.feature_indices_)
for i in range(0, len(self.enc.feature_indices_)-1):
@@ -1313,7 +1316,7 @@
# because cross-validation will be performed spatially
# ---------------------------------------------------------------
if group_raster != '':
- maplist2 = copy.deepcopy(maplist)
+ maplist2 = deepcopy(maplist)
maplist2.append(group_raster)
X, y, sample_coords = sample_predictors(response=response,
predictors=maplist2,
More information about the grass-commit mailing list