// p5.js sketch: a SwNetWide4 network trained by random mutation to recall Lissajous curves.
/**
 * Layered network built from sign-switched scalings interleaved with
 * Walsh-Hadamard transforms, operating on an internal vector four times
 * as wide as its input.
 */
class SwNetWide4 {
  /**
   * @param {number} vecLen - Input/output length; must be 4, 8, 16, 32, ...
   *   (recall() uses `vecLen - 1` as an index bit mask).
   * @param {number} depth - Number of switch-then-transform layers.
   */
  constructor(vecLen, depth) {
    const width = 4 * vecLen;
    this.vecLen = vecLen;
    this.depth = depth;
    // NOTE(review): presumably offsets the growth of the unnormalised wht() - confirm.
    this.scale = 1 / sqrt(width);
    this.work = new Float32Array(width);
    // Fixed random sign pattern, already multiplied by the scale factor.
    this.flips = Float32Array.from({ length: width }, () =>
      random(-1, 1) < 0 ? -this.scale : this.scale
    );
    // Two parameters per work element per layer, all starting at `scale`.
    this.params = new Float32Array(8 * vecLen * depth).fill(this.scale);
  }

  /**
   * Runs the network on `inVec` and writes the first `vecLen` outputs
   * into `result`.
   *
   * @param {Float32Array|number[]} result - Receives elements 0..vecLen-1.
   * @param {Float32Array|number[]} inVec - Input vector; read at indices 0..vecLen-1.
   */
  recall(result, inVec) {
    const { work, flips, params, depth } = this;
    const width = 4 * this.vecLen;
    const mask = this.vecLen - 1;

    // Tile the input across the wide buffer (index wraps via the mask)
    // and apply the fixed sign flips.
    for (let k = 0; k < width; k++) {
      work[k] = inVec[k & mask] * flips[k];
    }
    wht(work);

    // Each element owns a pair of parameters per layer: the first of the
    // pair is used when the element is negative, the second otherwise.
    let cursor = 0;
    for (let layer = 0; layer < depth; layer++) {
      for (let k = 0; k < width; k++, cursor += 2) {
        const chosen = work[k] < 0 ? cursor : cursor + 1;
        work[k] *= params[chosen];
      }
      wht(work);
    }

    for (let k = 0; k <= mask; k++) {
      result[k] = work[k];
    }
  }
}
/**
 * Fast Walsh-Hadamard transform, performed in place and without any
 * normalisation (applying it twice multiplies the data by its length).
 *
 * Intended for lengths that are a power of two; for other lengths the
 * butterfly indexes past the end of the array.
 *
 * @param {Float32Array|number[]} vec - Data to transform; overwritten.
 */
function wht(vec) {
  const len = vec.length;
  // The butterfly span doubles each pass: 1, 2, 4, ...
  for (let half = 1; half < len; half *= 2) {
    // Walk the array in chunks of 2 * half, pairing k with k + half.
    for (let start = 0; start < len; start += 2 * half) {
      const stop = start + half;
      for (let k = start; k < stop; k++) {
        const lo = vec[k];
        const hi = vec[k + half];
        vec[k] = lo + hi;
        vec[k + half] = lo - hi;
      }
    }
  }
}
/**
 * Sum of squared differences between two vectors.
 *
 * Iterates over the indices of `vec`; `tar` is read at the same indices.
 *
 * @param {Float32Array|number[]} vec - Values being scored.
 * @param {Float32Array|number[]} tar - Target values.
 * @returns {number} Sum over i of (vec[i] - tar[i]) squared.
 */
function costL2(vec, tar) {
  let total = 0;
  let k = 0;
  while (k < vec.length) {
    const diff = vec[k] - tar[k];
    total += diff * diff;
    k += 1;
  }
  return total;
}
/**
 * Applies a batch of random perturbations to a parameter vector and
 * remembers enough to revert them.
 */
class Mutator {
  /**
   * @param {number} size - Number of perturbations per mutate() call.
   * @param {number} precision - Lower bound magnitude of the random exponent;
   *   step sizes are `2 * limit * exp(r)` with r drawn from random(-precision, 0).
   * @param {number} limit - Values must stay strictly inside (-limit, limit).
   */
  constructor(size, precision, limit) {
    this.previous = new Float32Array(size);
    this.pIdx = new Int32Array(size);
    this.precision = precision;
    this.limit = limit;
  }

  /**
   * Perturbs randomly chosen entries of `vec` in place, recording each
   * touched index and its prior value for undo().
   *
   * @param {Float32Array|number[]} vec - Vector to mutate.
   */
  mutate(vec) {
    const count = this.previous.length;
    for (let n = 0; n < count; n++) {
      const slot = int(random(vec.length));
      const current = vec[slot];
      this.pIdx[n] = slot;
      this.previous[n] = current;

      const magnitude = 2 * this.limit * exp(random(-this.precision, 0));
      const step = random() < 0.5 ? -magnitude : magnitude;
      const candidate = current + step;

      // A step that reaches or crosses either bound is discarded.
      const outOfRange = candidate >= this.limit || candidate <= -this.limit;
      vec[slot] = outOfRange ? current : candidate;
    }
  }

  /**
   * Reverts the most recent mutate() call on `vec`.
   *
   * @param {Float32Array|number[]} vec - The vector that was mutated.
   */
  undo(vec) {
    // Restore newest-first so an index hit more than once ends up with
    // its oldest recorded value.
    let n = this.previous.length;
    while (n-- > 0) {
      vec[this.pIdx[n]] = this.previous[n];
    }
  }
}
// Test with Lissajous curves
// Sketch-level state shared by setup() and draw().
let c1; // colour used for the training-data points and as the text fill
let c2; // colour passed to set() when plotting the recalled points
let ex = []; // training examples; setup() stores one Float32Array(256) per curve
let work = new Float32Array(256); // scratch buffer that receives each recall() result
let parentCost = Number.POSITIVE_INFINITY; // lowest summed cost accepted so far
let parentNet; // the SwNetWide4 instance being trained
let mut; // Mutator applied to parentNet.params
/**
 * p5.js setup hook: creates the canvas, the network and its mutator, the
 * two plot colours, and the eight Lissajous training examples.
 *
 * Each example is a Float32Array(256) holding interleaved (x, y) pairs
 * with x = sin(fx * t) and y = sin(fy * t).
 */
function setup() {
  createCanvas(400, 400);
  parentNet = new SwNetWide4(256, 2);
  mut = new Mutator(8, 35, 2 * parentNet.scale);
  c1 = color("gold");
  // Previously never assigned, so draw() handed `undefined` to set() for
  // every recall point.
  c2 = color("cyan");

  // [x frequency, y frequency] for each training curve, in example order.
  const frequencies = [
    [1, 2],
    [2, 1],
    [2, 3],
    [3, 2],
    [3, 4],
    [4, 3],
    [2, 5],
    [5, 2],
  ];
  frequencies.forEach(([fx, fy], k) => {
    const curve = new Float32Array(256);
    // 127 samples fill 254 of the 256 slots; the final pair stays (0, 0).
    for (let i = 0; i < 127; i++) {
      const t = (i * 2 * PI) / 127;
      curve[2 * i] = sin(fx * t);
      curve[2 * i + 1] = sin(fy * t);
    }
    ex[k] = curve;
  });

  textSize(16);
}
/**
 * p5.js draw hook: runs 100 mutate/evaluate rounds on the network, then
 * plots the training curves, the network's recall of each, and the
 * status text.
 */
function draw() {
  background(0);
  loadPixels();

  // Keep a mutation only when it lowers the summed cost over all eight
  // examples; otherwise roll it back.
  for (let round = 0; round < 100; round++) {
    mut.mutate(parentNet.params);
    let total = 0;
    for (let k = 0; k < 8; k++) {
      const target = ex[k];
      parentNet.recall(work, target);
      total += costL2(work, target);
    }
    if (total < parentCost) {
      parentCost = total;
    } else {
      mut.undo(parentNet.params);
    }
  }

  // Plots the interleaved (x, y) pairs of `points` in the given column,
  // vertically centred on `baseY`.
  const plotCurve = (points, column, baseY, tint) => {
    for (let p = 0; p < 255; p += 2) {
      set(25 + column * 40 + 18 * points[p], baseY + 18 * points[p + 1], tint);
    }
  };

  fill(c1);
  for (let k = 0; k < 8; k++) {
    plotCurve(ex[k], k, 44, c1);
  }
  for (let k = 0; k < 8; k++) {
    parentNet.recall(work, ex[k]);
    plotCurve(work, k, 104, c2);
  }
  updatePixels();

  text("Training Data", 5, 20);
  text("Recall", 5, 80);
  text(`Iterations: ${frameCount}`, 5, 150);
  text(`Cost: ${parentCost.toFixed(3)}`, 5, 170);
}