[GRASS-SVN] r58628 - grass/trunk/lib/python/pygrass/vector

svn_grass at osgeo.org svn_grass at osgeo.org
Mon Jan 6 10:09:29 PST 2014


Author: zarch
Date: 2014-01-06 10:09:28 -0800 (Mon, 06 Jan 2014)
New Revision: 58628

Modified:
   grass/trunk/lib/python/pygrass/vector/table.py
Log:
Add additional parameters to the Table methods

Modified: grass/trunk/lib/python/pygrass/vector/table.py
===================================================================
--- grass/trunk/lib/python/pygrass/vector/table.py	2014-01-06 17:54:24 UTC (rev 58627)
+++ grass/trunk/lib/python/pygrass/vector/table.py	2014-01-06 18:09:28 UTC (rev 58628)
@@ -7,8 +7,10 @@
 
 
 """
+import os
 import ctypes
 import numpy as np
+from sqlite3 import OperationalError
 
 try:
     from collections import OrderedDict
@@ -18,11 +20,11 @@
 import grass.lib.vector as libvect
 from grass.pygrass.gis import Mapset
 from grass.pygrass.errors import DBError
+from grass.pygrass.functions import table_exist
 from grass.script.db import db_table_in_vector
 from grass.script.core import warning
-import sys
+
 import sql
-import os
 
 
 DRIVERS = ('sqlite', 'pg')
@@ -360,10 +362,26 @@
 
         ..
         """
-        valid_type = ('DOUBLE PRECISION', 'INT', 'DATE')
-        if 'VARCHAR' in col_type or col_type.upper() not in valid_type:
-            str_err = "Type is not supported, supported types are: %s"
-            raise TypeError(str_err % ", ".join(valid_type))
+        def check_col(col_type):
+            valid_type = ('DOUBLE PRECISION', 'DOUBLE', 'INT', 'INTEGER',
+                          'DATE')
+            if 'VARCHAR' in col_type or col_type.upper() not in valid_type:
+                str_err = "Type is not supported, supported types are: %s"
+                raise TypeError(str_err % ", ".join(valid_type))
+
+        if isinstance(col_name, str):
+            check_col(col_type)
+        else:
+            if len(col_name) == len(col_type):
+                cvars = []
+                for name, ctype in zip(col_name, col_type):
+                    check_col(ctype)
+                    cvars.append('%s %s' % (name, ctype))
+                col_name = ''
+                col_type = ','.join(cvars)
+            else:
+                str_err = "The lenghts of the columns are different:\n%r\n%r"
+                raise TypeError(str_err % (col_name, col_type))
         cur = self.conn.cursor()
         cur.execute(sql.ADD_COL.format(tname=self.tname,
                                        cname=col_name,
@@ -604,7 +622,8 @@
     driver = property(fget=_get_driver, fset=_set_driver)
 
     def __init__(self, layer=1, name=None, table=None, key='cat',
-                 database='$GISDBASE/$LOCATION_NAME/PERMANENT/sqlite/sqlite.db',
+                 database='$GISDBASE/$LOCATION_NAME/'
+                          '$MAPSET/sqlite/sqlite.db',
                  driver='sqlite', c_fieldinfo=None):
         if c_fieldinfo is not None:
             self.c_fieldinfo = c_fieldinfo
@@ -620,6 +639,13 @@
     def __repr__(self):
         return "Link(%d, %s, %s)" % (self.layer, self.name, self.driver)
 
+    def __eq__(self, link):
+        """Links are equal if layer, names, key and driver all match"""
+        if not isinstance(link, Link):
+            return NotImplemented  # NOTE(review): Python 2 also needs __ne__
+        return all(getattr(self, attr) == getattr(link, attr) for attr in
+                   ('layer', 'name', 'table_name', 'key', 'driver'))
+
     def connection(self):
         """Return a connection object. ::
 
@@ -654,6 +680,7 @@
         elif self.driver == 'pg':
             try:
                 import psycopg2
+                psycopg2.paramstyle = 'qmark'
                 db = ' '.join(self.database.split(','))
                 return psycopg2.connect(db)
             except ImportError:
@@ -744,16 +771,24 @@
         return "DBlinks(%r)" % [link for link in self.__iter__()]
 
     def by_index(self, indx):
+        """Return the Link at position indx (negative counts from the end)"""
+        nlinks = self.num_dblinks()
+        if nlinks == 0:
+            raise IndexError("The vector map has no database links")
+        if indx < 0:
+            indx += nlinks
+        if indx < 0 or indx >= nlinks:  # valid positions: 0 .. nlinks - 1
+            raise IndexError("Link index out of range")
         c_fieldinfo = libvect.Vect_get_dblink(self.c_mapinfo, indx)
-        return Link(c_fieldinfo=c_fieldinfo)
+        return Link(c_fieldinfo=c_fieldinfo) if c_fieldinfo else None
 
     def by_layer(self, layer):
         c_fieldinfo = libvect.Vect_get_field(self.c_mapinfo, layer)
-        return Link(c_fieldinfo=c_fieldinfo)
+        return Link(c_fieldinfo=c_fieldinfo) if c_fieldinfo else None  # NULL ptr -> None
 
     def by_name(self, name):
         c_fieldinfo = libvect.Vect_get_field_by_name(self.c_mapinfo, name)
-        return Link(c_fieldinfo=c_fieldinfo)
+        return Link(c_fieldinfo=c_fieldinfo) if c_fieldinfo else None  # NULL ptr -> None
 
     def num_dblinks(self):
         return libvect.Vect_get_num_dblinks(self.c_mapinfo)
@@ -839,7 +874,7 @@
         cur = self.conn.cursor()
         cur.execute(sql.RENAME_TAB.format(old_name=old_name,
                                           new_name=new_name))
-        cur.commit()
+        self.conn.commit()
         cur.close()
 
     name = property(fget=_get_name, fset=_set_name)
@@ -875,26 +910,21 @@
         """Return the number of rows"""
         return self.n_rows()
 
-    def drop(self, name=None, force=False):
+    def drop(self, cursor=None, force=False):  # NOTE(review): 1st arg was `name`
         """Private method to drop table from database"""
-        if name:
-            name = name
-        else:
-            name = self.name
-        used = db_table_in_vector(name)
-        if len(used) > 0:
-            warning(_("Deleting table <%s> which is attached to following map(s):") % name)
-            for vect in used:
-                warning("%s" % vect)
-            if not force:
-                warning(_("You must use the force flag to actually remove it. Exiting."))
-                sys.exit(0)
+        # a table still used by vector maps is only dropped when force=True
+        cur = cursor if cursor else self.conn.cursor()
+        if self.exist(cursor=cur):
+            used = db_table_in_vector(self.name)
+            if len(used) > 0 and not force:
+                warning(_("Deleting table <%s> which is attached"
+                          " to following map(s):") % self.name)
+                for vect in used:
+                    warning("%s" % vect)
+                warning(_("You must use the force flag to actually"
+                          " remove it."))
             else:
-                cur = self.conn.cursor()
-                cur.execute(sql.DROP_TAB.format(tname=name))
-        else:
-            cur = self.conn.cursor()
-            cur.execute(sql.DROP_TAB.format(tname=name))
+                cur.execute(sql.DROP_TAB.format(tname=self.name))
 
     def n_rows(self):
         """Return the number of rows
@@ -912,7 +942,7 @@
         cur.close()
         return number
 
-    def execute(self, sql_code=None):
+    def execute(self, sql_code=None, cursor=None, many=False, values=None):
         """Execute SQL code from a given string or build with filters and
         return a cursor object. ::
 
@@ -930,36 +960,50 @@
         """
         try:
             sqlc = sql_code if sql_code else self.filters.get_sql()
-            cur = self.conn.cursor()
-            #if hasattr(self.cur, 'executescript'):
-            #    return cur.executescript(sqlc)
+            cur = cursor if cursor else self.conn.cursor()
+            if many and values:  # otherwise `values` is ignored (plain execute)
+                return cur.executemany(sqlc, values)
             return cur.execute(sqlc)
         except:
             raise ValueError("The SQL is not correct:\n%r" % sqlc)
 
-    def insert(self, values, many=False):
+    def exist(self, cursor=None):
+        """Tell whether a table called self.name is present in the DB."""
+        # fall back to a fresh cursor when the caller did not pass one
+        return table_exist(cursor or self.conn.cursor(), self.name)
+
+    def insert(self, values, cursor=None, many=False):  # NOTE(review): old insert(vals, True) now sets cursor=True
         """Insert a new row"""
-        cur = self.conn.cursor()
+        cur = cursor if cursor else self.conn.cursor()
         if many:
             return cur.executemany(self.columns.insert_str, values)
         return cur.execute(self.columns.insert_str, values)
 
-    def update(self, key, values):
+    def update(self, key, values, cursor=None, many=False):  # NOTE(review): `many` is not used
         """Update a column for each row"""
-        cur = self.conn.cursor()
+        cur = cursor if cursor else self.conn.cursor()
         vals = list(values)
         vals.append(key)
         return cur.execute(self.columns.update_str, vals)
 
-    def create(self, cols, name=None):
+    def create(self, cols, name=None, overwrite=False, cursor=None):
         """Create a new table"""
-        cur = self.conn.cursor()
+        cur = cursor if cursor else self.conn.cursor()
         coldef = ',\n'.join(['%s %s' % col for col in cols])
         if name:
             newname = name
         else:
             newname = self.name
-        cur.execute(sql.CREATE_TAB.format(tname=newname, coldef=coldef))
-        self.conn.commit()
+        try:
+            cur.execute(sql.CREATE_TAB.format(tname=newname, coldef=coldef))
+            self.conn.commit()
+        except OperationalError:  # sqlite3 only; e.g. the table exists
+            if overwrite:  # replace `newname`, not necessarily self.name
+                cur.execute(sql.DROP_TAB.format(tname=newname))
+                cur.execute(sql.CREATE_TAB.format(tname=newname,
+                                                  coldef=coldef))
+                self.conn.commit()
+            else:
+                warning("The table: %s already exists." % newname)
         cur.close()
         self.columns.update_odict()



More information about the grass-commit mailing list