2023-04-30
Here, we play with the attention mechanism and word embeddings to build a KV store that fuzzily searches through the key space, and might also return candidate values that are nonexistent in the store.
Suppose we have the following key/value pairs:
{
'one': 'number',
'cat': 'animal',
'dog': 'animal',
'ball': 'toy',
'sphere': 'shape',
'male': 'gender',
}
Here are some example queries and their results:
rhombus ['shape' 'shaped' 'fit' 'fits' 'resembles']
parabola ['shape' 'shaped' 'fit' 'fits' 'resembles']
female ['gender' 'ethnicity' 'orientation' 'racial' 'mainstreaming']
seven ['number' 'numbers' 'ten' 'only' 'other']
seventy ['gender' 'ethnicity' 'orientation' 'regardless' 'defining']
cylinder ['toy' 'instance' 'unlike' 'besides' 'newest']
So as we can see, the top result for seven is number, because seven is similar to one. But there are also other results that are not in the dataset.
Then there is seventy, which should also have returned number as at least one result, but it didn't. Somehow our store thinks seventy is closer to male than to one.
# Init with a dict containing the key/value pairs.
db = Database({
'key1': 'val1',
'key2': 'val2',
})
# Query
results = db.query('foo')
assert type(results) == list

Consider a dataset D containing key-value pairs \(D = \{(\textbf{k}_1, \textbf{v}_1), ..., (\textbf{k}_m, \textbf{v}_m)\}\).
Imagine a query mechanism that, given a query \(\textbf{q}\), produces a result given by:
\[ Attention(\textbf{q}, D) = \sum_{i=1}^{m}\alpha(\textbf{q}, \textbf{k}_i) \textbf{v}_i \]
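The weighted sum above can be sketched in a few lines of NumPy. This is a minimal illustration, not the store itself: the toy 2-d vectors are made up, and α is taken to be a softmax over dot products (one common choice):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax.
    e = np.exp(x - x.max())
    return e / e.sum()

def attention(q, keys, values):
    """Attention-weighted sum of values.

    q:      (d,)   query embedding
    keys:   (m, d) key embeddings
    values: (m, d) value embeddings
    """
    alpha = softmax(keys @ q)  # (m,) attention weights, sum to 1
    return alpha @ values      # (d,) weighted sum of value vectors

# Toy example: the query is closest to the first key,
# so the output leans toward the first value.
keys = np.array([[1.0, 0.0], [0.0, 1.0]])
values = np.array([[2.0, 0.0], [0.0, 2.0]])
q = np.array([1.0, 0.0])
out = attention(q, keys, values)
```

Note that the result is a blend of all values, weighted by how well each key matches the query; it need not equal any stored value exactly, which is why the store can surface "values" that were never inserted.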
Note:
Our DB interface uses variable-length words, but the attention mechanism requires fixed-length vectors. While we could use a simple one-hot encoding of words, it wouldn't capture the semantic closeness of words.
Instead, we'll use GloVe embeddings, which let us represent variable-length words with fixed-size vectors.
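Loading pretrained GloVe vectors can be sketched as below. This assumes the standard text format of the published GloVe files (each line is `word v1 v2 ... vd`, space-separated); the path is whatever file you downloaded:

```python
import numpy as np

def load_glove(path):
    """Parse a GloVe text file into a {word: vector} dict.

    Assumes each line is 'word v1 v2 ... vd', space-separated.
    """
    embeddings = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            embeddings[parts[0]] = np.array(parts[1:], dtype=np.float32)
    return embeddings
```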
Sketch:
keys of shape (len(records), embed_size)
q_embed of shape (embed_size,)
w = softmax(keys.mul(q_embed)) to get a (len(records),) shaped vector of attention weights.
sum(w .* values) to get a weighted sum of values in a (n_embed,) vector.
Runnable code in https://colab.research.google.com/drive/1ReE-WIEzAGNr1a-6TUHsQucwtgWFhuTF?usp=sharing
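Putting the sketch together, here is one hypothetical way the Database could look end to end. A small hand-made embedding table stands in for GloVe, and the pooled value vector is decoded back to words by cosine similarity against the vocabulary (an assumption on my part; the real notebook may decode differently):

```python
import numpy as np

# Toy embedding table standing in for GloVe; vectors are made up.
EMBED = {
    'one':    np.array([1.0, 0.1]),
    'seven':  np.array([0.9, 0.2]),
    'number': np.array([0.8, 0.0]),
    'cat':    np.array([0.1, 1.0]),
    'animal': np.array([0.0, 0.9]),
}

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

class Database:
    def __init__(self, records):
        # keys: (len(records), embed_size), values: same shape.
        self.keys = np.stack([EMBED[k] for k in records])
        self.values = np.stack([EMBED[v] for v in records.values()])

    def query(self, word, top_k=3):
        q = EMBED[word]
        w = softmax(self.keys @ q)  # (len(records),) attention weights
        pooled = w @ self.values    # weighted sum of value embeddings
        # Decode: top_k vocabulary words most cosine-similar to pooled.
        vocab = list(EMBED)
        sims = [pooled @ EMBED[t] /
                (np.linalg.norm(pooled) * np.linalg.norm(EMBED[t]))
                for t in vocab]
        order = np.argsort(sims)[::-1][:top_k]
        return [vocab[i] for i in order]

db = Database({'one': 'number', 'cat': 'animal'})
results = db.query('seven')
```

Because the decode step searches the whole embedding vocabulary rather than the stored values, the results can include words that were never inserted, which matches the fuzzy behavior seen in the example queries above.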