from __future__ import nested_scopes
import re,string,os,sys,shelve,copy,math
from types import *
import Settings,MontyUtils
if Settings.Settings().JYTHON_P:
    import jarray
    import struct
else:
    import array
class OMCSNetFast:


    def __init__(self,suppress_init_p=0):
        """Create the empty network containers and, unless suppressed, load data.

        suppress_init_p -- when true, only the empty containers are created
            and no datafile is looked up or loaded.

        Otherwise the fast datafile ('OMCSNET_DATA.1') is loaded when it can
        be found.  Failing that, if a text-only datafile is found the user is
        asked on the console whether to build the fast datafiles from it; a
        'Y'/'y' answer builds them and then exits the process, any other
        answer loads the text-only data.  When no datafile is found the
        process exits.  Both exits use sys.exit(0).
        """
        self.java_p = Settings.Settings().JYTHON_P
        self.filename1 = 'semanticnet.txt'
        self.filename2 = 'predicates.txt'
        self.fast_omcsnet_filename = 'OMCSNET_DATA'
        self.semantic_net = {}
        self.nodes_string = ''
        self.sentences_string = ''
        self.preds = []
        # get_nodeid_by_nodename reads this attribute; give it a value up
        # front so lookups on an instance that has not loaded any data return
        # their defaults instead of raising AttributeError.
        self.origin_node_count = 0
        array_attributes = (
            'node_start_pos',
            'node_end_pos',
            'sentence_start_pos',
            'sentence_end_pos',
            'edge_origin_nodeid',
            'edge_destination_nodeid',
            'edge_predid',
            'edge_sentenceid',
            'backedge_destination_nodeid',
            'backedge_edgeid',
        )
        for attribute in array_attributes:
            if self.java_p:
                setattr(self,attribute,jarray.array([],'l'))
            else:
                setattr(self,attribute,array.array('L'))
        if suppress_init_p:
            return
        if MontyUtils.MontyUtils().find_file(self.fast_omcsnet_filename+'.1') != '':
            print("Fast OMCSNet Found! Now Loading!")
            self.load_fast_omcsnet()
            return
        text_datafile_found = (MontyUtils.MontyUtils().find_file(self.filename1) != ''
                               or MontyUtils.MontyUtils().find_file(self.filename2) != '')
        if not text_datafile_found:
            print("No OMCSNet datafiles found...")
            print("The text-only datafile look like: 'semanticnet.txt' or 'predicates.txt'")
            print("The fast datafiles looks like: 'OMCSNET_DATA.*'")
            print("Now exiting...")
            sys.exit(0)
        print("A text-only OMCSNet Datafile was found.")
        print("OMCSNet can build a 10x faster datafile.")
        print("However, it could take a couple of hours to build.")
        print("Press Y to build fast datafile,")
        print("or any other key to continue using text-only datafile...")
        res = raw_input('> ')
        if len(res) < 1 or res[0] not in ['Y','y']:
            print("Loading text-only datafile...")
            self.use_old_omcsnet()
            return
        print("Now building faster datafiles...")
        print("Please be patient...")
        self.use_old_omcsnet()
        self.make_fast_omcsnet()
        print("Finished building fast OMCSNet datafiles: 'OMCSNET_DATA.*'")
        print("Always include new datafiles in the working directory")
        print("Please restart OMCSNet.  Now exiting...")
        sys.exit(0)
    def get_edges_by_origin_nodename(self,nodename,sentence_p=1,default=[]):

        edgeids = self.get_edgeids_by_origin_nodename(nodename,[])
        edges = []
        for edgeid in edgeids:
            edge = self.get_edge_by_edgeid(edgeid,[])
            if edge == []:
                print "internal consistency error"
                return default
            if sentence_p:
                sentence = self.get_sentence_by_edgeid(edgeid,'')
                edges.append(edge+[sentence])
            else:
                edges.append(edge)
        return edges
    def get_edges_by_destination_nodename(self,nodename,sentence_p=1,default=[]):

        edgeids = self.get_edgeids_by_destination_nodename(nodename,[])
        edges = []
        for edgeid in edgeids:
            edge = self.get_edge_by_edgeid(edgeid,[])
            if edge == []:
                print "internal consistency error"
                return default
            if sentence_p:
                sentence = self.get_sentence_by_edgeid(edgeid,'')
                edges.append(edge+[sentence])
            else:
                edges.append(edge)
        return edges
    def get_pred_by_edgeid(self,edgeid,default=''):

        predid = self.edge_predid[edgeid]
        try:
            pred = self.preds[predid]
            return pred
        except:
            return default
        return default
    def get_sentence_by_edgeid(self,edgeid,default=''):

        sentenceid = self.edge_sentenceid[edgeid]
        try:
            sentence = self.sentences_string[self.sentence_start_pos[sentenceid]:self.sentence_end_pos[sentenceid]]
            return sentence
        except:
            return default
        return default
    def get_edge_by_edgeid(self,edgeid,default=[]):

        origin_nodeid = self.edge_origin_nodeid[edgeid]
        destination_nodeid = self.edge_destination_nodeid[edgeid]
        triple = [self.get_pred_by_edgeid(edgeid)] + map(lambda x:self.get_nodename_by_nodeid(x),[origin_nodeid,destination_nodeid])
        return triple
    def get_edgeids_by_destination_nodename(self,nodename,default=[]):

        nodeid = self.get_nodeid_by_nodename(nodename,-1)
        if nodeid < 0:
            return default
        return self.get_edgeids_by_destination_nodeid(nodeid,default)
    def get_edgeids_by_origin_nodename(self,nodename,default=[]):

        nodeid = self.get_nodeid_by_nodename(nodename,-1)
        if nodeid < 0:
            return default
        return self.get_edgeids_by_origin_nodeid(nodeid,default)
    def get_edgeids_by_destination_nodeid(self,nodeid,default=[]):

        if nodeid < 0:
            return default
        startpoint = 0
        endpoint = len(self.backedge_destination_nodeid)
        real_endpoint = endpoint
        found_one = -1
        while endpoint>=startpoint and startpoint < real_endpoint:
            midpoint = (endpoint-startpoint)/2+startpoint
            cur_nodeid = self.backedge_destination_nodeid[midpoint]
            res = nodeid-cur_nodeid
            if res != 0 and startpoint==endpoint:
                return default
            elif res < 0:
                endpoint = midpoint
            elif res > 0 and endpoint-startpoint == 1:
                startpoint = endpoint
            elif res > 0:
                startpoint = midpoint
            elif res == 0:
                found_one = midpoint
                break
            else:
                return default
        if found_one == -1:
            return default
        low = found_one
        high = found_one
        position = found_one
        while position >= 0:
            position -= 1
            if self.backedge_destination_nodeid[position] != nodeid:
                break
            else:
                low -= 1
        position = found_one
        while position < len(self.backedge_destination_nodeid):
            position += 1
            if self.backedge_destination_nodeid[position] != nodeid:
                break
            else:
                high += 1
        backedgeids = range(low,high+1)
        edgeids = map(lambda x:self.backedge_edgeid[x],backedgeids)
        return edgeids
    def get_edgeids_by_origin_nodeid(self,nodeid,default=[]):

        if nodeid < 0:
            return default
        startpoint = 0
        endpoint = len(self.edge_origin_nodeid)
        real_endpoint = endpoint
        found_one = -1
        while endpoint>=startpoint and startpoint < real_endpoint:
            midpoint = (endpoint-startpoint)/2+startpoint
            cur_nodeid = self.edge_origin_nodeid[midpoint]
            res = nodeid-cur_nodeid
            if res != 0 and startpoint==endpoint:
                return default
            elif res < 0:
                endpoint = midpoint
            elif res > 0 and endpoint-startpoint == 1:
                startpoint = endpoint
            elif res > 0:
                startpoint = midpoint
            elif res == 0:
                found_one = midpoint
                break
            else:
                return default
        if found_one == -1:
            return default
        low = found_one
        high = found_one
        position = found_one
        while position >= 0:
            position -= 1
            if self.edge_origin_nodeid[position] != nodeid:
                break
            else:
                low -= 1
        position = found_one
        while position < len(self.edge_origin_nodeid):
            position += 1
            if self.edge_origin_nodeid[position] != nodeid:
                break
            else:
                high += 1
        return range(low,high+1)
    def get_nodename_by_nodeid(self,nodeid,default=''):

        try:
            nodename= self.nodes_string[self.node_start_pos[nodeid]:self.node_end_pos[nodeid]]
            return nodename
        except:
            return default
    def get_nodeid_by_nodename(self,nodename,default=-1):

        startpoint = 0
        endpoint = self.origin_node_count
        real_endpoint = endpoint
        while endpoint>=startpoint and startpoint < real_endpoint:
            midpoint = (endpoint-startpoint)/2+startpoint
            res = self.get_nodeids_by_nodename_helper(midpoint,nodename)
            if res != 0 and startpoint==endpoint:
                return default
            elif res < 0:
                endpoint = midpoint
            elif res > 0 and endpoint-startpoint == 1:
                startpoint = endpoint
            elif res > 0:
                startpoint = midpoint
            elif res == 0:
                return midpoint
                break
            else:
                return default
        return default
    def get_nodeids_by_nodename_helper(self,i,word):

        cur_word = self.nodes_string[self.node_start_pos[i]:self.node_end_pos[i]]
        if word<cur_word: return -1
        elif word>cur_word: return 1
        else: return 0
    def use_old_omcsnet(self):
        """Load a text-only datafile and build the array representation from it.

        The predicates file (self.filename2) is used when it can be found;
        otherwise the loader for self.filename1 is used.
        """
        predicates_path = MontyUtils.MontyUtils().find_file(self.filename2)
        if predicates_path == '':
            self.load_pp_omcsnet()
        else:
            self.load_predicates_omcsnet()
        self.populate_efficient_repr()
    def load_predicates_omcsnet(self):
        """Read the predicates file into self.semantic_net.

        Only lines that are wrapped in parentheses are used.  Such a line is
        split on double quotes: the text before the first quote is the
        predicate, the first quoted field is the origin node and the second
        quoted field is the destination node.  Each distinct triplet is
        stored under its origin as [pred, dest, '', 'fw', 0.0].
        Parenthesised lines without two quoted fields are skipped.
        """
        print("Loading OMCSNet from Predicates file")
        f = open(MontyUtils.MontyUtils().find_file(self.filename2),'r')
        # try/finally so the handle is released even if reading fails.
        try:
            raw_lines = f.readlines()
        finally:
            f.close()
        triplets = []
        for raw_line in raw_lines:
            line = raw_line.strip()
            if line == '' or line[0] != '(' or line[-1] != ')':
                continue
            fields = line[1:-1].split('"')
            if len(fields) < 4:
                # Not enough quoted fields to hold an origin and a destination.
                continue
            triplets.append([fields[0].strip(),fields[1].strip(),fields[3].strip()])
        print("Found %d pred-arg structures" % len(triplets))
        for pred,origin,dest in triplets:
            existing_edges = self.semantic_net.get(origin,[])
            new_edge = [pred,dest,'','fw',0.0]
            if new_edge not in existing_edges:
                existing_edges.append(new_edge)
            self.semantic_net[origin] = existing_edges
        print("Loaded %d keys" % len(self.semantic_net))
        return
    def load_fast_omcsnet(self):
        """Load the fast datafiles 'OMCSNET_DATA.1' .. 'OMCSNET_DATA.13'.

        File '.11' holds nine newline-terminated integers (the lengths of the
        eight node/sentence/edge arrays followed by the origin node count)
        and, on its last line, the space-separated predicate names.  Files
        '.1' and '.2' hold the node and sentence strings; the remaining files
        hold one array each and are read through array_fromfile.  The two
        back-edge arrays ('.12', '.13') are read with the length of the edge
        destination array.
        """
        f = open(MontyUtils.MontyUtils().find_file(self.fast_omcsnet_filename+'.11'),'r')
        try:
            segments = f.read().split('\n')
        finally:
            f.close()
        # Unpacking fails unless the header carries exactly nine integers.
        (len_node_start_pos,len_node_end_pos,
         len_sentence_start_pos,len_sentence_end_pos,
         len_edge_origin_nodeid,len_edge_destination_nodeid,
         len_edge_predid,len_edge_sentenceid,
         origin_node_count) = [int(segment) for segment in segments[:-1]]
        self.origin_node_count = origin_node_count
        self.preds = segments[-1].split()
        for suffix,attribute in (('.1','nodes_string'),('.2','sentences_string')):
            f = open(MontyUtils.MontyUtils().find_file(self.fast_omcsnet_filename+suffix),'r')
            try:
                setattr(self,attribute,f.read())
            finally:
                f.close()
        # (file suffix, attribute name, element count); the attribute name is
        # also the java_code that array_fromfile expects for that array.
        array_files = (
            ('.3','node_start_pos',len_node_start_pos),
            ('.4','node_end_pos',len_node_end_pos),
            ('.5','sentence_start_pos',len_sentence_start_pos),
            ('.6','sentence_end_pos',len_sentence_end_pos),
            ('.7','edge_origin_nodeid',len_edge_origin_nodeid),
            ('.8','edge_destination_nodeid',len_edge_destination_nodeid),
            ('.9','edge_predid',len_edge_predid),
            ('.10','edge_sentenceid',len_edge_sentenceid),
            ('.12','backedge_destination_nodeid',len_edge_destination_nodeid),
            ('.13','backedge_edgeid',len_edge_destination_nodeid),
        )
        for suffix,attribute,length in array_files:
            f = open(MontyUtils.MontyUtils().find_file(self.fast_omcsnet_filename+suffix),'rb')
            try:
                self.array_fromfile(f,getattr(self,attribute),length,self.java_p,java_code=attribute)
            finally:
                f.close()
    def make_fast_omcsnet(self):
        """Write the in-memory representation to 'OMCSNET_DATA.1' .. '.13'.

        '.1' and '.2' receive the node and sentence strings, '.3'-'.10' and
        '.12'-'.13' one array each (via the array's tofile), and '.11' a text
        header: the lengths of the eight node/sentence/edge arrays and the
        origin node count, one per line, followed by the space-separated
        predicate names.

        A file that find_file cannot locate yet is created under its bare
        name, i.e. relative to the current working directory.
        """
        header_values = [
            len(self.node_start_pos),
            len(self.node_end_pos),
            len(self.sentence_start_pos),
            len(self.sentence_end_pos),
            len(self.edge_origin_nodeid),
            len(self.edge_destination_nodeid),
            len(self.edge_predid),
            len(self.edge_sentenceid),
            self.origin_node_count,
        ]
        header = ''.join([str(value)+'\n' for value in header_values]) + ' '.join(self.preds)
        # (file suffix, open mode, payload): 'w' payloads are strings,
        # 'wb' payloads are arrays.
        outputs = (
            ('.1','w',self.nodes_string),
            ('.2','w',self.sentences_string),
            ('.3','wb',self.node_start_pos),
            ('.4','wb',self.node_end_pos),
            ('.5','wb',self.sentence_start_pos),
            ('.6','wb',self.sentence_end_pos),
            ('.7','wb',self.edge_origin_nodeid),
            ('.8','wb',self.edge_destination_nodeid),
            ('.9','wb',self.edge_predid),
            ('.10','wb',self.edge_sentenceid),
            ('.11','w',header),
            ('.12','wb',self.backedge_destination_nodeid),
            ('.13','wb',self.backedge_edgeid),
        )
        for suffix,mode,payload in outputs:
            filename = self.fast_omcsnet_filename+suffix
            path = MontyUtils.MontyUtils().find_file(filename)
            if path == '':
                # The rest of this class treats '' from find_file as "not
                # found"; on a first build the file does not exist yet, so
                # fall back to creating it under its bare name.
                path = filename
            f = open(path,mode)
            try:
                if mode == 'w':
                    f.write(payload)
                else:
                    payload.tofile(f)
            finally:
                f.close()
        return
    def array_fromfile(self,file_ptr,array_ptr,length,java_p=0,java_code='',endian_order='little'):

        item_length = 4
        if java_p:
            if endian_order == 'big':
                format_character = '>'
            else:
                format_character = '<'
            arr = struct.unpack(format_character+str(length)+'L',file_ptr.read())
            if java_code == 'node_start_pos':
                self.node_start_pos = jarray.array(arr,'l')
            elif java_code == 'node_end_pos':
                self.node_end_pos = jarray.array(arr,'l')
            elif java_code == 'sentence_start_pos':
                self.sentence_start_pos = jarray.array(arr,'l')
            elif java_code == 'sentence_end_pos':
                self.sentence_end_pos = jarray.array(arr,'l')
            elif java_code == 'edge_origin_nodeid':
                self.edge_origin_nodeid = jarray.array(arr,'l')
            elif java_code == 'edge_destination_nodeid':
                self.edge_destination_nodeid = jarray.array(arr,'l')
            elif java_code == 'edge_predid':
                self.edge_predid = jarray.array(arr,'l')
            elif java_code == 'edge_sentenceid':
                self.edge_sentenceid = jarray.array(arr,'l')
            elif java_code == 'backedge_destination_nodeid':
                self.backedge_destination_nodeid = jarray.array(arr,'l')
            elif java_code == 'backedge_edgeid':
                self.backedge_edgeid = jarray.array(arr,'l')
            else:
                print "error! java code invalid!"
                sys.exit(-1)
        else:
            array_ptr.fromfile(file_ptr,length)
    def populate_efficient_repr(self):

        keys = self.semantic_net.keys()
        keys.sort()
        terminal_nodes = []
        count = 0
        edge_count = 0
        cur_len_nodestr = 0
        cur_len_sentencestr = 0
        for i in range(len(keys)):
            count += 1
            if count % 1000 == 0:
                print count
            self.node_start_pos.append(cur_len_nodestr)
            self.nodes_string += keys[i]
            cur_len_nodestr += len(keys[i])
            self.node_end_pos.append(cur_len_nodestr)
            edges = self.semantic_net.get(keys[i],[])
            for edge in edges:
                self.edge_origin_nodeid.append(i)
                if edge[1] not in keys:
                    terminal_nodes.append(edge[1])
                    self.edge_destination_nodeid.append(len(keys) + terminal_nodes.index(edge[1]))
                else:
                    self.edge_destination_nodeid.append(keys.index(edge[1]))
                if edge[0] not in self.preds:
                    self.preds.append(edge[0])
                self.edge_predid.append(self.preds.index(edge[0]))
                self.sentence_start_pos.append(cur_len_sentencestr)
                self.sentences_string += edge[2]
                cur_len_sentencestr += len(edge[2])
                self.sentence_end_pos.append(cur_len_sentencestr)
                self.edge_sentenceid.append(edge_count)
                edge_count += 1
        self.origin_node_count = count
        for i in range(len(terminal_nodes)):
            count += 1
            if count % 1000 == 0:
                print count
            self.node_start_pos.append(cur_len_nodestr)
            self.nodes_string += terminal_nodes[i]
            cur_len_nodestr += len(terminal_nodes[i])
            self.node_end_pos.append(cur_len_nodestr)
        self.populate_backedge_arrays()
        return
    def populate_backedge_arrays(self):

        backedge_destination_nodeid = self.edge_destination_nodeid.tolist()
        backedge_pairs = map(lambda i:[backedge_destination_nodeid[i],i],range(len(backedge_destination_nodeid)))
        backedge_pairs.sort(lambda x,y:int(x[0]-y[0]))
        self.backedge_destination_nodeid = array.array('L',map(lambda x:x[0],backedge_pairs))
        self.backedge_edgeid = array.array('L',map(lambda x:x[1],backedge_pairs))
        return
    def load_pp_omcsnet(self):
        """Read the text-only datafile (self.filename1) into self.semantic_net.

        Lines are collected until a line starting with '*****' is reached;
        the collected lines are then handed to parse_nodeset and collection
        starts over.  Separators with no lines collected before them are
        ignored.

        NOTE(review): lines after the last separator are never parsed --
        confirm the file format always ends with a separator line.
        """
        f = open(MontyUtils.MontyUtils().find_file(self.filename1),'r')
        # try/finally so the handle is released even if parsing fails.
        try:
            node_set = []
            line = f.readline()
            while line:
                if line[0:5] == '*****':
                    # parse_nodeset indexes node_set[0], so never pass it an
                    # empty set.
                    if node_set:
                        self.parse_nodeset(node_set)
                    node_set = []
                else:
                    node_set.append(line)
                line = f.readline()
        finally:
            f.close()
        return
    def parse_nodeset(self,node_set):

        originnode = string.strip(string.split(node_set[0],'NODE:')[-1])
        edges = []
        edge = []
        for line in node_set[1:]:
            if line[:5] == 'EDGE:':
                if edge != []:
                    if edge[3] != 'bw':
                        edges.append(edge)
                edge = []
            else:
                stripped = string.strip(string.split(line,': ')[-1])
                edge.append(stripped)
        if edge != []:
            edges.append(edge)
        edges = map(lambda edge:edge[:4]+[float(edge[4])],edges)
        self.semantic_net[originnode] = edges
# Script entry point: constructing OMCSNetFast with its default arguments
# runs the datafile lookup/loading logic in __init__, which may prompt on the
# console and may call sys.exit.
if __name__ == '__main__':
    s = OMCSNetFast()