[GRASS-SVN] r70212 - grass-addons/grass7/raster/r.randomforest

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Jan 2 14:14:28 PST 2017


Author: spawley
Date: 2017-01-02 14:14:28 -0800 (Mon, 02 Jan 2017)
New Revision: 70212

Modified:
   grass-addons/grass7/raster/r.randomforest/r.randomforest.py
Log:
Lazy import of all external dependencies in r.randomforest to avoid problems when compiling.

Modified: grass-addons/grass7/raster/r.randomforest/r.randomforest.py
===================================================================
--- grass-addons/grass7/raster/r.randomforest/r.randomforest.py	2017-01-02 20:17:12 UTC (rev 70211)
+++ grass-addons/grass7/raster/r.randomforest/r.randomforest.py	2017-01-02 22:14:28 UTC (rev 70212)
@@ -259,37 +259,6 @@
 from grass.pygrass.raster.buffer import Buffer
 from grass.pygrass.modules.shortcuts import raster as r
 
-try:
-    import sklearn
-    from sklearn.externals import joblib
-    from sklearn import metrics
-    from sklearn import preprocessing
-    from sklearn.model_selection import StratifiedKFold
-    from sklearn.model_selection import GroupKFold
-    from sklearn.model_selection import train_test_split
-    from sklearn.model_selection import GridSearchCV
-    from sklearn.feature_selection import SelectKBest
-    from sklearn.feature_selection import f_classif
-    from sklearn.utils import shuffle
-    from sklearn.cluster import KMeans
-    
-    from sklearn.linear_model import LogisticRegression
-    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
-    from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
-    from sklearn.naive_bayes import GaussianNB
-    from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
-    from sklearn.ensemble import GradientBoostingClassifier
-    from sklearn.ensemble import GradientBoostingRegressor
-    from sklearn.svm import SVC
-
-except:
-    grass.fatal("Scikit learn is not installed")
-
-if (sklearn.__version__) < 0.18:
-    grass.fatal("Scikit learn 0.18 or newer is required")
-
-
 def cleanup():
 
     grass.run_command("g.remove", name='tmp_clfmask',
@@ -303,6 +272,16 @@
                       min_samples_split, min_samples_leaf,
                       n_estimators, subsample, learning_rate):
     
+    from sklearn.linear_model import LogisticRegression
+    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+    from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
+    from sklearn.naive_bayes import GaussianNB
+    from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
+    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+    from sklearn.ensemble import GradientBoostingClassifier
+    from sklearn.ensemble import GradientBoostingRegressor
+    from sklearn.svm import SVC
+    
     classifiers = {
         'SVC': SVC(C=C, probability=True, random_state=random_state),
         'LogisticRegression': LogisticRegression(C=C, class_weight=class_weight,
@@ -460,6 +439,8 @@
 
 
 def sample_predictors(response, predictors, shuffle_data, lowmem, random_state):
+    
+    from sklearn.utils import shuffle
 
     """
     Samples a list of GRASS rasters using a labelled raster
@@ -705,7 +686,11 @@
 
 
 def cross_val_classification(clf, X, y, group_ids, cv, rstate):
-
+    
+    from sklearn.model_selection import StratifiedKFold
+    from sklearn.model_selection import GroupKFold
+    from sklearn import metrics
+    
     """
     Stratified Kfold cross-validation
     Generates several scoring_metrics
@@ -788,6 +773,9 @@
 
 
 def tune_split(X, y, Id, estimator, metric, params, test_size, random_state):
+    
+    from sklearn.model_selection import train_test_split
+    from sklearn.model_selection import GridSearchCV    
 
     if Id is None:
         X, X_devel, y, y_devel = train_test_split(X, y, test_size=test_size,
@@ -806,6 +794,9 @@
 
 
 def feature_importances(clf, X, y):
+    
+    from sklearn.feature_selection import SelectKBest
+    from sklearn.feature_selection import f_classif
 
     try:
         clfimp = clf.feature_importances_
@@ -820,6 +811,9 @@
 def sample_training_data(roi, maplist, cv, cvtype, model_load,
                          load_training, save_training, lowmem, random_state):
     
+    from sklearn.externals import joblib
+    from sklearn.cluster import KMeans
+    
     # load the model or training data
     if model_load != '':
         clf = joblib.load(model_load)
@@ -867,6 +861,13 @@
 
 
 def main():
+    
+    try:
+        from sklearn import preprocessing
+        from sklearn import metrics
+        from sklearn.externals import joblib
+    except:
+        grass.fatal("Scikit learn 0.18 or newer is not installed")
 
     """
     GRASS options and flags



More information about the grass-commit mailing list