#!/usr/bin/env python
# Jacob Joseph

from JJutil import pgutils, pickler
import os, math, itertools, time, numpy, logging

class seqq(object):
    """Base class for queries about sequences in the DurandLab2
database.  Includes methods for querying sequence sets (See seq_set.py
for building sequence sets.)."""

    dbw = None
    cacheq = None
    pickledir = None
    
    def __init__(self, cacheq=False, pickledir=None, debug=True):
        """Set up logging, cache settings and the database wrapper.

        cacheq    -- when true, fetch_seq_set also consults and fills
                     the pickle cache through pickler.cachefn.
        pickledir -- directory handed to the pickle cache.  Defaults to
                     '$HOME/tmp/pickles' with environment variables
                     expanded.
        debug     -- unused.  Accepted only so that callers which still
                     pass it keep working.
        """
        self.log = logging.getLogger("%s" % self.__module__)

        self.cacheq = cacheq

        # Identity test: only a literal None selects the default path.
        if pickledir is not None:
            self.pickledir = pickledir
        else:
            self.pickledir = os.path.expandvars('$HOME/tmp/pickles')

        self.dbw = pgutils.dbwrap( logger = self.log)

        # In-process result cache; fetch_seq_set keys it by a tuple of
        # its own name and arguments.
        self.memory_cache = {}

    def close(self):
        """Close the database wrapper held in self.dbw."""
        wrapper = self.dbw
        wrapper.close()

    def fetch_seq( self, seq_id):
        """Return the sequence string stored for seq_id.

        The value is whatever self.dbw.fetchsingle yields for the
        query below.
        """
        query = """SELECT sequence
FROM prot_seq
JOIN prot_seq_str USING (seq_str_id)
WHERE seq_id=%(seq_id)s"""

        # Hand over only the bind value the query references.
        return self.dbw.fetchsingle(query, {'seq_id': seq_id})

    def fetch_seq_crc(self, seq_id):
        """Return the stored CRC of the sequence with the given seq_id.

        The value is whatever self.dbw.fetchsingle yields for the
        query below.
        """
        query = """SELECT crc
FROM prot_seq
JOIN prot_seq_str USING (seq_str_id)
WHERE seq_id=%(seq_id)s"""

        # Hand over only the bind value the query references.
        return self.dbw.fetchsingle(query, {'seq_id': seq_id})

        
    def fetch_seq_acc( self, seq_id):
        """Return the (Uniprot) primary accession for a given sequence id.

        The value is read from the primary_acc column of
        prot_seq_version via self.dbw.fetchsingle.
        """
        query = """SELECT primary_acc
        FROM prot_seq_version
        WHERE seq_id=%(seq_id)s"""

        # Hand over only the bind value the query references.
        return self.dbw.fetchsingle(query, {'seq_id': seq_id})

    def fetch_seq_by_acc(self, primary_acc, source_name, source_version=None):
        """Return the seq_id associated with a particular primary
        accession from a source_name.

        When source_version is given, only that version of the source
        is considered.  Among the matching rows the highest seq_id is
        returned (ORDER BY seq_id DESC LIMIT 1).
        """

        query = """SELECT seq_id FROM prot_seq_source
        JOIN prot_seq_source_ver USING (source_id)
        JOIN prot_seq_version USING (source_ver_id)
        WHERE source_name = %(source_name)s
        AND primary_acc = %(primary_acc)s"""
        params = {'source_name': source_name,
                  'primary_acc': primary_acc}

        # The version filter and its bind value are added together.
        if source_version is not None:
            query += "\nAND version = %(source_version)s"
            params['source_version'] = source_version

        query += "\nORDER BY seq_id DESC LIMIT 1"

        return self.dbw.fetchsingle(query, params)

    def fetch_seq_id( self, br_id, primary_acc):
        """Return the sequence id for a given (Uniprot) primary
        accession, from the set of sequences belonging to a blast run.

        br_id selects the blast run; its sequence set is reached by
        joining blast_run to prot_seq_set_member on set_id.
        """
        query = """SELECT seq_id
        FROM prot_seq_version
        JOIN prot_seq_set_member USING (seq_id)
        JOIN blast_run USING (set_id)
        WHERE br_id=%(br_id)s
        AND primary_acc=%(primary_acc)s"""

        # Hand over only the bind values the query references.
        params = {'br_id': br_id, 'primary_acc': primary_acc}
        return self.dbw.fetchsingle(query, params)

    def fetch_seq_in_set( self, set_id, primary_acc, source_name=None,
                          source_version=None):
        """Return the sequence id for a given (Uniprot) primary
        accession, from the members of the sequence set set_id.

        source_name and source_version must be given together or not
        at all (AssertionError otherwise).  When given, the lookup is
        further restricted to that source name and version.  Specifying
        the source greatly increases speed, but is optional.
        """

        # Both None or both set; anything else is a caller error.
        assert (source_name is None) == (source_version is None), "Neither, or both of source_name and source_version should be specified."

        with_source = source_name is not None

        query = """SELECT seq_id
        FROM prot_seq_version
        JOIN prot_seq_set_member USING (seq_id)"""

        if with_source:
            query += """\nJOIN prot_seq_source_ver using (source_ver_id)
            JOIN prot_seq_source using (source_id)"""

        query += """\nWHERE set_id=%(set_id)s
        AND primary_acc=%(primary_acc)s"""

        params = {'set_id': set_id, 'primary_acc': primary_acc}

        if with_source:
            query += """\nAND source_name = %(source_name)s
            AND version = %(source_version)s"""
            params['source_name'] = source_name
            params['source_version'] = source_version

        return self.dbw.fetchsingle(query, params)

    def fetch_seq_sp_xtra( self, seq_id):
        """Return (entry_name, gene_name, description) from the
        sp_xtra Uniprot annotation table.

        The row comes straight from self.dbw.fetchone for the query
        below.
        """
        query = """SELECT entry_name, gene_name, description
        FROM sp_xtra
        WHERE seq_id=%(seq_id)s"""

        # Hand over only the bind value the query references.
        return self.dbw.fetchone(query, {'seq_id': seq_id})


    def fetch_fa_descr( self, seq_id):
        """Return descr from the fa_descr table for seq_id.

        The value is whatever self.dbw.fetchsingle yields for the
        query below.
        """

        query = """SELECT descr
        FROM fa_descr
        WHERE seq_id=%(seq_id)s"""

        # Hand over only the bind value the query references.
        return self.dbw.fetchsingle(query, {'seq_id': seq_id})

    def build_database_urls(self, seq_id):
        """Return a list of URLs for as many databases as specified in
        fa_descr.  This only works if there is a string like
        'HUMAN|ENSEMBL:ENSG00000131068|UniProtKB:Q96PH6' in fa_descr.

        Not implemented: always raises AssertionError.
        """
        # Raised explicitly because a bare 'assert False' is compiled
        # away under 'python -O', which would let this stub silently
        # return None.
        raise AssertionError("Unimplemented")
        
    def build_urls(self, seq_id):
        """Map database name -> URL for the identifiers found in the
        fa_descr entry of seq_id.

        The description is split on '|'.  Each piece of the exact form
        'DB:ACCESSION' whose DB has a template below contributes one
        entry; every other piece is skipped.  A warning is logged when
        the result is empty.
        """

        # One '%s' placeholder per template, filled with the accession.
        templates = {
            'UniProtKB': "http://www.uniprot.org/uniprot/%s",
            'NCBI': "http://www.ncbi.nlm.nih.gov/sites/entrez?Db=protein&amp;Cmd=DetailsSearch&amp;Term=%s",
            'TAIR': "http://www.arabidopsis.org/servlets/TairObject?accession=%s",
            'ZFIN': "http://zfin.org/cgi-bin/webdriver?MIval=aa-markerview.apg&amp;OID=%s",
            'WB': "http://www.wormbase.org/db/seq/protein?name=%s",
            'dictyBase': "http://dictybase.org/db/cgi-bin/gene_page.pl?dictybaseid=%s",
            'FB': "http://flybase.org/reports/%s",
            'ECOLI': "http://biocyc.org/ECOLI/NEW-IMAGE?object=%s",
            'ENTREZ': "http://www.ncbi.nlm.nih.gov/gene/%s",
            'ENSEMBL': "http://www.ensembl.org/Multi/Search/Results?species=all;idx=;q=%s",
            'MGI': "http://www.informatics.jax.org/searches/accession_report.cgi?id=%s",
            'RGD': "http://rgd.mcw.edu/tools/genes/genes_view.cgi?id=%s",
            'SGD': "http://www.yeastgenome.org/cgi-bin/locus.fpl?dbid=%s",
            'GeneDB_Spombe': "http://www.genedb.org/genedb/Search?submit=Search+for&amp;organism=pombe&amp;name=%s"
            }

        description = self.fetch_fa_descr(seq_id)

        # A missing description simply yields no pieces to inspect.
        if description is None:
            pieces = []
        else:
            pieces = description.strip().split('|')

        urls = {}
        for piece in pieces:
            parts = piece.split(':')

            # Keep only 'DB:ACCESSION' pairs naming a known database.
            if len(parts) == 2 and parts[0] in templates:
                urls[parts[0]] = templates[parts[0]] % parts[1]

        if not urls:
            self.log.warning("No URLs for seq_id %d", seq_id)
        return urls

    def fetch_fa_annotate(self, seq_id):
        """Return (synonyms, descr) from the fa_annotate table.

        When the query yields no row (fetchone returns None), a pair of
        empty strings is returned instead.
        """

        query = """SELECT synonyms, descr
        FROM fa_annotate
        WHERE seq_id=%(seq_id)s"""

        # Hand over only the bind value the query references.
        row = self.dbw.fetchone(query, {'seq_id': seq_id})
        if row is None:
            return ('', '')

        return row
        

    def fetch_set_info(self, set_id=None):
        """Fetch (name, description) for a sequence set.

        The query is executed first and the row is then read with a
        separate fetchone() call on the wrapper.
        """
        # NOTE(review): set_id defaults to None and is passed through
        # unchanged as the bind value; callers presumably always supply
        # a real id.
        query = """SELECT name, description
        FROM prot_seq_set
        WHERE set_id = %(set_id)s"""

        # Hand over only the bind value the query references.
        self.dbw.execute( query, {'set_id': set_id})
        return self.dbw.fetchone()

    def fetch_set_size(self, set_id):
        """Fetch the number of sequences in a sequence set.

        Counts the rows of prot_seq_set_member with the given set_id.
        The query is executed first and the count is then read with a
        separate fetchsingle() call on the wrapper.
        """
        query = """SELECT count(*)
        FROM prot_seq_set_member
        WHERE set_id = %(set_id)s"""

        # Hand over only the bind value the query references.
        self.dbw.execute( query, {'set_id': set_id})
        return self.dbw.fetchsingle()

    def fetch_seq_set( self, set_id=None, br_id=None, nc_id=None,
                       limit=None, offset=None):
        """Return a set of integer seq_ids for a particular seq set id,
        blast run id, or nc run id.

        The first of set_id, br_id and nc_id that is not None selects
        the query, in that order of precedence.  If all three are None
        an AssertionError is raised.

        limit adds a LIMIT clause.  offset adds an OFFSET clause, but
        only when limit is also given; an offset without a limit is
        ignored.

        Results are remembered in self.memory_cache, keyed by all five
        arguments; the cached set object itself is returned on a hit.
        When self.cacheq is true the pickle cache (pickler.cachefn) is
        consulted after the memory cache and filled after the query.
        """

        # fetch from an in-memory cache if possible
        cache_args = ('fetch_seq_set', set_id, br_id, nc_id, limit, offset)
        if cache_args in self.memory_cache:
            return self.memory_cache[ cache_args]

        # fetch from the pickle cache if possible
        # NOTE(review): cachefn is not given this function's arguments,
        # so it presumably derives its key from the calling frame.
        # Keep both cachefn calls directly in this method -- confirm
        # against pickler before refactoring.
        if self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir)
            # A falsy hit (e.g. an empty set) falls through to the query.
            if retval: return retval

        if set_id is not None:
            q = """SELECT seq_id
            FROM prot_seq_set_member
            WHERE set_id=%(set_id)s"""

        elif br_id is not None:
            q = """SELECT seq_id
            FROM prot_seq_set_member
            JOIN blast_run USING (set_id)
            WHERE br_id=%(br_id)s"""

        elif nc_id is not None:
            q = """SELECT seq_id
            FROM nc_run
            JOIN blast_run USING (br_id)
            JOIN prot_seq_set_member USING (set_id)
            WHERE nc_id=%(nc_id)s"""

        else:
            # Raised explicitly: an 'assert False' would vanish under
            # 'python -O' and fall through to fetchall() with no query.
            raise AssertionError(
                "fetch_seq_set: must specify either set_id, br_id, or nc_id")

        # Range clause, shared by all three queries.
        if limit is not None:
            q += " LIMIT %(limit)s"
            if offset is not None:
                q += " OFFSET %(offset)s"

        # Hand over only the values a query may reference.
        self.dbw.execute( q, {'set_id': set_id,
                              'br_id': br_id,
                              'nc_id': nc_id,
                              'limit': limit,
                              'offset': offset})

        seqs = self.dbw.fetchall()
        seqs = set( int(row[0]) for row in seqs)

        self.memory_cache[ cache_args] = seqs

        if self.cacheq:
            pickler.cachefn( retval = seqs,
                             pickledir=self.pickledir)
        return seqs

    def fetch_taxonomy( self, seq_id):
        """Return (gb_tax_id, scientific name, [common names]).

        The row holds the tax_id, name and common_names columns from
        prot_seq_taxonomy joined to taxon, as returned by
        self.dbw.fetchone.
        """

        query = """SELECT tax_id, name, common_names
        FROM prot_seq_taxonomy
        JOIN taxon USING (tax_id)
        WHERE seq_id=%(seq_id)s"""

        # Hand over only the bind value the query references.
        return self.dbw.fetchone(query, {'seq_id': seq_id})
