// 2D regression
// using MobileNet, ml5.js, p5.js. Based on:
// https://editor.p5js.org/ml5/sketches/FeatureExtractor_Image_Regression
// Useful for training a 2D tracker.
// Keep in mind you'll need to add lots of training samples!
//
// X and Y are learned by two separate regressors; each training sample is
// labelled with the mouse position normalised to [0, 1] on that axis.
let video;                               // webcam capture, created in setup()
let statusString = "STATUS_ADD_SAMPLES"; // current phase, shown by draw()
let nSamples = 0;                        // number of training samples added so far
let bFlipHorizontal = true;              // mirror the webcam image when drawing it
let featureExtractorX;                   // MobileNet feature extractor for the X axis
let featureExtractorY;                   // MobileNet feature extractor for the Y axis
let regressorX;                          // regression head for X (labels in [0, 1])
let regressorY;                          // regression head for Y (labels in [0, 1])
let trainingLossX;                       // latest loss reported while training X
let trainingLossY;                       // latest loss reported while training Y
let predictionValueX01;                  // latest X prediction, undefined until predicting
let predictionValueY01;                  // latest Y prediction, undefined until predicting
//--------------------------------------
/**
 * p5.js entry point, run once. Creates a 320x240 canvas, starts a webcam
 * capture whose DOM element is hidden, and builds one MobileNet feature
 * extractor plus one regression head per output axis (X and Y). Both heads
 * are fed from the same video.
 */
function setup() {
  createCanvas(320, 240);

  // The capture is painted onto the canvas in draw(), so hide its DOM element.
  video = createCapture(VIDEO);
  video.hide();

  // Each regressor receives a single scalar label per sample, so the two
  // axes get separate extractors/heads.
  const newMobileNetExtractor = () => ml5.featureExtractor('MobileNet');
  featureExtractorX = newMobileNetExtractor();
  featureExtractorY = newMobileNetExtractor();

  const attachRegression = (extractor) => extractor.regression(video);
  regressorX = attachRegression(featureExtractorX);
  regressorY = attachRegression(featureExtractorY);
}
//--------------------------------------
/**
 * p5.js per-frame render. Draws, in order:
 *  - the faded webcam frame, optionally mirrored;
 *  - two edge "sliders" with tick marks;
 *  - a red circle at the predicted (x, y);
 *  - instructions and debug state as text.
 *
 * A prediction counts as present only when it is a finite number. A valid
 * prediction of exactly 0 is therefore drawn at the edge instead of being
 * treated as missing.
 */
function draw() {
  background('white');

  // Faded webcam frame, mirrored horizontally when bFlipHorizontal is true.
  tint(255, 255, 255, 90);
  push();
  if (bFlipHorizontal) {
    translate(width, 0);
    scale(-1, 1);
  }
  image(video, 0, 0, width, height);
  pop();

  // Convert the predictions (trained on labels in [0, 1]) to canvas pixels
  // once. null means "no usable prediction yet". Number.isFinite is used
  // instead of a truthiness test so that a prediction of 0 is kept.
  const predictedX = Number.isFinite(predictionValueX01)
    ? map(predictionValueX01, 0, 1, 0, width)
    : null;
  const predictedY = Number.isFinite(predictionValueY01)
    ? map(predictionValueY01, 0, 1, 0, height)
    : null;

  // Simple "sliders" along the top and left edges. The ticks show the
  // prediction when available and fall back to the mouse position.
  noFill();
  stroke(0);
  rect(0, 0, width, 10);
  rect(0, 0, 10, height);
  const tickX = predictedX !== null ? predictedX : mouseX;
  const tickY = predictedY !== null ? predictedY : mouseY;
  rect(tickX, 0, 1, 10);
  rect(0, tickY, 10, 1);

  // Red circle at the predicted (x, y); each axis defaults to the canvas centre.
  const circleX = predictedX !== null ? predictedX : width / 2;
  const circleY = predictedY !== null ? predictedY : height / 2;
  noStroke();
  fill(255, 0, 0);
  ellipse(circleX, circleY, 50, 50);

  // Diagnostic/debug information.
  fill('black');
  let instructions = "Press s and adjust mouseX to add samples. \n";
  instructions += "Press t to train model. \n";
  instructions += "Press p to start predicting. \n";
  text(instructions, 15, 30);
  text("status: " + statusString, 15, 75);
  text("nSamples: " + nSamples, 15, 90);
  text("trainingLossX: " + trainingLossX, 15, 105);
  text("trainingLossY: " + trainingLossY, 15, 120);
  const pStrX = predictedX !== null ? nf(predictionValueX01, 1, 3) : "undefined";
  const pStrY = predictedY !== null ? nf(predictionValueY01, 1, 3) : "undefined";
  text("predictionX: " + pStrX, 15, 135);
  text("predictionY: " + pStrY, 15, 150);
}
// p5.js mouse handler. A click adds one training sample at the current mouse
// position, the same as pressing 's'. The sample is added here directly
// because keyPressed() takes no argument and dispatches on the global `key`.
// Calling keyPressed('s') would replay whatever key was pressed last, e.g.
// retraining after 't' or starting another prediction loop after 'p'.
function mousePressed() {
  // Labels are the mouse position normalised (and clamped) to [0, 1].
  const xValue = constrain(map(mouseX, 0, width, 0, 1), 0, 1);
  const yValue = constrain(map(mouseY, 0, height, 0, 1), 0, 1);
  regressorX.addImage(xValue);
  regressorY.addImage(yValue);
  nSamples++;
}
//--------------------------------------
/**
 * p5.js keyboard handler; dispatches on the global `key`:
 *   's' - add one training sample labelled with the current mouse position,
 *   't' - train both regressors,
 *   'p' - start the continuous prediction loop.
 * Other keys are ignored.
 */
function keyPressed() {
  if (key === 's') {
    addSampleFromMouse();
  } else if (key === 't') {
    trainRegressors();
  } else if (key === 'p') {
    startPredicting();
  }
}

// Adds a training sample to both regressors. addImage() is given only the
// label; the regressors were created with the webcam video as their input.
// Labels are the mouse position normalised and clamped to [0, 1].
function addSampleFromMouse() {
  const xValue = constrain(map(mouseX, 0, width, 0, 1), 0, 1);
  const yValue = constrain(map(mouseY, 0, height, 0, 1), 0, 1);
  regressorX.addImage(xValue);
  regressorY.addImage(yValue);
  nSamples++;
}

// Trains both regressors and keeps statusString in sync. ml5's train()
// reports through its callback asynchronously, so checking completion flags
// right after calling it would always see them unset. The status is
// therefore set to "training" up front and switched to "done" from inside
// the callbacks once both models have finished.
function trainRegressors() {
  let xDone = false;
  let yDone = false;
  statusString = "STATUS_TRAINING_MODEL(S)";

  const markDoneIfBoth = function () {
    if (xDone && yDone) {
      statusString = "STATUS_DONE_TRAINING";
    }
  };

  // The callback receives a loss value while training and an empty value when
  // finished. Comparing against null (rather than truthiness) keeps a loss of
  // exactly 0 from being mistaken for completion.
  // NOTE(review): null/undefined as the "done" signal mirrors the original
  // falsy check — confirm against the ml5 version in use.
  regressorX.train(function (lossValueX) {
    if (lossValueX == null) {
      xDone = true;
      markDoneIfBoth();
    } else {
      trainingLossX = lossValueX;
    }
  });
  regressorY.train(function (lossValueY) {
    if (lossValueY == null) {
      yDone = true;
      markDoneIfBoth();
    } else {
      trainingLossY = lossValueY;
    }
  });
}

// Starts continuous prediction. Each result callback requests the next
// prediction itself.
// NOTE(review): pressing 'p' again starts an additional, parallel loop.
function startPredicting() {
  statusString = "STATUS_PREDICTING";
  regressorX.predict(gotResultsCallbackX);
  regressorY.predict(gotResultsCallbackY);
}
//--------------------------------------
// Prediction callback for regressorX. It stores the latest X prediction and
// immediately requests the next one, so predictions keep looping until an
// error or a result without a value arrives (errors are logged).
function gotResultsCallbackX(err, result) {
  if (err) {
    console.error(err);
  }
  // Check for a missing value, not a falsy one: a prediction of exactly 0 is
  // valid and must not silently end the loop.
  if (result && result.value != null) {
    predictionValueX01 = result.value;
    regressorX.predict(gotResultsCallbackX);
  }
}
// Prediction callback for regressorY. It stores the latest Y prediction and
// immediately requests the next one, so predictions keep looping until an
// error or a result without a value arrives (errors are logged).
function gotResultsCallbackY(err, result) {
  if (err) {
    console.error(err);
  }
  // Check for a missing value, not a falsy one: a prediction of exactly 0 is
  // valid and must not silently end the loop.
  if (result && result.value != null) {
    predictionValueY01 = result.value;
    regressorY.predict(gotResultsCallbackY);
  }
}