#! /usr/bin/python

from Numeric import *

#                                        David MacKay  Dec 2005
#  - writes a huffman code (derived from huffmano.p)
#
# usage:
# $        huffman.py
# (the counts vector is hard-coded in the program)

# define objects that contain count information (inputs)
class node:
    """Input record for the Huffman builder: a symbol's count and its index.

    Instances compare by ``count`` alone (``index`` is ignored), so sorting a
    list of nodes puts the symbol with the smallest count first.

    Attributes:
        count -- the count (weight) of the symbol; merged counts accumulate here
        index -- the identifier of the symbol
    """

    def __init__(self, count, index ):
        self.count = count
        self.index = index

    # Rich comparisons: Python 3 ignores __cmp__, and list.sort() needs "<".
    def __lt__(self, other):
        return self.count < other.count

    def __le__(self, other):
        return self.count <= other.count

    def __gt__(self, other):
        return self.count > other.count

    def __ge__(self, other):
        return self.count >= other.count

    def __eq__(self, other):
        return self.count == other.count

    def __cmp__(self, other):
        # Kept for Python 2 callers (it also backs "!=" there).  Written
        # without the cmp() builtin, which does not exist in Python 3.
        return (self.count > other.count) - (self.count < other.count)

# define objects that contain codewords (outputs)
class nodeword:
    """Output record of the Huffman builder: a symbol index and its codeword.

    Instances compare by ``index`` alone (``word`` is ignored), so sorting a
    list of nodewords lists the symbols in index order.

    Attributes:
        index -- the identifier of the symbol
        word  -- the codeword assigned to the symbol, as a string
    """

    def __init__(self, index , word):
        self.index = index
        self.word = word

    # Rich comparisons: Python 3 ignores __cmp__, and list.sort() needs "<".
    def __lt__(self, other):
        return self.index < other.index

    def __le__(self, other):
        return self.index <= other.index

    def __gt__(self, other):
        return self.index > other.index

    def __ge__(self, other):
        return self.index >= other.index

    def __eq__(self, other):
        return self.index == other.index

    def __cmp__(self, other):
        # Kept for Python 2 callers (it also backs "!=" there).  Written
        # without the cmp() builtin, which does not exist in Python 3.
        return (self.index > other.index) - (self.index < other.index)

    def report(self):
        """Print the index and the codeword, separated by a single space."""
        # A single pre-formatted string prints identically whether print is
        # a statement (Python 2) or a function (Python 3).
        print("%s %s" % (self.index, self.word))

## recursive function that reads in a count list and returns the codeword list
def iterate (c , m) :
    """Recursively build the Huffman codewords for a list of count nodes.

    Each call sorts ``c``, merges its two smallest nodes and recurses on the
    shortened list; on the way back up, the codeword of the merged node is
    split into a '0' branch (the smallest node) and a '1' branch (the second
    smallest).  Each merge and the remaining (count, index) pairs are printed
    as a progress trace.

    Arguments:
        c -- list of node objects.  It is modified in place: it is sorted,
             the smallest node is deleted at every level, and that node's
             count is added onto the second-smallest node.
        m -- number of iterations remaining; recursion continues while
             m > 2.  The caller in this file passes len(c).

    Returns:
        A list of nodeword objects, in no particular order.

    Raises:
        ValueError -- if ``c`` holds fewer than two nodes, so there is
                      nothing to merge.
    """
    if len(c) < 2:
        raise ValueError("iterate needs at least two nodes, got %d" % len(c))

    c.sort()  # ascending by count, using the comparison defined by the nodes

    first = c[0].index   # the index of the smallest node
    second = c[1].index  # and of the second smallest

    # MERGE THE BOTTOM TWO: the survivor keeps the index `second` and
    # carries the combined count.
    c[1].count += c[0].count
    del c[0]

    # report what has been done
    print("combining %s %s" % (first, second))
    for v in c:
        print("%s %s" % (v.count, v.index))

    if m <= 2:
        # Last merge: the two remaining symbols get one bit each.
        return [nodeword(first, "0"), nodeword(second, "1")]

    codewords = iterate(c, m - 1)

    # Find the codeword of the node that was merged at this step and split
    # it: the survivor's word gains a '1', the absorbed node gets the same
    # stem followed by '0'.
    for co in codewords:
        if co.index == second:
            stem = co.word
            co.word = stem + '1'
            codewords.append(nodeword(first, stem + '0'))
            break  # required: the list was just extended while iterating

    return codewords
## that's all for iterate

# The counts vector to encode (hard-coded).
# A smaller alternative for experimenting: [ 10, 5, 2, 3 ]
counts = [ 1,2,3,4,5,6,7,8,9,10 ]

# Wrap each count in a node; symbol indices are 1-based, in the order the
# counts are listed above.
c = [ node( count, position + 1 ) for position, count in enumerate( counts ) ]

## c = [ node( 10 , 1 ) , node( 5 , 2 ) , node( 2 , 3 ) , node( 3 , 4 ) ] 

c.sort()

## main program starts here
M = len(c)
print(M)

if M <= 1:
    print("too few symbols\n")
    # Same exit status as before, but without relying on the exit() helper
    # that the site module installs for interactive use.
    raise SystemExit(0)

#
# make huffman code
#

answer = iterate ( c , M )

# write the answer, ordered by symbol index
answer.sort()

print("Done ========= ")
for codeword in answer:
    codeword.report()
# end


