[GRASS-SVN] r70212 - grass-addons/grass7/raster/r.randomforest
svn_grass at osgeo.org
svn_grass at osgeo.org
Mon Jan 2 14:14:28 PST 2017
Author: spawley
Date: 2017-01-02 14:14:28 -0800 (Mon, 02 Jan 2017)
New Revision: 70212
Modified:
grass-addons/grass7/raster/r.randomforest/r.randomforest.py
Log:
Lazily import all external dependencies in r.randomforest to avoid problems when compiling.
Modified: grass-addons/grass7/raster/r.randomforest/r.randomforest.py
===================================================================
--- grass-addons/grass7/raster/r.randomforest/r.randomforest.py 2017-01-02 20:17:12 UTC (rev 70211)
+++ grass-addons/grass7/raster/r.randomforest/r.randomforest.py 2017-01-02 22:14:28 UTC (rev 70212)
@@ -259,37 +259,6 @@
from grass.pygrass.raster.buffer import Buffer
from grass.pygrass.modules.shortcuts import raster as r
-try:
- import sklearn
- from sklearn.externals import joblib
- from sklearn import metrics
- from sklearn import preprocessing
- from sklearn.model_selection import StratifiedKFold
- from sklearn.model_selection import GroupKFold
- from sklearn.model_selection import train_test_split
- from sklearn.model_selection import GridSearchCV
- from sklearn.feature_selection import SelectKBest
- from sklearn.feature_selection import f_classif
- from sklearn.utils import shuffle
- from sklearn.cluster import KMeans
-
- from sklearn.linear_model import LogisticRegression
- from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
- from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
- from sklearn.naive_bayes import GaussianNB
- from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
- from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
- from sklearn.ensemble import GradientBoostingClassifier
- from sklearn.ensemble import GradientBoostingRegressor
- from sklearn.svm import SVC
-
-except:
- grass.fatal("Scikit learn is not installed")
-
-if (sklearn.__version__) < 0.18:
- grass.fatal("Scikit learn 0.18 or newer is required")
-
-
def cleanup():
grass.run_command("g.remove", name='tmp_clfmask',
@@ -303,6 +272,16 @@
min_samples_split, min_samples_leaf,
n_estimators, subsample, learning_rate):
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+ from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
+ from sklearn.naive_bayes import GaussianNB
+ from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+ from sklearn.ensemble import GradientBoostingClassifier
+ from sklearn.ensemble import GradientBoostingRegressor
+ from sklearn.svm import SVC
+
classifiers = {
'SVC': SVC(C=C, probability=True, random_state=random_state),
'LogisticRegression': LogisticRegression(C=C, class_weight=class_weight,
@@ -460,6 +439,8 @@
def sample_predictors(response, predictors, shuffle_data, lowmem, random_state):
+
+ from sklearn.utils import shuffle
"""
Samples a list of GRASS rasters using a labelled raster
@@ -705,7 +686,11 @@
def cross_val_classification(clf, X, y, group_ids, cv, rstate):
-
+
+ from sklearn.model_selection import StratifiedKFold
+ from sklearn.model_selection import GroupKFold
+ from sklearn import metrics
+
"""
Stratified Kfold cross-validation
Generates several scoring_metrics
@@ -788,6 +773,9 @@
def tune_split(X, y, Id, estimator, metric, params, test_size, random_state):
+
+ from sklearn.model_selection import train_test_split
+ from sklearn.model_selection import GridSearchCV
if Id is None:
X, X_devel, y, y_devel = train_test_split(X, y, test_size=test_size,
@@ -806,6 +794,9 @@
def feature_importances(clf, X, y):
+
+ from sklearn.feature_selection import SelectKBest
+ from sklearn.feature_selection import f_classif
try:
clfimp = clf.feature_importances_
@@ -820,6 +811,9 @@
def sample_training_data(roi, maplist, cv, cvtype, model_load,
load_training, save_training, lowmem, random_state):
+ from sklearn.externals import joblib
+ from sklearn.cluster import KMeans
+
# load the model or training data
if model_load != '':
clf = joblib.load(model_load)
@@ -867,6 +861,13 @@
def main():
+
+ try:
+ from sklearn import preprocessing
+ from sklearn import metrics
+ from sklearn.externals import joblib
+ except:
+ grass.fatal("Scikit learn 0.18 or newer is not installed")
"""
GRASS options and flags
More information about the grass-commit mailing list