Applying A* path finding to latent word vectors

Path finding algorithms

There are many algorithms for finding the shortest path between two points, but few are better known than Dijkstra’s algorithm and the A* algorithm.

Dijkstra’s algorithm

Dijkstra’s algorithm is often described as being “shortest-path first”: it repeatedly expands the unvisited node closest to the start. Because it explores uniformly in every direction, with no sense of where the goal lies, it can be slow on large graphs.
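To make this concrete, here is a minimal Dijkstra sketch over a dict-of-dicts adjacency map. The graph, node names, and weights are illustrative only, not from this post:

```python
import heapq

def dijkstra(graph, start, goal):
    """Shortest path over an adjacency map {node: {neighbor: weight}}."""
    dist = {start: 0.0}
    parent = {start: None}
    frontier = [(0.0, start)]  # (cost so far, node)
    visited = set()
    while frontier:
        d, node = heapq.heappop(frontier)
        if node in visited:
            continue
        visited.add(node)
        if node == goal:
            break
        for neighbor, weight in graph.get(node, {}).items():
            nd = d + weight
            # Relax the edge if we found a cheaper way to reach the neighbor
            if neighbor not in dist or nd < dist[neighbor]:
                dist[neighbor] = nd
                parent[neighbor] = node
                heapq.heappush(frontier, (nd, neighbor))
    # Walk parent pointers back from the goal
    path = []
    node = goal
    while node is not None:
        path.append(node)
        node = parent[node]
    return path[::-1], dist[goal]
```

Note that the frontier is ordered purely by cost-so-far; that is exactly the blindness A* fixes.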


A* search algorithm

[Figure: A* compared to Weighted A*]

A* is a best-first search algorithm: when a heuristic estimate of the remaining distance to the goal is available, A* uses it to prioritize which nodes to explore next. The resulting speed-up has made the A* search algorithm very popular.
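The priority A* assigns to a node n is f(n) = g(n) + w·h(n), where g is the cost accumulated so far, h is the heuristic estimate to the goal, and w is a weight: w = 1 gives standard A*, while w > 1 gives Weighted A*, which trades optimality for speed (the search later in this post uses w = 4). A one-line sketch:

```python
def priority(g_cost, h_cost, weight=1.0):
    # f(n) = g(n) + w * h(n); weight > 1 is Weighted A* (faster, possibly suboptimal)
    return g_cost + weight * h_cost
```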

Applying A* to word vectors

In my last post, I discussed how word2vec can be used to qualitatively fingerprint sets of items. Since word vectors can be thought of as spatial coordinates, a shortest path through the vector space can be found between any two words.
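A natural edge cost between two words is the cosine distance between their vectors, which is what the get_transition_cost helper later in this post computes via the trained model. A dependency-free sketch of the same idea:

```python
import math

def transition_cost(vec1, vec2):
    # Cosine distance: 0.0 for identical directions, up to 2.0 for opposite ones
    dot = sum(a * b for a, b in zip(vec1, vec2))
    norm1 = math.sqrt(sum(a * a for a in vec1))
    norm2 = math.sqrt(sum(b * b for b in vec2))
    return 1.0 - dot / (norm1 * norm2)
```

Summing these costs along a path gives the path length that the search minimizes.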

After training on the IMDB Movie Reviews Dataset, this is the result:

animated -> carrying
Explored: 68, Frontier: 169, Cost: 3.356


know -> tasty
Explored: 11166, Frontier: 2618, Cost: 7.393


tasty -> boxing
Explored: 461, Frontier: 1305, Cost: 4.298


threaten -> restore
Explored: 4820, Frontier: 2992, Cost: 6.216




The following is a snippet from Word segmentation in Python and Building a recommendation system using product embeddings.

import os
import sentencepiece as spm
from gensim.models import Doc2Vec

name = 'reviews'
vocab_size = 12000
epochs = 100

# Train a SentencePiece tokenizer if one doesn't already exist
if not os.path.isfile('{0}.model'.format(name)):
    spm.SentencePieceTrainer.Train('--input={0}.txt --model_prefix={0} --vocab_size={1} --split_by_whitespace=True'.format(name,vocab_size))

sp = spm.SentencePieceProcessor()
sp.Load('{0}.model'.format(name))

if not os.path.isfile('./checkpoints/{0}-{1}.w2v'.format(name,epochs-1)):
    # Tokenize the corpus once, skipping very short lines
    if not os.path.isfile('{0}_tokenized.tsv'.format(name)):
        with open('{0}_tokenized.tsv'.format(name),'w+') as f:
            for i,line in enumerate(open('{0}.txt'.format(name))):
                ids = sp.EncodeAsIds(line)
                if len(ids) > 10:
                    f.write('{0}\t{1}\n'.format(i,' '.join(str(x) for x in ids)))

    model = Doc2Vec(vector_size=300, window=8, min_count=3, negative=5, workers=16, epochs=1)

    # get_documents and train_model are defined in the posts linked above
    documents = get_documents('{0}_tokenized.tsv'.format(name),name)
    model = train_model(model,documents,name,epochs)

model = Doc2Vec.load('./checkpoints/{0}-{1}.w2v'.format(name,epochs-1))

A* search algorithm

Adapted from WikiClassify, a project that would later become the Chrome extension WikiPopularity.

import heapq

class elem_t:
    def __init__(self,value,parent=None,cost=None):
        self.value = value
        self.cost = cost
        self.column_offset = 0
        self.parent = parent

class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._index = 0

    def push(self,item):
        heapq.heappush(self._queue, (item.cost,self._index,item) )
        self._index += 1

    def pop(self):
        index,item = heapq.heappop(self._queue)[1:]
        return item

    def length(self):
        return len(self._queue)

def get_transition_cost(word1,word2,doc2vec_model):
    return 1.0-float(doc2vec_model.similarity(word1,word2))

def a_star_search(start_word, end_word, doc2vec_model, branching_factor=60, weight=4.):

    cost_list = {start_word:0}

    frontier = PriorityQueue()
    start_elem = elem_t(start_word,parent=None,cost=get_transition_cost(start_word,end_word,doc2vec_model))
    frontier.push(start_elem)

    path_end = start_elem
    explored = []
    while True:

        # Stop if the frontier is exhausted
        if frontier.length() == 0:
            break

        current_node = frontier.pop()
        current_word = current_node.value

        # Goal reached
        if current_word == end_word:
            path_end = current_node
            break

        explored.append(current_word)

        neighbors = [x[0] for x in doc2vec_model.most_similar(current_word,topn=branching_factor) if x[0]!=current_word]
        if not neighbors:
            continue

        base_cost = cost_list[current_word]

        for neighbor_word in neighbors:
            if current_word == neighbor_word:
                continue
            cost = base_cost + get_transition_cost(current_word,neighbor_word,doc2vec_model)
            new_elem = elem_t(neighbor_word,parent=current_node,cost=cost)
            new_elem.column_offset = neighbors.index(neighbor_word)
            if (neighbor_word not in cost_list or cost<cost_list[neighbor_word]) and neighbor_word not in explored:
                cost_list[neighbor_word] = cost
                # Weighted A*: priority = path cost so far + weight * heuristic distance to goal
                new_elem.cost = cost + weight*get_transition_cost(neighbor_word,end_word,doc2vec_model)
                frontier.push(new_elem)

        print("Explored: "+str(len(explored))+", Frontier: "+str(frontier.length())+", Cost: "+str(base_cost)[:5],end='\r')

    # Walk back up the parent pointers to recover the path
    path = [path_end.value]
    cur = path_end
    while cur:
        cur = cur.parent
        if cur:
            path.append(cur.value)
    return path[::-1]

About the author

Hi, I'm Nathan. I'm a software engineer in the Los Angeles area. Keep an eye out for more content being posted soon.
