[GRASS-SVN] r70361 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Thu Jan 12 21:56:55 PST 2017
Author: spawley
Date: 2017-01-12 21:56:55 -0800 (Thu, 12 Jan 2017)
New Revision: 70361
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'fix bug with balancing during cross-validation'
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-01-13 04:51:43 UTC (rev 70360)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-01-13 05:56:55 UTC (rev 70361)
@@ -224,14 +224,6 @@
#%end
#%option
-#% key: n_iter
-#% type: integer
-#% description: Number of randomized parameter tuning steps
-#% answer: 1
-#% guisection: Optional
-#%end
-
-#%option
#% key: tune_cv
#% type: integer
#% description: Number of cross-validation folds used for parameter tuning
@@ -621,7 +613,7 @@
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import GroupKFold
- from sklearn.model_selection import RandomizedSearchCV
+ from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn import metrics
"""
@@ -683,23 +675,31 @@
X_train, X_test = self.X[train_indices], self.X[test_indices]
y_train, y_test = self.y[train_indices], self.y[test_indices]
- # also get indices of groups for the training partition
- if self.groups is not None:
- groups_train = self.groups[train_indices]
-
# balance the fold
if self.balance == True:
X_train, y_train = self.random_oversampling(X_train, y_train, random_state=random_state)
if self.groups is not None:
+ groups_train = self.groups[train_indices]
groups_train, _ = self.random_oversampling(
- groups_train, y_train, random_state=random_state)
-
+ groups_train, self.y[train_indices], random_state=random_state)
+
+ else:
+ # also get indices of groups for the training partition
+ if self.groups is not None:
+ groups_train = self.groups[train_indices]
+
# fit the model on the training data and predict the test data
# need the groups parameter because the estimator can be a
# RandomizedSearchCV estimator where cv=GroupKFold
- if self.groups is not None and isinstance(self.estimator, RandomizedSearchCV):
- fit = self.estimator.fit(X_train, y_train, groups=groups_train)
+ if isinstance(self.estimator, RandomizedSearchCV) == True \
+ or isinstance(self.estimator, GridSearchCV):
+ param_search = True
else:
+ param_search = False
+
+ if self.groups is not None and param_search == True:
+ fit = self.estimator.fit(X_train, y_train, groups=groups_train)
+ else:
fit = self.estimator.fit(X_train, y_train)
y_pred = fit.predict(X_test)
More information about the grass-commit mailing list