import math
import numpy as np
import torch
import torch.nn as nn

# Generate synthetic data

# Toy parallel corpora: the English sentence at index i of ``copus_a`` pairs
# with the symbolic sentence at index i of ``copus_b``.
# ("copus" is a typo for "corpus"; the names are kept for compatibility.)
copus_a = ["one is one", "two is two", "three is three", "four is four", "five is five",
           "six is six", "seven is seven", "eight is eight", "nine is nine"]
copus_b = ["1 = 1", "2 = 2", "3 = 3", "4 = 4", "5 = 5",
           "6 = 6", "7 = 7", "8 = 8", "9 = 9"]

# One-hot word vectors (length 12) for the English vocabulary; every word
# owns a distinct slot.
embed_a = {"one":  [1.0,0,0,0,0,0,0,0,0,0,0,0],
           "two":  [0,1.0,0,0,0,0,0,0,0,0,0,0],
           "three":[0,0,1.0,0,0,0,0,0,0,0,0,0],
           "four": [0,0,0,1.0,0,0,0,0,0,0,0,0],
           "five": [0,0,0,0,1.0,0,0,0,0,0,0,0],
           "six":  [0,0,0,0,0,1.0,0,0,0,0,0,0],
           "seven":[0,0,0,0,0,0,1.0,0,0,0,0,0],
           "eight":[0,0,0,0,0,0,0,1.0,0,0,0,0],
           "nine": [0,0,0,0,0,0,0,0,1.0,0,0,0],
           "is":   [0,0,0,0,0,0,0,0,0,1.0,0,0],
           "less": [0,0,0,0,0,0,0,0,0,0,1.0,0],
           "more": [0,0,0,0,0,0,0,0,0,0,0,1.0]
          }

# One-hot vectors for the symbolic vocabulary. Digits occupy slots 0-8 in
# reverse order ("9" takes slot 0), while the relation tokens reuse the slots
# of their English counterparts: "=" <-> "is", "<" <-> "less", ">" <-> "more".
# Each token must have its own slot; "<" and ">" previously duplicated "=".
embed_b = {"9": [1.0,0,0,0,0,0,0,0,0,0,0,0],
           "8": [0,1.0,0,0,0,0,0,0,0,0,0,0],
           "7": [0,0,1.0,0,0,0,0,0,0,0,0,0],
           "6": [0,0,0,1.0,0,0,0,0,0,0,0,0],
           "5": [0,0,0,0,1.0,0,0,0,0,0,0,0],
           "4": [0,0,0,0,0,1.0,0,0,0,0,0,0],
           "3": [0,0,0,0,0,0,1.0,0,0,0,0,0],
           "2": [0,0,0,0,0,0,0,1.0,0,0,0,0],
           "1": [0,0,0,0,0,0,0,0,1.0,0,0,0],
           "=": [0,0,0,0,0,0,0,0,0,1.0,0,0],
           "<": [0,0,0,0,0,0,0,0,0,0,1.0,0],
           ">": [0,0,0,0,0,0,0,0,0,0,0,1.0],
          }
def sentence_embed(sentence, embed_dict):
    """Look up the vector for every whitespace-separated token of ``sentence``.

    Returns a list holding, in order, ``embed_dict[token]`` for each token;
    a token missing from ``embed_dict`` raises KeyError.
    """
    tokens = sentence.split()
    return [embed_dict[token] for token in tokens]
# Embed one aligned sentence pair and convert each to a float32 tensor
# (one row per token).
inp = torch.tensor(sentence_embed("one is one", embed_a), dtype=torch.float32)
out = torch.tensor(sentence_embed("1 = 1", embed_b), dtype=torch.float32)
inp.shape, out.shape
(torch.Size([3, 12]), torch.Size([3, 12]))

# Scaled dot-product attention

def dot_attention(q, k, v):
    """Compute scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

    q: query tensor, shape (..., seq_q, d_k)
    k: key tensor, shape (..., seq_k, d_k); d_k is read from its last dim
    v: value tensor, shape (..., seq_k, d_v)

    Any leading (batch) dimensions are matched/broadcast by ``@``.
    Returns a tensor of shape (..., seq_q, d_v).
    """
    dk = k.size(-1)
    # Swap only the last two axes so leading batch axes are left in place;
    # transpose(0, -1) coincides with this only for 2-D inputs.
    logit = (q @ k.transpose(-2, -1)) / math.sqrt(dk)
    # Normalize over the key axis so each query's weights sum to 1.
    weights = torch.softmax(logit, dim=-1)
    return weights @ v
# Plain self-attention on the raw embeddings: query, key and value are all `inp`.
q = k = v = inp
dot_attention(q, k, v)
tensor([[0.7275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.2725, 0.0000, 0.0000],
        [0.5998, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.4002, 0.0000, 0.0000],
        [0.7275, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.2725, 0.0000, 0.0000]])

# Multi-head attention

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``dot_attention``.

    Each of the ``nh`` heads linearly projects the input to queries, keys and
    values of size ``dm // nh``, applies ``dot_attention`` to them, and the
    head outputs are concatenated on the last axis and mixed by a final
    ``dm -> dm`` linear layer.
    """

    def __init__(self, dm, nh, verbose=False):
        """
        dm: model dimension (size of the input's last axis)
        nh: number of heads; must be positive and divide ``dm`` evenly
        verbose: if True, ``forward`` prints intermediate tensor shapes

        Raises ValueError if ``nh`` is not positive or does not divide ``dm``.
        """
        super().__init__()
        # The concatenated head outputs have nh * (dm // nh) features, which
        # only matches self.out's input size dm when nh divides dm.
        if nh <= 0 or dm % nh != 0:
            raise ValueError(f"dm ({dm}) must be divisible by a positive nh ({nh})")
        self.dm, self.nh = dm, nh
        self.dk = dm // nh
        self.verbose = verbose
        # ModuleList/ModuleDict register the projections as submodules so they
        # appear in .parameters() and state_dict() and follow .to(device);
        # a plain list of dicts would hide them from PyTorch.
        self.heads = nn.ModuleList(
            nn.ModuleDict({"wq": nn.Linear(self.dm, self.dk),
                           "wk": nn.Linear(self.dm, self.dk),
                           "wv": nn.Linear(self.dm, self.dk)})
            for _ in range(nh)
        )
        self.out = nn.Linear(dm, dm)

    def forward(self, inp):
        """Apply every head to ``inp`` and merge the results.

        inp: tensor whose last dimension is ``dm``, e.g. (seq, dm); a batched
             (batch, seq, dm) input presumably works too, provided
             ``dot_attention`` handles batches — confirm.
        Returns a tensor whose last dimension is ``dm``.
        """
        head_outputs = []
        for head in self.heads:
            q, k, v = head["wq"](inp), head["wk"](inp), head["wv"](inp)
            if self.verbose:
                print(q.shape, k.shape, v.shape)
            head_outputs.append(dot_attention(q, k, v))
        # Join heads on the feature (last) axis; dim=1 is only the feature
        # axis for unbatched 2-D input.
        concat = torch.cat(head_outputs, dim=-1)
        res = self.out(concat)
        if self.verbose:
            print(concat.shape, res.shape)
        return res
# A 12-dimensional model split across 3 heads: each head works in
# dk = 12 // 3 = 4 dimensions.
dm, nh = 12, 3
mul_head = MultiHeadAttention(dm, nh)
mul_head(inp)
torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])
torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])
torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])
torch.Size([3, 12]) torch.Size([3, 12])
tensor([[-1.4953e-01,  6.1958e-02, -9.2505e-02,  1.4574e-01,  1.0211e-01,
         -1.9842e-03,  8.9212e-02,  9.2313e-02, -2.3563e-01, -5.9226e-02,
         -2.6632e-01, -1.9141e-01],
        [-1.5497e-01,  6.4145e-02, -9.3638e-02,  1.4637e-01,  1.0329e-01,
         -6.2585e-07,  9.0302e-02,  9.7207e-02, -2.3449e-01, -5.7356e-02,
         -2.6794e-01, -1.9412e-01],
        [-1.4953e-01,  6.1958e-02, -9.2505e-02,  1.4574e-01,  1.0211e-01,
         -1.9842e-03,  8.9212e-02,  9.2313e-02, -2.3563e-01, -5.9226e-02,
         -2.6632e-01, -1.9141e-01]], grad_fn=<AddmmBackward>)