#!/usr/bin/env python

# Jacob Joseph
# 4 June 2010
# Perform clustering using SPICI, and insert into the database


import os, tempfile, subprocess, copy, math
from DurandDB import blastq

class spici(object):
    """Cluster sequences by running the external SPICI program on pairwise scores.

    Typical use is to construct an instance, call cluster(), and then read
    the result with get_cluster_map().

    Parameters:
      stype            -- score type passed to blastq.fetch_hits_direct();
                          'bit_score' additionally triggers log-normalization
                          of the scores written for SPICI.
      run_id           -- run identifier, used as both br_id and nc_id when
                          fetching hits.
      set_id_filter    -- passed through to fetch_hits_direct().
      score_threshold  -- passed through to fetch_hits_direct() as 'thresh'.
      param_list       -- extra command-line arguments for SPICI.  A shallow
                          copy is stored; it is never modified afterwards.
      verbose          -- print progress messages when true.
      spici_path       -- directory containing the 'spici' executable.  When
                          None, 'spici' is looked up on the PATH.
    """

    def __init__(self, stype, run_id, set_id_filter=None,
                 score_threshold = None,
                 param_list = None, verbose=True,
                 spici_path=None):

        self.stype = stype
        self.run_id = run_id
        self.set_id_filter = set_id_filter
        self.thresh = score_threshold
        self.verbose = verbose
        self.spici_path = spici_path

        # Copy so that the caller's list is never shared with this instance.
        self.param_list = copy.copy(param_list)
        if self.param_list is None: self.param_list = []

        self.bq = blastq.blastq()

        # cluster identifier -> list of integer sequence IDs
        self.clusters = {}

    def write_scores(self):
        """Write the SPICI input file and return its path.

        The file holds one tab-separated 'seq_id_0 seq_id_1 score' line per
        retained pair.  Its name encodes stype, run_id, set_id_filter and the
        threshold; if a file of that name already exists it is reused without
        querying the database.

        The scores are written to a per-process partial file that is renamed
        into place only after every line has been written, so a failure never
        leaves a truncated file that a later call would reuse.
        """
        fname = os.path.expandvars("$HOME/tmp/spici_input_%s_%s_%s_%s" % (
            self.stype, self.run_id, self.set_id_filter, self.thresh))

        if os.path.exists(fname):
            return fname

        scores = self.bq.fetch_hits_direct(
            stype=self.stype,
            br_id = self.run_id,
            nc_id = self.run_id,
            symmetric = True,
            self_hits = False,
            set_id_filter = self.set_id_filter,
            thresh = self.thresh,
            correct_e_value = True) # gives similarities (original author's note)

        normalize = (self.stype == 'bit_score')
        if normalize:
            # Bit scores are rescaled as log10(score) / log10(max score).
            max_bit_score = math.log10(self.bq.dbw.fetchsingle("""select max(bit_score)
            from (select explode(bit_score) as bit_score
                  from blast_hit_symmetric_arr
                  where br_id=%s) as foo""", (self.run_id,)))

        if self.verbose: print("*** Writing score file: %s ***" % fname)

        partial_fname = "%s.partial.%d" % (fname, os.getpid())
        try:
            with open(partial_fname, 'w') as fd:
                for seq_id_0, seq_id_1, score in scores:
                    # The hits are fetched symmetric; keep only the half of
                    # the matrix with seq_id_0 <= seq_id_1.
                    if seq_id_0 > seq_id_1:
                        continue

                    # ignore old NC negative scores.  They were spurious
                    if score < 0:
                        continue

                    if normalize:
                        # log10(0) is undefined and would raise ValueError.
                        if score == 0:
                            continue
                        score = math.log10(score) / max_bit_score

                    fd.write("%d\t%d\t%g\n" % (seq_id_0, seq_id_1, score))

            # Publish the complete file under its final (cached) name.
            os.rename(partial_fname, fname)
        finally:
            # Only present here if writing or renaming failed.
            if os.path.exists(partial_fname):
                os.remove(partial_fname)

        return fname

    def _build_command(self, input_fname, output_fname):
        """Return the SPICI argument vector without modifying self.param_list."""
        if self.spici_path is None:
            executable = 'spici'
        else:
            executable = os.path.join(self.spici_path, 'spici')

        return ([executable] + list(self.param_list)
                + ["-i", input_fname, "-o", output_fname])

    def run_spici(self, input_fname, output_fname):
        """Run SPICI on input_fname, writing clusters to output_fname.

        Raises subprocess.CalledProcessError if SPICI exits with a non-zero
        status.
        """
        if self.verbose: print("*** Running SPICI ***")

        # The command is built fresh each time, so repeated calls do not
        # accumulate '-i'/'-o' arguments.
        subprocess.check_call(self._build_command(input_fname, output_fname))
        return

    def parse_output(self, output_fname):
        """Read a SPICI output file into self.clusters.

        Each line is one cluster, listing its whitespace-separated sequence
        IDs.  Clusters are keyed by zero-based line number.  Any clusters from
        a previous parse are discarded first.
        """
        if self.verbose: print("*** Parsing clusters: %s ***" % output_fname)

        # Clear in place so the dictionary object itself stays the same.
        self.clusters.clear()

        with open(output_fname) as fd:
            # arbitrary cluster identifier, by line number
            for cluster_id, line in enumerate(fd):
                self.clusters[cluster_id] = [int(seq_id)
                                             for seq_id in line.split()]
        return

    def cluster(self):
        """Perform clustering with a call to the spici program. Populates self.clusters"""

        input_fname = self.write_scores()

        # Reserve a safe output filename.  The descriptor is closed at once
        # because SPICI opens the path itself.
        out_fd, output_fname = tempfile.mkstemp(dir=os.path.expandvars('$HOME/tmp'))
        os.close(out_fd)

        try:
            self.run_spici(input_fname, output_fname)
            self.parse_output(output_fname)
        finally:
            # remove the temporary output file, even if SPICI or parsing failed
            if os.path.exists(output_fname):
                os.remove(output_fname)

        return

    def get_cluster_map(self):
        """Return the dictionary mapping cluster identifier to its list of sequence IDs"""
        return self.clusters
