Self-Attention from scratch
Implement self-attention from scratch using plain torch tensors; the goal is to understand how multi-head attention works.
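The core operation is scaled dot-product attention,

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V,
\]

and multi-head attention simply runs it nh times in parallel on learned lower-dimensional projections of the input, then concatenates and re-projects the results.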
import math
import numpy as np
import torch
import torch.nn as nn
corpus_a = ["one is one", "two is two", "three is three", "four is four", "five is five",
            "six is six", "seven is seven", "eight is eight", "nine is nine"]
corpus_b = ["1 = 1", "2 = 2", "3 = 3", "4 = 4", "5 = 5",
            "6 = 6", "7 = 7", "8 = 8", "9 = 9"]
embed_a = {"one": [1.0,0,0,0,0,0,0,0,0,0,0,0],
"two": [0,1.0,0,0,0,0,0,0,0,0,0,0],
"three":[0,0,1.0,0,0,0,0,0,0,0,0,0],
"four": [0,0,0,1.0,0,0,0,0,0,0,0,0],
"five": [0,0,0,0,1.0,0,0,0,0,0,0,0],
"six": [0,0,0,0,0,1.0,0,0,0,0,0,0],
"seven":[0,0,0,0,0,0,1.0,0,0,0,0,0],
"eight":[0,0,0,0,0,0,0,1.0,0,0,0,0],
"nine": [0,0,0,0,0,0,0,0,1.0,0,0,0],
"is": [0,0,0,0,0,0,0,0,0,1.0,0,0],
"less": [0,0,0,0,0,0,0,0,0,0,1.0,0],
"more": [0,0,0,0,0,0,0,0,0,0,0,1.0]
}
embed_b = {"9": [1.0,0,0,0,0,0,0,0,0,0,0,0],
"8": [0,1.0,0,0,0,0,0,0,0,0,0,0],
"7": [0,0,1.0,0,0,0,0,0,0,0,0,0],
"6": [0,0,0,1.0,0,0,0,0,0,0,0,0],
"5": [0,0,0,0,1.0,0,0,0,0,0,0,0],
"4": [0,0,0,0,0,1.0,0,0,0,0,0,0],
"3": [0,0,0,0,0,0,1.0,0,0,0,0,0],
"2": [0,0,0,0,0,0,0,1.0,0,0,0,0],
"1": [0,0,0,0,0,0,0,0,1.0,0,0,0],
"=": [0,0,0,0,0,0,0,0,0,1.0,0,0],
"<": [0,0,0,0,0,0,0,0,0,1.0,0,0],
">": [0,0,0,0,0,0,0,0,0,1.0,0,0],
}
def sentence_embed(sentence, embed_dict):
    """Generate an embedding for a sentence"""
    res = []
    for word in sentence.split():
        res.append(embed_dict[word])
    return res
inp = sentence_embed("one is one", embed_a)
out = sentence_embed("1 = 1", embed_b)
inp = torch.tensor(inp, dtype=torch.float32)
out = torch.tensor(out, dtype=torch.float32)
inp.shape, out.shape
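Every sentence in corpus_a has exactly three tokens, so the same helper can also build a batched tensor; a minimal sketch (the rest of the notebook keeps working on a single unbatched sentence):

batch = torch.stack([torch.tensor(sentence_embed(s, embed_a), dtype=torch.float32)
                     for s in corpus_a])
batch.shape  # torch.Size([9, 3, 12]) -- (sentences, tokens, embedding dim)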
def dot_attention(q, k, v):
    """Scaled dot-product attention: softmax(q @ k.T / sqrt(dk)) @ v"""
    dk = k.size(-1)  # key dimension, used to scale the logits
    logit = (q @ k.transpose(-2, -1)) / math.sqrt(dk)
    weights = torch.softmax(logit, dim=-1)  # attention weights, each row sums to 1
    res = weights @ v                       # weighted sum of the values
    return res
q, k, v = inp, inp, inp
dot_attention(q, k, v)
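As a sanity check, the hand-written function can be compared against PyTorch's built-in scaled dot-product attention (assuming PyTorch 2.0+, which provides torch.nn.functional.scaled_dot_product_attention; a batch dimension is added to match its (N, ..., L, E) layout):

import torch.nn.functional as F

qb, kb, vb = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)  # add a batch dim
# Both compute softmax(q @ k.T / sqrt(dk)) @ v, so this should print True
print(torch.allclose(dot_attention(qb, kb, vb),
                     F.scaled_dot_product_attention(qb, kb, vb),
                     atol=1e-6))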
class MultiHeadAttention(nn.Module):
    def __init__(self, dm, nh):
        """
        dm: model dimension
        nh: number of heads
        """
        super().__init__()
        self.dm, self.nh = dm, nh
        self.dk = dm // nh  # per-head dimension
        # Use ModuleList/ModuleDict instead of a plain list of dicts so the
        # per-head projections are registered as submodules and their
        # parameters show up in the module's parameters()
        self.heads = nn.ModuleList([
            nn.ModuleDict({"wq": nn.Linear(self.dm, self.dk),
                           "wk": nn.Linear(self.dm, self.dk),
                           "wv": nn.Linear(self.dm, self.dk)})
            for _ in range(nh)
        ])
        self.out = nn.Linear(dm, dm)

    def forward(self, inp):
        res = []
        for head in self.heads:
            # Project the input into this head's query, key and value spaces
            q, k, v = head["wq"](inp), head["wk"](inp), head["wv"](inp)
            print(q.shape, k.shape, v.shape)
            res.append(dot_attention(q, k, v))
        # Concatenate the nh heads of size dk back into dm features
        concat = torch.cat(res, dim=-1)
        res = self.out(concat)
        print(concat.shape, res.shape)
        return res
dm = 12
nh = 3
# dk = 12/3 = 4
mul_head = MultiHeadAttention(dm, nh)
mul_head(inp)
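For comparison, PyTorch's built-in nn.MultiheadAttention produces an output of the same shape; its values will differ because its projection weights are initialized independently of the hand-rolled module. A minimal sketch (batch_first=True so the input layout is (batch, tokens, features)):

mha = nn.MultiheadAttention(embed_dim=dm, num_heads=nh, batch_first=True)
x = inp.unsqueeze(0)                       # add a batch dim: (1, 3, 12)
attn_out, attn_weights = mha(x, x, x)      # self-attention: query = key = value
print(attn_out.shape, attn_weights.shape)  # (1, 3, 12) and (1, 3, 3)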