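// Doodle classifier sketch (p5.js + ml5.js): trains a neural network on cat,
// rainbow, and train drawings, then classifies what is drawn on the canvas.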
// Raw binary files, assigned in preload().
let catFile;
let rainbowFile;
let trainFile;

// Number of values per image: 28 x 28 pixels.
let len = 28 * 28;

// Per-class lists of normalized images, filled in by prepareData().
let catData = [];
let rainbowData = [];
let trainData = [];

// The ml5 neural network instance, created in addData().
let myNN;
// Requests the three binary drawing files and stores the loadBytes() results
// in catFile, rainbowFile and trainFile (in that order).
function preload() {
  const catPath = "./data/cats1000.bin";
  const rainbowPath = "./data/rainbows1000.bin";
  const trainPath = "./data/trains1000.bin";

  catFile = loadBytes(catPath);
  rainbowFile = loadBytes(rainbowPath);
  trainFile = loadBytes(trainPath);
}
// Splits each binary file into 1000 images of `len` values and stores them,
// scaled from bytes (0-255) down to 0-1, in catData, rainbowData and trainData.
function prepareData() {
  const imageCount = 1000;

  for (let img = 0; img < imageCount; img++) {
    // One array per image; each ends up holding `len` normalized values.
    const catPixels = [];
    const rainbowPixels = [];
    const trainPixels = [];
    catData[img] = catPixels;
    rainbowData[img] = rainbowPixels;
    trainData[img] = trainPixels;

    // Images are stored back to back, so image `img` starts at img * len.
    const start = img * len;
    for (let px = 0; px < len; px++) {
      const at = start + px;
      // Divide by 255 so every value lands in the 0-1 range.
      const catValue = catFile.bytes[at] / 255;
      const rainbowValue = rainbowFile.bytes[at] / 255;
      const trainValue = trainFile.bytes[at] / 255;
      catPixels.push(catValue);
      rainbowPixels.push(rainbowValue);
      trainPixels.push(trainValue);
    }
  }

  console.log("catData", catData);
}
// Creates the classification network in myNN and feeds it the first 1000
// images of each class, interleaved as cat, rainbow, train for every index.
function addData() {
  myNN = ml5.neuralNetwork({
    task: "classification",
    inputs: 784, // values per image
    outputs: 3, // number of classes
    debug: true, // ask ml5 for debug output
  });

  // Each data set paired with the label its images are registered under.
  const labelledSets = [
    [catData, "cat"],
    [rainbowData, "rainbow"],
    [trainData, "train"],
  ];

  const sampleCount = 1000;
  for (let idx = 0; idx < sampleCount; idx++) {
    for (const [images, label] of labelledSets) {
      myNN.addData(images[idx], [label]);
    }
  }
}
// Example of a single training sample: myNN.addData([1, 0, 0, 0, ...], ["cat"])
// Sets up the drawing surface, builds the training set, and starts training.
function setup() {
  // White 400x400 canvas with a thick black pen.
  createCanvas(400, 400);
  background(255);
  stroke(0);
  strokeWeight(30);

  // Convert the loaded bytes into samples, then register them with the network.
  prepareData();
  addData();

  // Start training; doneTraining is passed as the completion callback.
  const trainingOptions = { epochs: 10 };
  myNN.train(trainingOptions, doneTraining);
}
// While the mouse button is held, draws a segment from the previous mouse
// position to the current one.
function draw() {
  if (!mouseIsPressed) {
    return;
  }
  line(pmouseX, pmouseY, mouseX, mouseY);
}
// When the mouse is released, shrinks the canvas contents to 28x28, converts
// them to `len` values in the 0-1 range, and asks the network to classify
// them. The result is delivered to gotResults.
//
// NOTE: this handler does not check whether training has finished, so a
// release during training still triggers a classify() call.
function mouseReleased() {
  const snapshot = get();
  snapshot.resize(28, 28);
  snapshot.loadPixels();

  const myInput = [];
  for (let i = 0; i < len; i++) {
    // The pixel buffer holds 4 entries per pixel, so i * 4 reads the first
    // channel of pixel i. The canvas is dark-on-white, so the value is
    // inverted, then divided by 255 to match the 0-1 scaling applied to the
    // training data (the previous "/ 4" produced values up to 63.75).
    const value = (255 - snapshot.pixels[i * 4]) / 255;
    myInput.push(value);
  }

  myNN.classify(myInput, gotResults);
}
// Callback passed to myNN.train() in setup(); currently does nothing.
function doneTraining() {}
// Callback passed to myNN.classify(): reports a failure if one occurred,
// otherwise logs the classification results.
//
// error   - failure reported by the classifier, or a falsy value on success
//           (assumes the error-first callback order shown in the signature).
// results - classification output; only logged when there is no error.
function gotResults(error, results) {
  if (error) {
    // Surface the failure instead of silently logging undefined results.
    console.error(error);
    return;
  }
  console.log(results);
}