[GRASS-SVN] r70492 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Feb 6 17:24:39 PST 2017


Author: spawley
Date: 2017-02-06 17:24:39 -0800 (Mon, 06 Feb 2017)
New Revision: 70492

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'added XGBoost as an optional classifier to r.learn.ml'

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-07 00:29:21 UTC (rev 70491)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-07 01:24:39 UTC (rev 70492)
@@ -48,7 +48,7 @@
 #% label: Classifier
 #% description: Supervised learning model to use
 #% answer: RandomForestClassifier
-#% options: LogisticRegression,LinearDiscriminantAnalysis,QuadraticDiscriminantAnalysis,GaussianNB,DecisionTreeClassifier,DecisionTreeRegressor,RandomForestClassifier,RandomForestRegressor,GradientBoostingClassifier,GradientBoostingRegressor,SVC,EarthClassifier,EarthRegressor
+#% options: LogisticRegression,LinearDiscriminantAnalysis,QuadraticDiscriminantAnalysis,GaussianNB,DecisionTreeClassifier,DecisionTreeRegressor,RandomForestClassifier,RandomForestRegressor,GradientBoostingClassifier,GradientBoostingRegressor,SVC,EarthClassifier,EarthRegressor,XGBClassifier,XGBRegressor
 #%end
 
 #%option
@@ -1045,12 +1045,31 @@
                            'EarthRegressor': Earth(max_degree=max_degree)}
         except:
             grass.fatal('Py-earth package not installed')
+            
+    elif estimator == 'XGBClassifier' or estimator == 'XGBRegressor':
+        try:
+            from xgboost import XGBClassifier, XGBRegressor
+
+            if max_depth is None:
+                max_depth = int(3)
+
+            classifiers = {'XGBClassifier': XGBClassifier(learning_rate=learning_rate,
+                                                          n_estimators=n_estimators,
+                                                          max_depth=max_depth,
+                                                          subsample=subsample),
+                           'XGBRegressor': XGBRegressor(learning_rate=learning_rate,
+                                                        n_estimators=n_estimators,
+                                                        max_depth=max_depth,
+                                                        subsample=subsample)}
+        except:
+            grass.fatal('Py-earth package not installed')
     else:
         # core sklearn classifiers go here
         classifiers = {
             'SVC': SVC(C=C, probability=True, random_state=random_state),
             'LogisticRegression':
-                LogisticRegression(C=C, random_state=random_state, n_jobs=-1),
+                LogisticRegression(C=C, random_state=random_state, n_jobs=-1,
+                                   fit_intercept=True),
             'DecisionTreeClassifier':
                 DecisionTreeClassifier(max_depth=max_depth,
                                        max_features=max_features,
@@ -1113,6 +1132,7 @@
         or estimator == 'LinearDiscriminantAnalysis' \
         or estimator == 'QuadraticDiscriminantAnalysis' \
         or estimator == 'EarthClassifier' \
+        or estimator == 'XGBClassifier' \
             or estimator == 'SVC':
             mode = 'classification'
     else:
@@ -1506,7 +1526,7 @@
             param_grid['max_depth'] = [int(i) for i in max_depth.split(',')]
             max_depth = None
         else:
-            max_depth = float(max_depth)
+            max_depth = int(max_depth)
 
     max_features = options['max_features']
     if max_features == '':



More information about the grass-commit mailing list