App下載

pytorch怎么實(shí)現(xiàn)beam search?

猿友 2021-08-06 14:45:30 瀏覽數(shù) (3016)
反饋

在機(jī)器學(xué)習(xí)中經(jīng)常用到一種搜索算法——束搜索算法,又叫beam search 算法,他是貪心算法的一種優(yōu)化實(shí)現(xiàn)。在機(jī)器學(xué)習(xí)中我們需要自行實(shí)現(xiàn)這種算法,接下來(lái)這篇文章主要記錄兩種不同的beam search版本,小伙伴可以進(jìn)行對(duì)比和學(xué)習(xí)。

版本一

使用類(lèi)似層次遍歷的方式進(jìn)行搜索,用隊(duì)列進(jìn)行維護(hù),每次循環(huán)對(duì)當(dāng)前層的所有節(jié)點(diǎn)進(jìn)行搜索,這些節(jié)點(diǎn)每個(gè)分別對(duì)應(yīng)topk個(gè)節(jié)點(diǎn)作為下一層候選節(jié)點(diǎn),取所有候選節(jié)點(diǎn)的前tok個(gè)作為下一層節(jié)點(diǎn)加入隊(duì)列

bfs with width constraint. 啟發(fā)式搜索的一種. 屬于貪心算法. 如果k -> inf,那么等價(jià)于bfs.

從根節(jié)點(diǎn)開(kāi)始(),選取所有可能(大概幾萬(wàn)個(gè))里面概率最大的k個(gè),拓展為下一層節(jié)點(diǎn).

然后在這k個(gè)節(jié)點(diǎn)里面,其可能拓展的所有節(jié)點(diǎn)中(一般是k * 幾萬(wàn)個(gè)),再選取概率最大的k個(gè)(注意這里的概率是累乘,即從根節(jié)點(diǎn)到該節(jié)點(diǎn)的概率乘積)拓展. 這里拓展的k個(gè)子節(jié)點(diǎn),其父節(jié)點(diǎn)可以是上一層的k個(gè),也可以只是其中一部分,甚至全部出自其中一個(gè)節(jié)點(diǎn). 以此類(lèi)推.

這樣形成的是一棵每層都是k個(gè)節(jié)點(diǎn)樹(shù)(除了根節(jié)點(diǎn)、末尾,和候選者不足k個(gè)的情況).

一般概率取log,避免值過(guò)小.

舉個(gè)例子:k=2

<sos> 選取概率最大的三個(gè), “i”: 0.6, “he”: 0.4. 其他單詞忽略不計(jì)

拓展一共有4個(gè) (1)“i"后面接,假設(shè)概率最大的是"love”: 0.7, “l(fā)ike”: 0.3 其他單詞忽略不計(jì)(2)“he"后面接:假設(shè)概率最大的是"hates”: 0.9, “l(fā)oves”: 0.1 其他單詞忽略不計(jì); 這樣4種可能中,到這里 "i love"概率是0.6 * 0.7 = 0.42, "i like"概率是0.6 * 0.3 = 0.18, "he hates"概率是0.4 * 0.9 = 0.36, "he loves"概率是0.4 * 0.1 = 0.04; 選取概率最大的兩個(gè),“i love"和"he hates”.

下一層拓展仍為4個(gè) (1) "i love"后面接 ,假設(shè)概率最大是 “you”:0.9, 其他單詞加起來(lái)0.1;(2)“he hates"后面接,假設(shè)概率最大的是"her”:0.8, “himself”:0.1, 其他單詞加起來(lái)0.1; 那么"i love you"概率為 0.42 * 0.9 = 0.378; "he hates her"概率為0.36*0.8 = 0.228,其他不用算了都小于這個(gè)值. 最后也選取2個(gè)概率最大的: "i love you"和 “he hates her”

下一層拓展, “i love you"應(yīng)該拓展兩個(gè)子節(jié)點(diǎn),發(fā)現(xiàn)”"概率0.99,其他單詞加起來(lái)0.01;“he hates her"應(yīng)該拓展兩個(gè)子節(jié)點(diǎn),發(fā)現(xiàn)”"概率0.99,其他單詞加起來(lái)0.01;所以概率最大的是"i love you "和"he hates you ". 因兩個(gè)分支均遇到,均結(jié)束搜索.

最后在兩個(gè)當(dāng)中選擇概率最大的 "i love you ". 結(jié)束

代碼是從一個(gè)項(xiàng)目中截取的,只選取了關(guān)鍵內(nèi)容,pytorch實(shí)現(xiàn):

class Node(object):
    def __init__(self, hidden, previous_node, decoder_input, attn, log_prob, length):
        self.hidden = hidden
        self.previous_node = previous_node
        self.decoder_input = decoder_input
        self.attn = attn
        self.log_prob = log_prob
        self.length = length        
def beam_search(beam_width):
    ...
    root = Node(hidden, None, decoder_input, None, 0, 1)
    q = Queue()
    q.put(root)
    
    end_nodes = [] #最終節(jié)點(diǎn)的位置,用于回溯
    while not q.empty():
        candidates = []  #每一層的可能被拓展的節(jié)點(diǎn),只需選取每個(gè)父節(jié)點(diǎn)的兒子節(jié)點(diǎn)中概率最大的k個(gè)即可
    
        for _ in range(q.qsize()):
            node = q.get()
            decoder_input = node.decoder_input
            hidden = node.hidden
            
            # 搜索終止條件
            if decoder_input.item() == EOS or node.length >= 50:
                end_nodes.append(node)
                continue
              
            log_prob, hidden, attn = decoder(
                 decoder_input, hidden, encoder_input
             )
             
             log_prob, indices = log_prob.topk(beam_width) #選取某個(gè)父節(jié)點(diǎn)的兒子節(jié)點(diǎn)概率最大的k個(gè)
             
             for k in range(beam_width):
                  index = indices[k].unsqueeze(0)
                  log_p = log_prob[k].item()
                  child = Node(hidden, node, index, attn, node.log_prob + log_p, node.length + 1)
                  candidates.append((node.log_prob + log_p, child))  #建立候選兒子節(jié)點(diǎn),注意這里概率需要累計(jì)
           
         candidates = sorted(candidates, key=lambda x:x[0], reverse=True) #候選節(jié)點(diǎn)排序
         length = min(len(candidates), beam_width)  #取前k個(gè),如果不足k個(gè),則全部入選
         for i in range(length):
             q.put(candidates[i][1])  
    # 后面是回溯, 省略
    ...

版本二

不進(jìn)行層次遍歷,而是每次從整個(gè)隊(duì)列中拿出概率最大的節(jié)點(diǎn)出隊(duì)(優(yōu)先隊(duì)列)進(jìn)行搜索,將該節(jié)點(diǎn)的topk加入優(yōu)先隊(duì)列,循環(huán)終止的條件是節(jié)點(diǎn)所在位置對(duì)應(yīng)長(zhǎng)度達(dá)到限制或隊(duì)列節(jié)點(diǎn)個(gè)數(shù)超過(guò)限制

import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from queue import PriorityQueue
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SOS_token = 0
EOS_token = 1
MAX_LENGTH = 50
class DecoderRNN(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout=0.1):
        '''
        Illustrative decoder
        '''
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.embedding = nn.Embedding(num_embeddings=output_size,
                                      embedding_dim=embedding_size,
                                      )
        self.rnn = nn.GRU(embedding_size, hidden_size, bidirectional=True, dropout=dropout, batch_first=False)
        self.dropout_rate = dropout
        self.out = nn.Linear(hidden_size, output_size)
    def forward(self, input, hidden, not_used):
        embedded = self.embedding(input).transpose(0, 1)  # [B,1] -> [ 1, B, D]
        embedded = F.dropout(embedded, self.dropout_rate)
        output = embedded
        # batch_first=False, output維度為 (seq_len, batch_size, num_directions * hidden_size) = [1, batch_size, 2*hidden_size]
        output, hidden = self.rnn(output, hidden)
        out = self.out(output.squeeze(0))
        # output維度為 [batch_size, vocab_size]
        # hidden維度為 [num_layers * num_directions, batch_size, hidden_size]
        output = F.log_softmax(out, dim=1)
        return output, hidden
class BeamSearchNode(object):
    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate:
        :param previousNode:
        :param wordId:
        :param logProb:
        :param length:
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length
    def eval(self, alpha=1.0):
        reward = 0
        # Add here a function for shaping a reward
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward
decoder = DecoderRNN()
def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''
    beam_width = 10
    topk = 1  # how many sentence do you want to generate
    decoded_batch = []
    # decoding goes sentence by sentence
    for idx in range(target_tensor.size(0)):
        if isinstance(decoder_hiddens, tuple):  # LSTM case
            decoder_hidden = (decoder_hiddens[0][:,idx, :].unsqueeze(0),decoder_hiddens[1][:,idx, :].unsqueeze(0))
        else:
            decoder_hidden = decoder_hiddens[:, idx, :].unsqueeze(0)
        encoder_output = encoder_outputs[:,idx, :].unsqueeze(1)
        # Start with the start of the sentence token
        decoder_input = torch.LongTensor([[SOS_token]], device=device)
        # Number of sentence to generate
        endnodes = []
        number_required = min((topk + 1), topk - len(endnodes))
        # starting node -  hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()
        # start the queue
        nodes.put((-node.eval(), node))
        qsize = 1
        # start beam search
        while True:
            # give up when decoding takes too long
            if qsize > 2000: break
            # fetch the best node
            score, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h
            if n.wordid.item() == EOS_token and n.prevNode != None:
                endnodes.append((score, n))
                # if we reached maximum # of sentences required
                if len(endnodes) >= number_required:
                    break
                else:
                    continue
            # output維度為 [batch_size, vocab_size]
            # hidden維度為 [num_layers * num_directions, batch_size, hidden_size]
            # decode for one step using decoder
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)
            # PUT HERE REAL BEAM SEARCH OF TOP
            # log_prov, indexes維度為 [batch_size, beam_width] = [1, beam_width]
            log_prob, indexes = torch.topk(decoder_output, beam_width, dim=1)
            nextnodes = []
            for new_k in range(beam_width):
                # decoded_t: [1,1],通過(guò)view(1,-1)將數(shù)字tensor變?yōu)榫S度為[1,1]的tensor
                decoded_t = indexes[0][new_k].view(1, -1)
                # log_p, int
                log_p = log_prob[0][new_k].item() # item()將tensor數(shù)字變?yōu)閕nt
                node = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                score = -node.eval()
                nextnodes.append((score, node))
            # put them into queue
            for i in range(len(nextnodes)):
                score, nn = nextnodes[i]
                nodes.put((score, nn))
                # increase qsize
            qsize += len(nextnodes) - 1
        # choose nbest paths, back trace them
        if len(endnodes) == 0:
            endnodes = [nodes.get() for _ in range(topk)]
        utterances = []
        for score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = []
            utterance.append(n.wordid)
            # back trace
            while n.prevNode != None:
                n = n.prevNode
                utterance.append(n.wordid)
            utterance = utterance[::-1]
            utterances.append(utterance)
        decoded_batch.append(utterances)
    return decoded_batch
def greedy_decode(decoder_hidden, encoder_outputs, target_tensor):
    '''
    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence
    :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch
    '''
    batch_size, seq_len = target_tensor.size()
    decoded_batch = torch.zeros((batch_size, MAX_LENGTH))
    decoder_input = torch.LongTensor([[SOS_token] for _ in range(batch_size)], device=device)
    for t in range(MAX_LENGTH):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        topv, topi = decoder_output.data.topk(1)  # get candidates
        topi = topi.view(-1)
        decoded_batch[:, t] = topi
        decoder_input = topi.detach().view(-1, 1)
    return decoded_batch

補(bǔ)充:beam search 簡(jiǎn)單例子實(shí)現(xiàn)及講解

看代碼吧~

from math import log
from numpy import array
from numpy import argmax
# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    # walk over each step in sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -log(row[j])]
                all_candidates.append(candidate)
        # order all candidates by score
        ordered = sorted(all_candidates, key=lambda tup :tup[1])
        # select k best
        sequences = ordered[:k]
    return sequences
def greedy_decoder(data):
    # index for largest probability each row
    return [argmax(s) for s in data]
# define a sequence of 10 words over a vocab of 5 words
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = array(data)
# decode sequence
result = beam_search_decoder(data, 3)
# print result
for seq in result:
    print(seq)

每次循環(huán)sequences的值

[[[4], 0.6931471805599453], [[3], 0.916290731874155], [[2], 1.2039728043259361]]

[[[4, 0], 0.4804530139182014], [[4, 1], 0.6351243373717793], [[3, 0], 0.6351243373717793]]

[[[4, 0, 4], 0.33302465198892944], [[4, 0, 3], 0.4402346437542523], [[4, 1, 4], 0.4402346437542523]]

最終print的結(jié)果

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108]

[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397]

[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]

以上就是束搜索算法的全部介紹,希望能給大家一個(gè)參考,也希望大家多多支持W3Cschool。



0 人點(diǎn)贊