Category orthogonal associative memory

edited August 2017 in Share Your Work

It is known that the end result of image processing in the brain is a small number of orthogonal categories (orthogonal meaning at right angles to one another in a high-dimensional space). People's faces, for example, are recognized using something like 50 such categories.
I am only at the very beginning of trying this idea with associative memory (AM), to see whether something similar is possible in software. There is a lot more experimenting to do, such as making the AM nonlinear. The idea is to gradually refine (layer by layer) a very large number of weak associations, learned from a very large number of examples, down to a small number of distinct categories. At the moment there is plenty in the code to criticize. Anyway:

// Layer stack of associative memories; the slots are filled in setup().
AM[] nets=new AM[3];
// Category classifier built on top of the layer stack (created in setup()).
AMCat am;
// One 32x32 greyscale pixel sample per frame, stored row-major.
float[] inVec=new float[32*32];

void setup() {
  size(100, 100);
  // A huge target rate just means "run as fast as possible" while training.
  frameRate(1000000);
  // White glyphs, large enough to fill most of the 32x32 sampling window.
  fill(255);
  textSize(26);

  // Three AM layers over 1024-element (32x32) vectors. The first uses 10
  // random projections, the other two use 100; each has its own seed.
  int[] projections={10, 100, 100};
  int[] seeds={12345, 23456, 34567};
  for (int k=0; k<projections.length; k++) {
    nets[k]=new AM(1024, projections[k], .01f, seeds[k]);
  }
  am=new AMCat(1024, nets, 95);
}

void draw() {
  background(0);
  // Draw a random character with code 33..127: 95 categories, matching the
  // AMCat built in setup(). NOTE(review): 127 is DEL, which most fonts
  // render as blank or as a placeholder glyph.
  char ch=(char)random(33, 128);
  text(ch, 0, 21);
  stroke(255);
  // Sample the top-left 32x32 pixels into the input vector (row-major).
  // The scene is white-on-black, so the low (blue) byte is the grey level.
  for (int y=0; y<32; y++) {
    for (int x=0; x<32; x++) {
      inVec[x+32*y]=get(x, y)&0xff;
    }
  }
  // Train each layer in turn for 4096 (2^12) frames, then switch to recall.
  // The layer count follows the actual stack size rather than a literal 3.
  int layer=frameCount>>12;
  if (layer<nets.length) {
    am.train(inVec, ch-33, layer);
  } else {
    frameRate(1);  // slow down so each recalled character can be read
    text((char)(am.recall(inVec)+33), 0, 21+32);
  }
}

/**
 * Classifies input vectors into one of {@code categories} classes using a
 * stack of AM layers.
 *
 * <p>Every layer except the last is trained to map its input to a fixed
 * random category vector. The last layer is trained to output the WHT of a
 * one-hot vector. RP.wht is scaled by 1/sqrt(n), so it is its own inverse;
 * applying it to the last layer's output therefore recovers per-category
 * scores in the first {@code categories} elements.
 *
 * <p>Layers are trained one at a time with {@link #train}. Layer
 * {@code level} sees the input after it has passed through layers
 * {@code 0..level-1}, exactly as during recall (both use the same helper).
 */
class AMCat {
  int vecLen;
  int categories;
  float[][] catVecs;   // one random target vector (length vecLen) per category
  float[] workA;       // scratch: activations flowing through the layer stack
  float[] workB;       // scratch: WHT'd one-hot target for the last layer
  float[] outVec;      // per-category scores from the most recent recall
  AM[] nets;           // the layer stack, shared with the caller (not copied)
  RP rp=new RP();

  /**
   * @param vecLen     length of input and internal vectors; must be a power
   *                   of two because RP.wht is applied to them
   * @param nets       layer stack, at least one AM with matching vecLen
   * @param categories number of classes, 1..vecLen
   * @throws IllegalArgumentException if nets is null/empty or categories is
   *                                  out of range
   */
  AMCat(int vecLen, AM[] nets, int categories) {
    if (nets==null || nets.length==0) {
      throw new IllegalArgumentException("nets must contain at least one AM");
    }
    if (categories<1 || categories>vecLen) {
      throw new IllegalArgumentException(
        "categories must be in 1.."+vecLen+", got "+categories);
    }
    this.vecLen=vecLen;
    this.nets=nets;
    this.categories=categories;
    outVec=new float[categories];
    workA=new float[vecLen];
    workB=new float[vecLen];
    catVecs=new float[categories][vecLen];
    Xor128 rnd=new Xor128();
    for (int i=0; i<categories; i++) {
      for (int j=0; j<vecLen; j++) {
        catVecs[i][j]=rnd.nextFloatSym();
      }
      // The WHT mixes every uniform entry into every output, giving roughly
      // Gaussian entries; in high dimensions such random vectors are
      // approximately (not exactly) orthogonal to one another.
      rp.wht(catVecs[i]);
    }
  }

  /**
   * Runs inVec through all layers and writes per-category scores into cats.
   *
   * @param cats  receives the first {@code categories} scores; its length
   *              must be at least categories
   * @param inVec input vector of length vecLen
   */
  void recall(float[] cats, float[] inVec) {
    propagate(inVec, nets.length);
    rp.wht(workA);  // undo the WHT that was applied to the one-hot targets
    System.arraycopy(workA, 0, cats, 0, categories);
  }

  /**
   * Returns the index of the highest-scoring category for inVec; ties go to
   * the lowest index. The scores are left in outVec.
   */
  int recall(float[] inVec) {
    recall(outVec, inVec);
    int best=0;
    for (int i=1; i<categories; i++) {
      if (outVec[i]>outVec[best]) {
        best=i;
      }
    }
    return best;
  }

  /**
   * Performs one training step of layer {@code level} for the given example.
   *
   * @param inVec input vector of length vecLen
   * @param cat   category index, 0..categories-1
   * @param level layer to train, 0..nets.length-1
   * @throws IllegalArgumentException if cat or level is out of range
   */
  void train(float[] inVec, int cat, int level) {
    if (level<0 || level>=nets.length) {
      throw new IllegalArgumentException(
        "level must be in 0.."+(nets.length-1)+", got "+level);
    }
    if (cat<0 || cat>=categories) {
      throw new IllegalArgumentException(
        "cat must be in 0.."+(categories-1)+", got "+cat);
    }
    float[] layerInput=inVec;
    if (level>0) {
      // Feed layer `level` the output of layers 0..level-1, as recall does.
      propagate(inVec, level);
      layerInput=workA;
    }
    nets[level].trainVec(targetFor(cat, level), layerInput);
  }

  /** Runs inVec through nets[0..count-1] (count >= 1), leaving the result in workA. */
  private void propagate(float[] inVec, int count) {
    nets[0].recallVec(workA, inVec);
    for (int i=1; i<count; i++) {
      // AM.recallVec normalises its input into a private buffer before
      // writing the result, so passing workA as both arguments is safe.
      nets[i].recallVec(workA, workA);
    }
  }

  /** Returns the training target for layer `level`: a category vector, or the WHT'd one-hot for the last layer. */
  private float[] targetFor(int cat, int level) {
    if (level<nets.length-1) {
      return catVecs[cat];
    }
    java.util.Arrays.fill(workB, 0f);
    workB[cat]=1f;
    rp.wht(workB);
    return workB;
  }
}

/**
 * Linear associative memory on vectors of length vecLen.
 *
 * Recall first normalises the input to unit RMS. It then generates density
 * random projections one after another; each is a seeded sign flip followed
 * by a WHT, applied to the previous projection. The output is, element by
 * element, the sum over projections of weights[i] * projection_i.
 * Training is a single delta-rule (LMS) step on those weights towards a
 * target vector.
 */
class AM {
  RP rp=new RP();
  int vecLen;
  int density;       // number of random projections
  float rate;        // per-projection learning rate: constructor rate / density
  int hash;          // base seed; projection i is seeded with hash+i
  float[][] weights; // [density][vecLen] learned weights
  float[][] surface; // [density][vecLen] projections from the latest recallVec
  float[] workA;     // scratch: the running projection
  float[] workB;     // scratch: recall output, then error, during training
  // vecLen must be 2,4,8,16,32..... (RP.wht requirement)
  AM(int vecLen, int density, float rate, int hash) {
    this.vecLen=vecLen;
    this.density=density;
    this.hash=hash;
    // Presumably split across projections so the combined step size does
    // not grow with density.
    this.rate=rate/density;
    workA=new float[vecLen];
    workB=new float[vecLen];
    weights=new float[density][vecLen];
    surface=new float[density][vecLen];
  }

  /**
   * Writes the memory's response to inVec into resultVec. It also records
   * every projection in surface[][] for use by trainVec.
   * The input is copied into a private buffer before resultVec is cleared,
   * so resultVec may be the same array as inVec.
   */
  void recallVec(float[] resultVec, float[] inVec) {
    rp.adjust(workA, inVec);
    java.util.Arrays.fill(resultVec, 0f);
    for (int d=0; d<density; d++) {
      final float[] proj=surface[d];
      final float[] w=weights[d];
      rp.signFlip(workA, hash+d);
      rp.wht(workA);
      // Nonlinear alternative (disabled): rp.signOf(proj, workA);
      System.arraycopy(workA, 0, proj, 0, vecLen);
      for (int k=0; k<vecLen; k++) {
        resultVec[k]+=w[k]*proj[k];
      }
    }
  }

  /** One delta-rule step: nudges the weights so recallVec(inVec) moves towards targetVec. */
  void trainVec(float[] targetVec, float[] inVec) {
    recallVec(workB, inVec);  // also refreshes surface[][] for this input
    for (int k=0; k<vecLen; k++) {
      workB[k]=targetVec[k]-workB[k];  // error = target - current output
    }
    for (int d=0; d<density; d++) {
      final float[] w=weights[d];
      final float[] proj=surface[d];
      for (int k=0; k<vecLen; k++) {
        w[k]+=workB[k]*proj[k]*rate;
      }
    }
  }
}

/**
 * Random-projection helpers: an orthonormal in-place Walsh-Hadamard
 * transform, seeded random sign flips, sign binarisation and RMS
 * normalisation.
 */
final class RP {
  final Xor128 sfRnd=new Xor128();  // reseeded on every signFlip call

  /**
   * In-place fast Walsh-Hadamard transform, scaled by 1/sqrt(n). The scaling
   * makes it orthonormal and therefore its own inverse.
   * vec.length must be a power of two (2,4,8,16,32.....).
   */
  void wht(float[] vec) {
    final int n=vec.length;
    final float scale=1f/sqrt(n);
    for (int span=1; span<n; span+=span) {
      for (int base=0; base<n; base+=2*span) {
        for (int k=base; k<base+span; k++) {
          final float lo=vec[k];
          final float hi=vec[k+span];
          vec[k]=lo+hi;
          vec[k+span]=lo-hi;
        }
      }
    }
    for (int k=0; k<n; k++) {
      vec[k]*=scale;
    }
  }

  /**
   * Negates a pseudo-random subset of vec's elements in place. The generator
   * is reseeded with h, so the same h always gives the same sign pattern.
   * Bit 31 of each generator output is XORed into the float's sign bit,
   * which avoids a branch.
   */
  void signFlip(float[] vec, int h) {
    sfRnd.setSeed(h);
    for (int k=0; k<vec.length; k++) {
      final int signMask=((int)sfRnd.nextLong())&0x80000000;
      final int flipped=Float.floatToRawIntBits(vec[k])^signMask;
      vec[k]=Float.intBitsToFloat(flipped);
    }
  }

  /**
   * Writes +1 or -1 into biVec according to the sign bit of each element of
   * x. Anything with the sign bit set, including -0.0f, maps to -1.
   */
  void signOf(float[] biVec, float[] x ) {
    final int oneBits=Float.floatToRawIntBits(1f);
    for (int k=0; k<biVec.length; k++) {
      final int signBit=Float.floatToRawIntBits(x[k])&0x80000000;
      biVec[k]=Float.intBitsToFloat(oneBits|signBit);
    }
  }

  /**
   * Writes x scaled to unit RMS (root-mean-square 1) into resultVec. If the
   * RMS is below 1e-20 (e.g. an all-zero input), it falls back to signOf.
   */
  void adjust(float[] resultVec, float[] x) {
    float sumsq=0f;
    for (float e : x) {
      sumsq+=e*e;
    }
    final float rms=sqrt(sumsq/x.length);
    if (rms<1e-20f) {
      signOf(resultVec, x);
      return;
    }
    final float inv=1f/rms;
    for (int k=0; k<x.length; k++) {
      resultVec[k]=inv*x[k];
    }
  }
}

// Random number generator
final class Xor128 {

  private long s0;
  private long s1;

  Xor128() {
    setSeed(System.nanoTime());
  }

  public long nextLong() {
    final long s0 = this.s0;
    long s1 = this.s1;
    final long result = s0 + s1;
    s1 ^= s0;
    this.s0 = Long.rotateLeft(s0, 55) ^ s1 ^ s1 << 14;
    this.s1 = Long.rotateLeft(s1, 36);
    return result;
  }

  public float nextFloat() {
    return (nextLong()&0x7FFFFFFFFFFFFFFFL)*1.0842021e-19f;
  }

  public float nextFloatSym() {
    return nextLong()*1.0842021e-19f;
  }

  public boolean nextBoolean() {
    return nextLong() < 0;
  }

  public void setSeed(long seed) {
    s0 = seed*0xBB67AE8584CAA73BL;
    s1 = ~seed*0x9E3779B97F4A7C15L;
  }
}
Sign In or Register to comment.