xxxxxxxxxx
184
// Machine Learning for Artists and Designers
// NYUSH F24 - gohai
// State of the most recent lookup: the query word as typed, and the list of
// { word, dist } objects returned for it by Word2Vec.nearest().
let word = "";
let neighbors = [];
// Word2Vec helper, created in setup().
let w2v;
// DOM elements created in setup(): the query text field and the "Get" button.
let inputElem;
let buttonElem;
/**
 * p5.js setup hook: creates the square canvas, the Word2Vec helper, and the
 * text field plus "Get" button used to query a word.
 */
function setup() {
  const canvasSize = 600;
  createCanvas(canvasSize, canvasSize);

  w2v = new Word2Vec();

  // Clicking the button runs getVector(), which reads the text field.
  inputElem = createInput("");
  buttonElem = createButton("Get");
  buttonElem.mousePressed(getVector);
}
/**
 * p5.js draw loop: shows the query word in the middle of the canvas and its
 * neighbors on evenly spaced spokes around it. Each neighbor is pushed outward
 * in proportion to its distance, relative to the farthest neighbor.
 *
 * neighbors[0] is skipped. The list is presumably sorted by ascending
 * distance, so its first entry is normally the query word itself (distance ~0).
 */
function draw() {
  background(220);
  textSize(18);
  text(word, width / 2, height / 2);

  const spokes = neighbors.length - 1;
  if (spokes < 1) {
    return;
  }

  // With an ascending list, the last entry is the farthest neighbor.
  const maxDist = neighbors[neighbors.length - 1].dist;
  const maxRadius = width / 2 - 50;
  // map() divides by maxDist. A zero or non-finite value (e.g. NaN distances)
  // would yield NaN offsets, so fall back to placing labels on the outer radius.
  const canScale = Number.isFinite(maxDist) && maxDist > 0;

  for (let i = 1; i < neighbors.length; i++) {
    const dist = neighbors[i].dist;
    const d =
      canScale && Number.isFinite(dist)
        ? map(dist, 0, maxDist, 0, maxRadius)
        : maxRadius;
    push();
    translate(width / 2, height / 2);
    rotate(radians(i * (360 / spokes)));
    translate(d, 0);
    text(neighbors[i].word, 0, 0);
    pop();
  }
}
/**
 * Button handler: reads the word from the text field, stores it in `word`, and
 * stores its 10 nearest neighbors in `neighbors` for draw() to display.
 *
 * Word2Vec.get() returns an all-zero vector for words missing from the
 * dataset. Searching with that vector produces meaningless neighbors (cosine
 * similarity against a zero vector is 0/0), so the list is cleared instead.
 */
function getVector() {
  word = inputElem.value();
  const vector = w2v.get(word);
  const isUnknown = vector.every((component) => component === 0);
  neighbors = isUnknown ? [] : w2v.nearest(vector, 10);
  console.log(neighbors);
}
// ---
// The code below implements the Word2Vec class, which
// is used above to explore word2vec.
// The dataset consists of vectors for the 25k
// most common English words, trained on a part of
// the Google News dataset (about 100 billion words).
// Source: https://github.com/turbomaze/word2vecjson
class Word2Vec {
  /**
   * Wraps a word -> vector lookup table.
   *
   * @param {Object<string, number[]>} [dataset=wordVecs] - maps words to their
   *   vectors. Keys are expected to be lowercase, since get() lowercases its
   *   input. Defaults to the global `wordVecs`, which must be defined elsewhere.
   */
  constructor(dataset = wordVecs) {
    this.dataset = dataset;
    const keys = Object.keys(dataset);
    // Dimensionality is read from the first entry; all vectors are assumed to
    // share it. An empty dataset has nothing to measure, so use 0 rather than
    // crashing on `undefined.length`.
    this.dims = keys.length > 0 ? dataset[keys[0]].length : 0;
    console.log(
      `Using a dataset with ${keys.length} words and ${this.dims} dimensions per vector`
    );
  }

  /**
   * @returns {string[]} every word in the dataset
   */
  words() {
    return Object.keys(this.dataset);
  }

  /**
   * Looks up the vector for a word after trimming and lowercasing it.
   *
   * @param {string} word
   * @returns {number[]} the stored vector (not a copy), or a new all-zero vector
   *   of length `dims` (with a console warning) if the word is not in the dataset
   */
  get(word) {
    const key = word.trim().toLowerCase();
    // Only own properties count. A bare `this.dataset[key]` would also find
    // inherited members such as "constructor" or "toString" and return a function.
    const vector = Object.prototype.hasOwnProperty.call(this.dataset, key)
      ? this.dataset[key]
      : undefined;
    if (vector) {
      return vector;
    }
    console.warn(`The word ${key} is not in the dataset`);
    return new Array(this.dims).fill(0);
  }

  /**
   * Cosine distance (1 - cosine similarity) between two words or vectors.
   * Ranges from 0 (same direction) to 2 (opposite directions).
   *
   * @param {string|number[]} arr1
   * @param {string|number[]} arr2
   * @returns {number}
   * @throws {Error} if an argument is neither a string nor a vector of length `dims`
   */
  dist(arr1, arr2) {
    const a = this.ensure_vector(arr1);
    const b = this.ensure_vector(arr2);
    return 1 - this.cosine_similarity(a, b);
  }

  /**
   * Finds the dataset words closest (by cosine distance) to a word or vector.
   *
   * @param {string|number[]} arr - query word or vector
   * @param {number} [top=10] - maximum number of results
   * @returns {{word: string, dist: number}[]} at most `top` entries, sorted by
   *   ascending distance. A query word that is in the dataset is included too.
   * @throws {Error} if `arr` is neither a string nor a vector of length `dims`
   */
  nearest(arr, top = 10) {
    const query = this.ensure_vector(arr);
    // Without this guard, top <= 0 would read `.dist` of a missing entry.
    if (!(top > 0)) {
      return [];
    }
    const nearest = [];
    for (const word of Object.keys(this.dataset)) {
      const dist = this.dist(query, this.dataset[word]);
      if (nearest.length < top || dist < nearest[nearest.length - 1].dist) {
        nearest.push({ word, dist });
        nearest.sort((a, b) => a.dist - b.dist);
        // Drop whatever fell out of the top so the buffer never exceeds `top`.
        if (nearest.length > top) {
          nearest.pop();
        }
      }
    }
    return nearest;
  }

  /**
   * Element-wise sum of two words or vectors.
   * @returns {number[]} a new vector
   */
  add(arr1, arr2) {
    const a = this.ensure_vector(arr1);
    const b = this.ensure_vector(arr2);
    return a.map((value, i) => value + b[i]);
  }

  /**
   * Element-wise difference (arr1 - arr2) of two words or vectors.
   * @returns {number[]} a new vector
   */
  sub(arr1, arr2) {
    const a = this.ensure_vector(arr1);
    const b = this.ensure_vector(arr2);
    return a.map((value, i) => value - b[i]);
  }

  /**
   * Element-wise mean of two words or vectors.
   * @returns {number[]} a new vector
   */
  avg(arr1, arr2) {
    const a = this.ensure_vector(arr1);
    const b = this.ensure_vector(arr2);
    return a.map((value, i) => (value + b[i]) / 2);
  }

  /**
   * Linear interpolation from arr1 (t = 0) to arr2 (t = 1).
   * @param {number} t - interpolation factor, clamped to [0, 1]
   * @returns {number[]} a new vector
   */
  lerp(arr1, arr2, t) {
    const a = this.ensure_vector(arr1);
    const b = this.ensure_vector(arr2);
    const amount = Math.min(Math.max(t, 0), 1);
    return a.map((value, i) => value * (1 - amount) + b[i] * amount);
  }

  /**
   * Normalizes an argument to a vector: strings are looked up with get(),
   * arrays of length `dims` are returned unchanged.
   *
   * @param {string|number[]} param
   * @returns {number[]}
   * @throws {Error} for anything else
   */
  ensure_vector(param) {
    if (typeof param === "string") {
      return this.get(param);
    }
    if (Array.isArray(param) && param.length === this.dims) {
      return param;
    }
    throw new Error(
      `The parameter needs to be an array of ${this.dims} numbers (a vector), or a string (a word).`
    );
  }

  /**
   * Cosine similarity of two vectors, in the range -1 to 1.
   * It is undefined for a zero vector (0 / 0 = NaN), so 0 ("unrelated") is
   * returned in that case. The zero vector is what get() returns for unknown words.
   *
   * @param {number[]} arr1
   * @param {number[]} arr2
   * @returns {number}
   */
  cosine_similarity(arr1, arr2) {
    const magnitudeA = this.magnitude(arr1);
    const magnitudeB = this.magnitude(arr2);
    if (magnitudeA === 0 || magnitudeB === 0) {
      return 0;
    }
    return this.dot(arr1, arr2) / (magnitudeA * magnitudeB);
  }

  /**
   * Dot product of two vectors, iterating over arr1's length.
   * @returns {number}
   */
  dot(arr1, arr2) {
    let result = 0;
    for (let i = 0; i < arr1.length; ++i) {
      result += arr1[i] * arr2[i];
    }
    return result;
  }

  /**
   * Euclidean length of a vector.
   * @returns {number}
   */
  magnitude(arr) {
    return Math.sqrt(arr.reduce((acc, val) => acc + val * val, 0));
  }
}