[GRASS-SVN] r73783 - grass-addons/grass7/raster/r.mregression.series
svn_grass at osgeo.org
svn_grass at osgeo.org
Mon Dec 10 07:45:05 PST 2018
Author: DmitryKolesov
Date: 2018-12-10 07:45:05 -0800 (Mon, 10 Dec 2018)
New Revision: 73783
Modified:
grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py
Log:
Add robust linear model
Modified: grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py
===================================================================
--- grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py 2018-12-09 17:31:10 UTC (rev 73782)
+++ grass-addons/grass7/raster/r.mregression.series/r.mregression.series.py 2018-12-10 15:45:05 UTC (rev 73783)
@@ -18,7 +18,7 @@
#############################################################################
#%Module
-#% description: Calculates multiple regression between time series: Y = b1*X1 + ... + bn*Xn.
+#% description: Calculates multiple regression between time series: Y(t) = b1*X1(t) + ... + bn*Xn(t).
#% overwrite: yes
#% keyword: raster
#% keyword: statistics
@@ -40,8 +40,18 @@
#% required : yes
#% multiple: no
#%end
+#%option
+#% key: model
+#% type: string
+#% gisprompt: model
+#% description: model type: ols (ordinary least squares), rlm (robust linear model)
+#% required: no
+#% answer: ols
+#% multiple: no
+#%end
+
import os
import sys
@@ -73,7 +83,7 @@
return value
-def ols(y, x):
+def fit(y, x, model='ols'):
"""Ordinary least squares.
:param x: MxN matrix of data points
@@ -94,7 +104,13 @@
# The system can't be solved
return [FNULL for i in range(factor_count)]
- model = sm.OLS(y, x)
+ if model == 'ols':
+ model = sm.OLS(y, x)
+ elif model == 'rlm':
+ model = sm.robust.robust_linear_model.RLM(y, x)
+ else:
+ raise NotImplementedError("Model %s doesn't implemented" % (model, ))
+
try:
results = model.fit()
coefs = results.params
@@ -227,7 +243,7 @@
return Y, X
- def ols(self, overwrite=None):
+ def fit(self, model='ols', overwrite=None):
try:
reg = Region()
self.open_rasters(overwrite=overwrite)
@@ -235,7 +251,7 @@
for r in range(rows):
for c in range(cols):
Y, X = self.get_sample(r, c)
- coefs = ols(Y, X)
+ coefs = fit(Y, X, model)
for i in range(self.factor_count):
b = self.b(i)
b.put(r, c, coefs[i])
@@ -246,6 +262,7 @@
def main(options, flags):
samples = options['samples']
res_pref = options['result_prefix']
+ model_type = options['model']
if not os.path.isfile(samples):
sys.stderr.write("File '%s' doesn't exist.\n" % (samples, ))
sys.exit(1)
@@ -253,7 +270,7 @@
headers, outputs, inputs = get_sample_names(samples)
model = DataModel(headers, outputs, inputs, res_pref)
- model.ols(overwrite=grass.overwrite())
+ model.fit(model=model_type, overwrite=grass.overwrite())
sys.exit(0)
if __name__ == "__main__":
More information about the grass-commit mailing list