// Sketch: previews a 3x3 convolution and 2x2 max pooling on digit images.
/**
* - Mouse click to cycle between two digits
* from the MNIST dataset
* - Key press to pick a random filter (kernel)
*/
// Digit images appended by preload(); `cur` is the index of the one being shown.
let imgs = [];
// Declared explicitly: a bare `cur = 0` creates an implicit global and throws
// a ReferenceError under strict mode / ES modules.
let cur = 0;
// 3x3 convolution kernel, indexed as filter[row][column].
// keyPressed() overwrites its entries in place.
// NOTE(review): `filter` is also the name of a p5.js function; confirm that
// shadowing it here is intended.
let filter = [
  [-1, 1, 0],
  [-1, 1, 0],
  [-1, 1, 0],
];
/**
 * Loads the two digit images and appends them to `imgs`, in file order.
 */
function preload() {
  const digitFiles = ['mnist1.png', 'mnist2.png'];
  for (const file of digitFiles) {
    imgs.push(loadImage(file));
  }
}
/**
 * Creates the 400 x 800 drawing canvas.
 */
function setup() {
  const canvasWidth = 400;
  const canvasHeight = 800;
  createCanvas(canvasWidth, canvasHeight);
}
/**
 * Renders one frame: the current digit, its convolved version and the
 * max-pooled version, drawn left to right as 48 x 48 previews, with text
 * labels at y = 100.
 */
function draw() {
  const previewSize = 48;
  const labelY = 100;

  background(255);

  // Source digit.
  const digit = imgs[cur];
  image(digit, 0, 0, previewSize, previewSize);

  // Convolution result.
  const featureMap = convolve(digit);
  image(featureMap, 96, 0, previewSize, previewSize);
  text('conv2d', 96, labelY);

  // 2x2 max pooling of the convolution result.
  const downsampled = maxPool(featureMap, 2);
  image(downsampled, 192, 0, previewSize, previewSize);
  text('pooled', 192, labelY);

  text('...', 288, labelY);
}
/**
 * Advances `cur` to the next image index, wrapping back to 0 after the last
 * entry of `imgs`.
 */
function mousePressed() {
  const following = cur + 1;
  cur = following % imgs.length;
}
/**
 * Overwrites every entry of the 3x3 `filter` with a fresh `random(-1, 1)`
 * value, row by row.
 */
function keyPressed() {
  const indices = [0, 1, 2];
  for (const row of indices) {
    for (const col of indices) {
      filter[row][col] = random(-1, 1);
    }
  }
}
/**
 * Applies a square convolution kernel to an image, one colour channel at a
 * time.
 *
 * Every pixel that has a full neighbourhood is replaced by the weighted sum
 * of the red, green and blue values around it. Pixels closer to the border
 * than the kernel radius are never written in the output.
 *
 * @param {object} input - Image exposing `width`, `height` and `get(x, y)`
 *   (presumably a p5.Image).
 * @param {number[][]} [kernel=filter] - Square kernel with an odd side
 *   length, indexed as kernel[row][column]. Defaults to the sketch-wide
 *   `filter`, so existing one-argument calls behave as before.
 * @returns {object} A new image created with `createImage`, the same size as
 *   `input`.
 */
function convolve(input, kernel = filter) {
  // Distance from the kernel centre to its edge (1 for a 3x3 kernel).
  const radius = Math.floor(kernel.length / 2);
  const output = createImage(input.width, input.height);

  for (let y = radius; y < input.height - radius; y++) {
    for (let x = radius; x < input.width - radius; x++) {
      let sumR = 0;
      let sumG = 0;
      let sumB = 0;

      for (let dy = -radius; dy <= radius; dy++) {
        const weights = kernel[dy + radius];
        for (let dx = -radius; dx <= radius; dx++) {
          const pixel = input.get(x + dx, y + dy);
          // Look the weight up once and reuse it for all three channels.
          const weight = weights[dx + radius];
          sumR += red(pixel) * weight;
          sumG += green(pixel) * weight;
          sumB += blue(pixel) * weight;
        }
      }

      output.set(x, y, color(sumR, sumG, sumB));
    }
  }

  output.updatePixels();
  return output;
}
/**
 * Downsamples an image by max pooling: each non-overlapping
 * `poolSize` x `poolSize` window of the input becomes one output pixel whose
 * red, green and blue values are the per-channel maxima of that window.
 *
 * The output dimensions are rounded down, so trailing rows or columns that do
 * not fill a whole window are dropped. This keeps the output size integral
 * and every write inside the output image.
 *
 * @param {object} input - Image exposing `width`, `height` and `get(x, y)`
 *   (presumably a p5.Image).
 * @param {number} poolSize - Window side length; must be a positive integer.
 * @returns {object} A new image created with `createImage`.
 * @throws {RangeError} If `poolSize` is not a positive integer.
 */
function maxPool(input, poolSize) {
  // A zero step would never advance through the image, so reject it up front.
  if (!Number.isInteger(poolSize) || poolSize < 1) {
    throw new RangeError(`poolSize must be a positive integer, got ${poolSize}`);
  }

  const outWidth = Math.floor(input.width / poolSize);
  const outHeight = Math.floor(input.height / poolSize);
  const output = createImage(outWidth, outHeight);

  for (let outY = 0; outY < outHeight; outY++) {
    for (let outX = 0; outX < outWidth; outX++) {
      // Top-left corner of this window in input coordinates.
      const originX = outX * poolSize;
      const originY = outY * poolSize;
      let maxR = 0;
      let maxG = 0;
      let maxB = 0;

      for (let dy = 0; dy < poolSize; dy++) {
        for (let dx = 0; dx < poolSize; dx++) {
          const pixel = input.get(originX + dx, originY + dy);
          maxR = max(maxR, red(pixel));
          maxG = max(maxG, green(pixel));
          maxB = max(maxB, blue(pixel));
        }
      }

      output.set(outX, outY, color(maxR, maxG, maxB));
    }
  }

  output.updatePixels();
  return output;
}