#!/usr/bin/env python

# Jacob Joseph
# 1 December 2009

# Run and parse output from the MC-UPGMA average linkage clustering
# package.

import os, subprocess, sys, tempfile, copy
import gzip
from shutil import rmtree

from JJcluster.agglomerative import agglomerative_base
from JJcluster import cluster_obj

class mcupgma(agglomerative_base):
    """Average-linkage clustering backed by the external MC-UPGMA package.

    cluster() writes the pairwise distances for the configured run to a
    gzipped edge file, runs MC-UPGMA's ``scripts/cluster.pl`` on it, and
    parses the resulting merge tree into ``cluster_obj.cobj`` nodes.

    NOTE(review): ``self.bq``, ``self.all_seqs``, ``self.max_distance``,
    ``self.seq_to_clust`` and ``self.clust_to_seq`` are not set in this
    class; presumably agglomerative_base provides them -- confirm.
    """

    def __init__(self, br_id = None, nc_id = None,
                 stype='e_value', self_hits = False,
                 symmetric = True,
                 set_id_filter = None,
                 score_threshold = None,
                 param_list = None,
                 mcupgma_path = None,
                 tmp_dir = None):
        """Configure a clustering run.

        All arguments up to ``score_threshold`` are forwarded unchanged to
        agglomerative_base.

        param_list   -- extra command-line arguments for cluster.pl.  A
                        private list copy is kept, so the caller's sequence
                        is never modified (any iterable of strings works).
        mcupgma_path -- root of the MC-UPGMA installation; defaults to
                        $HOME/Durand/external_software/mcupgma.
        tmp_dir      -- directory holding the edge/tree files and the
                        per-run scratch directories; defaults to $HOME/tmp/.
        """
        agglomerative_base.__init__(self, br_id=br_id, nc_id=nc_id,
                                    stype=stype, self_hits=self_hits,
                                    symmetric = symmetric,
                                    score_threshold = score_threshold,
                                    set_id_filter=set_id_filter)

        # Private copy: neither this object nor the caller should see the
        # other's modifications.
        self.param_list = list(param_list) if param_list is not None else []

        if mcupgma_path is None:
            mcupgma_path = os.path.expandvars(
                "$HOME/Durand/external_software/mcupgma")
        self.mcupgma_path = mcupgma_path

        if tmp_dir is None:
            tmp_dir = os.path.expandvars("$HOME/tmp/")
        self.tmp = tmp_dir

    def write_edgefile(self, filename):
        """Write this run's pairwise distances to the gzipped file *filename*.

        Each line is ``seq_id_0 seq_id_1 distance`` (space separated).  Only
        hits with ``seq_id_0 <= seq_id_1`` are written, so each pair appears
        in one orientation.  For similarity score types (``nc_score``,
        ``bit_score``) the value written is ``self.max_distance - score``;
        other score types are written unchanged.

        The data is compressed in-process with the gzip module instead of a
        ``gzip > filename`` shell pipeline.  Filenames containing spaces or
        shell metacharacters are therefore safe, and write/compression
        errors raise instead of being silently dropped.
        """
        sys.stderr.write("Writing edgefile:  %s\n" % filename)

        # NOTE(review): an older comment here read "always use symmetric
        # scores.  Affects blast.", but the value passed is self.symmetric,
        # i.e. whatever the caller configured -- confirm which is intended.
        hits = self.bq.fetch_hits_direct(stype = self.stype,
                                         br_id = self.br_id,
                                         nc_id = self.nc_id,
                                         self_hits = self.self_hits,
                                         symmetric = self.symmetric,
                                         set_id_filter = self.set_id_filter)

        # Similarity scores must be flipped into distances for MC-UPGMA.
        to_distance = self.stype in ('nc_score', 'bit_score')

        # Level 6 matches the gzip command-line default (gzip.open defaults
        # to 9, which is noticeably slower on large edge files).  try/finally
        # rather than `with`: GzipFile only supports `with` from Python 2.7.
        fd = gzip.open(filename, 'wb', 6)
        try:
            for seq_id_0, seq_id_1, score in hits:
                # MCUPGMA takes only one of the (symmetric) edges
                if seq_id_0 > seq_id_1:
                    continue

                if to_distance:
                    score = self.max_distance - score

                # Same text the old `print >> fd, a, b, c` produced; encoded
                # so it is bytes on Python 3 (an equal str on Python 2).
                line = "%s %s %s\n" % (seq_id_0, seq_id_1, score)
                fd.write(line.encode('ascii'))
        finally:
            fd.close()

    def build_tree(self, tree_fname):
        """Parse the MC-UPGMA tree file *tree_fname* into cobj nodes.

        Every sequence in ``self.all_seqs`` starts as a leaf at distance 0.
        Each non-blank line of the file is ``id0 id1 dist newid``: the
        clusters with ids ``id0`` and ``id1`` are merged into a new node
        ``newid`` at distance ``dist``.  Throughout, ``self.seq_to_clust``
        maps each sequence to its current top-level cluster, and
        ``self.clust_to_seq`` maps each top-level cluster to its set of
        sequences.  If more than one top-level cluster remains at the end,
        they are joined under an artificial root (``cluster_id=-1``) placed
        0.01 above the largest remaining distance.

        Raises KeyError if a line refers to an id that is neither a
        sequence in ``self.all_seqs`` nor an earlier ``newid`` (or merges a
        cluster that was already merged).  Raises ValueError on a
        malformed line.
        """
        sys.stderr.write("Reading tree:  %s\n" % tree_fname)

        clust_id_object = {}

        # tree leaves
        for seq_id in self.all_seqs:
            leaf = cluster_obj.cobj(distance=0,
                                    cluster_id=seq_id,
                                    items=[seq_id])
            clust_id_object[seq_id] = leaf

            # sequence to cluster map
            self.seq_to_clust[seq_id] = leaf

            # cluster to sequence set map
            self.clust_to_seq[leaf] = set([seq_id])

        with open(tree_fname) as fd:
            for line in fd:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                id0, id1, dist, newid = fields
                newid = int(newid)

                c0 = clust_id_object[int(id0)]
                c1 = clust_id_object[int(id1)]

                c_new = cluster_obj.cobj(distance=float(dist),
                                         cluster_id=newid,
                                         items=[c0, c1])
                clust_id_object[newid] = c_new

                # The two children stop being top-level clusters; all of
                # their sequences now belong to the merged node.
                seq_set = set()
                for clust in (c0, c1):
                    seq_set.update(self.clust_to_seq.pop(clust))
                self.clust_to_seq[c_new] = seq_set

                for seq_id in seq_set:
                    self.seq_to_clust[seq_id] = c_new

        # Several top-level clusters can remain, e.g. for sequences in our
        # set that were not present at all in the clustering input.
        # Connect them under a single root now.
        if len(self.clust_to_seq) > 1:
            # Snapshot the keys: the loop below pops from the dict, which
            # is an error while iterating a live keys() view on Python 3.
            clusters = list(self.clust_to_seq)
            c_root = cluster_obj.cobj(
                distance = max(c.distance() for c in clusters) + 0.01,
                cluster_id=-1,
                items=clusters)

            seq_set = set()
            for clust in clusters:
                seq_set.update(self.clust_to_seq.pop(clust))
            self.clust_to_seq[c_root] = seq_set

            for seq_id in seq_set:
                self.seq_to_clust[seq_id] = c_root

    def run_mcupgma(self, edge_fname, tree_fname):
        """Run MC-UPGMA's cluster.pl on *edge_fname*, producing *tree_fname*.

        The command line is ``self.param_list`` followed by:
        ``-max_distance <self.max_distance>`` (unless the caller already
        supplied it), ``-otree <tree_fname>``, the edge file, and
        ``-max_singleton <largest id in self.all_seqs>``.  The arguments are
        built in a fresh list, so ``self.param_list`` is left untouched and
        repeated calls do not accumulate duplicate arguments.

        cluster.pl runs inside a new temporary directory under ``self.tmp``
        so no state files from earlier runs are reused.  That directory is
        removed even if the command fails.

        Raises subprocess.CalledProcessError if cluster.pl exits non-zero,
        and ValueError if ``self.all_seqs`` is empty.
        """
        sys.stderr.write("Running MCUPGMA\n")

        params = list(self.param_list)

        if "-max_distance" not in params:
            # '%.17g' round-trips any float; the former '%g' kept only 6
            # significant digits, so the declared maximum could differ from
            # the distances written to the edge file.
            params.extend(["-max_distance", "%.17g" % self.max_distance])

        # The child runs with cwd set to a scratch directory, so hand it
        # absolute paths; relative ones would resolve inside that directory.
        params.extend(["-otree", os.path.abspath(tree_fname),
                       os.path.abspath(edge_fname)])

        # NOTE(review): presumably tells MC-UPGMA the largest leaf id so ids
        # it assigns to merged clusters do not collide with sequence ids --
        # confirm against the MC-UPGMA documentation.
        params.extend(["-max_singleton", "%d" % max(self.all_seqs)])

        cmd = [os.path.join(self.mcupgma_path, 'scripts', 'cluster.pl')]
        cmd.extend(params)

        # be sure to not reuse any temporary state files from clustering
        cwd = tempfile.mkdtemp(dir=self.tmp)
        try:
            subprocess.check_call(cmd, cwd=cwd)
        finally:
            # remove the temporary cwd, whether or not cluster.pl succeeded
            rmtree(cwd)

    def cluster(self):
        """Write the input file, run mcupgma, and parse the output.

        The edge file (``edges_<tag>.edges.gz``) and tree file
        (``cluster_tree_<tag>``) are created in ``self.tmp`` and left there.
        ``<tag>`` joins stype, br_id, nc_id, self_hits, symmetric and
        set_id_filter with underscores.

        NOTE(review): two concurrent runs with identical parameters would
        share these filenames and overwrite each other's files.
        """
        run_tag = "%s_%s_%s_%s_%s_%s" % (
            self.stype, self.br_id, self.nc_id,
            self.self_hits, self.symmetric, self.set_id_filter)

        edge_fname = os.path.join(self.tmp, "edges_%s.edges.gz" % run_tag)
        tree_fname = os.path.join(self.tmp, "cluster_tree_%s" % run_tag)

        self.write_edgefile(edge_fname)

        self.run_mcupgma(edge_fname, tree_fname)

        self.build_tree(tree_fname)
