// Webcam image classifier with a third prediction class (cat, dog, human),
// extending the original two-class sketch: https://is.gd/ml5_itp
let featureExtractor;
let classifier;
let video;
let loss;
// Number of training examples collected so far for each label.
let humanImages = 0;
let dogImages = 0;
let catImages = 0;
// Options passed to ml5.featureExtractor('MobileNet', options, ...).
const options = {
  version: 1,
  alpha: 1.0,
  topk: 3,
  learningRate: 0.00005,
  hiddenUnits: 100,
  epochs: 60, // chosen by trial and error
  // Older ml5 releases read `numClasses`; newer ones read `numLabels`.
  // Setting both keeps the third class regardless of which version is loaded.
  // NOTE(review): confirm against the ml5 version the page includes.
  numClasses: 3,
  numLabels: 3,
  // Presumably a fraction of the collected examples rather than a count
  // (0.4 is ml5's default) — confirm for the ml5 version in use.
  batchSize: 0.4,
};
// p5.js entry point: opens the webcam, loads the MobileNet feature extractor,
// builds a classifier over the video, then wires up the page's buttons.
function setup() {
  // Output goes to DOM elements only, so no p5 canvas is created.
  noCanvas();

  // Webcam feed, placed inside the #videoContainer element.
  video = createCapture(VIDEO);
  video.parent('videoContainer');

  // modelReady is invoked by ml5 once MobileNet has loaded.
  const baseModel = 'MobileNet';
  featureExtractor = ml5.featureExtractor(baseModel, options, modelReady);

  // The classifier reads its frames from the same video element.
  classifier = featureExtractor.classification(video);

  createButtons();
}
// Load callback handed to ml5.featureExtractor: updates the #loading status
// text once the MobileNet base model is available.
function modelReady() {
  console.log("modelReady()");
  const statusText = 'Base Model (MobileNet) loaded!';
  const loadingElement = select('#loading');
  loadingElement.html(statusText);
}
// Store the video frame currently on screen as a training example for `label`
// (the classifier was bound to the video element in setup()).
function addImage(label) {
  console.log("addImage()");
  const exampleLabel = label;
  classifier.addImage(exampleLabel);
}
// Request a prediction for the frame currently in the video; the classifier
// delivers the outcome to gotResults.
function classify() {
  console.log("classify()");
  const onPrediction = gotResults;
  classifier.classify(onPrediction);
}
// Attach click handlers to the page's buttons: one per training label
// (cat, dog, human), plus Train and Predict. Buttons and counter elements
// are looked up by id and must already exist in the HTML.
function createButtons() {
  console.log("createButtons()");

  // Each label button adds the current video frame under its label. The
  // counter is pre-incremented so the displayed number includes the frame
  // just added (post-increment showed the count from before the click).
  wireLabelButton('#catButton', 'cat', '#amountOfCatImages', () => ++catImages);
  wireLabelButton('#dogButton', 'dog', '#amountOfDogImages', () => ++dogImages);
  wireLabelButton('#humanButton', 'human', '#amountOfHumanImages', () => ++humanImages);

  // Train button. While training runs the callback receives the current loss;
  // a null/undefined value is treated as the end-of-training signal.
  const trainButton = select('#train');
  trainButton.mousePressed(() => {
    classifier.train((lossValue) => {
      if (lossValue != null) {
        loss = lossValue;
        select('#loss').html(`Loss: ${loss}`);
      } else {
        select('#loss').html(`Done Training! Final Loss: ${loss}`);
      }
    });
  });

  // Predict button starts classification; gotResults keeps re-requesting.
  const predictButton = select('#buttonPredict');
  predictButton.mousePressed(classify);
}

// Make the button at `buttonSelector` add the current frame with `label`
// and write the updated count returned by `bumpCount` into `countSelector`.
function wireLabelButton(buttonSelector, label, countSelector, bumpCount) {
  select(buttonSelector).mousePressed(() => {
    addImage(label);
    select(countSelector).html(bumpCount());
  });
}
// Show the latest prediction in #result, then immediately request another
// one so the display keeps following the live video.
//
// Accepts both callback shapes: gotResults(result) and the error-first
// gotResults(error, results). An array of { label, ... } objects is shown
// as its first label.
// NOTE(review): confirm which shape the loaded ml5 version actually uses.
function gotResults(errOrResult, results) {
  console.log("gotResults()");
  if (errOrResult instanceof Error) {
    // Stop the loop rather than re-requesting from a failing classifier.
    console.error(errOrResult);
    return;
  }
  const prediction = results === undefined ? errOrResult : results;
  select('#result').html(describePrediction(prediction));
  classify();
}

// Display text for a prediction: the first entry's label when given an
// array of { label } objects, otherwise the value unchanged.
function describePrediction(prediction) {
  if (Array.isArray(prediction) && prediction.length > 0 &&
      prediction[0] && prediction[0].label !== undefined) {
    return prediction[0].label;
  }
  return prediction;
}