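// Sketch: collect labelled mouse clicks, train an ml5 neural network
// classifier on them, then classify new clicks.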
// Labels that can be chosen from the keyboard for newly collected points.
let letters = ['C', 'D', 'E'];

// Label attached to the next point added while collecting data.
let targetLabel = 'C';

// Current phase of the sketch: 'collection', 'training' or 'prediction'.
let state = 'collection';

// The ml5 neural network; assigned in setup().
let model;
/**
 * p5.js setup hook: creates the 400x400 canvas and builds the ml5
 * classifier that maps an (x, y) position to a label.
 *
 * The created network is stored in the file-level `model` variable.
 */
function setup() {
  createCanvas(400, 400);

  const options = {
    inputs: ['x', 'y'],
    outputs: ['label'],
    task: 'classification',
    // Boolean flag, not the string 'true': any non-empty string (even
    // 'false') is truthy, so a string here is misleading.
    debug: true,
  };

  model = ml5.neuralNetwork(options);
}
/**
 * p5.js keyboard hook.
 *
 * - 't' starts training: switches `state` to 'training', normalizes the
 *   collected data and trains the model for 100 epochs.
 * - Any key whose upper-cased value is listed in `letters` becomes the
 *   new `targetLabel`.
 * - Every other key is ignored.
 */
function keyPressed() {
  if (key === 't') {
    // A training run is already in progress; normalizing and training
    // again on the same model would start an overlapping run.
    if (state === 'training') {
      return;
    }

    state = 'training';
    console.log('starting training 🏃♀️');

    // ml5 function that normalizes the data (between 0 and 1).
    model.normalizeData();

    const options = {
      epochs: 100,
    };
    model.train(options, whileTraining, finishTraining);
    return;
  }

  // Letters are matched case-insensitively against the target labels.
  const pressedLabel = key.toUpperCase();
  if (letters.includes(pressedLabel)) {
    targetLabel = pressedLabel;
  }
}
/**
 * Completion callback handed to model.train(): logs that training is
 * over and moves the sketch into the 'prediction' phase.
 */
function finishTraining() {
  const doneMessage = `I'm done training 🏃♂️`;
  console.log(doneMessage);

  state = 'prediction';
}
/**
 * Progress callback handed to model.train(): logs the epoch value it
 * receives.
 *
 * @param epoch - Value logged to the console.
 * @param loss - Accepted but not used by this function.
 */
function whileTraining(epoch, loss) {
  console.log(epoch);
}
/**
 * p5.js mouse hook. What a click does depends on the current `state`:
 *
 * - 'collection': the mouse position is added to the model as a training
 *   example labelled with `targetLabel`, and a labelled circle is drawn
 *   at that position.
 * - 'prediction': the mouse position is classified by the model and the
 *   outcome is delivered to gotResults.
 * - any other state (e.g. 'training'): the click is ignored.
 */
function mousePressed() {
  const x = mouseX;
  const y = mouseY;
  const inputs = { x, y };

  switch (state) {
    case 'collection': {
      // addData takes key/value pairs matching the inputs and outputs
      // declared in the options object the model was created with.
      const target = { label: targetLabel };
      model.addData(inputs, target);

      // Outline of the collected point.
      stroke(0);
      noFill();
      circle(x, y, 24);

      // Its label, centred inside the outline.
      fill(0);
      textAlign(CENTER, CENTER);
      text(targetLabel, x, y);
      break;
    }
    case 'prediction':
      model.classify(inputs, gotResults);
      break;
    default:
      break;
  }
}
/**
 * Callback handed to model.classify().
 *
 * @param error - When truthy, it is reported with console.error and the
 *   results are not logged.
 * @param results - Logged with console.log when there is no error.
 */
function gotResults(error, results) {
  if (error) {
    console.error(error);
  } else {
    console.log(results);
  }
}