SpikeStream Library  0.2
WeightlessNeuron.cpp
Go to the documentation of this file.
//SpikeStream includes
#include "SpikeStreamException.h"
#include "Util.h"
#include "WeightlessNeuron.h"
using namespace spikestream;

//Qt includes
#include <QDebug>

//Other includes
#include <algorithm>
#include <iostream>
#include <valarray>
#include <vector>
using namespace std;
00014 
00021 byte lookup[] = {
00022         0,1,1,2, 1,2,2,3, 1,2,2,3, 2,3,3,4,
00023         1,2,2,3, 2,3,3,4, 2,3,3,4, 3,4,4,5,
00024         1,2,2,3, 2,3,3,4, 2,3,3,4, 3,4,4,5,
00025         2,3,3,4, 3,4,4,5, 3,4,4,5, 4,5,5,6,
00026         1,2,2,3, 2,3,3,4, 2,3,3,4, 3,4,4,5,
00027         2,3,3,4, 3,4,4,5, 3,4,4,5, 4,5,5,6,
00028         2,3,3,4, 3,4,4,5, 3,4,4,5, 4,5,5,6,
00029         3,4,4,5, 4,5,5,6, 4,5,5,6, 5,6,6,7,
00030         1,2,2,3, 2,3,3,4, 2,3,3,4, 3,4,4,5,
00031         2,3,3,4, 3,4,4,5, 3,4,4,5, 4,5,5,6,
00032         2,3,3,4, 3,4,4,5, 3,4,4,5, 4,5,5,6,
00033         3,4,4,5, 4,5,5,6, 4,5,5,6, 5,6,6,7,
00034         2,3,3,4, 3,4,4,5, 3,4,4,5, 4,5,5,6,
00035         3,4,4,5, 4,5,5,6, 4,5,5,6, 5,6,6,7,
00036         3,4,4,5, 4,5,5,6, 4,5,5,6, 5,6,6,7,
00037         4,5,5,6, 5,6,6,7, 5,6,6,7, 6,7,7,8
00038 };
00039 
00040 
00042 WeightlessNeuron::WeightlessNeuron(QHash<unsigned int, QList<unsigned int> >& connectionMap, unsigned int id){
00043         //Store variables
00044         this->id = id;
00045         this->connectionMap = connectionMap;
00046 
00047         //Calculate number of connections to this neuron
00048         numberOfConnections = 0;
00049         for(QHash<unsigned int, QList<unsigned int> >::iterator iter = connectionMap.begin(); iter != connectionMap.end(); ++iter)
00050                 numberOfConnections += iter.value().size();
00051 
00052         /* Calculate training data length
00053            Each bit stores an input with a byte at the beginning for the result */
00054         if(numberOfConnections % 8 == 0)
00055                 trainingDataLength = numberOfConnections / 8 + 1;
00056         else
00057                 trainingDataLength = numberOfConnections / 8 + 2;
00058 
00059         //Default setting for hamming threshold
00060         hammingThreshold = 0;
00061 }
00062 
00063 
00065 WeightlessNeuron::~WeightlessNeuron(){
00066         //Delete the training data
00067         resetTraining();
00068 }
00069 
00070 
00071 /*-------------------------------------------------------------*/
00072 /*-------                  PUBLIC METHODS                ------*/
00073 /*-------------------------------------------------------------*/
00074 
00076 void WeightlessNeuron::addTraining(QByteArray& newData, unsigned int output){
00077         if( (newData.size() + 1) != trainingDataLength)
00078                 throw SpikeStreamException("New  training data length " + QString::number(newData.size() + 1) + " does not match current training data length " + QString::number(trainingDataLength));
00079 
00080         //Store the data
00081         byte* newTrainingArray = new byte[trainingDataLength];
00082         newTrainingArray[0] = output;
00083         for(int i=1; i<trainingDataLength; ++i){
00084                 newTrainingArray[i] = newData[i-1];
00085         }
00086         trainingData.append(newTrainingArray);
00087 }
00088 
00089 
00095 double WeightlessNeuron::getFiringStateProbability(byte inPatArr[], int inPatArrLen, int firingState){
00096         if(inPatArrLen != (trainingDataLength - 1) )
00097                 throw SpikeStreamException("Training data length " + QString::number(trainingDataLength-1) + " does not match pattern length " + QString::number(inPatArrLen));
00098 
00099         //Return 0.5 if there is no training data - there will be no matches with the incoming pattern.
00100         if(trainingData.size() == 0)
00101                 return 0.5;
00102 
00103         //Work through all of the training patterns
00104         bool firstTime = true;
00105         byte currentResponse;
00106         unsigned int minDist = hammingThreshold + 1;//Initialise to take account of case where there is no training data
00107         byte* trainingPattern;
00108         QHash<int, byte> minDistIndxMap;
00109         for(int listIndx=0; listIndx<trainingData.size(); ++listIndx){
00110                 trainingPattern = trainingData[listIndx];
00111 
00112                 /* Work through all of the bytes in the input pattern and calculate
00113                 the total hamming distance between the input pattern and the stored pattern */
00114                 unsigned int hamDist = 0;
00115                 for(int i=0; i<inPatArrLen; ++i){
00116                         hamDist += lookup[inPatArr[i] ^ trainingPattern[i + 1]];
00117                         if(hamDist > hammingThreshold)
00118                                 break;
00119                 }
00120                 currentResponse = trainingPattern[0];
00121 
00122                 if(firstTime){
00123                         minDist = hamDist;
00124                         minDistIndxMap[listIndx] = currentResponse;
00125                         firstTime = false;
00126                 }
00127                 else{
00128                         //Set this pattern to be the current minimum
00129                         if(hamDist < minDist){
00130                                 minDist = hamDist;
00131                                 minDistIndxMap.clear();
00132                                 minDistIndxMap[listIndx] = currentResponse;
00133                         }
00134                         //Add the index of this pattern to the map if it has the same Hamming distance
00135                         else if(hamDist == minDist){
00136                                 minDistIndxMap[listIndx] = currentResponse;
00137                         }
00138                 }
00139         }
00140         //Return 0.5 if there is  no match within the minimum hamming distance
00141         if(minDist > hammingThreshold)
00142                 return 0.5;
00143 
00144         //Return 0.5 if there are multiple contradicting matches
00145         bool zeroFound = false, oneFound = false;
00146         for(QHash<int, byte>::iterator iter = minDistIndxMap.begin(); iter != minDistIndxMap.end(); ++iter){
00147                 if(iter.value() == 0)
00148                         zeroFound = true;
00149                 else if(iter.value() == 1)
00150                         oneFound = true;
00151                 else
00152                         throw SpikeStreamException("Entry in training data not recognized. It should be 1 or 0: " + QString::number(iter.value()));
00153 
00154                 //Responses contradict one another
00155                 if(oneFound && zeroFound)
00156                         return 0.5;
00157         }
00158 
00159         /*One or more best matches have been found with the same output.
00160           Return 1.0 if the output matches the specified firing state or 0.0 otherwise */
00161         if(minDistIndxMap.begin().value() == firingState)
00162                 return 1.0;
00163         return 0.0;
00164 }
00165 
00166 
00168 double WeightlessNeuron::getTransitionProbability(const QList<unsigned int>& neurIDList, const QString& s0Pattern, int firingState){
00169         //Run checks on the data
00170         if(neurIDList.size() != s0Pattern.size())
00171                 throw SpikeStreamException("Neuron ID list size does not match the size of the s0 pattern.");
00172 
00173         /* Start pattern corresponds to a list of neuron ids, but need the pattern that corresponds
00174         to the connections to this neuron. FiringNeuronIndexMap links an index in the input pattern
00175         to the firing state of that neuron*/
00176         QHash<unsigned int, byte> firingNeuronIndexMap;
00177 
00178         //Set the state of the known neuron IDs in the input pattern
00179         for(int i=0; i<neurIDList.size(); ++i){
00180                 //Neuron in list is connected to this neuron
00181                 if(connectionMap.contains(neurIDList[i])){
00182                         /* Set the part of the input pattern to the state corresponding to the neuron's state
00183                    This may have to be done several times if there are several connections between two neurons */
00184                         foreach(unsigned int tmpNeurIndex, connectionMap[neurIDList[i]]){
00185                                 firingNeuronIndexMap[tmpNeurIndex] = Util::getUInt(s0Pattern[i]);
00186                         }
00187                 }
00188         }
00189         int missingNeuronCount = numberOfConnections - firingNeuronIndexMap.size();
00190         if(missingNeuronCount <0)
00191                 throw SpikeStreamException("Error in transition probability calculation. Missing neuron count is less than zero.");
00192 
00193         //Array used to select all combinations of missing neurons
00194         bool missingNeurSelectionArr[missingNeuronCount];
00195 
00196         //Create array used in the comparison with the stored data in the neuron
00197         int inPattArrSize = trainingDataLength - 1;
00198         byte inPattArr [inPattArrSize];
00199 
00200         //Variable to sum the output probabilities for each partially random input string
00201         double tmpProbOut = 0.0;
00202         int randomStringCount = 0;
00203 
00204         //Work through all firing combinations of the undefined input neurons
00205         int numOnes = 0;
00206         while(numOnes <= missingNeuronCount){
00207 
00208                 //Initialize selection array with first combination of 1s and 0s
00209                 Util::fillSelectionArray(missingNeurSelectionArr, missingNeuronCount, numOnes);
00210 
00211                 bool permutationsComplete = false;
00212                 while(!permutationsComplete){
00213                         //Build the input pattern
00214                         buildInputPattern(inPattArr, inPattArrSize, missingNeurSelectionArr, missingNeuronCount, firingNeuronIndexMap);
00215 
00216                         //Sum the probability that it has the specified output and keep track of the number of calculations
00217                         tmpProbOut += getFiringStateProbability(inPattArr, inPattArrSize, firingState);
00218                         ++randomStringCount;
00219 
00220                         //Change the selection array
00221                         permutationsComplete = !next_permutation(&missingNeurSelectionArr[0], &missingNeurSelectionArr[missingNeuronCount]);
00222                 }
00223                 ++numOnes;
00224         }
00225 
00226         /* Unknown neuron states are random, so each string has the same probability of occurring
00227            So probability of the firing state is the sum of the prob(state occurring) * prob(firing state for that string) */
00228         tmpProbOut /= randomStringCount;
00229         return tmpProbOut;
00230 }
00231 
00232 
00234 void WeightlessNeuron::resetTraining(){
00235         foreach(byte* array, trainingData)
00236                 delete array;
00237         trainingData.clear();
00238 }
00239 
00240 
00242 void WeightlessNeuron::setGeneralization(double generalization){
00243         hammingThreshold = Util::rUInt((double)numberOfConnections * (1.0 - generalization));
00244 }
00245 
00246 
00247 /*-------------------------------------------------------------*/
00248 /*-------                 PRIVATE METHODS                ------*/
00249 /*-------------------------------------------------------------*/
00250 
00253 void WeightlessNeuron::buildInputPattern(byte inPatArr[], int inPatArrSize, bool selArr[], int selArrSize, QHash<unsigned int, byte>& firingNeuronIndexMap){
00254         //Clear inPatArr
00255         for(int i=0; i<inPatArrSize; ++i)
00256                 inPatArr[i] = 0;
00257 
00258         /* Work through each connection and either populate the input array with the permuted random selection
00259         or with the actual firing state of the neuron, which is stored in the firingNeuronIndexMap */
00260         int selIndx = 0;
00261         for(int inputIndx=0; inputIndx < numberOfConnections; ++inputIndx){
00262                 if(inputIndx / 8 > inPatArrSize)
00263                         throw SpikeStreamException("Input pattern array index out of range.");
00264 
00265                 if(firingNeuronIndexMap.contains(inputIndx)){
00266                         //Set bit in the byte array to 1 or 0 depending on whether it is 1 or 0 in the map of relevant neuron firing states
00267                         if(firingNeuronIndexMap[inputIndx]){
00268                                 inPatArr[ inputIndx/8 ] |= 1<<(inputIndx % 8);
00269                         }
00270                 }
00271                 else{
00272                         if(selIndx >= selArrSize)
00273                                 throw SpikeStreamException("Selection index out of range. Actual=" + QString::number(selIndx) + "; maximum=" + QString::number(selArrSize));
00274 
00275                         //Set bit in the byte array to 1 if it is set to 1 in selection array
00276                         if(selArr[selIndx]){
00277                                 inPatArr[ inputIndx/8 ] |= 1<<(inputIndx % 8);
00278                         }
00279                         ++selIndx;
00280                 }
00281         }
00282 }
00283 
00284 
00286 void WeightlessNeuron::printConnectionMap(){
00287         for(QHash<unsigned int, QList<unsigned int> >::iterator iter = connectionMap.begin(); iter != connectionMap.end(); ++iter){
00288                 foreach(unsigned int tmpPattIndx, iter.value())
00289                         cout<<"Connection from "<<iter.key()<<" pattern index "<<tmpPattIndx<<endl;
00290         }
00291 }
00292 
00293 
00295 void WeightlessNeuron::printSelectionArray(bool selArr[], int arrSize){
00296         for(int i=0; i<arrSize; ++i){
00297                 if(selArr[i])
00298                         cout<<"1";
00299                 else
00300                         cout<<"0";
00301         }
00302         cout<<endl;
00303 }
00304 
00305 
00307 void WeightlessNeuron::printTraining(){
00308         foreach(byte* byteArr, trainingData)
00309                 Util::printByteArray(byteArr, trainingDataLength);
00310 }
00311 
 All Classes Files Functions Variables Typedefs Defines