#!/usr/bin/env python

# Jacob Joseph
# 16 May 2009

# Postgres database tools

import psycopg2, getpass, tempfile, os, sys, logging

# support SQL_IN, and other additions
import psycopg2.extensions

#def quote( s):
#    return PgSQL.PgQuoteString( s)

# http://www.python.org/dev/peps/pep-0342/
def consumer(func):
    """Decorator for PEP 342 'consumer' generators.

    Wraps a generator function so that calling it returns a generator
    that has already been advanced to its first ``yield`` and is
    therefore immediately ready to receive values via ``.send()``.
    """
    def wrapper(*args, **kw):
        gen = func(*args, **kw)
        next(gen)  # prime the generator (gen.next() was Python-2-only)
        return gen
    # Preserve the wrapped function's identity for introspection.
    wrapper.__name__ = func.__name__
    wrapper.__dict__ = func.__dict__
    wrapper.__doc__  = func.__doc__
    return wrapper


class dbwrap:
    """Convenience wrapper around a psycopg2 connection.

    Maintains one ordinary (client-side) cursor for normal queries and,
    on demand, a named (server-side) cursor for streaming large result
    sets.  Connection parameters left as None fall back to libpq's
    defaults, so the environment variables PGUSER, PGPASSWORD,
    PGDATABASE, and PGHOST should be set if the values are not supplied
    as arguments here.
    """

    def __init__( self, dbhost=None, dbport=None, dbname=None, dbuser=None,
                  dbpasswd=None, debug=False, logger=None):
        """Record connection parameters.  No connection is opened yet.

        logger: optional parent logger whose name prefixes this
            instance's logger name.
        debug: lower the *root* logger's level to INFO.
        """
        self.log = logging.getLogger("%s%s" % (logger.name + ':' if logger is not None else '',
                                               self.__module__))
        # FIXME: We probably shouldn't adjust the root logger
        if debug:
            rootlogger = logging.getLogger()
            rootlogger.setLevel(min(logging.INFO, rootlogger.level))

        if len(self.log.handlers) == 0:
            self.log.addHandler(logging.StreamHandler())

        # None values are omitted from the DSN string in open(), letting
        # libpq apply its environment-variable defaults.
        self.dsn = {'host': dbhost,
                    'port': dbport,
                    'dbname': dbname,
                    'user': dbuser,
                    'password': dbpasswd}

        self.log.debug("Database DSN: %s", self.dsn)

        self.conn = None    # psycopg2 connection
        self.curs = None    # ordinary (client-side) cursor
        self.ncurs = None   # named (server-side) cursor; see named_curs_open()

        self.ready = 0      # becomes 1 once open() succeeds

    def open( self, get_pass=True):
        """Connect to the server and create the default cursor.

        If the connection fails for lack of a password and get_pass is
        true, prompt interactively for a username and password and
        retry once.
        """
        self.log.info("Opening PgSQL server connection (%s)", self.dsn['host'])

        dsn_string = " ".join("%s=%s" % (k, v)
                              for k, v in self.dsn.items()
                              if v is not None)

        try:
            self.conn = psycopg2.connect(dsn_string)
        except psycopg2.OperationalError as m:  # 'except E, m' was Py2-only
            if get_pass and str(m).find("password") >= 0:
                # raw_input() was renamed to input() in Python 3.
                try:
                    prompt = raw_input
                except NameError:
                    prompt = input
                user = prompt("Database username [%s]:" % getpass.getuser())
                if len(user) == 0:
                    user = getpass.getuser()
                self.dsn['user'] = user
                self.dsn['password'] = getpass.getpass("Database password:")
                self.open(get_pass=False)
                return  # the recursive call created conn/curs and set ready
            else:
                raise

        #self.conn.autocommit = 0
        # A cursor must be named to support server-side queries; however,
        # named cursors may only be used once, so keep this one unnamed.
        self.curs = self.conn.cursor()
        self.ready = 1

    def named_curs_open(self):
        """Overwrite the existing named cursor.  Should be called
        before executing."""
        if self.ncurs is not None:
            try:
                self.ncurs.close()
            except psycopg2.ProgrammingError as e:
                # Closing an already-consumed named cursor is harmless;
                # anything else is a real error.  (str(e) replaces the
                # deprecated/removed e.message attribute.)
                if str(e).find("named cursor isn't valid anymore") < 0:
                    raise

            self.conn.commit()        # the current transaction must end
        self.ncurs = self.conn.cursor("named_cursor")

    def close( self):
        """Close the connection and reset cursor state."""
        self.conn.close()
        self.ready = 0
        self.conn  = None
        self.curs  = None

    def commit( self):
        """Commit the current transaction, if a connection is open."""
        # this will fail if the connection is not open, but what else
        # can be done?
        if self.ready:
            self.conn.commit()

    def execute( self, s, args=None, stream=False):
        """Execute statement s with parameter sequence/mapping args.

        If stream is true, run on a fresh named cursor so rows can be
        fetched incrementally with fetchall_iter().
        """
        if not self.ready:
            self.open()
        try:
            if stream:
                self.named_curs_open()
                self.ncurs.execute(s, args)
            else:
                self.curs.execute(s, args)
        except Exception:
            # Log the statement with parameters interpolated to aid debugging.
            self.log.error("Execute: Error with the following statement:\n\n%s",
                           self.curs.mogrify(s, args))
            raise
        return

    def _raise_not_open(self, caller):
        """Log and raise when a fetch is attempted with no open connection.

        The original code raised a bare string, which has been illegal
        since Python 2.6 (it produced a TypeError rather than the
        intended message); raise a RuntimeError carrying the message.
        """
        m = "%s: database is not open --\n" \
            "  was last op a query as it should have been?" % caller
        self.log.error(m)
        raise RuntimeError(m)

    def fetchall_iter(self, q = None, args=None, batch_size=1000):
        """Yield rows one at a time from a streaming (named-cursor)
        query, fetching batch_size rows from the server at a time."""
        if q is not None:
            self.execute(q, args, stream=True)

        if not self.ready:
            self._raise_not_open("fetchall")

        while True:
            results = self.ncurs.fetchmany(batch_size)
            if not results:
                break
            for result in results:
                yield result

    def fetchall( self, q = None, args=None):
        """Execute q (if given) and return all rows as a list."""
        if q is not None:
            self.execute(q, args)

        if not self.ready:
            self._raise_not_open("fetchall")
        return self.curs.fetchall()

    def fetchone( self, q = None, args=None):
        "Fetch one row."
        if q is not None:
            self.execute(q, args)

        if not self.ready:
            self._raise_not_open("fetchone")
        return self.curs.fetchone()

    def fetchone_d( self, q = None, args=None):
        "Fetch one row, as a dictionary of field names."
        if q is not None:
            self.execute(q, args)

        if not self.ready:
            self._raise_not_open("fetchone_d")

        row = self.curs.fetchone()
        if row is None:
            return None
        # Map column names (first element of each description entry) to values.
        field_names = [field[0] for field in self.curs.description]
        return dict(zip(field_names, row))

    def fetchmany( self, q = None, args=None):
        """Execute q (if given) and return the cursor's next batch of rows."""
        if q is not None:
            self.execute(q, args)

        if not self.ready:
            self._raise_not_open("fetchmany")
        return self.curs.fetchmany()

    def fetchsingle( self, q = None, args = None):
        """Fetch a single value. Previous execute must only return a
single row,column"""
        if q is not None:
            self.execute(q, args)

        if not self.ready:
            self._raise_not_open("fetchsingle")

        tup = self.curs.fetchone()
        if not tup:
            return None
        if len(tup) != 1:
            m = "fetchsingle: more than one column returned"
            self.log.error(m)
            raise RuntimeError(m)
        return tup[0]

    def lastrowid(self):
        """Unsupported under PostgreSQL; always raises."""
        # 'assert False' is stripped under 'python -O'; raise explicitly.
        raise NotImplementedError(
            "lastrowid() Unimplemented.  Use e.g., INSERT INTO ... RETURNING seq_id")

    def fetchcolumn( self, q = None, args = None):
        """Fetch a single column, returned as an array"""
        if q is not None:
            self.execute(q, args)

        if not self.ready:
            self._raise_not_open("fetchcolumn")
        return [row[0] for row in self.curs.fetchall()]

    def insert_dict( self, table, rows):
        """Insert a single row, from a dictionary with keys
        corresponding to column names.  If 'rows' is a dictionary,
        insert one row.  'rows' may also be a list of dictionaries,
        each of which will be inserted simultaneously (all rows must
        specify the same set of columns).

        NOTE: still unimplemented; always raises NotImplementedError.
        """
        if isinstance(rows, dict):
            cols = list(rows.keys())
        elif isinstance(rows, list):
            cols = list(rows[0].keys())
        else:
            raise TypeError("Unknown rows type: '%s'" % type(rows))

        i_base = "INSERT INTO %s (%s)\n VALUES (%s)"

        # str.join replaces the original reduce() concatenation (reduce
        # is no longer a builtin in Python 3).
        i_cols = ", ".join(cols)
        #i_val
        raise NotImplementedError("Not fully implemented")


    # FIXME: Streaming queries were used with mysql.  Are these needed
    # with postgres?
    def ssexecute(self, s):
        """Legacy alias for execute() (MySQL server-side-cursor API)."""
        self.execute(s)

    def ssfetchmany(self, count):
        """Legacy alias for fetchmany(); 'count' is ignored, as before."""
        return self.fetchmany()

    def _copy_flush(self, tmp, table, columns):
        """Rewind tmp and COPY its contents into table, then commit."""
        tmp.seek(0)
        self.curs.copy_from(tmp, table, columns=columns)
        self.commit()

    @consumer
    def copy_from(self, table, columns, batch_size=100000, format_str=None):
        """Returns a generator that is used to send data by COPY back
        to the database.  Repeatedly call .send() with a tuple of row
        data.  This will be cached to a temporary file, and executed
        on the server when batch_size rows are queued, or .close() is
        called.  BE SURE TO CALL .close()"""

        if not self.ready:
            self.open()

        if format_str is None:
            # One %r per column, tab-separated (COPY's default delimiter).
            format_str = "\t".join(["%r"] * len(columns))

        # Text mode: we write str lines (binary-mode default would
        # reject str under Python 3).
        tmp = tempfile.TemporaryFile(mode='w+')

        lines = 0
        try:
            while True:
                row_tup = yield
                tmp.write(format_str % row_tup + '\n')
                lines += 1

                if lines >= batch_size:
                    self._copy_flush(tmp, table, columns)
                    lines = 0
                    tmp.close()
                    tmp = tempfile.TemporaryFile(mode='w+')

        except GeneratorExit:
            # .close() was called: flush whatever remains queued.
            self._copy_flush(tmp, table, columns)
            tmp.close()

        return

    @consumer
    def copy_from_innertry(self, table, columns, batch_size=100000, format_str=None):
        """Returns a generator that is used to send data by COPY back
        to the database.  Repeatedly call .send() with a tuple of row
        data.  This will be cached to a temporary file, and executed
        on the server when batch_size rows are queued, or .close() is
        called.  BE SURE TO CALL .close()"""

        # This is somewhat slower (try/except inside the loop).

        if not self.ready:
            self.open()

        if format_str is None:
            format_str = "\t".join(["%r"] * len(columns))

        tmp = tempfile.TemporaryFile(mode='w+')

        lines = 0
        while True:
            try:
                row_tup = yield
                tmp.write(format_str % row_tup + '\n')
                lines += 1

            except GeneratorExit:
                # .close() was called: flush whatever remains queued.
                self._copy_flush(tmp, table, columns)
                tmp.close()
                return

            else:
                if lines >= batch_size:
                    self._copy_flush(tmp, table, columns)
                    lines = 0
                    tmp.close()
                    tmp = tempfile.TemporaryFile(mode='w+')



if __name__ == "__main__":
    # Minimal smoke test: connection parameters come from the PG*
    # environment variables, or from the interactive password prompt
    # inside open() if the first connection attempt is refused.
    dbw = dbwrap()
    dbw.open()
    
                  
