#!/usr/bin/env python

# Jacob Joseph
# 2008 July 1
# Class to perform clustering using the MCL algorithm

import copy
import math
import os
import shlex
import subprocess

from IPython.Shell import IPShellEmbed

from DurandDB import blastq

class mcl(object):
    """Cluster sequences by piping pairwise hit scores through the 'mcl' program.

    Hits are fetched with blastq.fetch_hits_iter and written to mcl in its
    label ("--abc") input format; the clusters mcl prints are parsed into
    self.clusters (cluster index -> list of integer sequence ids).
    """

    # Named MCL parameters accepted by param_cl(), mapped to the printf-style
    # mcl command-line option each one renders to.
    _PARAM_FORMATS = { 'main_inflation': '-I %f',
                       'initial_inflation': '-i %f',
                       'main_loop_length': '-L %d',
                       'initial_loop_length': '-l %d',
                       'warn_factor': '-warn-factor %d',
                       'warn_pct': '-warn-pct %d',
                       'adapt_factor': '-af %f',
                       'adapt_exponent': '-ae %f',
                       'jury_window_index': '-nj %d',
                       'y_window_index': '-ny %d',
                       'x_window_index': '-nx %d',
                       'recovery_pct': '-pct %d',
                       'recovery_num': '-R %d',
                       'selection_num': '-S %d',
                       'prune_num': '-P %d',
                       'threads': '-te %d',
                       'scheme': '-scheme %d'}

    def __init__(self, br_id = None, nc_id = None,
                 stype='e_value', self_hits = True,
                 score_threshold = None,
                 set_id_filter = None,
                 param_list = None):
        """Store the hit-selection settings and open a blastq handle.

        br_id, nc_id, stype, self_hits, score_threshold and set_id_filter
        are passed straight through to blastq.fetch_hits_iter (the
        threshold as its 'thresh' argument).

        param_list -- extra mcl options.  Either a dict of named parameters
                      (keys of _PARAM_FORMATS, rendered by param_cl()), a
                      string of raw mcl options such as "-I 2.0", or a
                      sequence of raw option strings.  None means no extra
                      options.  The value is shallow-copied.
        """
        self.br_id = br_id
        self.nc_id = nc_id
        self.stype = stype
        self.score_threshold = score_threshold
        self.set_id_filter = set_id_filter
        self.self_hits = self_hits
        self.param_list = copy.copy(param_list)
        if self.param_list is None: self.param_list = []

        # param_cl() renders named parameters from param_dict, which is
        # only populated when the caller supplied a dict.
        if hasattr(self.param_list, 'keys'):
            self.param_dict = self.param_list
        else:
            self.param_dict = {}

        self.bq = blastq.blastq()

        # cluster index -> list of sequence ids; filled by cluster()
        self.clusters = {}

    def cluster(self, verbose=True):
        """Perform clustering with a call to the 'mcl' program.

        Every hit from blastq.fetch_hits_iter is written to mcl's stdin as
        a "seq_id_0<TAB>seq_id_1<TAB>score" line.  Each non-blank line of
        mcl's output is one cluster of whitespace-separated sequence ids.
        The result replaces self.clusters.

        Raises RuntimeError if mcl exits with a non-zero status.
        """
        if verbose: print("*** Running MCL ***")

        # Pass an argument list (no shell) so option values are never
        # re-interpreted by /bin/sh.
        proc = subprocess.Popen(self._mcl_args(),
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                universal_newlines=True)

        try:
            # similarities are needed.  Take a -log of e_value
            # correct_e_value takes a -log10(e_value)
            hits = self.bq.fetch_hits_iter( br_id = self.br_id,
                                            nc_id = self.nc_id,
                                            stype = self.stype,
                                            set_id_filter = self.set_id_filter,
                                            thresh = self.score_threshold,
                                            correct_e_value = True,
                                            self_hits = self.self_hits)

            write = proc.stdin.write
            for (seq_id_0, seq_id_1, score) in hits:
                write("%d\t%d\t%f\n" % (seq_id_0, seq_id_1, score))

            # EOF tells mcl that the input graph is complete.
            proc.stdin.close()
        except Exception:
            # Do not leave an orphaned mcl process behind on failure.
            proc.kill()
            proc.wait()
            raise

        if verbose: print("\n*** All pairwise scores piped ***\n")

        # NOTE(review): stdout is only read after all input is written.
        # This assumes mcl emits nothing on stdout until its input is
        # complete; otherwise a full stdout pipe could deadlock.
        lines = proc.stdout.readlines()
        proc.stdout.close()

        retval = proc.wait()
        if retval != 0:
            raise RuntimeError("MCL retval: %d" % retval)

        # Rebuild rather than update, so a previous run leaves no stale ids.
        self.clusters = self._parse_clusters(lines)

    def _mcl_args(self):
        """Return the mcl argv: --abc input on stdin, clusters to stdout."""
        extra = self.param_list
        if hasattr(extra, 'keys'):
            # Named parameters, rendered via param_cl().
            options = self.param_cl()
        elif hasattr(extra, 'split'):
            # A string of raw mcl options, tokenized with shell quoting rules.
            options = extra
        else:
            # A sequence of raw option strings.
            options = " ".join(str(opt) for opt in extra)
        return ['mcl', '-', '--abc', '-o', '-'] + shlex.split(options)

    @staticmethod
    def _parse_clusters(lines):
        """Map consecutive indices to int-id lists, one per non-blank line."""
        clusters = {}
        for line in lines:
            fields = line.split()
            if not fields:
                continue
            clusters[len(clusters)] = [int(field) for field in fields]
        return clusters

    def param_cl(self):
        """Build an mcl command-line fragment from self.param_dict.

        Each (name, value) pair is rendered with its _PARAM_FORMATS entry,
        e.g. {'main_inflation': 2.0} -> "-I 2.000000 ".  Options are emitted
        in sorted key order, each followed by a single space.

        Raises ValueError for a parameter name not in _PARAM_FORMATS.
        """
        options = []
        for key in sorted(self.param_dict):
            if key not in self._PARAM_FORMATS:
                raise ValueError("MCL parameter (%s) unknown" % key)
            options.append(self._PARAM_FORMATS[key] % self.param_dict[key])
        return "".join(opt + " " for opt in options)

    def get_cluster_map(self):
        """Return a dictionary of cluster assignments.

        Keys are cluster indices; values are lists of integer sequence ids,
        as filled in by cluster().
        """
        return self.clusters
