from __future__ import nested_scopes
import re,string,os,sys,shelve,copy,math
from types import *
import MontyTagger
import OMCSNetFast
class OMCSNetTools:
    """Graph queries on an OMCSNetFast handle plus MontyTagger text helpers.

    The network handle is supplied by the caller; a MontyTagger instance is
    created per tools object in ``__init__``.
    """

    # Class-level placeholders; every instance rebinds both in __init__.
    theMontyTagger = None
    theOMCSNetFast = None

    # Scoring knobs used by find_analogous_nodes.  Predicates are looked up
    # after one trailing digit (if any) has been stripped from their name.
    _PREDICATE_WEIGHTS = {
        'OftenNear': 1,
        'LocationOf': 5,
        'UsedFor': 12,
        'PropertyOf': 12,
    }
    _DEFAULT_PREDICATE_WEIGHT = 8
    _RELATION_TYPE_BONUS = 25
    _SCORE_SCALE = 200.0
    _SCORE_CAP = 0.9

    def __init__(self, theOMCSNetFastHandle):
        """Create a tagger and remember the network handle to query."""
        self.theMontyTagger = MontyTagger.MontyTagger()
        self.theOMCSNetFast = theOMCSNetFastHandle

    def tokenize_text(self, text):
        """Return ``text`` as tokenized by the MontyTagger instance."""
        return self.theMontyTagger.tokenize(text)

    def tag_text(self, text, tokenized_p=0):
        """Tag ``text``, tokenizing it first unless ``tokenized_p`` is true.

        Returns whatever ``MontyTagger.tag_tokenized`` returns.
        """
        if not tokenized_p:
            # The text must be forwarded here; calling tokenize_text() with
            # no argument raised TypeError on every untokenized call.
            text = self.tokenize_text(text)
        return self.theMontyTagger.tag_tokenized(text)

    def find_analogous_nodes(self, nodename):
        """Rank nodes that share outgoing (predicate, target) pairs with ``nodename``.

        Edges are read positionally as ``edge[0]`` = predicate,
        ``edge[1]`` = origin node name, ``edge[2]`` = destination node name.
        A candidate must share at least two (predicate, target) pairs.

        Returns a list of ``[candidate_name, shared_pairs, score]`` sorted by
        descending score, where ``shared_pairs`` is a list of
        ``[predicate, target_name]`` and ``score`` is capped at 0.9.
        """
        net = self.theOMCSNetFast
        shared = {}
        for out_edge in net.get_edges_by_origin_nodename(nodename, sentence_p=0):
            common_pred = out_edge[0]
            common_node = out_edge[2]
            in_edges = net.get_edges_by_destination_nodename(common_node,
                                                             sentence_p=0)
            for in_edge in in_edges:
                if in_edge[1] != nodename and in_edge[0] == common_pred:
                    shared.setdefault(in_edge[1], []).append(
                        [common_pred, common_node])

        ranked = []
        for candidate, pairs in shared.items():
            if len(pairs) < 2:
                continue
            raw_score = 0
            relation_types = []
            for pred, _target in pairs:
                base_pred = pred
                # Strip a single trailing digit; the emptiness guard avoids
                # an IndexError on a blank predicate name.
                if pred and pred[-1] in '0123456789':
                    base_pred = pred[:-1]
                if base_pred != 'OftenNear' and base_pred not in relation_types:
                    relation_types.append(base_pred)
                raw_score += self._PREDICATE_WEIGHTS.get(
                    base_pred, self._DEFAULT_PREDICATE_WEIGHT)
            # Each distinct relation type (OftenNear excluded) earns a bonus.
            raw_score += self._RELATION_TYPE_BONUS * len(relation_types)
            ranked.append([candidate, pairs, raw_score])

        # Stable descending sort on the integer raw score.
        ranked.sort(key=lambda item: item[2], reverse=True)
        return [[name, pairs, min(self._SCORE_CAP, raw / self._SCORE_SCALE)]
                for name, pairs, raw in ranked]

    def find_paths_from_a_to_b(self, node1name, node2name, sentences_p=0,
                               max_node_visits=10000,
                               max_number_of_results=200,
                               restrict_predicates=''):
        """Breadth-first search for edge paths from ``node1name`` to ``node2name``.

        Parameters:
            sentences_p: when true, each step also carries the edge sentence.
            max_node_visits: expansion stops once this many edges were seen.
            max_number_of_results: only the first N paths found are returned.
            restrict_predicates: whitespace-separated predicate names; when
                non-empty only edges with these predicates are followed.

        Returns a list of paths.  Each path is a list of steps
        ``[predicate, node_name]`` (or ``[predicate, node_name, sentence]``
        with ``sentences_p``); the first step is ``['start', node1name]``
        (with an empty sentence).  Returns ``[]`` if either node is unknown.

        Raises ValueError if a name in ``restrict_predicates`` is not in the
        network's ``preds`` list.  No cycle detection is performed; the
        search is bounded only by ``max_node_visits``.
        """
        net = self.theOMCSNetFast
        restricted = restrict_predicates != ''
        allowed_predids = set()
        if restricted:
            allowed_predids = set([net.preds.index(pred)
                                   for pred in restrict_predicates.split()])

        origin_nodeid = net.get_nodeid_by_nodename(node1name, -1)
        destination_nodeid = net.get_nodeid_by_nodename(node2name, -1)
        if origin_nodeid < 0 or destination_nodeid < 0:
            return []

        # Each partial path is a list of (edgeid, nodeid) pairs; the edge id
        # of the first pair is a placeholder and is never looked up.
        frontier = [[(0, origin_nodeid)]]
        found = []
        cursor = 0
        nodes_seen = 1
        while cursor < len(frontier) and nodes_seen < max_node_visits:
            current = frontier[cursor]
            edgeids = net.get_edgeids_by_origin_nodeid(current[-1][1])
            if restricted:
                edgeids = [eid for eid in edgeids
                           if net.edge_predid[eid] in allowed_predids]
            for eid in edgeids:
                next_nodeid = net.edge_destination_nodeid[eid]
                extended = current + [(eid, next_nodeid)]
                if next_nodeid == destination_nodeid:
                    found.append(extended)
                else:
                    frontier.append(extended)
            nodes_seen += len(edgeids)
            cursor += 1

        output_paths = []
        for raw_path in found[:max_number_of_results]:
            start_step = ['start', net.get_nodename_by_nodeid(raw_path[0][1])]
            if sentences_p:
                start_step.append('')
            rendered = [start_step]
            for eid, nodeid in raw_path[1:]:
                step = [net.get_pred_by_edgeid(eid),
                        net.get_nodename_by_nodeid(nodeid)]
                if sentences_p:
                    step.append(net.get_sentence_by_edgeid(eid))
                rendered.append(step)
            output_paths.append(rendered)
        # The result is returned only; it is no longer echoed to stdout.
        return output_paths

    def get_context(self, nodelist, max_node_visits=1000,
                    max_number_of_results=200):
        """Score nodes reachable from the nodes named in ``nodelist``.

        Every hop halves the score carried from its origin (seeds start at
        1.0).  Multiple arrivals at the same node are merged with
        ``hi + (1 - hi) * lo``.  Seed nodes are removed from the result.

        Returns up to ``max_number_of_results`` ``[node_name, score]`` pairs
        in descending score order, or ``[]`` if any name is unknown.
        """
        net = self.theOMCSNetFast
        discount_factor = 0.5

        origin_nodeids = []
        queue = []
        for nodename in nodelist:
            origin_nodeid = net.get_nodeid_by_nodename(nodename, -1)
            if origin_nodeid < 0:
                return []
            origin_nodeids.append(origin_nodeid)
            queue.append([origin_nodeid, 1.0])

        nodes_seen = 1
        cursor = 0
        while cursor < len(queue):
            if nodes_seen > max_node_visits:
                break
            cur_node, cur_score = queue[cursor]
            edgeids = net.get_edgeids_by_origin_nodeid(cur_node)
            next_score = cur_score * discount_factor
            for eid in edgeids:
                queue.append([net.edge_destination_nodeid[eid], next_score])
            nodes_seen += len(edgeids)
            cursor += 1

        node_dict = {}
        for node, score in queue:
            previous = node_dict.get(node, 0.0)
            hi = max(score, previous)
            lo = min(score, previous)
            node_dict[node] = hi + (1.0 - hi) * lo

        for origin_nodeid in origin_nodeids:
            # pop() tolerates the same seed being listed more than once.
            node_dict.pop(origin_nodeid, None)

        # Sort on the exact score; the former comparator truncated gaps
        # smaller than 0.01 to "equal", which gave an inexact ordering.
        scored = sorted(node_dict.items(), key=lambda pair: pair[1],
                        reverse=True)
        return [[net.get_nodename_by_nodeid(node), score]
                for node, score in scored[:max_number_of_results]]
if __name__ == '__main__':
    # Demo run: build a network handle, wrap it in the tools class and
    # print the paths found between 'professor' and 'student'.
    import OMCSNetFast
    net = OMCSNetFast.OMCSNetFast()
    tools = OMCSNetTools(net)
    demo_paths = tools.find_paths_from_a_to_b(
        'professor', 'student',
        sentences_p=1, max_node_visits=1000, max_number_of_results=200)
    # One parenthesised argument prints the same under Python 2's print
    # statement and Python 3's print function.
    print(demo_paths)