2023-04-30
Here, we play with the attention mechanism and word embeddings to make a KV store that fuzzily searches the key space, and may also return candidate values that don't exist in the store.
Suppose we have the following key/value pairs:
{
    'one': 'number',
    'cat': 'animal',
    'dog': 'animal',
    'ball': 'toy',
    'sphere': 'shape',
    'male': 'gender',
}
Here are some example queries and their results:
rhombus ['shape' 'shaped' 'fit' 'fits' 'resembles']
parabola ['shape' 'shaped' 'fit' 'fits' 'resembles']
female ['gender' 'ethnicity' 'orientation' 'racial' 'mainstreaming']
seven ['number' 'numbers' 'ten' 'only' 'other']
seventy ['gender' 'ethnicity' 'orientation' 'regardless' 'defining']
cylinder ['toy' 'instance' 'unlike' 'besides' 'newest']
So as we can see, the top result for 'seven' is 'number', because 'seven' is similar to 'one'. But there are also other results that are not in the dataset. Then there is 'seventy', which should also have returned 'number' as at least one result, but it didn't. Somehow our store thinks 'seventy' is closer to 'male' than to 'one'.
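We can sanity-check that claim in embedding space directly. Here's a quick probe, assuming the pretrained GloVe vectors we introduce below (loaded via gensim's downloader); if the store's behaviour tracks the embeddings, the first similarity should come out lower than the second:

```python
import gensim.downloader

# Pretrained 50-dimensional GloVe vectors (a gensim-hosted model).
glove = gensim.downloader.load('glove-wiki-gigaword-50')

# Cosine similarity of 'seventy' to the two keys in question.
print(glove.similarity('seventy', 'one'))   # similarity to the 'number' key
print(glove.similarity('seventy', 'male'))  # similarity to the 'gender' key
```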
# Init with a dict containing the key/value pairs.
db = Database({
    'key1': 'val1',
    'key2': 'val2',
})
# Query
results = db.query('foo')
assert type(results) == list
Consider a dataset D containing key-value pairs \(D = \{(\textbf{k}_1, \textbf{v}_1), ..., (\textbf{k}_m, \textbf{v}_m)\}\).
Imagine a query mechanism that, given a query \(\textbf{q}\), produces a result given by:
\[ Attention(\textbf{q}, D) = \sum_{i=1}^{m}\alpha(\textbf{q}, \textbf{k}_i) \textbf{v}_i \]
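The formula leaves \(\alpha\) open; in the sketch below we use a softmax over key/query dot products:
\[ \alpha(\textbf{q}, \textbf{k}_i) = \frac{\exp(\textbf{q}^\top \textbf{k}_i)}{\sum_{j=1}^{m} \exp(\textbf{q}^\top \textbf{k}_j)} \]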
Note:
Our DB interface uses variable-length words, but the attention mechanism wants fixed-length vectors. While we could use a simple one-hot encoding of words, it wouldn't capture their semantic closeness.
Instead, we'll use GloVe embeddings, which represent variable-length words as fixed-size vectors that keep semantically similar words close together.
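For example (a quick check, again assuming gensim's downloadable 50-dimensional GloVe model):

```python
import gensim.downloader

glove = gensim.downloader.load('glove-wiki-gigaword-50')

# Words of any length map to vectors of one fixed size.
print(glove['cat'].shape)     # (50,)
print(glove['sphere'].shape)  # (50,)
```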
Sketch:
- Embed the keys and values into matrices of shape (len(records), embed_size).
- Embed the query into q_embed of shape (embed_size,).
- w = softmax(keys.mul(q_embed)) to get a (len(records),) shaped vector of attention weights.
- sum(w .* values) to get a weighted sum of value embeddings in an (embed_size,) vector.

Runnable code in https://colab.research.google.com/drive/1ReE-WIEzAGNr1a-6TUHsQucwtgWFhuTF?usp=sharing