__author__  = "Hugo Liu <hugo@media.mit.edu>"
__version__ = "1.3.1"
import sys,string,os.path
from types import *
import Settings,MontyUtils
if Settings.Settings().JYTHON_P:
    import jarray
    import struct
else:
    import array
class LexiconFast:


    java_p = Settings.Settings().JYTHON_P
    auto_load_lexicon = 0
    lexicon_filename = 'LEXICON'
    fast_lexicon_filename = 'FASTLEXICON'
    packed_words = ""
    packed_pos = ""
    if java_p:
        word_start_arr = jarray.array([],'l')
        word_end_arr = jarray.array([],'l')
        pos_start_arr = jarray.array([],'l')
        pos_end_arr = jarray.array([],'l')
    else:
        word_start_arr = array.array('L')
        word_end_arr = array.array('L')
        pos_start_arr = array.array('L')
        pos_end_arr = array.array('L')
    def __init__(self,notify=[]):

        if MontyUtils.MontyUtils().find_file(self.fast_lexicon_filename+'.1') != '':
            print "Fast Lexicon Found! Now Loading!"
            self.load_fastlexicon()
        elif self.auto_load_lexicon:
            print "No Fast Lexicon Detected...Now Building..."
            self.lexicon_filename = MontyUtils.MontyUtils().find_file(self.lexicon_filename)
            if self.lexicon_filename == '':
                print "ERROR: could not find LEXICON"
                print "in current dir, %MONTYLINGUA% or %PATH%"
            self.populate_lexicon_from_file(self.lexicon_filename)
            print "Fast Lexicon files will be made and put into"
            print "the current working directory"
            self.make_fastlexicon()
            print "Finished building.  Won't build again"
        else:
            print "No Fast Lexicon Detected. Standard Lexicon used."
            notify.append(-1)
            return
        print "Lexicon OK!"
        return
    def make_fastlexicon(self):

        f = open(self.fast_lexicon_filename+'.1','w')
        f.write(self.packed_words)
        f.close()
        f = open(self.fast_lexicon_filename+'.2','w')
        f.write(self.packed_pos)
        f.close()
        f = open(self.fast_lexicon_filename+'.3','wb')
        self.word_start_arr.tofile(f)
        f.close()
        f = open(self.fast_lexicon_filename+'.4','wb')
        self.word_end_arr.tofile(f)
        f.close()
        f = open(self.fast_lexicon_filename+'.5','wb')
        self.pos_start_arr.tofile(f)
        f.close()
        f = open(self.fast_lexicon_filename+'.6','wb')
        self.pos_end_arr.tofile(f)
        f.close()
        f = open(self.fast_lexicon_filename+'.7','w')
        f.write(str(len(self.word_start_arr))+'\n')
        f.write(str(len(self.word_end_arr))+'\n')
        f.write(str(len(self.pos_start_arr))+'\n')
        f.write(str(len(self.pos_end_arr))+'\n')
        f.close()
        return
    def load_fastlexicon(self):

        mu = MontyUtils.MontyUtils()
        flf1 = mu.find_file(self.fast_lexicon_filename+'.1')
        flf2 = mu.find_file(self.fast_lexicon_filename+'.2')
        flf3 = mu.find_file(self.fast_lexicon_filename+'.3')
        flf4 = mu.find_file(self.fast_lexicon_filename+'.4')
        flf5 = mu.find_file(self.fast_lexicon_filename+'.5')
        flf6 = mu.find_file(self.fast_lexicon_filename+'.6')
        flf7 = mu.find_file(self.fast_lexicon_filename+'.7')
        f = open(flf7,'r')
        len1,len2,len3,len4 = map(lambda x: int(x),f.read().split())
        f.close()
        f = open(flf1,'r')
        self.packed_words = f.read()
        f.close()
        f = open(flf2,'r')
        self.packed_pos = f.read()
        f.close()
        f = open(flf3,'rb')
        arr = self.array_fromfile(f,self.word_start_arr,len1,self.java_p,java_code='ws')
        f.close()
        f = open(flf4,'rb')
        self.array_fromfile(f,self.word_end_arr,len2,self.java_p,java_code='we')
        f.close()
        f = open(flf5,'rb')
        self.array_fromfile(f,self.pos_start_arr,len3,self.java_p,java_code='ps')
        f.close()
        f = open(flf6,'rb')
        self.array_fromfile(f,self.pos_end_arr,len4,self.java_p,java_code='pe')
        f.close()
    def compare(self,element1,element2):

        if element1[0]<element2[0]: return -1
        elif element1[0]>element2[0]: return 1
        else: return 0
    def get(self,word,default):

        startpoint = 0
        endpoint = len(self.word_start_arr)
        real_endpoint = len(self.word_start_arr)
        while endpoint>=startpoint and startpoint < real_endpoint:
            midpoint = (endpoint-startpoint)/2+startpoint
            res = self.get_helper(midpoint,word)
            if res != 0 and startpoint==endpoint:
                return default
            elif res < 0:
                endpoint = midpoint
            elif res > 0 and endpoint-startpoint == 1:
                startpoint = endpoint
            elif res > 0:
                startpoint = midpoint
            elif res == 0:
                poses = self.packed_pos[self.pos_start_arr[midpoint]:self.pos_end_arr[midpoint]]
                return poses.split()
            else:
                return default
        return default
    def get_helper(self,i,word):

        cur_word = self.packed_words[self.word_start_arr[i]:self.word_end_arr[i]]
        if word<cur_word: return -1
        elif word>cur_word: return 1
        else: return 0
    def primary_pos(self,word):

        pos_arr = self.get(word,[])
        if pos_arr == []:
            return ""
        else:
            return pos_arr[0]
    def all_pos(self,word):

        pos_arr = self.get(word,[])
        return pos_arr
    def has_pos(self,word,pos):

        return pos in self.get(word,[])
    def is_word(self,word,case_sensitivity=0):

        if case_sensitivity:
            return (self.get(word,[]) != [])
        else:
            word_initial_cap = word.lower()
            if len(word_initial_cap) > 1:
                word_initial_cap = word_initial_cap[0].upper() + word_initial_cap[1:]
            res = (self.get(word,[]) != []) or (self.get(word.lower(),[]) != []) or (self.get(word.upper(),[]) != []) or (self.get(word_initial_cap,[]) != [])
            return res
    def populate_lexicon_from_file(self,filename):

        temp_lex = []
        try:
            f = open(filename,'r')
            line = f.readline()
            while line:
                word_end_index = string.find(line,' ')
                word = line[:word_end_index]
                poses = line[word_end_index+1:]
                temp_lex.append([word,poses])
                line = f.readline()
            f.close()
        except:
            print "Error parsing Lexicon!"
            sys.exit(-1)
        temp_lex.sort(self.compare)
        pw = self.packed_words
        pp = self.packed_pos
        ws = self.word_start_arr
        we = self.word_end_arr
        ps = self.pos_start_arr
        pe = self.pos_end_arr
        count = 0
        cur_len_pw = 0
        cur_len_pp = 0
        for i in range(len(temp_lex)):
            count += 1
            if count % 1000 == 0:
                print count
            word,poses = temp_lex[i]
            ws.append(cur_len_pw)
            cur_len_pw += len(word)
            we.append(cur_len_pw)
            ps.append(cur_len_pp)
            cur_len_pp += len(poses)
            pe.append(cur_len_pp)
            pw += word
            pp += poses
        self.packed_words = pw
        self.packed_pos = pp
        self.word_start_arr = ws
        self.word_end_arr = we
        self.pos_start_arr = ps
        self.pos_end_arr = pe
        return
    def array_fromfile(self,file_ptr,array_ptr,length,java_p=0,java_code='',endian_order='little'):

        item_length = 4
        if java_p:
            if endian_order == 'big':
                format_character = '>'
            else:
                format_character = '<'
            arr = struct.unpack(format_character+str(length)+'L',file_ptr.read())
            if java_code == 'ws':
                self.word_start_arr = jarray.array(arr,'l')
            elif java_code == 'we':
                self.word_end_arr = jarray.array(arr,'l')
            elif java_code == 'ps':
                self.pos_start_arr = jarray.array(arr,'l')
            elif java_code == 'pe':
                self.pos_end_arr = jarray.array(arr,'l')
            else:
                print "error! java code invalid!"
                sys.exit(-1)
        else:
            array_ptr.fromfile(file_ptr,length)
if __name__ == "__main__":
    l = LexiconFast()
    a = "aberration bird ate an apple"
    for word in string.split(a):
        print l.get(word,'UNK')