[GRASS-SVN] r73783 - grass-addons/grass7/raster/r.mregression.series

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Dec 10 07:45:05 PST 2018


Author: DmitryKolesov
Date: 2018-12-10 07:45:05 -0800 (Mon, 10 Dec 2018)
New Revision: 73783

Modified:
   grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py
Log:
Add robust linear model

Modified: grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py
===================================================================
--- grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py	2018-12-09 17:31:10 UTC (rev 73782)
+++ grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py	2018-12-10 15:45:05 UTC (rev 73783)
@@ -18,7 +18,7 @@
 #############################################################################
 
 #%Module
-#% description: Calculates multiple regression between time series: Y = b1*X1 + ... + bn*Xn.
+#% description: Calculates multiple regression between time series: Y(t) = b1*X1(t) + ... + bn*Xn(t).
 #% overwrite: yes
 #% keyword: raster
 #% keyword: statistics
@@ -40,8 +40,18 @@
 #% required : yes
 #% multiple: no
 #%end
+#%option
+#% key: model
+#% type: string
+#% gisprompt: model 
+#% description: model type: ols (ordinary least squares), rlm (robust linear model)
+#% required: no
+#% answer: ols
+#% multiple: no
+#%end
 
 
+
 import os
 import sys
 
@@ -73,7 +83,7 @@
     return value
 
 
-def ols(y, x):
+def fit(y, x, model='ols'):
     """Ordinary least squares.
 
     :param x:   MxN matrix of data points
@@ -94,7 +104,13 @@
         # The system can't be solved
         return [FNULL for i in range(factor_count)]
 
-    model = sm.OLS(y, x)
+    if model == 'ols':
+        model = sm.OLS(y, x)
+    elif model == 'rlm':
+        model = sm.robust.robust_linear_model.RLM(y, x)
+    else:
+        raise NotImplementedError("Model %s doesn't implemented" % (model, ))
+
     try:
         results = model.fit()
         coefs = results.params
@@ -227,7 +243,7 @@
 
         return Y, X
 
-    def ols(self, overwrite=None):
+    def fit(self, model='ols', overwrite=None):
         try:
             reg = Region()
             self.open_rasters(overwrite=overwrite)
@@ -235,7 +251,7 @@
             for r in range(rows):
                 for c in range(cols):
                     Y, X = self.get_sample(r, c)
-                    coefs = ols(Y, X)
+                    coefs = fit(Y, X, model)
                     for i in range(self.factor_count):
                         b = self.b(i)
                         b.put(r, c, coefs[i])
@@ -246,6 +262,7 @@
 def main(options, flags):
     samples = options['samples']
     res_pref = options['result_prefix']
+    model_type = options['model']
     if not os.path.isfile(samples):
         sys.stderr.write("File '%s' doesn't exist.\n" % (samples, ))
         sys.exit(1)
@@ -253,7 +270,7 @@
     headers, outputs, inputs = get_sample_names(samples)
 
     model = DataModel(headers, outputs, inputs, res_pref)
-    model.ols(overwrite=grass.overwrite())
+    model.fit(model=model_type, overwrite=grass.overwrite())
     sys.exit(0)
 
 if __name__ == "__main__":



More information about the grass-commit mailing list