#!/usr/bin/env python

# Jacob Joseph
# 30 Dec 2009

# Wrapper to run PFAM models against sequences in the DurandLab2
# database

import getopt, os, sys, subprocess, tempfile, urllib
from collections import namedtuple
from DurandDB import seqq, pfamq
from JJutil import rate

class pfamscan(object):
    """Scan the sequences of one DurandDB sequence set with Pfam-A HMMs.

    Constructing an instance makes sure the Pfam-A model files exist in
    ``datadir`` (downloading and decompressing them when missing) and that
    they have been pressed with ``hmmpress``.  ``process_set()`` then runs
    ``pfam_scan.pl`` over every sequence of the set that has no rows in
    ``prot_seq_pfam`` yet and inserts the hits into that table.
    """

    # Default location of the Pfam release files.
    # NOTE(review): Pfam appears to be distributed from
    # ftp://ftp.ebi.ac.uk/pub/databases/Pfam/current_release/ nowadays --
    # confirm, and pass url_base to download_models() if the Sanger URL
    # no longer resolves.
    DEFAULT_URL_BASE = "ftp://ftp.sanger.ac.uk/pub/databases/Pfam/current_release/"

    # Index files that hmmpress (HMMER 3) writes next to the .hmm file.
    _PRESSED_SUFFIXES = (".h3m", ".h3i", ".h3f", ".h3p")

    # One parsed line of pfam_scan.pl output.  Column order:
    # <seq id> <alignment start> <alignment end> <envelope start> <envelope end> <hmm acc> <hmm name> <type> <hmm start> <hmm end> <hmm length> <bit score> <E-value> <significance> <clan>
    # e.g.
    # 21212773      4    237      4    239 PF00244.13  14-3-3            Domain     1   234   236    391.6  6.6e-118   1 No_clan
    pfam_record = namedtuple("pfam_record",
                             ['seq_id', 'align_start', 'align_end',
                              'envelope_start', 'envelope_end',
                              'hmm_acc', 'hmm_name', 'hmm_type',
                              'hmm_start', 'hmm_end', 'hmm_length',
                              'bit_score', 'e_value', 'significance',
                              'clan'])

    def __init__(self, set_id, tmpdir=None, datadir=None):
        """Prepare directories, the DB handle and the Pfam model files.

        set_id  -- DurandDB seq_set whose sequences are scanned.
        tmpdir  -- directory for temporary FASTA files
                   (default $HOME/tmp/pfam_fasta).
        datadir -- directory holding the Pfam model files
                   (default $HOME/tmp/pfam).
        """
        self.set_id = set_id

        if tmpdir is None:
            tmpdir = os.path.expandvars("$HOME/tmp/pfam_fasta")
        self.tmpdir = tmpdir

        if datadir is None:
            datadir = os.path.expandvars("$HOME/tmp/pfam")
        self.datadir = datadir

        # makedirs (not mkdir) so a missing parent such as $HOME/tmp is
        # created too.
        for path in (self.tmpdir, self.datadir):
            if not os.path.exists(path):
                os.makedirs(path)

        self.db = seqq.seqq()

        self.download_models()
        self.generate_hmm_binaries()

    def download_models(self, url_base=None):
        """Download and gunzip the Pfam release files that are missing.

        url_base -- URL prefix the archive names are appended to
                    (default DEFAULT_URL_BASE).

        Raises RuntimeError if gzip fails to decompress an archive.
        """
        if url_base is None:
            url_base = self.DEFAULT_URL_BASE

        archives = ["Pfam-A.hmm.gz", "Pfam-A.hmm.dat.gz", "active_site.dat.gz"]

        for archive in archives:
            archive_path = os.path.join(self.datadir, archive)
            # The extracted file should only exist if the archive was
            # successfully downloaded and extracted, so it marks completion.
            if os.path.exists(archive_path[:-len(".gz")]):
                continue

            url = url_base + archive
            print("fetching %s" % url)
            urllib.urlretrieve(url, filename=archive_path)

            ret = subprocess.call(['gzip', '-d', archive_path])
            if ret != 0:
                raise RuntimeError(
                    "gzip decompress returned error '%s' for file '%s'"
                    % (ret, archive_path))

    def generate_hmm_binaries(self):
        """Press Pfam-A.hmm with hmmpress unless up-to-date index files exist.

        hmmpress is skipped when every pressed file (.h3m/.h3i/.h3f/.h3p)
        exists and is no older than the .hmm file.  Otherwise it is run with
        -f so that stale or partial index files are overwritten.

        Raises RuntimeError if hmmpress exits with a non-zero status.
        """
        for name in ["Pfam-A.hmm"]:
            hmm_path = os.path.join(self.datadir, name)
            hmm_mtime = os.path.getmtime(hmm_path)
            pressed = [hmm_path + suffix for suffix in self._PRESSED_SUFFIXES]
            if all(os.path.exists(p) and os.path.getmtime(p) >= hmm_mtime
                   for p in pressed):
                continue

            print("running hmmpress %s" % hmm_path)
            ret = subprocess.call(['hmmpress', '-f', hmm_path])
            # hmmpress exits 0 on success.
            if ret != 0:
                raise RuntimeError(
                    "hmmpress returned error '%s' for file '%s'"
                    % (ret, hmm_path))

    def write_fasta(self, seqs):
        """Write one or more sequences to a new temporary FASTA file.

        seqs -- a single integer seq_id, or a list/tuple/set of seq_ids.
                Each sequence string is fetched with self.db.fetch_seq().

        Returns the path of the file, created in self.tmpdir.  The caller is
        responsible for deleting it.  If writing fails, the partial file is
        removed before the exception propagates.

        Raises TypeError if seqs is neither an integer nor a collection.
        """
        import numbers

        if isinstance(seqs, numbers.Integral):
            seq_ids = [seqs]
        elif isinstance(seqs, (list, tuple, set, frozenset)):
            seq_ids = list(seqs)
        else:
            raise TypeError("seqs should be a list, not %s" % type(seqs))

        fd = tempfile.NamedTemporaryFile(mode="w", dir=self.tmpdir,
                                         delete=False)
        fname = fd.name
        try:
            with fd:
                for seq_id in seq_ids:
                    # Header, sequence, blank separator line.
                    fd.write(">%d foo\n%s\n\n"
                             % (seq_id, self.db.fetch_seq(seq_id)))
        except BaseException:
            # delete=False: remove the partial file ourselves (also on
            # KeyboardInterrupt), then re-raise.
            os.remove(fname)
            raise

        return fname

    @classmethod
    def _parse_pfam_output(cls, lines):
        """Turn pfam_scan.pl output lines into a list of pfam_record tuples.

        Lines starting with '#' (comments/headers) and blank lines are
        skipped.  Fields are kept as the raw whitespace-separated strings.

        Raises ValueError for a line with an unexpected number of columns.
        """
        hits = []
        for line in lines:
            if line.startswith("#"):
                continue
            fields = line.split()
            if not fields:
                continue
            if len(fields) != len(cls.pfam_record._fields):
                raise ValueError(
                    "unexpected pfam_scan.pl output line: %r" % line)
            hits.append(cls.pfam_record(*fields))
        return hits

    def find_domains(self, fasta, scan_dir="PfamScan"):
        """Run pfam_scan.pl on a FASTA file and return its hits.

        fasta    -- path of the FASTA file to scan.
        scan_dir -- directory containing pfam_scan.pl.  The script is run
                    from there (default "PfamScan", relative to the current
                    working directory).

        Returns a list of pfam_record tuples.
        Raises RuntimeError if pfam_scan.pl exits with a non-zero status.
        """
        # The child runs with cwd=scan_dir, so pass absolute paths. Otherwise
        # paths relative to our own working directory would resolve wrongly.
        cmd = ["./pfam_scan.pl",
               "-fasta", os.path.abspath(fasta),
               "-dir", os.path.abspath(self.datadir)]
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact.  Comment lines are filtered in Python.
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, cwd=scan_dir,
                                universal_newlines=True)
        output = proc.communicate()[0]
        if proc.returncode != 0:
            raise RuntimeError(
                "pfam_scan.pl returned error '%s' for file '%s'"
                % (proc.returncode, fasta))

        return self._parse_pfam_output(output.splitlines())

    def insert_domains(self, pfam_records):
        """Insert a list of pfam records into prot_seq_pfam and commit."""

        i = """INSERT into prot_seq_pfam
        (seq_id, align_start, align_end, envelope_start, envelope_end, hmm_acc,
        hmm_name, hmm_type, hmm_start, hmm_end, hmm_length, bit_score, e_value,
        significance, clan)
        VALUES
        (%(seq_id)s, %(align_start)s, %(align_end)s, %(envelope_start)s, %(envelope_end)s,
        %(hmm_acc)s, %(hmm_name)s, %(hmm_type)s, %(hmm_start)s, %(hmm_end)s, %(hmm_length)s,
        %(bit_score)s, %(e_value)s, %(significance)s, %(clan)s)"""

        for rec in pfam_records:
            self.db.dbw.execute(i, rec._asdict())

        self.db.dbw.commit()

    def fetch_complete_seqs(self):
        """Return the set of seq_ids in self.set_id that already have rows
        in prot_seq_pfam."""
        q = """SELECT DISTINCT seq_id
        FROM prot_seq_pfam
        JOIN prot_seq_set_member USING (seq_id)
        WHERE set_id = %(set_id)s"""

        return set(self.db.dbw.fetchcolumn(q, {'set_id': self.set_id}))

    def fetch_seqs_hack(self, include_set_id=106, exclude_set_id=108):
        """Return seq_ids in seq_set include_set_id but not in exclude_set_id.

        One-off helper that process_set does not call.  The defaults are the
        previously hard-coded sets 106 and 108.
        """
        q = """SELECT seq_id
        FROM prot_seq_set_member
        WHERE set_id = %(include)s
        AND seq_id NOT IN (
          SELECT seq_id
          FROM prot_seq_set_member
          WHERE set_id = %(exclude)s)"""

        return set(self.db.dbw.fetchcolumn(
            q, {'include': include_set_id, 'exclude': exclude_set_id}))

    def process_set(self, chunksize=100):
        """Scan every not-yet-annotated sequence of self.set_id.

        Sequences are processed in chunks of at most chunksize.  For each
        chunk the steps are: write a FASTA file, run pfam_scan.pl, insert the
        hits, then remove the temporary file (also when a step fails).

        Raises ValueError if chunksize < 1, since the loop would never
        shrink the pending set.
        """
        if chunksize < 1:
            raise ValueError(
                "chunksize must be a positive integer, not %r" % (chunksize,))

        pending = set(self.db.fetch_seq_set(set_id=self.set_id))
        # NOTE: self.fetch_seqs_hack() can be substituted above for a
        # one-off subset.
        pending -= self.fetch_complete_seqs()

        seqrate = rate.rate(totcount=len(pending))

        sys.stderr.write("Processing %d sequences...\n" % len(pending))

        while pending:
            seq_chunk = [pending.pop()
                         for _ in range(min(chunksize, len(pending)))]
            lasti = seqrate.count
            seqrate.increment(len(seq_chunk))

            fname = self.write_fasta(seq_chunk)
            try:
                self.insert_domains(self.find_domains(fname))
            finally:
                os.remove(fname)

            # Report whenever the running count crosses a multiple of 100.
            if seqrate.count % 100 <= lasti % 100 or seqrate.count - lasti >= 100:
                print(seqrate)

        return
        
def usage():
    """Return the command-line help text for this script as one string."""
    help_lines = [
        "pfam_run.py",
        "",
        "Usage:",
        "",
        "    pfam_run.py -s <set_id> [options]",
        "",
        "    -s, --set_id <integer>",
        "          Calculate pfam domains for all sequences in this DurandDB",
        "          seq_set.",
        "",
        "     -h, --help",
        "          Print this help message.",
    ]
    # Every line, including the last, is newline-terminated.
    return "".join(line + "\n" for line in help_lines)

def main(argv=None):
    """Parse command-line options and run PfamScan over one sequence set.

    argv -- option list to parse (default sys.argv[1:]).

    Exit behaviour:
    - -h/--help prints the help text and exits with status 0.
    - An unknown option, a non-integer set_id or a missing -s/--set_id
      reports the problem on stderr and exits with status 2.

    Returns the pfamscan instance that processed the set.
    """
    if argv is None:
        argv = sys.argv[1:]

    try:
        opts, _args = getopt.getopt(argv, "s:h", ["set_id=", "help"])
    except getopt.GetoptError as err:
        sys.stderr.write("%s\n\n%s" % (err, usage()))
        sys.exit(2)

    printhelp = False
    set_id = None
    for opt, val in opts:
        if opt in ("-h", "--help"):
            printhelp = True
        elif opt in ("-s", "--set_id"):
            try:
                set_id = int(val)
            except ValueError:
                sys.stderr.write("set_id must be an integer, not %r\n" % val)
                sys.exit(2)

    if printhelp:
        print(usage())
        sys.exit(0)

    if set_id is None:
        # A missing required option is an error, not a help request.
        sys.stderr.write("missing required option -s/--set_id\n\n%s" % usage())
        sys.exit(2)

    pf = pfamscan(set_id)
    pf.process_set()

    return pf

if __name__ == "__main__":
    # Run the command-line entry point.  The returned pfamscan instance is
    # bound to the module-level name `pf`, presumably so it can be inspected
    # afterwards (e.g. under `python -i`) -- NOTE(review): confirm intent.
    pf = main()
    
    
