#!/usr/bin/env python

# Jacob Joseph
# 20 October 2009

# A quick plot to compare different scoring runs by a
# scatterplot. Fetch one query at a time from each axis, and run
# through them to align targets.  Draw a scatterplot.  Keep track of
# those missing from either. Draw two histograms of missing scores.

# Actually, instead of scatterplots, use heatmaps with high
# granularity to avoid the memory cost of very many points, and to
# better visualize the count at a particular point.

import matplotlib
from matplotlib import pyplot
from DurandDB import blastq
import numpy, cProfile, time
import scipy.weave as weave

class run:
    """Description of one scoring run to be compared against another.

    stype          -- score type.  'nc_score' routes run_id to nc_id;
                      'bit_score' and 'e_value' route it to br_id.
    run_id         -- identifier of the run
    set_id_filter  -- stored unchanged for use by the caller
    symmetric      -- stored unchanged for use by the caller
    descr          -- optional free-text description, shown by str()
    """

    def __init__(self, stype, run_id, set_id_filter, symmetric, descr=None):
        self.stype = stype
        self.run_id = run_id

        # At most one of nc_id / br_id is populated; the other stays None.
        self.nc_id = None
        self.br_id = None
        if stype == 'nc_score':
            self.nc_id = run_id
        if stype in ('bit_score', 'e_value'):
            self.br_id = run_id

        self.set_id_filter = set_id_filter
        self.symmetric = symmetric
        self.descr = descr

    def _key(self):
        # The fields that identify a run, in the order they are printed.
        return (self.stype, self.nc_id, self.br_id,
                self.set_id_filter, self.symmetric)

    def __repr__(self):
        # Underscore-separated form (no description).
        return "%s_%s_%s_%s_%s" % self._key()

    def __str__(self):
        # Comma-separated key, with the description (if any) on a second line.
        label = "" if self.descr is None else self.descr
        return "%s,%s,%s,%s,%s\n%s" % (self._key() + (label,))
        
        
class scatter:
    """Heatmap comparison of the scores two runs give to the same hits.

    For every query sequence in either run, the hit lists of the two
    runs are merged by hit id and counted into three histograms:

    map_both -- 2-D, [bin of run 0 score, bin of run 1 score], for hits
                present in both runs
    map_0    -- 1-D, hits present only in run 0, binned by run 0 score
    map_1    -- 1-D, hits present only in run 1, binned by run 1 score

    A score is binned as int(score * resolution).  This assumes scores
    lie in [0, 1] (the original code notes the transformation is
    specific to NC scores); out-of-range bins are clamped to the first
    or last bin rather than indexing outside the arrays.

    The merge relies on each hit list being sorted by hit id.
    NOTE(review): that ordering comes from blastq.fetch_hits_dictarray,
    which is not visible here -- confirm.
    """

    # C merge kernel handed to weave.inline().  Kept at class level so
    # the (loop-invariant) source string is built once.  Both lists are
    # walked until BOTH are exhausted, so hits trailing the end of the
    # other list are still counted as "missing from the other run".
    _MERGE_CODE = """
            int ind_0 = 0, ind_1 = 0;
            const int n_0 = (int) Nhits_0[0];
            const int n_1 = (int) Nhits_1[0];

            while ((ind_0 < n_0) || (ind_1 < n_1))
            {
                /* take_X: the next hit of run X is consumed this step.
                   Both are set when the two hit ids are equal. */
                int take_0 = (ind_0 < n_0)
                    && ((ind_1 >= n_1) || (hits_0[ind_0] <= hits_1[ind_1]));
                int take_1 = (ind_1 < n_1)
                    && ((ind_0 >= n_0) || (hits_1[ind_1] <= hits_0[ind_0]));
                int bin_0 = 0, bin_1 = 0;

                if (take_0) {
                    bin_0 = (int)(scores_0[ind_0] * hres);
                    /* clamp: a score of exactly 1.0 would index one
                       past the end of the array */
                    if (bin_0 >= hres) bin_0 = (int)(hres - 1);
                    if (bin_0 < 0) bin_0 = 0;
                }
                if (take_1) {
                    bin_1 = (int)(scores_1[ind_1] * vres);
                    if (bin_1 >= vres) bin_1 = (int)(vres - 1);
                    if (bin_1 < 0) bin_1 = 0;
                }

                if (take_0 && take_1) {
                    MAP_BOTH2(bin_0, bin_1) += 1;
                }
                else if (take_0) {
                    map_0[bin_0] += 1;
                }
                else {
                    map_1[bin_1] += 1;
                }

                ind_0 += take_0;
                ind_1 += take_1;
            }

            return_val = 0;
            """

    def __init__(self, run_0, run_1, hres=None, vres=None):
        """Fetch and bin the scores of run_0 (rows) against run_1 (columns).

        hres -- number of bins for run_0 scores (default 1000)
        vres -- number of bins for run_1 scores (default 1000)

        Opens a blastq connection and immediately fills self.map_0,
        self.map_1 and self.map_both via fetch_scatter_weave().
        """
        self.runs = [run_0, run_1]
        self.bq = blastq.blastq(cacheq=False)

        self.hres = 1000 if hres is None else hres
        self.vres = 1000 if vres is None else vres

        self.unique = None

        # Hand the requested resolution on; previously it was stored but
        # the fetch always ran at its own default of 1000.
        self.fetch_scatter_weave(self.hres, self.vres)

    def _new_maps(self, hres, vres):
        # Allocate zeroed (map_0, map_1, map_both) for the given resolution.
        if hres < 1 or vres < 1:
            raise ValueError("hres and vres must be at least 1")
        map_both = numpy.zeros([hres, vres], dtype=numpy.int32)
        map_0 = numpy.zeros(hres, dtype=numpy.int32)
        # map_1 is indexed by run 1 bins, so it is sized by vres.
        map_1 = numpy.zeros(vres, dtype=numpy.int32)
        return map_0, map_1, map_both

    def _query_seq_ids(self):
        # Union of the sequence sets of both runs; each is used as a query.
        seq_set = set()
        for r in self.runs:
            seq_set.update(self.bq.fetch_seq_set(br_id=r.br_id,
                                                 nc_id=r.nc_id,
                                                 set_id=r.set_id_filter))
        return seq_set

    def _hits_for(self, r, query_seq_id):
        # (hits, scores) of one query in run r; empty arrays if it has none.
        hitdict = self.bq.fetch_hits_dictarray(
            stype=r.stype, br_id=r.br_id, nc_id=r.nc_id,
            symmetric=r.symmetric, set_id_filter=r.set_id_filter,
            query_seq_id=query_seq_id)
        if query_seq_id in hitdict:
            return hitdict[query_seq_id]
        return (numpy.array([], dtype=numpy.int32),
                numpy.array([], dtype=numpy.float64))

    def _score_bin(self, score, res):
        # Bin index of a score at resolution res, clamped to [0, res - 1].
        return min(max(int(score * res), 0), res - 1)

    def fetch_scatter(self, hres=1000, vres=1000):
        """Pure-Python counterpart of fetch_scatter_weave().

        hres -- number of bins for run 0 scores
        vres -- number of bins for run 1 scores

        Returns (map_0, map_1, map_both); nothing is stored on self.
        Raises ValueError if hres or vres is less than 1.
        """
        map_0, map_1, map_both = self._new_maps(hres, vres)
        r0 = self.runs[0]
        r1 = self.runs[1]

        # FIXME: transformation is specific to NC.
        for query_seq_id in self._query_seq_ids():
            h0, s0 = self._hits_for(r0, query_seq_id)
            h1, s1 = self._hits_for(r1, query_seq_id)

            i0 = 0
            i1 = 0
            len0 = len(h0)
            len1 = len(h1)
            # Hits are in sorted order
            while i0 < len0 and i1 < len1:
                if h0[i0] < h1[i1]:
                    # hit only in run 0
                    map_0[self._score_bin(s0[i0], hres)] += 1
                    i0 += 1
                elif h0[i0] > h1[i1]:
                    # hit only in run 1
                    map_1[self._score_bin(s1[i1], vres)] += 1
                    i1 += 1
                else:
                    map_both[self._score_bin(s0[i0], hres),
                             self._score_bin(s1[i1], vres)] += 1
                    i0 += 1
                    i1 += 1

            # Whatever is left in one list after the other is exhausted
            # is missing from the other run as well.
            for k in range(i0, len0):
                map_0[self._score_bin(s0[k], hres)] += 1
            for k in range(i1, len1):
                map_1[self._score_bin(s1[k], vres)] += 1

        return map_0, map_1, map_both

    def fetch_scatter_weave(self, hres=1000, vres=1000):
        """Fill self.map_0, self.map_1 and self.map_both using weave.

        hres -- number of bins for run 0 scores
        vres -- number of bins for run 1 scores

        The per-query merge runs in C through weave.inline().  Returns
        None; results are left on self.  Raises ValueError if hres or
        vres is less than 1.
        """
        map_0, map_1, map_both = self._new_maps(hres, vres)

        self.map_both = map_both
        self.map_0 = map_0
        self.map_1 = map_1

        r0 = self.runs[0]
        r1 = self.runs[1]
        code = self._MERGE_CODE

        # FIXME: transformation is specific to NC.
        for query_seq_id in self._query_seq_ids():
            # weave.inline() picks these up from the local scope by
            # name, so the variable names must match the list below.
            hits_0, scores_0 = self._hits_for(r0, query_seq_id)
            hits_1, scores_1 = self._hits_for(r1, query_seq_id)

            weave.inline(code, ['hits_0', 'scores_0',
                                'hits_1', 'scores_1',
                                'map_0', 'map_1', 'map_both',
                                'vres', 'hres'])
        return

    def plotmap(self, title=None, fprefix=None):
        """Draw self.map_both as a log-scaled heatmap.

        title   -- optional figure title
        fprefix -- if given, save to 'figures/<date>_<fprefix><run 0>_vs_<run 1>'
                   and close the figure; otherwise show it interactively.

        Run 0 is on the y axis and run 1 on the x axis.
        """
        # +1 so that empty cells stay representable on the log scale.
        map_both1 = self.map_both + 1
        n_rows, n_cols = map_both1.shape

        # pyplot.pcolor(map_both1, cmap=pyplot.cm.hot,
        pyplot.imshow(map_both1, cmap=pyplot.cm.hot,
                      norm = matplotlib.colors.LogNorm(),
                      origin='lower')
        # Equal-score diagonal, following the actual map resolution
        # instead of a hard-coded 1000.
        pyplot.plot([0, n_cols], [0, n_rows], 'b--')
        pyplot.axis('image')
        pyplot.xlabel( str(self.runs[1]))
        pyplot.ylabel( str(self.runs[0]))

        pyplot.subplots_adjust(bottom=0.1, left=0.12, top=0.95, right=0.92)

        pyplot.colorbar(fraction=0.05,pad=0.01)

        if title is not None:
            pyplot.suptitle(title)

        #l 0.12 b .1 r .9 t .9 w .2 h .2

        if fprefix is not None:
            fname = (time.strftime('%Y%m%d')
                     + '_'
                     + fprefix + repr(self.runs[0])
                     + '_vs_'
                     + repr(self.runs[1]))
            pyplot.savefig( 'figures/%s' % fname)
            pyplot.close()
        else:
            pyplot.show()
        return
        
if __name__ == "__main__":
    # Script entry point: compare two NC scoring runs and save the
    # heatmap under figures/.
    #
    # Runs are described as
    #   run(score_type, run_id, set_id_filter, symmetric[, descr])

    # --- Earlier run definitions, kept for reference ------------------
    #s = scatter( run('nc_score', 746, 105, True),
    #              run('nc_score', 750, 105, True))

    #s = scatter( run('nc_score', 777, None, False),
    #             run('nc_score', 779, None, True))

    #r_777 = run('nc_score', 777, None, False, 'Not Symmetric, Composition=2, 12 species')
    #r_777hm = run('nc_score', 777, 110, False, 'Not Symmetric, Composition=2, families')
    #r_779 = run('nc_score', 779, None, True, 'Symmetric, Composition=2, 12 species')
    #r_779hm = run('nc_score', 779, 110, True, 'Symmetric, Composition=2, families')
    #r_780 = run('nc_score', 780, None, True, 'Not Symmetric, Composition=f, 12 species')
    #r_780hm = run('nc_score', 780, 110, True, 'Not Symmetric, Composition=f, families')
    #r_782hmblast = run('nc_score', 782, 109, 'Symmetric, Composition=f, NC from blast of human, mouse only')
    #r_789 = run('nc_score', 789, None, True, "48 Genomes, no log")

    # --- Runs currently in use ----------------------------------------
    # Reference run.
    r_808 = run(stype='nc_score', run_id=808, set_id_filter=None,
                symmetric=True, descr="48 Genomes. Correct calculation")

    # Same data at increasing blast_hit_limit thresholds.
    r_822 = run(stype='nc_score', run_id=822, set_id_filter=None,
                symmetric=True,
                descr="48 Genomes. Using blast_hit_limit >=50")
    r_823 = run(stype='nc_score', run_id=823, set_id_filter=None,
                symmetric=True,
                descr="48 Genomes. Using blast_hit_limit >=100")
    r_824 = run(stype='nc_score', run_id=824, set_id_filter=None,
                symmetric=True,
                descr="48 Genomes. Using blast_hit_limit >=500")

    # Profiling / direct-call snippets from earlier sessions:
    #cProfile.runctx('map_0, map_1, map_both = s.fetch_scatter_weave()', globals(), locals())
    #cProfile.runctx('map_0, map_1, map_both = s.fetch_scatter()', globals(), locals())
    #map_0, map_1, map_both = s.fetch_scatter_weave()

    # Non-interactive plotting with a fixed figure size.
    pyplot.ioff()
    pyplot.rc('figure', figsize=(9, 8))

    # --- Comparisons; only the last one is active ---------------------
    #s = scatter( r_777, r_779)
    #s = scatter( r_777, r_780)
    #s = scatter( r_779, r_780)
    #s = scatter( r_777hm, r_779hm)
    #s = scatter( r_777hm, r_780hm)
    #s = scatter( r_779hm, r_780hm)
    #s = scatter( r_780, r_782hmblast)
    #s = scatter( r_789, r_808)
    #s = scatter( r_816, r_808)
    #s = scatter( r_822, r_808)
    #s = scatter( r_823, r_808)
    s = scatter(r_824, r_808)

    # Giving a prefix makes plotmap() save to a file instead of showing.
    s.plotmap(fprefix='ncncscatter')


def foo():
    """Scratch plotting code kept from an interactive session.

    Nothing in this file calls it.  It reads the names ``cnt``,
    ``bit``, ``e``, ``map_0``, ``map_1`` and ``map_both`` as globals;
    this file does not define them, so they must already exist at
    module level (e.g. from an interactive session) before calling.

    Draws a bit-score vs. e-value scatter coloured by log10(count),
    round-trips the three maps through 'map_out.pickle', and draws the
    reloaded map_both as a log-scaled heatmap.
    """
    # Imported locally: the file-level imports provide neither module.
    import math
    try:
        import cPickle as pickle
    except ImportError:
        import pickle

    logcnt = [math.log10(c) for c in cnt]

    pyplot.scatter( bit, e, c=logcnt, marker='o')
    pyplot.axis('tight')
    pyplot.grid(color='gray', alpha=0.5)
    pyplot.colorbar()
    pyplot.xlabel('Bit-score')
    pyplot.ylabel('-log10(e_value/dbsize + 1E-200)')
    pyplot.subplots_adjust(left=0.1, bottom=0.05, right=0.98, top=0.95)
    pyplot.title("E-value vs Bit score (br_id: 91)")


    #pyplot.pcolor(map_both, cmap=pyplot.cm.hot)

    # color based upon the log of the number
    # Take a log, then normalize to a 0-1 range?
    #logmap_both = numpy.log10(numpy.log10(map_both))
    # remove -inf, replacing

    #logmap_both =

    # Binary mode is what pickle expects; 'with' closes the file even
    # if dumping or loading fails.
    with open('map_out.pickle', 'wb') as fd:
        pickle.dump((map_0, map_1, map_both), fd)

    # The reloaded maps get their own names.  Rebinding map_0/map_1/
    # map_both here would make them locals of this function and break
    # the dump above.  Only load pickle files you wrote yourself.
    with open('map_out.pickle', 'rb') as fd:
        (loaded_0, loaded_1, loaded_both) = pickle.load(fd)

    map_both1 = loaded_both + 1
    #pyplot.pcolor(map_both1, cmap=pyplot.cm.hot,
    pyplot.imshow(map_both1, cmap=pyplot.cm.hot,
                  norm = matplotlib.colors.LogNorm(),
                  origin='lower')
    #pyplot.xlabel("xlabel")
