// ---- Sketch-wide state, shared by setup(), draw(), keyPressed() and the ml5 callbacks ----

// ml5 neural network; created in setup().
let nn;
// true while collecting samples / training; doneTraining() sets it to false,
// after which keyPressed() sends the drawing to nn.classify() on Enter.
let training = true;
// Canvas background colour; also repainted by clearCanvas().
let bgColor = "white";
// Label of the most recent classification (written by doneClassifying()).
let predictLabel = "";
// Mouse positions (createVector results) recorded for the current drawing.
let drawPoints = [];
// NOTE(review): not referenced anywhere else in this file — candidate for removal.
let inputs = [];
/**
 * p5 setup callback: creates a 400x400 canvas, paints it with bgColor and
 * builds the ml5 network used to classify the drawings.
 */
function setup() {
  createCanvas(400, 400);
  background(bgColor);

  // inputs is width x height x channels. The 64x64 size matches the
  // img.resize(64, 64) applied to every captured canvas in keyPressed().
  // The 4 channels are presumably RGBA from get() — confirm.
  nn = ml5.neuralNetwork({
    task: "imageClassification",
    inputs: [64, 64, 4],
    debug: true,
  });
}
/**
 * p5 draw callback.
 *
 * While the mouse is pressed, it does four things:
 *  - records (mouseX, mouseY) in drawPoints;
 *  - draws the segment from the previous mouse position to the current one
 *    with strokeWeight(2) and noFill();
 *  - calls beginShape();
 *  - re-adds every recorded point as a vertex.
 *
 * When the mouse is not pressed and more than two points are recorded, it:
 *  - logs the distance between the first and last points;
 *  - draws a closing line between them;
 *  - calls endShape(), then fill(0).
 *
 * NOTE(review): beginShape() is only called while the mouse is pressed and
 * endShape() only while it is released, so they never run in the same frame.
 * Confirm this pairing renders the intended (filled?) shape in the p5 version
 * in use.
 */
function draw() {
if (mouseIsPressed) {
strokeWeight(2);
noFill();
let points = createVector(mouseX, mouseY);
drawPoints.push(points);
line(pmouseX,pmouseY,mouseX,mouseY);
// Reopen the shape and re-add every recorded point (done on every pressed frame).
beginShape();
for (let i = 0; i < drawPoints.length; i++) {
vertex(drawPoints[i].x, drawPoints[i].y);
}
} else {
// NOTE(review): drawPoints is only emptied by clearCanvas(), so this branch
// runs again on every frame after release (logging d each time) until
// clearCanvas() is called from keyPressed() or doneTraining().
if (drawPoints.length > 2) {
// Distance between first and last recorded points; it is only logged.
let d = dist(
drawPoints[0].x,
drawPoints[0].y,
drawPoints[drawPoints.length - 1].x,
drawPoints[drawPoints.length - 1].y
);
console.log(d);
// Closing segment from the first point to the last one.
line(
drawPoints[0].x,
drawPoints[0].y,
drawPoints[drawPoints.length - 1].x,
drawPoints[drawPoints.length - 1].y
);
endShape();
// NOTE(review): set after endShape(), so this fill affects later drawing
// calls (e.g. the next frame's endShape()), not the shape just ended.
fill(0);
}
}
}
// Set when "t" starts training. `training` stays true until doneTraining()
// runs, so without this flag a second "t" would start another nn.train().
// Likewise, "s"/"c" would add samples after normalizeData() had already run.
let trainingStarted = false;

/**
 * p5 key handler.
 *
 * Training mode (training === true):
 *   "s"   – add the current drawing as a "square" sample, then clear the canvas.
 *   "c"   – add the current drawing as a "circle" sample, then clear the canvas.
 *   "t"   – normalize the data, clear the canvas and train for 64 epochs;
 *           doneTraining() is passed as the completion callback.
 *   Once "t" has been pressed, further keys are ignored until doneTraining()
 *   leaves training mode.
 * Classification mode (training === false):
 *   Enter – classify the current drawing (doneClassifying() gets the result),
 *           then clear the canvas.
 *
 * The canvas is only captured for keys that actually use the image.
 */
function keyPressed() {
  if (training) {
    if (trainingStarted) {
      return;
    }
    if (key === "s") {
      nn.addData({ image: captureCanvasImage() }, ["square"]);
      clearCanvas();
    } else if (key === "c") {
      nn.addData({ image: captureCanvasImage() }, ["circle"]);
      clearCanvas();
    } else if (key === "t") {
      nn.normalizeData();
      clearCanvas();
      trainingStarted = true;
      nn.train({ epochs: 64 }, doneTraining);
    }
  } else if (keyCode === ENTER) {
    nn.classify({ image: captureCanvasImage() }, doneClassifying);
    clearCanvas();
  }
}

// Snapshot the whole canvas and downscale it to the 64x64 network input size.
function captureCanvasImage() {
  const img = get();
  img.resize(64, 64);
  return img;
}
/**
 * Completion callback handed to nn.train() in keyPressed().
 * Switches the sketch into classification mode and wipes the canvas.
 */
function doneTraining() {
  // keyPressed() reads this flag; from now on Enter triggers nn.classify().
  training = false;
  console.log("Done training!");
  clearCanvas();
}
/**
 * Repaints the canvas with bgColor and replaces drawPoints with a new empty
 * array. With two or fewer points, draw() no longer redraws the old outline.
 */
function clearCanvas() {
  const freshPoints = [];
  background(bgColor);
  drawPoints = freshPoints;
}
/**
 * Error-first callback passed to nn.classify() in keyPressed().
 *
 * @param {Error|null|undefined} error - Set when classification failed.
 * @param {Array<{label: string}>} results - Predictions. results[0].label is
 *   taken as the answer (presumably the highest-confidence entry — confirm for
 *   the ml5 version in use).
 */
function doneClassifying(error, results) {
  if (error) {
    // Report the failure instead of crashing on results[0] below.
    console.error(error);
    return;
  }
  if (!results || results.length === 0) {
    console.warn("doneClassifying: no classification results");
    return;
  }
  const { label } = results[0];
  console.log(label);
  predictLabel = label;
}