// Warning, pretty messy!
// Global sketch state shared by the p5.js / ml5.js callbacks in this file.
let featureExtractor;
let classifier;
let video;
let loss;
// Number of examples added to the classifier per label.
let imagesOfA = 0;
let imagesOfB = 0;
let classificationResult;
let image1;
let image2;
// Preloaded training images A0..A{nImages-1} and B0..B{nImages-1}.
const aImages = [];
const bImages = [];
const nImages = 10;
let canvas;
let whichImage = false;
let showImg = true;
// One-shot guard for the delayed training run started from draw().
let allowedToTrain = true;
// p5 preload hook: loads the two reference images plus the numbered
// training images A0.png..A{nImages-1}.png and B0.png..B{nImages-1}.png.
function preload() {
  image1 = loadImage('A.png');
  image2 = loadImage('B.png');
  for (let idx = 0; idx < nImages; idx += 1) {
    // A and B images are requested interleaved: A0, B0, A1, B1, ...
    aImages[idx] = loadImage(`A${idx}.png`);
    bImages[idx] = loadImage(`B${idx}.png`);
  }
}
// p5 setup hook: creates the canvas and hidden webcam capture, builds the
// MobileNet feature extractor and its classifier, wires up the UI buttons,
// and resizes every preloaded training image to 224x224.
function setup() {
  // 100x100 canvas rendered at 2x pixel density.
  canvas = createCanvas(100, 100);
  pixelDensity(2);

  // Webcam feed sized to the canvas; hidden because draw() paints it itself.
  video = createCapture(VIDEO);
  video.size(width, height);
  video.hide();

  // Pretrained MobileNet features; modelReady runs once the base model is loaded.
  featureExtractor = ml5.featureExtractor('MobileNet', { epochs: 500 }, modelReady);
  // Classifier on top of those features, fed from the webcam; videoReady runs when the video is ready.
  classifier = featureExtractor.classification(video, videoReady);

  setupButtons();

  for (let k = 0; k < nImages; k += 1) {
    aImages[k].resize(224, 224);
    bImages[k].resize(224, 224);
  }
}
// p5 draw loop: while the classifier is predicting, paints the webcam feed
// over a grey background; at frame 25 it starts one training run.
function draw() {
  if (classifier.isPredicting) {
    background(122);
    image(video, 0, 0, width, height);
  }
  // One-shot delayed training; allowedToTrain prevents re-triggering.
  // NOTE(review): presumably delayed so preloaded examples exist by then — confirm.
  if (frameCount === 25 && allowedToTrain) {
    allowedToTrain = false;
    console.log("now!");
    doTheTraining();
  }
}
// Callback for ml5.featureExtractor(): reports that the MobileNet base
// model has loaded, then feeds the preloaded training images to the classifier.
function modelReady() {
  const statusEl = select('#modelStatus');
  statusEl.html('Base Model (MobileNet) loaded!');
  trainOnLoadedImages();
}
/**
 * Adds every preloaded A/B training image to the classifier, then trains.
 *
 * Each image is drawn on the sketch canvas, snapshotted as a data URL and
 * decoded by a temporary <img>. The <img> loads asynchronously, so training
 * must wait until every snapshot has actually been added; calling train()
 * straight after the loops would run it with no examples.
 */
function trainOnLoadedImages() {
  const total = nImages * 2;
  let added = 0;

  if (total === 0) {
    // Nothing to add; train straight away.
    trainAfterLoadingImages();
    return;
  }

  // Counts one finished example; starts training after the last one.
  const onImageAdded = () => {
    added += 1;
    if (added === total) {
      trainAfterLoadingImages();
    }
  };

  for (let i = 0; i < nImages; i++) {
    addSnapshotToClassifier(aImages[i], 'A', () => {
      imagesOfA += 1;
      select('#amountOfAImages').html(imagesOfA);
      onImageAdded();
    });
  }
  for (let i = 0; i < nImages; i++) {
    addSnapshotToClassifier(bImages[i], 'B', () => {
      imagesOfB += 1;
      select('#amountOfBImages').html(imagesOfB);
      onImageAdded();
    });
  }
}

// Draws img on the canvas, snapshots it and adds the snapshot to the
// classifier under `label`; onAdded runs once the classifier has taken it.
function addSnapshotToClassifier(img, label, onAdded) {
  image(img, 0, 0);
  // Snapshot now: the next call redraws the canvas.
  const src = canvas.elt.toDataURL();
  const tmpImg = createImg(src, "Failed to create", () => {
    tmpImg.class("image").size(224, 224);
    // addImage's optional third argument is a completion callback (ml5 docs).
    classifier.addImage(tmpImg.elt, label, onAdded);
  });
  // The <img> only decodes the snapshot, so it is detached immediately;
  // NOTE(review): relies on a detached image still loading and firing its callback.
  tmpImg.remove();
}

// Trains on the examples added above, reporting progress in #loss.
function trainAfterLoadingImages() {
  classifier.train(function(lossValue) {
    // A missing loss marks the end of training; 0 is still a valid loss.
    if (lossValue != null) {
      loss = lossValue;
      select('#loss').html(`Loss: ${loss}`);
    } else {
      select('#loss').html(`Done Training! Final Loss: ${loss}`);
      showImg = false;
    }
  });
}
// Trains the classifier on the examples added so far, showing the loss in
// #loss, and starts classifying once training has finished.
function doTheTraining() {
  classifier.train(function(lossValue) {
    // A missing loss marks the end of training; 0 is still a valid loss and
    // must not be mistaken for completion (that would call classify() early).
    if (lossValue != null) {
      loss = lossValue;
      select('#loss').html(`Loss: ${loss}`);
    } else {
      select('#loss').html(`Done Training! Final Loss: ${loss}`);
      showImg = false;
      classify();
    }
  });
}
// Callback for featureExtractor.classification(): marks the webcam as ready
// and starts a training run, classifying once it finishes.
// NOTE(review): this trains without checking that any examples have been
// added yet; confirm that training at this point is intended.
function videoReady() {
  select('#videoStatus').html('Video ready!');
  classifier.train(function(lossValue) {
    // A missing loss marks the end of training; 0 is still a valid loss.
    if (lossValue != null) {
      loss = lossValue;
      select('#loss').html(`Loss: ${loss}`);
    } else {
      select('#loss').html(`Done Training! Final Loss: ${loss}`);
      showImg = false;
      classify();
    }
  });
}
// Asks the classifier to label the current video frame; the outcome is
// delivered to gotResults.
function classify() {
  const onResult = gotResults;
  classifier.classify(onResult);
}
// Wires the page's UI buttons to the classifier:
//   #ButtonA / #ButtonB  add the current video frame as an "A" / "B" example,
//   #train               trains and then starts classifying,
//   #buttonPredict       starts classifying.
function setupButtons() {
  // Locals, not implicit globals (which throw in strict mode).
  const buttonA = select('#ButtonA');
  buttonA.mousePressed(function() {
    classifier.addImage('A');
    // Show the count after adding, not the pre-increment value.
    imagesOfA += 1;
    select('#amountOfAImages').html(imagesOfA);
  });

  const buttonB = select('#ButtonB');
  buttonB.mousePressed(function() {
    classifier.addImage('B');
    imagesOfB += 1;
    select('#amountOfBImages').html(imagesOfB);
  });

  const trainButton = select('#train');
  trainButton.mousePressed(function() {
    classifier.train(function(lossValue) {
      // A missing loss marks the end of training; 0 is still a valid loss.
      if (lossValue != null) {
        loss = lossValue;
        select('#loss').html(`Loss: ${loss}`);
      } else {
        select('#loss').html(`Done Training! Final Loss: ${loss}`);
        showImg = false;
        classify();
      }
    });
  });

  const buttonPredict = select('#buttonPredict');
  buttonPredict.mousePressed(classify);
}
/**
 * Classification callback: shows the predicted label in #result, stores it
 * in classificationResult and requests the next classification.
 *
 * Accepts either a plain label string or an array of {label, confidence}
 * objects, in which case the first entry's label is used.
 * On error the error is logged and the prediction loop stops; otherwise a
 * persistent failure would re-trigger classify() endlessly. Pressing the
 * Predict button restarts it.
 *
 * @param {*} err - error reported by the classifier, if any.
 * @param {string|Array<{label: string}>} result - predicted label(s).
 */
function gotResults(err, result) {
  if (err) {
    console.error(err);
    return;
  }
  const label = Array.isArray(result) ? (result[0] && result[0].label) : result;
  select('#result').html(label);
  classificationResult = label;
  classify();
}