import java.awt.*;
import java.applet.*;
import java.util.*;

/**
 * Interactive single-perceptron demo applet.
 *
 * <p>The user clicks points onto a canvas, tagging each one as class "1" or "0" with the
 * checkboxes, then presses "End input". Random weights are chosen and each press of "Train"
 * runs one pass of the perceptron learning rule over every point. The decision line defined by
 * the weights is drawn on the canvas; once no point is misclassified the command button becomes
 * "Restart".
 *
 * <p>Values of {@code state}: 0 = collecting input, 1 = weights chosen, 2 = training,
 * -1 = no point misclassified (next command press starts over).
 */
public class Perceptron extends Applet {
    double[] inputs;   // current input vector; inputs[0] is the bias input, held at 1.0
    double[] weights;  // current weights: bias, x, y
    double value;      // weighted sum from the most recent classify() call
    double eta;        // learning rate, kept in (0, 1]
    int state;
    int count;         // number of entries in theNodes
    int misclassified; // result of the most recent countMisclassified() call
    Random r;
    CheckboxGroup inout;
    Checkbox in, out;
    Button cmdButton, clearButton;
    TextField etaField;
    InputSet showit;   // canvas that collects the points and draws the decision line
    int currentClassification;
    ClassifiedNode[] theNodes;

    @Override
    public void init() {
        setBackground(Color.cyan);
        weights = new double[3];
        inputs = new double[3];
        misclassified = 0;
        inputs[0] = 1.0; // bias input, held at 1.0
        eta = 0.1;       // default learning rate

        // Input widgets.
        inout = new CheckboxGroup();
        in = new Checkbox("1", inout, true);
        out = new Checkbox("0", inout, false);
        add(in);
        add(out);
        cmdButton = new Button("End input");
        add(cmdButton);
        etaField = new TextField(5);
        etaField.setBackground(Color.white);
        etaField.setText("" + eta); // show the learning rate currently in effect
        add(etaField);
        clearButton = new Button("Clear");
        add(clearButton);

        // showit collects the input points and displays the line given by the weights.
        showit = new InputSet(this);
        showit.setSize(201, 201);
        showit.setBackground(Color.white);
        add(showit);
        currentClassification = 1;
        r = new Random();
        state = 0;
    }

    @Override
    public boolean action(Event e, Object o) {
        if (e.target == in) {
            currentClassification = 1;
        } else if (e.target == out) {
            currentClassification = 0;
        } else if (e.target == cmdButton && state == 0) {
            // In state 0 the command button is "End input": freeze the points and
            // switch to training with randomly chosen starting weights.
            state = 1;
            showit.setState(1);
            cmdButton.setLabel("Train");
            setWeights();
            showit.setWeights(weights); // so the canvas can draw the line
            theNodes = showit.getNodes();
            count = showit.count;
        } else if (e.target == cmdButton && state > 0) {
            // While training, each press runs one round of the learning rule.
            state = 2;
            showit.setState(2);
            oneTrainingRound();
        } else if ((e.target == cmdButton && state == -1) || e.target == clearButton) {
            // State -1 means nothing is misclassified: return to the initial state.
            // "Clear" lets the user start over at any time.
            showit.clear();
            state = 0;
            cmdButton.setLabel("End input");
            showit.setState(0);
            showit.repaint();
            repaint();
            return true;
        } else if (e.target == etaField) {
            // Accept a new learning rate only if it lies in (0, 1]; either way the
            // field is reset to display the rate actually in effect.
            eta = parseEta(String.valueOf(e.arg), eta);
            etaField.setText("" + eta);
        }

        // If no point is misclassified, switch to the state from which we start over.
        if (state > 0) {
            countMisclassified();
            if (misclassified == 0) {
                state = -1;
                cmdButton.setLabel("Restart");
            }
        }
        showit.repaint();
        repaint();
        return true;
    }

    /**
     * Parses a learning rate typed by the user.
     *
     * @param text    text to parse; surrounding whitespace is ignored, {@code null} is rejected
     * @param current rate to keep when {@code text} is rejected
     * @return the parsed rate if it is a number in (0, 1], otherwise {@code current}
     */
    static double parseEta(String text, double current) {
        if (text == null) {
            return current;
        }
        double requested;
        try {
            requested = Double.parseDouble(text);
        } catch (NumberFormatException ex) {
            return current;
        }
        // NaN fails both comparisons and is therefore rejected as well.
        return (requested > 0.0 && requested <= 1.0) ? requested : current;
    }

    /** Sets each of the three weights to a random value in [-1, 1). */
    void setWeights() {
        for (int i = 0; i < 3; i++) {
            weights[i] = 2.0 * r.nextDouble() - 1.0;
        }
    }

    /**
     * Classifies the point currently held in {@code inputs[]}.
     *
     * @return 1 if the weighted sum (also stored in {@code value}) is positive, otherwise 0
     */
    int classify() {
        value = weights[0] * inputs[0] + weights[1] * inputs[1] + weights[2] * inputs[2];
        return value > 0 ? 1 : 0;
    }

    @Override
    public void paint(Graphics g) {
        if (state != 0) {
            g.drawString("weights:", 1, 250);
            g.drawString(" " + weights[0] + " " + weights[1] + " " + weights[2], 1, 275);
            g.drawString("misclassified: " + misclassified, 1, 300);
        }
    }

    /** Recomputes {@code misclassified}: the number of points the current weights get wrong. */
    void countMisclassified() {
        misclassified = 0;
        for (int j = 0; j < count; j++) {
            inputs[1] = theNodes[j].x;
            inputs[2] = theNodes[j].y;
            int whatClass = classify();
            if (whatClass == 0 && theNodes[j].classification == 1) {
                misclassified++;
            } else if (whatClass == 1 && theNodes[j].classification == 0) {
                misclassified++;
            }
        }
    }

    /**
     * Runs one pass of the perceptron learning rule over every point, updating the weights
     * after each misclassified point, and then hands the new weights to the canvas.
     */
    void oneTrainingRound() {
        for (int j = 0; j < count; j++) {
            inputs[1] = theNodes[j].x;
            inputs[2] = theNodes[j].y;
            int whatClass = classify();
            if (whatClass == 0 && theNodes[j].classification == 1) {
                // Point is in class 1 but classified as 0: move the line toward the point.
                for (int i = 0; i < 3; i++) {
                    weights[i] += eta * inputs[i];
                }
            } else if (whatClass == 1 && theNodes[j].classification == 0) {
                // Point is in class 0 but classified as 1: move the line away from it.
                for (int i = 0; i < 3; i++) {
                    weights[i] -= eta * inputs[i];
                }
            }
        }
        showit.setWeights(weights);
    }
}

/**
 * Canvas that collects input points from mouse clicks (state 0) and, in later states, draws
 * the decision line w0 + w1*x + w2*y = 0.
 *
 * <p>Pixel (a, b) corresponds to point x = (a - 101) / 100, y = (101 - b) / 100, matching
 * {@link ClassifiedNode}, so the origin is at pixel (101, 101).
 */
class InputSet extends Canvas {
    int count;
    InputNode first;    // head of the list of input points (most recent first)
    Perceptron theApplet;
    int state;
    double[] weights;   // local copy of the perceptron's weights, used for drawing

    InputSet(Perceptron a) {
        count = 0;
        state = 0;
        first = null;
        theApplet = a;
        weights = new double[3];
    }

    /** Discards all collected points and returns to the input-collecting state. */
    public void clear() {
        count = 0;
        state = 0;
        first = null;
    }

    /** While collecting input, records one point tagged with the applet's current class. */
    @Override
    public boolean mouseUp(Event e, int x, int y) {
        if (state == 0) {
            first = new InputNode(first, x, y, theApplet.currentClassification);
            count++;
            repaint();
        }
        return true;
    }

    @Override
    public void paint(Graphics g) {
        g.setColor(Color.black);
        // Axes through the origin, which is pixel (101, 101) under the coordinate mapping.
        g.drawLine(1, 101, 201, 101);
        g.drawLine(101, 1, 101, 201);

        // The points: class 0 in red, class 1 in blue.
        InputNode n = first;
        for (int i = 0; i < count; i++) {
            g.setColor(n.classification == 0 ? Color.red : Color.blue);
            g.drawOval(n.x - 3, n.y - 3, 6, 6);
            n = n.next;
        }

        // The decision line, once weights exist.
        if (state > 0) {
            g.setColor(Color.black);
            if (weights[2] != 0.0) {
                // NOTE(review): for |w2| very close to 0 the endpoints can exceed the int
                // range and wrap when narrowed; clipping in double space would fix that.
                g.drawLine(1, lineYAt(weights, 1), 201, lineYAt(weights, 201));
            } else if (weights[1] != 0.0) {
                // w2 == 0: the line w0 + w1*x = 0 is vertical.
                int px = verticalLineX(weights);
                g.drawLine(px, -200, px, 200);
            }
            // w1 == w2 == 0: there is no line to draw.
        }
    }

    /**
     * Pixel column of the vertical line w0 + w1*x = 0 (the w2 == 0 case).
     * Requires {@code w[1] != 0}.
     */
    static int verticalLineX(double[] w) {
        return (int) Math.round(-w[0] / w[1] * 100) + 101;
    }

    /**
     * Pixel row where the line w0 + w1*x + w2*y = 0 crosses pixel column {@code px}, obtained by
     * solving for b with x = (px - 101) / 100 and y = (101 - b) / 100. Requires {@code w[2] != 0}.
     */
    static int lineYAt(double[] w, int px) {
        return (int) Math.round((100 * w[0] - 101 * w[1] + 101 * w[2] + w[1] * px) / w[2]);
    }

    public void setState(int newstate) {
        state = newstate;
    }

    /** Copies the first three entries of {@code w} into the local weights used for drawing. */
    public void setWeights(double[] w) {
        for (int i = 0; i < 3; i++) {
            weights[i] = w[i];
        }
    }

    /** Returns the collected points, most recent first, converted to (-1..1) coordinates. */
    public ClassifiedNode[] getNodes() {
        ClassifiedNode[] retval = new ClassifiedNode[count];
        InputNode nextNode = first;
        for (int i = 0; i < count; i++) {
            retval[i] = new ClassifiedNode(nextNode.x, nextNode.y, nextNode.classification);
            nextNode = nextNode.next;
        }
        return retval;
    }
}

/** One clicked point in pixel coordinates; a node in a singly linked list. */
class InputNode {
    int x, y;
    InputNode next;
    int classification;

    InputNode(InputNode old, int locx, int locy, int what) {
        next = old;
        x = locx;
        y = locy;
        classification = what;
    }
}

/** A point mapped from pixel coordinates to the (-1.0 .. 1.0) square, with its class. */
class ClassifiedNode {
    double x, y;
    int classification;

    public ClassifiedNode(int a, int b, int c) {
        classification = c;
        // Map pixels (1,1) -> (201,201) onto (-1.0, 1.0) -> (1.0, -1.0); y grows upward.
        x = (double) (a - 101) / 100.0;
        y = (double) (101 - b) / 100.0;
    }
}