// A regression using MobileNet, ml5.js, and p5.js. Based on:
// https://editor.p5js.org/ml5/sketches/FeatureExtractor_Image_Regression
// Feature extractors and the regressors derived from them, one pair per axis.
let featureExtractorX;
let featureExtractorY;
let regressorX;
let regressorY;

// Webcam capture used as the regressors' input.
let video;

// Diagnostics shown by showInstructions().
let statusString = "STATUS_ADD_SAMPLES";
let trainingLossX;
let trainingLossY;
let nSamples = 0;

// Most recent regression outputs (X and Y respectively).
let predictionValue01;
let predictionValue02;

// Set once prediction has been started from the keyboard.
let trained = false;

// Line segments [x1, y1, x2, y2] accumulated while predicting.
var lines = [];
// Training target positions as [x, y] pairs.
var points = [];
// Index of the currently selected training target.
var thisPoint = 0;
// Samples added since the current target was selected.
var count = 0;

// End point of the previously drawn segment.
let oldPositionX;
let oldPositionY;

// Coordinates of the target used for the latest sample.
let currentX;
let currentY;
//--------------------------------------
// p5.js setup hook: creates a fixed-size canvas, requests fullscreen, opens the
// (hidden) webcam capture and builds one MobileNet regressor per axis.
function setup() {
  createCanvas(2400, 1350);
  fullscreen(true);

  video = createCapture(VIDEO);
  video.hide();

  // Each axis gets its own feature extractor and a regressor fed by the capture.
  const buildRegressor = () => {
    const extractor = ml5.featureExtractor('MobileNet');
    return [extractor, extractor.regression(video)];
  };

  [featureExtractorX, regressorX] = buildRegressor();
  [featureExtractorY, regressorY] = buildRegressor();
}
//--------------------------------------
// p5.js per-frame hook. Paints the mirrored webcam feed, adds a training sample
// while Enter is held, shows the 4x4 grid of training targets until prediction
// has been started, and afterwards accumulates and redraws the path traced by
// the predicted position.
function draw() {
  background('white');

  // Webcam feed, flipped horizontally and faded through the tint alpha.
  tint(255, 255, 255, 84);
  push();
  translate(width, 0);
  scale(-1, 1);
  image(video, 0, 0, width, height);
  pop();

  // Refresh the target grid before anything reads `points`, so addSample() and
  // drawSlider() always see 16 entries that match the current canvas size.
  rebuildTrainingPoints();

  if (keyIsDown(13)) addSample(); // 13 is the Enter key code

  if (!trained) {
    showInstructions();
    for (let i = 0; i < points.length; i++) {
      const [pointX, pointY] = points[i];
      if (i === thisPoint) {
        // Selected target: red, turning green once more than 30 samples exist.
        fill("red");
        if (count > 30) fill("lightgreen");
      } else {
        fill(0);
        noStroke();
      }
      ellipse(pointX, pointY, 50, 50);
    }
  }

  drawSlider();
  strokeWeight(1);

  // Default to the canvas centre until a usable prediction exists. A value of
  // exactly 0 is a valid prediction, so test for a finite number rather than
  // truthiness.
  let positionX = width / 2;
  let positionY = height / 2;
  if (Number.isFinite(predictionValue01)) {
    positionX = map(predictionValue01, 0, 1, 25, width - 25);
  }
  if (Number.isFinite(predictionValue02)) {
    positionY = map(predictionValue02, 0, 1, 25, height - 25);
  }

  if (trained) {
    // The first segment starts where it ends, so it has zero length.
    if (lines.length == 0) {
      oldPositionX = positionX;
      oldPositionY = positionY;
    }
    strokeWeight(4);
    stroke(0);
    lines.push([oldPositionX, oldPositionY, positionX, positionY]);
    // The background is cleared every frame, so the whole path is redrawn.
    for (const [x1, y1, x2, y2] of lines) {
      line(x1, y1, x2, y2);
    }
    oldPositionX = positionX;
    oldPositionY = positionY;
  }
}

// Recomputes the 4x4 grid of training targets in place (row-major order, 25 px
// inset from each canvas edge) so `points` never grows beyond 16 entries.
function rebuildTrainingPoints() {
  points.length = 0;
  for (let row = 0; row < 4; row++) {
    const pointY = map(row, 0, 3, 25, height - 25);
    for (let col = 0; col < 4; col++) {
      const pointX = map(col, 0, 3, 25, width - 25);
      points.push([pointX, pointY]);
    }
  }
}
// Renders the key bindings help text followed by live diagnostics: status,
// sample count, both training losses and the latest predictions.
function showInstructions() {
  noStroke();
  fill('black');

  const helpLines = [
    "Press enter to add samples. ",
    "Press + or - to move between training points. \n",
    "Press t to train model. \n",
    "Press p to start predicting. \n",
  ];
  text(helpLines.join(""), 15, 30);

  // Diagnostic rows, spaced 15 px apart.
  text(`status: ${statusString}`, 15, 75);
  text(`nSamples: ${nSamples}`, 15, 90);
  text(`trainingLoss: ${trainingLossX}`, 15, 105);
  text(`trainingLoss: ${trainingLossY}`, 15, 120);

  // A falsy prediction is displayed as the literal text "undefined".
  const describe = (value) => {
    if (value) {
      return nf(value, 1, 3);
    }
    return "undefined";
  };
  const xLabel = describe(predictionValue01);
  const yLabel = describe(predictionValue02);
  text(`predictionX: ${xLabel}`, 15, 135);
  text(`predictionY: ${yLabel}`, 15, 150);
}
// Draws a horizontal track along the top edge and a vertical track along the
// left edge, each with a 1 px marker. A marker sits at the predicted position
// when a prediction exists, otherwise at the selected training target.
function drawSlider() {
  noFill();
  stroke(0);

  // May be undefined if the target grid has not been populated yet; in that
  // case a marker is drawn only when a prediction is available.
  const target = points[thisPoint];

  rect(0, 0, width, 10);
  let mx = target ? target[0] : undefined;
  // 0 is a valid prediction, so test for a finite number, not truthiness.
  if (Number.isFinite(predictionValue01)) {
    mx = map(predictionValue01, 0, 1, 25, width - 25);
  }
  if (mx !== undefined) {
    rect(mx, 0, 1, 10);
  }

  rect(0, 0, 10, height);
  let my = target ? target[1] : undefined;
  if (Number.isFinite(predictionValue02)) {
    my = map(predictionValue02, 0, 1, 25, height - 25);
  }
  if (my !== undefined) {
    rect(0, my, 10, 1);
  }
}
//--------------------------------------
// Adds the current video frame to both regressors, labelled with the selected
// training target's position normalised to [0, 1] on each axis.
function addSample() {
  const target = points[thisPoint];
  // Nothing to label until the target grid has been populated.
  if (!target) return;

  currentX = target[0];
  currentY = target[1];

  // Normalise over the same 25 px-inset range that is used when predictions
  // are mapped back to the canvas (map(p, 0, 1, 25, size - 25)), so the label
  // mapping is its exact inverse.
  const xValue = constrain(map(currentX, 25, width - 25, 0, 1), 0, 1);
  const yValue = constrain(map(currentY, 25, height - 25, 0, 1), 0, 1);

  regressorX.addImage(xValue);
  regressorY.addImage(yValue);
  nSamples++;
  count++;
}
// p5.js key handler.
//   t / 5             train both regressors
//   p / 2             start the prediction loops and switch to drawing mode
//   r                 clear the drawn path
//   + / 6 / right     select the next training target (wraps 15 -> 0)
//   - / 4 / left      select the previous training target (wraps 0 -> 15)
function keyPressed() {
  if (key === 't' || key === "5") {
    // Each training callback reports a loss while training and a falsy value
    // once training has finished.
    regressorX.train((lossValueX) => {
      if (!lossValueX) {
        statusString = "STATUS_DONE_TRAINING_X";
        return;
      }
      trainingLossX = lossValueX;
      statusString = "STATUS_TRAINING_MODEL_X";
    });
    regressorY.train((lossValueY) => {
      if (!lossValueY) {
        statusString = "STATUS_DONE_TRAINING_Y";
        return;
      }
      trainingLossY = lossValueY;
      statusString = "STATUS_TRAINING_MODEL_Y";
    });
    return;
  }

  if (key === 'p' || key === "2") {
    statusString = "STATUS_PREDICTING";
    regressorX.predict(gotResultsCallbackX);
    regressorY.predict(gotResultsCallbackY);
    trained = true;
    return;
  }

  if (key === 'r') {
    lines = [];
    return;
  }

  const stepForward = key === "+" || key === "6" || keyCode === RIGHT_ARROW;
  if (stepForward) {
    thisPoint = thisPoint > 14 ? 0 : thisPoint + 1;
    count = 0;
    return;
  }

  const stepBack = key === "-" || key === "4" || keyCode === LEFT_ARROW;
  if (stepBack) {
    thisPoint = thisPoint === 0 ? 15 : thisPoint - 1;
    count = 0;
  }
}
// Click counter; within this file it is referenced only by the disabled
// mousePressed() handler below.
var starter = 0;
// Disabled handler: would request fullscreen on every click after the first.
// function mousePressed() {
// if (starter > 0) {
// let fs = fullscreen();
// fullscreen(true);
// }
// starter++;
// }
// p5.js resize hook: makes the canvas match the browser window's current size.
function windowResized() {
  const nextWidth = windowWidth;
  const nextHeight = windowHeight;
  resizeCanvas(nextWidth, nextHeight);
}
//--------------------------------------
// Store the results, and restart the process.
// Prediction callback for the X regressor: logs any error, stores a usable
// result in predictionValue01 and immediately requests the next prediction.
// The loop stops when no usable result is delivered.
function gotResultsCallbackX(err, result) {
  if (err) {
    console.error(err);
  }
  // Accept any finite number. A truthiness test would discard a prediction of
  // exactly 0 and, because the next predict() call is issued from here, end
  // the prediction loop for good.
  if (result && Number.isFinite(result.value)) {
    predictionValue01 = result.value;
    regressorX.predict(gotResultsCallbackX);
  }
}
// Prediction callback for the Y regressor: logs any error, stores a usable
// result in predictionValue02 and immediately requests the next prediction.
// The loop stops when no usable result is delivered.
function gotResultsCallbackY(err, result) {
  if (err) {
    console.error(err);
  }
  // Accept any finite number. A truthiness test would discard a prediction of
  // exactly 0 and, because the next predict() call is issued from here, end
  // the prediction loop for good.
  if (result && Number.isFinite(result.value)) {
    predictionValue02 = result.value;
    regressorY.predict(gotResultsCallbackY);
  }
}