[GRASS-SVN] r70492 - grass-addons/grass7/raster/r.learn.ml
svn_grass at osgeo.org
svn_grass at osgeo.org
Mon Feb 6 17:24:39 PST 2017
Author: spawley
Date: 2017-02-06 17:24:39 -0800 (Mon, 06 Feb 2017)
New Revision: 70492
Modified:
grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
Log:
'added XGBoost as an optional classifier to r.learn.ml'
Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-02-07 00:29:21 UTC (rev 70491)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py 2017-02-07 01:24:39 UTC (rev 70492)
@@ -48,7 +48,7 @@
#% label: Classifier
#% description: Supervised learning model to use
#% answer: RandomForestClassifier
-#% options: LogisticRegression,LinearDiscriminantAnalysis,QuadraticDiscriminantAnalysis,GaussianNB,DecisionTreeClassifier,DecisionTreeRegressor,RandomForestClassifier,RandomForestRegressor,GradientBoostingClassifier,GradientBoostingRegressor,SVC,EarthClassifier,EarthRegressor
+#% options: LogisticRegression,LinearDiscriminantAnalysis,QuadraticDiscriminantAnalysis,GaussianNB,DecisionTreeClassifier,DecisionTreeRegressor,RandomForestClassifier,RandomForestRegressor,GradientBoostingClassifier,GradientBoostingRegressor,SVC,EarthClassifier,EarthRegressor,XGBClassifier,XGBRegressor
#%end
#%option
@@ -1045,12 +1045,31 @@
'EarthRegressor': Earth(max_degree=max_degree)}
except:
grass.fatal('Py-earth package not installed')
+
+ elif estimator == 'XGBClassifier' or estimator == 'XGBRegressor':
+ try:
+ from xgboost import XGBClassifier, XGBRegressor
+
+ if max_depth is None:
+ max_depth = int(3)
+
+ classifiers = {'XGBClassifier': XGBClassifier(learning_rate=learning_rate,
+ n_estimators=n_estimators,
+ max_depth=max_depth,
+ subsample=subsample),
+ 'XGBRegressor': XGBRegressor(learning_rate=learning_rate,
+ n_estimators=n_estimators,
+ max_depth=max_depth,
+ subsample=subsample)}
+ except:
+ grass.fatal('Py-earth package not installed')
else:
# core sklearn classifiers go here
classifiers = {
'SVC': SVC(C=C, probability=True, random_state=random_state),
'LogisticRegression':
- LogisticRegression(C=C, random_state=random_state, n_jobs=-1),
+ LogisticRegression(C=C, random_state=random_state, n_jobs=-1,
+ fit_intercept=True),
'DecisionTreeClassifier':
DecisionTreeClassifier(max_depth=max_depth,
max_features=max_features,
@@ -1113,6 +1132,7 @@
or estimator == 'LinearDiscriminantAnalysis' \
or estimator == 'QuadraticDiscriminantAnalysis' \
or estimator == 'EarthClassifier' \
+ or estimator == 'XGBClassifier' \
or estimator == 'SVC':
mode = 'classification'
else:
@@ -1506,7 +1526,7 @@
param_grid['max_depth'] = [int(i) for i in max_depth.split(',')]
max_depth = None
else:
- max_depth = float(max_depth)
+ max_depth = int(max_depth)
max_features = options['max_features']
if max_features == '':
More information about the grass-commit mailing list