[GRASS-SVN] r70551 - grass-addons/grass7/raster/r.learn.ml

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Feb 13 08:43:14 PST 2017


Author: spawley
Date: 2017-02-13 08:43:14 -0800 (Mon, 13 Feb 2017)
New Revision: 70551

Modified:
   grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
   grass-addons/grass7/raster/r.learn.ml/raster_learning.py
Log:
r.learn.ml: refactored the OOP design to be more consistent with the scikit-learn philosophy; cross-validation command output also reformatted as tab-delimited text for easy cut and paste

Modified: grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-13 15:31:59 UTC (rev 70550)
+++ grass-addons/grass7/raster/r.learn.ml/r.learn.ml.py	2017-02-13 16:43:14 UTC (rev 70551)
@@ -65,7 +65,7 @@
 #% type: integer
 #% description: Number of features to consider during splitting for tree based classifiers. Default is sqrt(n_features) for classification, and n_features for regression
 #% required: no
-#% answer:
+#% answer:0
 #% multiple: yes
 #% guisection: Classifier Parameters
 #%end
@@ -75,7 +75,7 @@
 #% type: integer
 #% description: Optionally specifiy maximum tree depth. Otherwise full-growing occurs for decision trees and random forests, and max_depth=3 for gradient boosting
 #% required: no
-#% answer:
+#% answer:0
 #% multiple: yes
 #% guisection: Classifier Parameters
 #%end
@@ -230,14 +230,6 @@
 #%end
 
 #%option
-#% key: tune_cv
-#% type: integer
-#% description: Number of cross-validation folds used for parameter tuning
-#% answer: 3
-#% guisection: Optional
-#%end
-
-#%option
 #% key: n_permutations
 #% type: integer
 #% description: Number of permutations to perform for feature importances
@@ -299,6 +291,7 @@
 #%end
 
 import atexit
+import os
 import numpy as np
 import grass.script as grass
 from copy import deepcopy
@@ -308,31 +301,39 @@
 set_path('r.learn.ml')
 from raster_learning import train, model_classifiers
 from raster_learning import save_training_data, load_training_data
-from raster_learning import extract, maps_from_group
+from raster_learning import extract, maps_from_group, random_oversampling
 
+tmp_rast = []
 
+
 def cleanup():
-    grass.run_command("g.remove", name='tmp_clfmask',
-                      flags="f", type="raster", quiet=True)
-    grass.run_command("g.remove", name='tmp_roi_clumped',
-                      flags="f", type="raster", quiet=True)
+    for rast in tmp_rast:
+        grass.run_command("g.remove", rast=rast, quiet=True)
 
 
 def main():
 
+    """
+    Lazy imports for main------------------------------------------------------
+    """
+
     try:
         from sklearn.externals import joblib
         from sklearn.cluster import KMeans
         from sklearn.metrics import make_scorer, cohen_kappa_score
+        from sklearn.model_selection import StratifiedKFold, GroupKFold
+        from sklearn.preprocessing import StandardScaler
+        from sklearn.model_selection import GridSearchCV
+        import warnings
+        warnings.filterwarnings('ignore')  # turn off UndefinedMetricWarning
     except:
         grass.fatal("Scikit learn 0.18 or newer is not installed")
 
     """
-    GRASS options and flags
-    -----------------------
+    GRASS options and flags----------------------------------------------------
     """
 
-    # General options and flags
+    # General options and flags -----------------------------------------------
     group = options['group']
     response = options['trainingmap']
     output = options['output']
@@ -352,7 +353,6 @@
     load_training = options['load_training']
     save_training = options['save_training']
     importances = flags['f']
-    tune_cv = int(options['tune_cv'])
     n_permutations = int(options['n_permutations'])
     lowmem = flags['l']
     impute = flags['i']
@@ -365,6 +365,7 @@
     else:
         categorymaps = None
 
+    # classifier options and parameter grid settings --------------------------
     param_grid = {'C': None,
                   'min_samples_split': None,
                   'min_samples_leaf': None,
@@ -375,7 +376,6 @@
                   'max_features': None,
                   'max_degree': None}
 
-    # classifier options
     C = options['c']
     if ',' in C:
         param_grid['C'] = [float(i) for i in C.split(',')]
@@ -422,7 +422,7 @@
         subsample = float(subsample)
 
     max_depth = options['max_depth']
-    if max_depth == '':
+    if max_depth == '0':
         max_depth = None
     else:
         if ',' in max_depth:
@@ -432,7 +432,7 @@
             max_depth = int(max_depth)
 
     max_features = options['max_features']
-    if max_features == '':
+    if max_features == '0':
         max_features = 'auto'
     else:
         if ',' in max_features:
@@ -449,15 +449,17 @@
     else:
         max_degree = int(max_degree)
 
+    # remove empty items from the param_grid dict
+    param_grid = {k: v for k, v in param_grid.iteritems() if v is not None}
+
     if importances is True and cv == 1:
         grass.fatal('Feature importances require cross-validation cv > 1')
 
-    # fetch individual raster names from group
+    # fetch individual raster names from group --------------------------------
     maplist, map_names = maps_from_group(group)
 
     """
-    Sample training data and group ids
-    --------------------
+    Sample training data and group ids ----------------------------------------
     """
 
     if model_load == '':
@@ -466,13 +468,17 @@
         if load_training != '':
             X, y, group_id = load_training_data(load_training)
         else:
+            grass.message('Extracting training data')
+
             # clump the labelled pixel raster if labels represent polygons
             # then set the group_raster to the clumped raster to extract the
             # group_ids used in the GroupKFold cross-validation
             if cvtype == 'clumped' and group_raster == '':
-                r.clump(input=response, output='tmp_roi_clumped',
+                clumped_trainingmap = 'tmp_clumped_trainingmap'
+                tmp_rast.append(clumped_trainingmap)
+                r.clump(input=response, output=clumped_trainingmap,
                         overwrite=True, quiet=True)
-                group_raster = 'tmp_roi_clumped'
+                group_raster = clumped_trainingmap
 
             # extract training data from maplist and take group ids from
             # group_raster. Shuffle=False so that group ids and labels align
@@ -488,15 +494,6 @@
                 # take group id from last column and remove from predictors
                 group_id = X[:, -1]
                 X = np.delete(X, -1, axis=1)
-
-                # remove the clumped raster
-                try:
-                    grass.run_command(
-                        "g.remove", name='tmp_roi_clumped', flags="f",
-                        type="raster", quiet=True)
-                except:
-                    pass
-
             else:
                 # extract training data from maplist without group Ids
                 # shuffle this data by default
@@ -527,26 +524,46 @@
             save_training_data(X, y, group_id, save_training)
 
         """
-        Train the classifier
-        --------------------
+        Train the classifier --------------------------------------------------
         """
 
-        # retrieve sklearn classifier object and parameters
-        grass.message("Classifier = " + classifier)
-
+        # retrieve sklearn classifier object and parameters -------------------
         clf, mode = \
             model_classifiers(classifier, random_state,
                               C, max_depth, max_features, min_samples_split,
                               min_samples_leaf, n_estimators,
                               subsample, learning_rate, max_degree)
 
-        # turn off balancing if mode = regression
-        if mode == 'regression' and balance is True:
-            balance = False
+        # set other parameters based on classification or regression ----------
+        if mode == 'classification':
+            if len(np.unique(y)) == 2 and all([0, 1] == np.unique(y)):
+                scorers = 'binary'
+            else:
+                scorers = 'multiclass'
+            search_scorer = make_scorer(cohen_kappa_score)
+            labels = np.unique(y)
+        else:
+            scorers = 'regression'
+            search_scorer = 'r2'
+            labels = None  # no classes
+            balance = False  # no balancing for regression
+            if probability is True:
+                grass.warning(
+                        'Class probabilities only valid for classifications...'
+                        'ignoring')
+                probability = False
 
-        # remove empty items from the param_grid dict
-        param_grid = {k: v for k, v in param_grid.iteritems() if v is not None}
+        # setup model selection model -----------------------------------------
+        if any(param_grid) is True and cv == 1:
+            grass.fatal('Hyperparameter search requires cv > 1')
+        if any(param_grid) is True or cv > 1:
+            if group_id is None:
+                search_cv_method = StratifiedKFold(
+                        n_splits=cv, random_state=random_state)
+            else:
+                search_cv_method = GroupKFold(n_splits=cv)
 
+        # set-up parameter grid for hyperparameter search ---------------------
         # check that dict keys are compatible for the selected classifier
         clf_params = clf.get_params()
         param_grid = { key: value for key, value in param_grid.iteritems() if key in clf_params}
@@ -555,130 +572,172 @@
         # so that the train object will not perform GridSearchCV
         if any(param_grid) is not True:
             param_grid = None
+        else:
+            clf = GridSearchCV(estimator=clf, param_grid=param_grid,
+                               scoring=search_scorer, n_jobs=-1,
+                               cv=search_cv_method)
 
-        # Decide on scoring metric scheme and scorer to for grid search
-        if mode == 'classification':
-            if len(np.unique(y)) == 2 and all([0, 1] == np.unique(y)):
-                scorers = 'binary'
-            else:
-                scorers = 'multiclass'
-            search_scorer = make_scorer(cohen_kappa_score)
+        # preprocessing options -----------------------------------------------
+        if balance is True:
+            sampling = random_oversampling(random_state)
         else:
-            scorers = 'regression'
-            search_scorer = 'r2'
+            sampling = None
 
-        if mode == 'regression' and probability is True:
-            grass.warning(
-                'Class probabilities only valid for classifications...'
-                'ignoring')
-            probability = False
+        if norm_data is True:
+            scaler = StandardScaler()
+        else:
+            scaler = None
 
-        # create training object - onehot-encoded on-the-fly
-        learn_m = train(clf, X, y, group_id, categorical_var=categorymaps,
-                        standardize=norm_data, balance=balance)
+        # create training object ----------------------------------------------
+        learn_m = train(clf, categorical_var=categorymaps,
+                        preprocessing=scaler, sampling=sampling)
 
         """
-        Fitting, parameter search and cross-validation
-        ----------------
+        Fitting, parameter search and cross-validation ------------------------
         """
 
         # fit and parameter search
-        learn_m.fit(param_grid=param_grid, cv=tune_cv, scoring=search_scorer,
-                    random_state=random_state)
+        grass.message(os.linesep)
+        grass.message(('Fitting model using ' + classifier))
+        learn_m.fit(X, y, group_id)
 
         if param_grid is not None:
-            grass.message('\n')
+            grass.message(os.linesep)
             grass.message('Best parameters:')
             grass.message(str(learn_m.estimator.best_params_))
-
+            
         # If cv > 1 then use cross-validation to generate performance measures
         if cv > 1:
-            grass.message('\r\n')
-            grass.message(
-                "Cross validation global performance measures......:")
+            # check that a sufficient number of samples are present per class
+            if cv > np.histogram(y, bins=len(labels))[0].min():
+                grass.message(os.linesep)
+                grass.message('Number of cv folds is greater than number of '
+                              'samples in some classes. Cross-validation is being'
+                              ' skipped')
+            else:
+                grass.message(os.linesep)
+                grass.message(
+                    "Cross validation global performance measures......:")
 
-            # cross-validate the training object
-            learn_m.cross_val(scorers, cv, importances,
-                              n_permutations=n_permutations,
-                              random_state=random_state)
+                # cross-validate the training object
+                learn_m.cross_val(search_cv_method, X, y, group_id,
+                                  scorers, importances,
+                                  n_permutations=n_permutations,
+                                  random_state=random_state)
 
-            if mode == 'classification':
-                if scorers == 'binary':
-                    grass.message(
-                        "Accuracy   :\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['accuracy'].mean(),
-                         learn_m.scores['accuracy'].std()))
-                    grass.message(
-                        "AUC        :\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['auc'].mean(),
-                         learn_m.scores['auc'].std()))
-                    grass.message(
-                        "Kappa      :\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['kappa'].mean(),
-                         learn_m.scores['kappa'].std()))
-                    grass.message(
-                        "Precision  :\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['precision'].mean(),
-                         learn_m.scores['precision'].std()))
-                    grass.message(
-                        "Recall     :\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['recall'].mean(),
-                         learn_m.scores['recall'].std()))
-                    grass.message(
-                        "Specificity:\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['specificity'].mean(),
-                         learn_m.scores['specificity'].std()))
-                    grass.message(
-                        "F1         :\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['f1'].mean(),
-                         learn_m.scores['f1'].std()))
+                scores = learn_m.get_cross_val_scores()
 
-                if scorers == 'multiclass':
-                    grass.message(
-                        "Accuracy:\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['accuracy'].mean(),
-                         learn_m.scores['accuracy'].std()))
-                    grass.message(
-                        "Kappa   :\t%0.3f\t+/-SD\t%0.3f" %
-                        (learn_m.scores['kappa'].mean(),
-                         learn_m.scores['kappa'].std()))
+                if mode == 'classification':
+                    if scorers == 'binary':
+                        grass.message(
+                            "Accuracy   :\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['accuracy'].mean(),
+                             scores['accuracy'].std()))
+                        grass.message(
+                            "AUC        :\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['auc'].mean(),
+                             scores['auc'].std()))
+                        grass.message(
+                            "Kappa      :\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['kappa'].mean(),
+                             scores['kappa'].std()))
+                        grass.message(
+                            "Precision  :\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['precision'].mean(),
+                             scores['precision'].std()))
+                        grass.message(
+                            "Recall     :\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['recall'].mean(),
+                             scores['recall'].std()))
+                        grass.message(
+                            "Specificity:\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['specificity'].mean(),
+                             scores['specificity'].std()))
+                        grass.message(
+                            "F1         :\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['f1'].mean(),
+                             scores['f1'].std()))
 
-                # classification report
-                grass.message("\n")
-                grass.message("Classification report:")
-                grass.message(learn_m.scores_cm)
+                    if scorers == 'multiclass':
+                        # global scores
+                        grass.message(
+                            "Accuracy:\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['accuracy'].mean(),
+                             scores['accuracy'].std()))
+                        grass.message(
+                            "Kappa   :\t%0.3f\t+/-SD\t%0.3f" %
+                            (scores['kappa'].mean(),
+                             scores['kappa'].std()))
 
-            else:
-                grass.message("R2:\t%0.3f\t+/-\t%0.3f" %
-                              (learn_m.scores['r2'].mean(),
-                               learn_m.scores['r2'].std()))
+                        # per class scores
+                        grass.message(os.linesep)
+                        grass.message('Cross validation class performance measures......:')
+                        mat_precision = np.matrix(scores['precision'])
+                        mat_recall = np.matrix(scores['recall'])
+                        mat_f1 = np.matrix(scores['f1'])
 
-            # write cross-validation results for csv file
-            if errors_file != '':
-                try:
-                    import pandas as pd
-                    errors = pd.DataFrame(learn_m.scores)
-                    errors.to_csv(errors_file, mode='w')
-                except:
-                    grass.warning('Pandas is not installed. Pandas is '
-                                  'required to write the cross-validation '
-                                  'results to file')
+                        grass.message('Class \t' + '\t'.join(map(str, labels)))
+                        grass.message(
+                            'Precision mean \t' + '\t'.join(
+                                    map(str, np.round(
+                                            mat_precision.mean(axis=0), 2)[0])))
+                        grass.message(
+                            'Precision std \t' + '\t'.join(
+                                    map(str, np.round(
+                                            mat_precision.std(axis=0), 2)[0])))
+                        grass.message(
+                            'Recall mean \t' + '\t'.join(
+                                    map(str, np.round(
+                                            mat_recall.mean(axis=0), 2)[0])))
+                        grass.message(
+                            'Recall std \t' + '\t'.join(
+                                    map(str, np.round(
+                                            mat_recall.std(axis=0), 2)[0])))
+                        grass.message(
+                            'F1 score mean \t' + '\t'.join(
+                                    map(str, np.round(
+                                            mat_f1.mean(axis=0), 2)[0])))
+                        grass.message(
+                            'F1 score std \t' + '\t'.join(
+                                    map(str, np.round(
+                                            mat_f1.std(axis=0), 2)[0])))
 
-            # feature importances
-            if importances is True:
-                grass.message("\r\n")
-                grass.message("Feature importances")
-                grass.message("id" + "\t" + "Raster" + "\t" + "Importance")
+                        # remove perclass scores from dict
+                        del scores['precision']
+                        del scores['recall']
+                        del scores['f1']
 
-                # mean of cross-validation feature importances
-                for i in range(len(learn_m.fimp.mean(axis=0))):
-                    grass.message(
-                        str(i) + "\t" + maplist[i] +
-                        "\t" + str(round(learn_m.fimp.mean(axis=0)[i], 4)))
+                else:
+                    grass.message("R2:\t%0.3f\t+/-\t%0.3f" %
+                                  (scores['r2'].mean(),
+                                   scores['r2'].std()))
 
-                if fimp_file != '':
-                    np.savetxt(fname=fimp_file, X=learn_m.fimp, delimiter=',',
-                               header=','.join(maplist), comments='')
+                # write cross-validation results for csv file
+                if errors_file != '':
+                    try:
+                        import pandas as pd
+                        errors = pd.DataFrame(scores)
+                        errors.to_csv(errors_file, mode='w')
+                    except:
+                        grass.warning('Pandas is not installed. Pandas is '
+                                      'required to write the cross-validation '
+                                      'results to file')
+
+                # feature importances
+                if importances is True:
+                    grass.message(os.linesep)
+                    grass.message("Feature importances")
+                    grass.message("id" + "\t" + "Raster" + "\t" + "Importance")
+
+                    # mean of cross-validation feature importances
+                    for i in range(len(learn_m.fimp.mean(axis=0))):
+                        grass.message(
+                            str(i) + "\t" + maplist[i] +
+                            "\t" + str(round(learn_m.fimp.mean(axis=0)[i], 4)))
+
+                    if fimp_file != '':
+                        np.savetxt(fname=fimp_file, X=learn_m.fimp, delimiter=',',
+                                   header=','.join(maplist), comments='')
     else:
         # load a previously fitted train object
         # -------------------------------------
@@ -686,20 +745,18 @@
             # load a previously fitted model
             learn_m = joblib.load(model_load)
 
-    """
-    Optionally save the fitted model
-    ---------------------
-    """
-
+    # Optionally save the fitted model
     if model_save != '':
         joblib.dump(learn_m, model_save)
 
     """
-    Prediction on the rest of the GRASS rasters in the imagery group
-    ----------------------------------------------------------------
+    Prediction on the rest of the GRASS rasters in the imagery group ----------
     """
+
     if modelonly is not True:
-        learn_m.predict(maplist, output, probability, rowincr)
+        grass.message(os.linesep)
+        grass.message('Predicting raster...')
+        learn_m.predict(maplist, output, labels, probability, rowincr)
     else:
         grass.message("Model built and now exiting")
 

Modified: grass-addons/grass7/raster/r.learn.ml/raster_learning.py
===================================================================
--- grass-addons/grass7/raster/r.learn.ml/raster_learning.py	2017-02-13 15:31:59 UTC (rev 70550)
+++ grass-addons/grass7/raster/r.learn.ml/raster_learning.py	2017-02-13 16:43:14 UTC (rev 70551)
@@ -1,12 +1,9 @@
 # -*- coding: utf-8 -*-
-"""
-Created on Tue Feb  7 09:03:10 2017
 
- at author: steve
-"""
-
 import os
+import scipy
 import numpy as np
+from numpy.random import RandomState
 from copy import deepcopy
 import tempfile
 import grass.script as grass
@@ -17,59 +14,27 @@
 from subprocess import PIPE
 
 
-class train():
+class random_oversampling():
 
-    def __init__(self, estimator, X, y, groups=None, categorical_var=None,
-                 standardize=False, balance=False):
+    def __init__(self, random_state):
         """
-        Train class to perform preprocessing, fitting, parameter search and
-        cross-validation in a single step
+        Balances X, y observations using simple oversampling
 
         Args
         ----
-        estimator: Scikit-learn compatible estimator object
-        X, y: training data and labels as numpy arrays
-        groups: groups to be used for cross-validation
-        categorical_var: 1D list containing indices of categorical predictors
-        standardize: Transform predictors
-        balance: boolean to balance number of classes
+        random_state: Seed to pass onto random number generator
         """
 
-        # fitting data
-        self.estimator = estimator
-        self.X = X
-        self.y = y
-        self.groups = groups
-        self.balance = balance
+        self.random_state = random_state
 
-        # for onehot-encoding
-        self.enc = None
-        self.categorical_var = categorical_var
-        self.category_values = None
-
-        if self.categorical_var:
-            self.__onehotencode()
-
-        # for standardization
-        if standardize is True:
-            self.standardization()
-        else:
-            self.scaler = None
-
-        # for cross-validation scores
-        self.scores = None
-        self.scores_cm = None
-        self.fimp = None
-
-    def __random_oversampling(self, X, y, random_state=None):
+    def fit_sample(self, X, y):
         """
-        Balances X, y observations using simple oversampling
+        Performs equal balancing of response and explanatory variances
 
         Args
         ----
         X: numpy array of training data
         y: 1D numpy array of response data
-        random_state: Seed to pass onto random number generator
 
         Returns
         -------
@@ -77,7 +42,7 @@
         y_resampled: Numpy array of resampled response data
         """
 
-        np.random.seed(seed=random_state)
+        np.random.seed(seed=self.random_state)
 
         # count the number of observations per class
         y_classes = np.unique(y)
@@ -103,10 +68,54 @@
 
         return (X_resampled, y_resampled)
 
-    def __onehotencode(self):
+
+class train():
+
+    def __init__(self, estimator, categorical_var=None,
+                 preprocessing=None, sampling=None):
         """
+        Train class to perform preprocessing, fitting, parameter search and
+        cross-validation in a single step
+
+        Args
+        ----
+        estimator: Scikit-learn compatible estimator object
+        categorical_var: 1D list containing indices of categorical predictors
+        preprocessing: Sklearn preprocessing scaler
+        sampling: Balancing object e.g. from imbalance-learn
+        """
+
+        # fitting data
+        self.estimator = estimator
+
+        # for onehot-encoding
+        self.enc = None
+        self.categorical_var = categorical_var
+        self.category_values = None
+
+        if self.categorical_var:
+            self.__onehotencode()
+
+        # for preprocessing of data
+        self.sampling = sampling
+        self.preprocessing = preprocessing
+
+        # for cross-validation scores
+        self.scores = None
+        self.scores_cm = None
+        self.fimp = None
+        self.mean_tpr = None
+        self.mean_fpr = None
+
+    def __onehotencode(self, X):
+
+        """
         Method to convert a list of categorical arrays in X into a suite of
-        binary predictors which are added to the left of the array
+        binary predictors which are added to the end of the array
+
+        Args
+        ----
+        X: 2D numpy array containing training data
         """
 
         from sklearn.preprocessing import OneHotEncoder
@@ -114,140 +123,89 @@
         # store original range of values
         self.category_values = [0] * len(self.categorical_var)
         for i, cat in enumerate(self.categorical_var):
-            self.category_values[i] = np.unique(self.X[:, cat])
+            self.category_values[i] = np.unique(X[:, cat])
 
         # fit and transform categorical grids to a suite of binary features
         self.enc = OneHotEncoder(categorical_features=self.categorical_var,
                                  sparse=False)
-        self.enc.fit(self.X)
-        self.X = self.enc.transform(self.X)
+        self.enc.fit(X)
+        X = self.enc.transform(X)
 
-    def fit(self, param_distributions=None, param_grid=None,
-            scoring=None, n_iter=3, cv=3, random_state=None):
+        return(X)
 
+    def fit(self, X, y, groups=None):
+
         """
-        Main fit method for the train object. Performs fitting, hyperparameter
-        search and cross-validation in one step (inspired from R's CARET)
+        Main fit method for the train object
 
         Args
         ----
-        param_distributions: continuous parameter distribution to be used in a
-        randomizedCVsearch
-        param_grid: Dist of non-continuous parameters to grid search
-        n_iter: Number of randomized search iterations
-        cv: Number of cross-validation folds for parameter tuning
-        random_state: seed to be used during random number generation
+        X, y: training data and labels as numpy arrays
+        groups: groups to be used for cross-validation
         """
 
         from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
-        from sklearn.model_selection import GroupKFold
 
-        # Balance classes
-        if self.balance is True:
-            X, y = self.__random_oversampling(
-                    self.X, self.y, random_state=random_state)
+        # Balance classes prior to fitting
+        if self.sampling is not None:
+            # balance samples
+            y_original = deepcopy(y)
+            X, y = self.sampling.fit_sample(X, y)
 
-            if self.groups is not None:
-                groups, _ = self.__random_oversampling(
-                    self.groups, self.y, random_state=random_state)
-            else:
-                groups = None
-        else:
-            X = self.X
-            y = self.y
-            groups = self.groups
+            # balance groups if present
+            if groups is not None:
+                groups, _ = self.sampling.fit_sample(
+                        groups.reshape(-1, 1), y_original)
 
-        # Randomized or grid search
-        if param_distributions is not None or param_grid is not None:
+        if self.preprocessing is not None:
+            X = self.__preprocessor(X)
 
-            # use groupkfold for hyperparameter search if groups are present
-            if self.groups is not None:
-                cv_search = GroupKFold(n_splits=cv)
-            else:
-                cv_search = cv
+        if self.categorical_var is not None:
+            X = self.__onehotencode(X)
 
-            # Randomized search
-            if param_distributions is not None:
-                self.estimator = RandomizedSearchCV(
-                    estimator=self.estimator,
-                    param_distributions=param_distributions,
-                    n_iter=n_iter, scoring=scoring,
-                    cv=cv_search)
+        # fit the model on the training data and predict the test data
+        # need the groups parameter because the estimator can be a
+        # RandomizedSearchCV or GridSearchCV estimator where cv=GroupKFold
+        if isinstance(self.estimator, RandomizedSearchCV) \
+                or isinstance(self.estimator, GridSearchCV):
+            param_search = True
+        else:
+            param_search = False
 
-            # Grid Search
-            if param_grid is not None:
-                self.estimator = GridSearchCV(self.estimator,
-                                              param_grid,
-                                              n_jobs=-1, cv=cv_search,
-                                              scoring=scoring)
-
-            # if groups then fit RandomizedSearchCV.fit requires groups param
-            if self.groups is None:
-                self.estimator.fit(X, y)
-            else:
-                self.estimator.fit(X, y, groups=groups)
-
-        # Fitting without parameter search
+        if groups is not None and param_search is True:
+            self.estimator.fit(X, y, groups=groups)
         else:
             self.estimator.fit(X, y)
 
-    def standardization(self):
+    def __preprocessor(self, X):
         """
         Transforms the non-categorical X
+
+        Args
+        ----
+        X; 2D numpy array to transform
         """
 
-        from sklearn.preprocessing import StandardScaler
-
         # create mask so that indices that represent categorical
         # predictors are not selected
         if self.categorical_var is not None:
-            idx = np.arange(self.X.shape[1])
+            idx = np.arange(X.shape[1])
             mask = np.ones(len(idx), dtype=bool)
             mask[self.categorical_var] = False
         else:
-            mask = np.arange(self.X.shape[1])
+            mask = np.arange(X.shape[1])
 
-        X_continuous = self.X[:, mask]
-        self.scaler = StandardScaler()
-        self.scaler.fit(X_continuous)
-        self.X[:, mask] = self.scaler.transform(X_continuous)
+        X_continuous = X[:, mask]
+        self.preprocessing.fit(X=X_continuous)
+        X[:, mask] = self.preprocessing.transform(X_continuous)
 
-    def __pred_func(self, estimator, X_test, y_true, scorers):
-        """
-        Calculates a single performance metric depending on if scorer type
-        is binary, multiclass or regression
+        return(X)
 
-        To be called from the varImp_permutation
+    def varImp_permutation(self, estimator, X_test, y_true,
+                           n_permutations, scorers,
+                           random_state):
 
-        Args
-        ----
-        estimator: fitted estimator on training set
-        X_test: Test training data
-        y_true: Test labelled data
-        scorers: String indicating which type of scorer to be used
         """
-
-        from sklearn import metrics
-
-        if scorers == 'binary':
-            scorer = metrics.roc_auc_score
-            y_pred = estimator.predict_proba(X_test)[:, 1]
-        if scorers == 'multiclass':
-            scorer = metrics.accuracy_score
-            y_pred = estimator.predict(X_test)
-        if scorers == 'regression':
-            scorer = metrics.r2_score
-            y_pred = estimator.predict(X_test)
-
-        score = scorer(y_true, y_pred)
-
-        return (score)
-
-    def __varImp_permutation(self, estimator, X_test, y_true,
-                             n_permutations, scorers,
-                             random_state):
-
-        """
         Method to perform permutation-based feature importance during
         cross-validation (cross-validation is applied externally to this
         method)
@@ -272,11 +230,19 @@
         scores: AUC scores for each predictor following permutation
         """
 
+        from sklearn import metrics
+        if scorers == 'binary' or scorers == 'multiclass':
+            scorer = metrics.accuracy_score
+        if scorers == 'regression':
+            scorer = metrics.r2_score
+
         # calculate score on original variables without permutation
         # determine best metric type for binary/multiclass/regression scenarios
-        best_score = self.__pred_func(estimator, X_test, y_true, scorers)
+        y_pred = estimator.predict(X_test)
+        best_score = scorer(y_true, y_pred)
 
         np.random.seed(seed=random_state)
+        rstate = RandomState(random_state)
         scores = np.zeros((n_permutations, X_test.shape[1]))
 
         # outer loop to repeat the pemutation rep times
@@ -286,11 +252,11 @@
             # difference in auc
             for i in range(X_test.shape[1]):
                 Xscram = np.copy(X_test)
-                Xscram[:, i] = np.random.choice(X_test[:, i], X_test.shape[0])
+                Xscram[:, i] = rstate.choice(X_test[:, i], X_test.shape[0])
 
                 # fit the model on the training data and predict the test data
-                scores[rep, i] = best_score-self.__pred_func(
-                    estimator, Xscram, y_true, scorers)
+                y_pred = estimator.predict(Xscram)
+                scores[rep, i] = best_score-scorer(y_true, y_pred)
                 if scores[rep, i] < 0:
                     scores[rep, i] = 0
 
@@ -301,6 +267,19 @@
 
     def specificity_score(self, y_true, y_pred):
 
+        """
+        Simple method to calculate specificity score
+
+        Args
+        ----
+        y_true: 1D numpy array of truth values
+        y_pred: 1D numpy array of predicted classes
+
+        Returns
+        -------
+        specificity: specificity score
+        """
+
         from sklearn.metrics import confusion_matrix
 
         cm = confusion_matrix(y_true, y_pred)
@@ -314,11 +293,10 @@
 
         return (specificity)
 
-    def cross_val(self, scorers='binary', cv=3, feature_importances=False,
-                  n_permutations=25, random_state=None):
+    def cross_val(self, splitter, X, y, groups=None, scorers='binary',
+                  feature_importances=False, n_permutations=25,
+                  random_state=None):
 
-        from sklearn.model_selection import StratifiedKFold
-        from sklearn.model_selection import GroupKFold
         from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
         from sklearn import metrics
 
@@ -330,17 +308,33 @@
 
         Args
         ----
-        scorers: Suite of performance metrics to use
-        cv: Integer of cross-validation folds
+        splitter: Scikit learn model_selection object, e.g. StratifiedKFold
+        X, y: 2D numpy array of training data and 1D array of labels
+        groups: 1D numpy array of groups to be used for cross-validation
+        scorers: String specifying suite of performance metrics to use
         feature_importances: Boolean to perform permutation-based importances
         n_permutations: Number of permutations during feature importance
         random_state: Seed to pass to the random number generator
         """
 
+        # preprocessing -------------------------------------------------------
+        if self.preprocessing is not None:
+            X = self.__preprocessor(X)
+
+        if self.categorical_var is not None:
+            X = self.__onehotencode(X)
+
         # create copy of fitting estimator for cross-val fitting
         fit_train = deepcopy(self.estimator)
 
-        # dictionary of lists to store metrics
+        # create dictionary of lists to store metrics -------------------------
+        n_classes = len(np.unique(y))
+
+        if scorers == 'accuracy':
+            self.scores = {
+                'accuracy': []
+            }
+
         if scorers == 'binary':
             self.scores = {
                 'accuracy': [],
@@ -355,77 +349,89 @@
         if scorers == 'multiclass':
             self.scores = {
                 'accuracy': [],
-                'f1': [],
-                'kappa': []
-            }
+                'kappa': [],
+                'precision': np.zeros((0, n_classes)),  # scores per sample
+                'recall': np.zeros((0, n_classes)),
+                'f1': np.zeros((0, n_classes))
+                }
 
         if scorers == 'regression':
             self.scores = {
                 'r2': []
             }
 
-        y_test_agg = []
-        y_pred_agg = []
-        self.fimp = np.zeros((cv, self.X.shape[1]))
+        self.mean_tpr = 0
+        self.mean_fpr = np.linspace(0, 1, 100)
 
-        # generate Kfold indices
-        if self.groups is None:
-            k_fold = StratifiedKFold(
-                n_splits=cv,
-                shuffle=False,
-                random_state=random_state).split(self.X, self.y)
+        # create np array to store feature importance scores
+        # for each predictor per fold
+        if feature_importances is True:
+            self.fimp = np.zeros((splitter.get_n_splits(), X.shape[1]))
+            self.fimp[:] = np.nan
+
+        # generate Kfold indices ----------------------------------------------
+
+        if groups is None:
+            k_fold = splitter.split(X, y)
         else:
-            k_fold = GroupKFold(n_splits=cv).split(
-                self.X, self.y, groups=self.groups)
+            k_fold = splitter.split(
+                X, y, groups=groups)
 
+        # train on k-1 folds and test of k folds ------------------------------
+
         for train_indices, test_indices in k_fold:
 
             # get indices for train and test partitions
-            X_train, X_test = self.X[train_indices], self.X[test_indices]
-            y_train, y_test = self.y[train_indices], self.y[test_indices]
+            X_train, X_test = X[train_indices], X[test_indices]
+            y_train, y_test = y[train_indices], y[test_indices]
+            if groups is not None:
+                groups_train = groups[train_indices]
 
             # balance the fold
-            if self.balance is True:
-                X_train, y_train = self.__random_oversampling(
-                        X_train, y_train, random_state=random_state)
-                if self.groups is not None:
-                    groups_train = self.groups[train_indices]
-                    groups_train, _ = self.__random_oversampling(
-                        groups_train, self.y[train_indices],
-                        random_state=random_state)
+            if self.sampling is not None:
+                y_train_original = deepcopy(y_train)
+                X_train, y_train = self.sampling.fit_sample(
+                    X_train, y_train)
 
+                if groups is not None:
+                    groups_train, _ = self.sampling.fit_sample(
+                        groups_train.reshape(-1, 1), y_train_original)
+
             else:
                 # also get indices of groups for the training partition
-                if self.groups is not None:
-                    groups_train = self.groups[train_indices]
+                if groups is not None:
+                    groups_train = groups[train_indices]
 
             # fit the model on the training data and predict the test data
             # need the groups parameter because the estimator can be a
-            # RandomizedSearchCV estimator where cv=GroupKFold
-            if isinstance(self.estimator, RandomizedSearchCV) is True \
-                    or isinstance(self.estimator, GridSearchCV):
+            # RandomizedSearchCV or GridSearchCV estimator where cv=GroupKFold
+            if isinstance(fit_train, RandomizedSearchCV) is True \
+                    or isinstance(fit_train, GridSearchCV):
                 param_search = True
             else:
                 param_search = False
 
-            if self.groups is not None and param_search is True:
+            # train fit_train on training fold
+            if groups is not None and param_search is True:
                 fit_train.fit(X_train, y_train, groups=groups_train)
             else:
                 fit_train.fit(X_train, y_train)
 
+            # prediction of test fold
             y_pred = fit_train.predict(X_test)
-
-            y_test_agg = np.append(y_test_agg, y_test)
-            y_pred_agg = np.append(y_pred_agg, y_pred)
-
             labels = np.unique(y_pred)
 
             # calculate metrics
-            if scorers == 'binary':
+            if scorers == 'accuracy':
                 self.scores['accuracy'] = np.append(
                     self.scores['accuracy'],
                     metrics.accuracy_score(y_test, y_pred))
 
+            elif scorers == 'binary':
+                self.scores['accuracy'] = np.append(
+                    self.scores['accuracy'],
+                    metrics.accuracy_score(y_test, y_pred))
+
                 y_pred_proba = fit_train.predict_proba(X_test)[:, 1]
                 self.scores['auc'] = np.append(
                     self.scores['auc'],
@@ -451,8 +457,9 @@
                     self.scores['kappa'],
                     metrics.cohen_kappa_score(y_test, y_pred))
 
-                self.scores_cm = metrics.classification_report(
-                        y_test_agg, y_pred_agg)
+                fpr, tpr, thresholds = metrics.roc_curve(y_test, y_pred_proba)
+                self.mean_tpr += scipy.interp(self.mean_fpr, fpr, tpr)
+                self.mean_tpr[0] = 0.0
 
             elif scorers == 'multiclass':
 
@@ -464,33 +471,47 @@
                     self.scores['kappa'],
                     metrics.cohen_kappa_score(y_test, y_pred))
 
-                self.scores_cm = metrics.classification_report(
-                        y_test_agg, y_pred_agg)
+                self.scores['precision'] = np.vstack((
+                    self.scores['precision'],
+                    np.array(metrics.precision_score(
+                        y_test, y_pred, average=None))))
 
+                self.scores['recall'] = np.vstack((
+                    self.scores['recall'],
+                    np.array(metrics.recall_score(
+                        y_test, y_pred, average=None))))
+
+                self.scores['f1'] = np.vstack((
+                    self.scores['f1'],
+                    np.array(metrics.f1_score(
+                        y_test, y_pred, average=None))))
+
             elif scorers == 'regression':
                 self.scores['r2'] = np.append(
                     self.scores['r2'], metrics.r2_score(y_test, y_pred))
 
             # feature importances using permutation
             if feature_importances is True:
-                if bool((self.fimp == 0).all()) is True:
-                    self.fimp = self.__varImp_permutation(
+                if bool((np.isnan(self.fimp)).all()) is True:
+                    self.fimp = self.varImp_permutation(
                         fit_train, X_test, y_test, n_permutations, scorers,
                         random_state)
                 else:
                     self.fimp = np.row_stack(
-                        (self.fimp, self.__varImp_permutation(
+                        (self.fimp, self.varImp_permutation(
                             fit_train, X_test, y_test,
                             n_permutations, scorers, random_state)))
 
+        # summarize data ------------------------------------------------------
+
         # convert onehot-encoded feature importances back to original vars
         if self.fimp is not None and self.enc is not None:
 
             # get start,end positions of each suite of onehot-encoded vars
             feature_ranges = deepcopy(self.enc.feature_indices_)
             for i in range(0, len(self.enc.feature_indices_)-1):
-                feature_ranges[i+1] =\
-                    feature_ranges[i] + len(self.category_values[i])
+                feature_ranges[i+1] = feature_ranges[i] + \
+                              len(self.category_values[i])
 
             # take sum of each onehot-encoded feature
             ohe_feature = [0] * len(self.categorical_var)
@@ -498,7 +519,7 @@
 
             for i in range(len(self.categorical_var)):
                 ohe_feature[i] = \
-                    self.fimp[:, feature_ranges[i]:feature_ranges[i+1]]
+                           self.fimp[:, feature_ranges[i]:feature_ranges[i+1]]
                 ohe_sum[i] = ohe_feature[i].sum(axis=1)
 
             # remove onehot-encoded features from the importances array
@@ -507,12 +528,22 @@
 
             # insert summed importances into original positions
             for index in self.categorical_var:
-                self.fimp = np.insert(
-                        self.fimp, np.array(index), ohe_sum[0], axis=1)
+                self.fimp = np.insert(self.fimp, np.array(index),
+                                      ohe_sum[0], axis=1)
 
-    def predict(self, predictors, output, class_probabilities=False,
-                rowincr=25):
+        if scorers == 'binary':
+            self.mean_tpr /= splitter.get_n_splits(X, y)
+            self.mean_tpr[-1] = 1.0
 
+    def get_roc_curve(self):
+        return (self.mean_tpr, self.mean_fpr)
+
+    def get_cross_val_scores(self):
+        return (self.scores)
+
+    def predict(self, predictors, output, labels=None,
+                class_probabilities=False, rowincr=25):
+
         """
         Prediction on list of GRASS rasters using a fitted scikit learn model
 
@@ -527,9 +558,7 @@
         """
 
         # determine output data type and nodata
-        predicted = self.estimator.predict(self.X)
-
-        if bool((predicted % 1 == 0).all()) is True:
+        if labels is not None:
             ftype = 'CELL'
             nodata = -2147483648
         else:
@@ -568,7 +597,6 @@
         if class_probabilities is True:
 
             # determine number of classes
-            labels = np.unique(self.y)
             nclasses = len(labels)
 
             prob_out_raster = [0] * nclasses
@@ -583,118 +611,122 @@
         """
         Prediction using row blocks
         """
+        try:
+            for rowblock in range(0, current.rows, rowincr):
+                grass.percent(rowblock, current.rows, rowincr)
+                # check that the row increment does not exceed the number of rows
+                if rowblock+rowincr > current.rows:
+                    rowincr = current.rows - rowblock
+                img_np_row = np.zeros((rowincr, current.cols, n_features))
+                mask_np_row = np.zeros((rowincr, current.cols))
 
-        for rowblock in range(0, current.rows, rowincr):
+                # loop through each row, and each band
+                # and add these values to the 2D array img_np_row
+                for row in range(rowblock, rowblock+rowincr, 1):
+                    mask_np_row[row-rowblock, :] = np.array(mask_raster[row])
 
-            # check that the row increment does not exceed the number of rows
-            if rowblock+rowincr > current.rows:
-                rowincr = current.rows - rowblock
-            img_np_row = np.zeros((rowincr, current.cols, n_features))
-            mask_np_row = np.zeros((rowincr, current.cols))
+                    for band in range(n_features):
+                        img_np_row[row-rowblock, :, band] = \
+                            np.array(rasstack[band][row])
 
-            # loop through each row, and each band
-            # and add these values to the 2D array img_np_row
-            for row in range(rowblock, rowblock+rowincr, 1):
-                mask_np_row[row-rowblock, :] = np.array(mask_raster[row])
+                mask_np_row[mask_np_row == -2147483648] = np.nan
+                nanmask = np.isnan(mask_np_row)  # True in mask means invalid data
 
-                for band in range(n_features):
-                    img_np_row[row-rowblock, :, band] = \
-                        np.array(rasstack[band][row])
+                # reshape each row-band matrix into a n*m array
+                nsamples = rowincr * current.cols
+                flat_pixels = img_np_row.reshape((nsamples, n_features))
 
-            mask_np_row[mask_np_row == -2147483648] = np.nan
-            nanmask = np.isnan(mask_np_row)  # True in mask means invalid data
+                # remove NaN values and GRASS CELL nodata vals
+                flat_pixels[flat_pixels == -2147483648] = np.nan
+                flat_pixels = np.nan_to_num(flat_pixels)
 
-            # reshape each row-band matrix into a n*m array
-            nsamples = rowincr * current.cols
-            flat_pixels = img_np_row.reshape((nsamples, n_features))
+                # rescale
+                if self.preprocessing is not None:
+                    # create mask so that indices that represent categorical
+                    # predictors are not selected
+                    if self.categorical_var is not None:
+                        idx = np.arange(n_features)
+                        mask = np.ones(len(idx), dtype=bool)
+                        mask[self.categorical_var] = False
+                    else:
+                        mask = np.arange(n_features)
+                    flat_pixels_continuous = flat_pixels[:, mask]
+                    flat_pixels[:, mask] = self.preprocessing.transform(
+                            flat_pixels_continuous)
 
-            # remove NaN values and GRASS CELL nodata vals
-            flat_pixels[flat_pixels == -2147483648] = np.nan
-            flat_pixels = np.nan_to_num(flat_pixels)
+                # onehot-encoding
+                if self.enc is not None:
+                    try:
+                        flat_pixels = self.enc.transform(flat_pixels)
+                    except:
+                        # if this fails it is because the onehot-encoder was fitted
+                        # on the training samples, but the prediction data contains
+                        # new values, i.e. the training data has not sampled all of
+                        # categories
+                        grass.fatal('There are values in the categorical rasters '
+                                    'that are not present in the training data '
+                                    'set, i.e. the training data has not sampled '
+                                    'all of the categories')
 
-            # onehot-encoding
-            if self.enc is not None:
-                try:
-                    flat_pixels = self.enc.transform(flat_pixels)
-                except:
-                    # if this fails it is because the onehot-encoder was fitted
-                    # on the training samples, but the prediction data contains
-                    # new values, i.e. the training data has not sampled all of
-                    # categories
-                    grass.fatal('There are values in the categorical rasters '
-                                'that are not present in the training data '
-                                'set, i.e. the training data has not sampled '
-                                'all of the categories')
+                # perform prediction
+                result = self.estimator.predict(flat_pixels)
+                result = result.reshape((rowincr, current.cols))
 
-            # rescale
-            if self.scaler is not None:
-                # create mask so that indices that represent categorical
-                # predictors are not selected
-                if self.categorical_var is not None:
-                    idx = np.arange(self.X.shape[1])
-                    mask = np.ones(len(idx), dtype=bool)
-                    mask[self.categorical_var] = False
-                else:
-                    mask = np.arange(self.X.shape[1])
-                flat_pixels_continuous = flat_pixels[:, mask]
-                flat_pixels[:, mask] = self.scaler.transform(
-                        flat_pixels_continuous)
+                # replace NaN values so that the prediction does not have a border
+                result = np.ma.masked_array(
+                    result, mask=nanmask, fill_value=-99999)
 
-            # perform prediction
-            result = self.estimator.predict(flat_pixels)
-            result = result.reshape((rowincr, current.cols))
+                # return a copy of result, with masked values filled with a value
+                result = result.filled([nodata])
 
-            # replace NaN values so that the prediction does not have a border
-            result = np.ma.masked_array(
-                result, mask=nanmask, fill_value=-99999)
+                # for each row we can perform computation, and write the result
+                for row in range(rowincr):
+                    newrow = Buffer((result.shape[1],), mtype=ftype)
+                    newrow[:] = result[row, :]
+                    classification.put_row(newrow)
 
-            # return a copy of result, with masked values filled with a value
-            result = result.filled([nodata])
+                # same for probabilities
+                if class_probabilities is True:
+                    result_proba = self.estimator.predict_proba(flat_pixels)
 
-            # for each row we can perform computation, and write the result
-            for row in range(rowincr):
-                newrow = Buffer((result.shape[1],), mtype=ftype)
-                newrow[:] = result[row, :]
-                classification.put_row(newrow)
+                    for iclass in range(result_proba.shape[1]):
 
-            # same for probabilities
-            if class_probabilities is True:
-                result_proba = self.estimator.predict_proba(flat_pixels)
+                        result_proba_class = result_proba[:, iclass]
+                        result_proba_class = result_proba_class.reshape(
+                                                (rowincr, current.cols))
 
-                for iclass in range(result_proba.shape[1]):
+                        result_proba_class = np.ma.masked_array(
+                            result_proba_class, mask=nanmask, fill_value=np.nan)
 
-                    result_proba_class = result_proba[:, iclass]
-                    result_proba_class = result_proba_class.reshape(
-                                            (rowincr, current.cols))
+                        result_proba_class = result_proba_class.filled([np.nan])
 
-                    result_proba_class = np.ma.masked_array(
-                        result_proba_class, mask=nanmask, fill_value=np.nan)
+                        for row in range(rowincr):
+                            newrow = Buffer((
+                                        result_proba_class.shape[1],),
+                                        mtype='FCELL')
 
-                    result_proba_class = result_proba_class.filled([np.nan])
+                            newrow[:] = result_proba_class[row, :]
+                            prob[iclass].put_row(newrow)
+        finally:
+            # close all predictors
+            for i in range(n_features):
+                rasstack[i].close()
 
-                    for row in range(rowincr):
+            # close classification and mask maps
+            classification.close()
+            mask_raster.close()
 
-                        newrow = Buffer((
-                                    result_proba_class.shape[1],),
-                                    mtype='FCELL')
+            grass.run_command("g.remove", name='tmp_clfmask',
+                              flags="f", type="raster", quiet=True)
 
-                        newrow[:] = result_proba_class[row, :]
-                        prob[iclass].put_row(newrow)
+            # close all class probability maps
+            try:
+                for iclass in range(nclasses):
+                    prob[iclass].close()
+            except:
+                pass
 
-        # close all maps
-        for i in range(n_features):
-            rasstack[i].close()
 
-        classification.close()
-        mask_raster.close()
-
-        try:
-            for iclass in range(nclasses):
-                prob[iclass].close()
-        except:
-            pass
-
-
 def model_classifiers(estimator='LogisticRegression', random_state=None,
                       C=1, max_depth=None, max_features='auto',
                       min_samples_split=2, min_samples_leaf=1,



More information about the grass-commit mailing list