#!/usr/bin/env python

# Jacob Joseph
# 19 June 2010

# Basis for exploring a clustering in various ways, particularly as
# HTML.  For example, for a hierarchical clustering, facilitate
# generation of a "cluster browser" annotated with network statistics
# and family annotations.

from JJcluster.cluster_sql import flatcluster, hcluster
from DurandDB import familyq, blastq
from JJnetstat.stathelper import nxstat
from JJutil import pickler, rate
import logging, os, numpy
import networkx as NX

class describe(object):
    """Explore a stored clustering run, e.g. to render it as HTML.

    Wraps a cluster run loaded through ``hcluster`` (when
    ``clustering_type == 'hierarchical'``) or ``flatcluster``
    (otherwise), plus a query object -- ``familyq`` when a family set
    name is given, ``blastq`` otherwise -- used to fetch hits, sequence
    annotations and set information.
    """

    def __init__(self, cluster_run_id,
                 clustering_type,
                 cacheq = False,
                 pickledir = None,
                 family_set_name = None,
                 debug_level = None, # default is warning and above,
                                        # or whatever the root logger
                                        # was configured to have
                                        # elsewhere
                 ):
        """Load the cluster run and set up the query object.

        cluster_run_id  -- id of the cluster run to load.
        clustering_type -- 'hierarchical' selects hcluster; any other
                           value selects flatcluster.
        cacheq          -- if true, pickle-cache computed statistics.
        pickledir       -- cache directory; defaults to
                           $HOME/tmp/pickles.
        family_set_name -- family set for familyq; if None, a plain
                           blastq is used instead.
        debug_level     -- if given, sets the ROOT logger's level.
        """
        self.cacheq = cacheq
        if pickledir is not None:
            self.pickledir = pickledir
        else:
            self.pickledir = os.path.expandvars('$HOME/tmp/pickles')

        self.log = logging.getLogger('JJcluster.describe')

        if debug_level is not None:
            logging.getLogger('').setLevel(debug_level)

        if clustering_type == 'hierarchical':
            self.CR = hcluster(cluster_run_id = cluster_run_id,
                               cacheq = self.cacheq)
        else:
            self.CR = flatcluster(cluster_run_id = cluster_run_id,
                                  cacheq = self.cacheq)

        # reading queries from the disk is often slower than the database
        if family_set_name is not None:
            self.fq = familyq.familyq(family_set_name = family_set_name, cacheq=False)
        else:
            # FIXME: this will break the html building, where they depend on family functions
            self.fq = blastq.blastq(cacheq=False)

    def _fetch_hits(self, query_seq, stype=None, **filters):
        """Fetch the hit dictionary of one query sequence.

        Calls ``self.fq.fetch_hits_dictarray`` with this cluster run's
        blast run, NC run, symmetry flag and set filter, excluding self
        hits.  ``stype`` defaults to the clustering's score type; any
        ``filters`` (e.g. ``set_id_filter`` or ``seq_set``) are passed
        through unchanged.  The returned dict maps query_seq to a
        ``(hit_list, score_list)`` pair and lacks that key when the
        sequence has no hits.
        """
        if stype is None:
            stype = self.CR.stype
        return self.fq.fetch_hits_dictarray(
            stype=stype, br_id=self.CR.br_id, nc_id=self.CR.nc_id,
            symmetric=self.CR.symmetric,
            query_seq_id=query_seq,
            # potentially faster if specified
            seq_id_0_set_id=self.CR.set_id_filter,
            self_hits=False,
            **filters)

    def cluster_conductance(self, cluster_id):
        """Calculate the weighted conductance of a cluster.

        With S the cluster's sequences and V the full sequence set,
        conductance = cut(S, V-S) / min(vol(S), vol(V-S)), where cut is
        the summed score of hits leaving S and vol(X) is the summed
        score of all hits from members of X.

        Returns -1 when the denominator is zero (no hits in the
        cluster, or an empty/hitless complement).
        """
        cluster_seq_set = self.CR.fetch_cluster(cluster_id)
        full_seq_set = self.fq.fetch_seq_set(br_id=self.CR.br_id,
                                             nc_id=self.CR.nc_id,
                                             set_id=self.CR.set_id_filter)

        self.log.debug("Calculating conductance: %d (size: %d)",
                       cluster_id, len(cluster_seq_set))

        boundary_sum = 0.0
        a_s = 0.0
        # volume of the cluster and weight of its edge boundary
        for query_seq in cluster_seq_set:
            d = self._fetch_hits(query_seq,
                                 set_id_filter=self.CR.set_id_filter)

            # a sequence may have no hits
            if query_seq not in d:
                continue

            (hit_list, score_list) = d[query_seq]
            a_s += numpy.sum(score_list)

            for (seq_id_1, score) in zip(hit_list, score_list):
                if seq_id_1 not in cluster_seq_set:
                    boundary_sum += score

        # FIXME: what do we do about having a symmetric matrix?
        a_not_s = 0.0
        for query_seq in (full_seq_set - cluster_seq_set):
            d = self._fetch_hits(query_seq,
                                 set_id_filter=self.CR.set_id_filter)

            # a sequence may have no hits
            if query_seq not in d:
                continue

            (hit_list, score_list) = d[query_seq]
            a_not_s += numpy.sum(score_list)
            # Only min(a_s, a_not_s) is needed, so stop as soon as the
            # complement's volume is known to exceed the cluster's.
            if a_not_s > a_s:
                break

        denominator = min(a_s, a_not_s)
        conductance = (boundary_sum / denominator) if denominator > 0 else -1

        self.log.debug("Conductance: %g, boundary: %g, a_s: %g, minimum a_not_s: %g",
                       conductance, boundary_sum, a_s, a_not_s)

        return conductance

    def fetch_cluster_hits(self, cluster_id, stype=None):
        """Fetch the sequence (sub)network for all nodes in cluster_id.

        Hits are restricted to sequences inside the cluster.  stype
        defaults to the clustering's score type, but e.g. bit-scores
        may be requested after clustering with NC scores.

        Returns (hit_dict, seq_set): hit_dict maps each member with at
        least one in-cluster hit to its (hit_list, score_list); seq_set
        is the cluster's full member set.
        """
        seq_set = self.CR.fetch_cluster(cluster_id)

        if stype is None:
            stype = self.CR.stype

        self.log.debug("Fetching cluster: %d (stype: %s)", cluster_id, stype)

        hit_dict = {}

        qrate = rate.rate(totcount=len(seq_set))
        for query_seq in seq_set:
            qrate.increment()
            if qrate.count % 1000 == 0:
                self.log.info("%d %s", cluster_id, qrate)

            d = self._fetch_hits(query_seq, stype=stype, seq_set=seq_set)

            assert len(d) <= 1, "d has more than one key for query_seq %s" % query_seq

            if query_seq in d:
                hit_dict[query_seq] = d[query_seq]

        return (hit_dict, seq_set)

    def network_stats(self, cluster_id, hit_dict=None):
        """Compute nxstat statistics of the cluster's hit graph.

        hit_dict -- optional {seq: (hit_list, score_list)}; fetched via
                    fetch_cluster_hits when omitted.

        The result is pickle-cached under (cr_id, cluster_id) when
        self.cacheq is set.
        """
        cache_args = "%s_%s" % (self.CR.cr_id, cluster_id)
        if self.cacheq:
            retval = pickler.cachefn(pickledir = self.pickledir,
                                     args = cache_args)
            if retval is not None:
                return retval

        if hit_dict is None:
            # fetch_cluster_hits returns a (hit_dict, seq_set) pair
            hit_dict, _seq_set = self.fetch_cluster_hits(cluster_id)

        self.log.debug("Calculating network stats: %d (size %d)",
                       cluster_id, len(hit_dict))

        # build a networkx graph of all hits in the cluster.
        # NOTE: only edges are added, so members without any in-cluster
        # hit do not appear in G.
        G = NX.Graph()
        G.add_edges_from(
            (node, nbr)
            for (node, (hit_list, _score_list)) in hit_dict.items()
            for nbr in hit_list)

        # NOTE(review): nxstat caching is always enabled here,
        # independent of self.cacheq -- confirm this is intended.
        net = nxstat(G=G, cacheq=True, unique_parameters=cache_args)
        net_stats = net.calc_statistics(omit_spl=True)

        if self.cacheq:
            pickler.cachefn(pickledir = self.pickledir,
                            args = cache_args,
                            retval = net_stats)
        return net_stats

    def cluster_stats(self, cluster_id, hit_dict=None, seq_set=None, stype=None):
        """Summarize the score distribution of hits within a cluster.

        hit_dict -- optional {seq: (hit_list, score_list)}; fetched via
                    fetch_cluster_hits(cluster_id, stype) when omitted.
        seq_set  -- cluster members; taken from the cluster run when
                    hit_dict is given without it.
        stype    -- score type used when fetching hits.

        Returns a dict with num_nodes, num_edges (number of directed
        scores), frac_edges and density (relative to the
        num_nodes*(num_nodes-1) ordered pairs), and min/max/mean/stdev
        of the scores.  Values that cannot be computed are -1.
        """
        if self.cacheq:
            # read and write with the same key, so stype is honoured
            cache_args = "%s_%s_%s_%s" % (self.CR.cr_id,
                                          self.CR.set_id_filter,
                                          cluster_id,
                                          stype)
            retval = pickler.cachefn(pickledir = self.pickledir,
                                     args = cache_args)
            if retval is not None:
                return retval

        if hit_dict is None:
            hit_dict, seq_set = self.fetch_cluster_hits(cluster_id, stype=stype)
        elif seq_set is None:
            seq_set = self.CR.fetch_cluster(cluster_id)

        # FIXME: Not currently symmetric.  Repetitive, and likely
        # quite slow for very many clusters.  Some values could be
        # calculated by traversing the tree.

        # one flat array of all scores within this set
        score_arrays = [numpy.asarray(score_list, dtype=numpy.float64)
                        for (_hit_list, score_list) in hit_dict.values()]
        if score_arrays:
            all_scores = numpy.concatenate(score_arrays)
        else:
            all_scores = numpy.empty(0, dtype=numpy.float64)

        num_scores = len(all_scores)
        num_nodes = len(seq_set)
        # ordered pairs of distinct members
        num_pairs = num_nodes * (num_nodes - 1)

        stats = {}
        stats['num_nodes'] = num_nodes
        stats['num_edges'] = num_scores
        if num_scores > 0:
            stats['min'] = numpy.min(all_scores)
            stats['max'] = numpy.max(all_scores)
            stats['mean'] = numpy.mean(all_scores)
            stats['stdev'] = numpy.std(all_scores)
        else:
            stats['min'] = stats['max'] = stats['mean'] = stats['stdev'] = -1

        if num_scores > 0 and num_pairs > 0:
            stats['frac_edges'] = float(num_scores) / num_pairs
            stats['density'] = stats['mean'] * num_scores / num_pairs
        else:
            stats['frac_edges'] = -1
            stats['density'] = -1

        if self.cacheq:
            pickler.cachefn(pickledir = self.pickledir,
                            args = cache_args,
                            retval = stats)
        return stats

    def html_run_header(self, descr="", orgarg=""):
        """Return an HTML table describing this cluster run.

        The table lists the run, score type, database set, clustering
        query set, blast run and NC run, followed by orgarg and descr,
        which are inserted verbatim (not HTML-escaped).
        """
        s_table = """<table>
<tr><td valign="top">Cluster Run:</td>
    <td>%(cr_id)d<br>
        %(cr_comment)s<br>
        %(cr_params)s</td></tr>
<tr><td valign="top">Score Type:</td>
    <td>%(stype)s</td></tr>
<tr><td valign="top">Database Set:</td>
    <td>%(set_id)d (%(num_seqs)d sequences)<br>
        %(set_name)s<br>
        %(set_descr)s</td></tr>
<tr><td valign="top">Clustering Query Set:</td>
    <td>%(q_set_id)d (%(q_num_seqs)d sequences)<br>
        %(q_set_name)s<br>
        %(q_set_descr)s</td></tr>
<tr><td valign="top">Blast Run:</td>
    <td>%(br_id)d<br>
        Query set %(blast_query_set)d</td></tr>
<tr><td valign="top">NC Run:</td>
    <td>%(nc_id)s</td></tr>
<tr><td valign="top">Background argument:</td>
    <td>%(orgarg)s</td></tr>
<tr><td valign="top" colspan="2">%(descr)s</td></tr>
</table>
<br><hr><br>\n"""

        blast_info = self.fq.fetch_blast_info_d( self.CR.br_id)

        (set_name, set_descr) = self.fq.fetch_set_info( blast_info['set_id'])

        # the clustering was restricted to set_id_filter when given,
        # otherwise it covers the blast run's query set
        if self.CR.set_id_filter is not None:
            clustering_set_id = self.CR.set_id_filter
        else:
            clustering_set_id = blast_info['query_set_id']

        (q_set_name, q_set_descr) = self.fq.fetch_set_info( clustering_set_id)

        s = s_table % {'cr_id': self.CR.cr_id,
                       'stype': self.CR.stype,
                       'cr_comment': self.CR.cluster_comment,
                       'cr_params': self.CR.cluster_params,
                       'set_id': blast_info['set_id'],
                       'num_seqs': blast_info['num_sequences'],
                       'set_name': set_name,
                       'set_descr': set_descr,
                       'q_set_id': clustering_set_id,
                       'q_num_seqs': self.fq.fetch_set_size( clustering_set_id),
                       'q_set_name': q_set_name,
                       'q_set_descr': q_set_descr,
                       'br_id': self.CR.br_id,
                       'blast_query_set': blast_info['query_set_id'],
                       'nc_id': self.CR.nc_id,
                       'orgarg': orgarg,
                       'descr': descr
                       }
        return s

    def cluster_info(self, cluster_id):
        """Return a dictionary of basic cluster information.  For
        speed, include cluster_id, size, and any additional keys
        relevant to the specific clustering method used (e.g, cluster
        similarity for hierarchical)."""

        info_d = {'cluster_id': cluster_id,
                  'size': len(self.CR.fetch_cluster(cluster_id))}

        # TODO: add method-specific keys (e.g. cluster similarity for
        # hierarchical clusterings).
        return info_d

    def cluster_info_extended(self, cluster_id):
        """Include all from cluster_info, but also calculate cluster
        statistics (the keys returned by cluster_stats).  Slower, and
        best called only for smaller clusters."""

        info_d = self.cluster_info(cluster_id)
        info_d.update(self.cluster_stats(cluster_id))
        return info_d

    def html_cluster(self, cluster_id, do_netstats=False, orgarg="",
                     clust=None, parent_size=None):
        """Return an HTML summary line for one hierarchical cluster.

        clust       -- the hierarchical cluster node; its distance() is
                       used for the cluster similarity (1 - distance).
        parent_size -- size of the parent cluster, used for
                       J = (parent_size - size) / size.

        Both are required; a ValueError is raised when either is
        missing.  With do_netstats true, the compact template without
        edge statistics is produced; otherwise cluster_stats is
        computed.  NOTE(review): the flag name looks inverted relative
        to this behavior -- confirm intent.
        """
        s_cluster = """<b><a name="%(cluster_id)d" href="http://quantbio-tools.princeton.edu/cgi-bin/CE?url=http://diatom.compbio.cs.cmu.edu:8001/%(cr_id)d/%(cluster_id)d/%(query_set_id)d%(orgarg)s">
Cluster %(cluster_id)d</a>:
Cluster Similarity: %(clustsim)0.4f,
Size: %(num_nodes)d</b>,
Density: %(density)0.4f,
J: %(J)0.4f,
Edges: %(num_edges)d,
Frac. Edges: %(frac_edges)0.4f,
Mean: %(mean)0.4f(%(stdev)0.4f)<br>\n"""

        s_cluster_large = """<b><a name="%(cluster_id)d" href="http://quantbio-tools.princeton.edu/cgi-bin/CE?url=http://diatom.compbio.cs.cmu.edu:8001/%(cr_id)d/%(cluster_id)d/%(query_set_id)d%(orgarg)s">
Cluster %(cluster_id)d</a>: Cluster Similarity: %(clustsim)0.4f, Size: %(num_nodes)d</b>, J: %(J)0.4f<br>\n"""

        # work around bad emacs highlighting "
        if clust is None or parent_size is None:
            raise ValueError("html_cluster requires the hierarchical cluster "
                             "node (clust) and its parent's size (parent_size)")

        values = {'cluster_id': cluster_id,
                  'clustsim': 1 - clust.distance(),
                  'cr_id': self.CR.cr_id,
                  'query_set_id': self.CR.set_id_filter,
                  'orgarg': orgarg}

        if do_netstats:
            cluster_size = len(self.CR.fetch_cluster(cluster_id))
            values['num_nodes'] = cluster_size
            values['J'] = float(parent_size - cluster_size) / cluster_size
            return s_cluster_large % values

        edge_stats = self.cluster_stats(cluster_id)
        cluster_size = edge_stats['num_nodes']
        values.update({'num_nodes': cluster_size,
                       # scores are counted in both directions
                       'num_edges': edge_stats['num_edges'] // 2,
                       'frac_edges': edge_stats['frac_edges'],
                       'density': edge_stats['density'],
                       'mean': edge_stats['mean'],
                       'stdev': edge_stats['stdev'],
                       'J': float(parent_size - cluster_size) / cluster_size})
        return s_cluster % values

    def html_sequence(self, seq_id):
        """Return one HTML line describing a sequence.

        Shows the accession, species, family, database links, synonyms
        and description; a missing species or family is rendered as an
        empty string.
        """
        s_seq = """<a name="%(seq_id)d">Seq: %(seq_id)d</a>: %(name)s
<a style="color:green;">%(species)s</a>
<a style="color:red;">%(family)s</a>
%(urls_str)s
<a style="color:purple;">%(synonyms)s %(descr)s</a><br>\n"""

        s_url = """<a href="%s">%s</a>"""
        urls = [s_url % (url, db) for (db, url) in self.fq.build_urls(seq_id).items()]
        urls_str = ', '.join(urls)

        (synonyms, descr) = self.fq.fetch_fa_annotate( seq_id)

        family = self.fq.fetch_seq_family( seq_id)
        species = self.fq.fetch_taxonomy(seq_id)

        s = s_seq % {'seq_id': seq_id,
                     'name': self.fq.fetch_seq_acc( seq_id),
                     'urls_str': urls_str,
                     'species': species[1] if species is not None else '',
                     'family': family if family is not None else '',
                     'synonyms': synonyms,
                     'descr': descr}
        return s