Monday, February 09, 2015

Generating random numbers from a given set of numbers according to a distribution in java

Recently I was asked to implement the following in Java, based on the interface below:

Implement the method nextNum() and a minimal but effective set of unit tests.  As a quick check, given Random Numbers are [-1, 0, 1, 2, 3] and Probabilities are [0.01, 0.3, 0.58, 0.1, 0.01] if we call nextNum() 100 times we may get the following results. As the results are random, these particular results are unlikely.

-1: 1 times
0: 22 times
1: 57 times
2: 20 times
3: 0 times
public class RandomGen {
 // Values that may be returned by nextNum()
 private int[] randomNums;
 // Probability of the occurrence of randomNums
 // (the problem statement lists one probability per value, in the same order)
 private float[] probabilities;

 /**
 Returns one of the randomNums. When this method is called
 multiple times over a long period, it should return the
 numbers roughly with the initialized probabilities.
 */
 public int nextNum() {
  // Exercise skeleton: the body is intentionally left empty; implementing it is the task.

 }
}
Solution:
package com.denizstij.rand;

import java.util.Arrays;
import java.util.Random;
/**
 * @author Deniz Turan, http://denizstij.blogspot.co.uk denizstij AT gmail.com
 */
public class RandomGen {
	// Values that may be returned by nextNum()
	private int[] randomNums;
	// Probability of the occurence of randomNums
	private float[] probabilities;
	private float[] cumProbabilities;
	private int[] validRandomNums;
	private int totalNonZeroProbElement=0;
	private Random random;
	
	public RandomGen(int[] randomNums,float[] probabilities){
		validateProcessInputs(randomNums,probabilities);
		random= new Random();
	}
	
	public RandomGen(int[] randomNums,float[] probabilities, long seed){
		validateProcessInputs(randomNums,probabilities);
		random= new Random(seed);
	}
		
	private void validateProcessInputs(int[] randomNums,float[] probabilities){			
		if (randomNums==null || probabilities==null || randomNums.length!= probabilities.length || probabilities.length==0){
			throw new IllegalArgumentException("RandomNums and probabilities must be non empty and same size");
		}
		
		int len=probabilities.length;
		cumProbabilities= new float[len];
		validRandomNums= new int[len];
				
		for (int i=0;i < len;i++) {			
			float prob = probabilities[i];
			int randomNum = randomNums[i];
			if (MathUtil.isLess(prob,0.0f)){
				throw new IllegalArgumentException("Probability can not be negative");
			}
			
			if (MathUtil.isGreater(prob,1.0f)){
				throw new IllegalArgumentException("Probability can not be greater than 1");
			}
			
			if (MathUtil.checkAnyIsNanOrInfinite(prob)){
				throw new IllegalArgumentException("Nan or Infite prob is not accepted");
			}

			// not processing elements with zero probabilities as they can not occur
			if (MathUtil.isEqual(prob, 0.0f)){
				continue;
			}
			// store valid random number
			validRandomNums[totalNonZeroProbElement]=randomNum;
			if (totalNonZeroProbElement==0){
				cumProbabilities[totalNonZeroProbElement]=prob;				
			} else {
				cumProbabilities[totalNonZeroProbElement]=prob+cumProbabilities[totalNonZeroProbElement-1];				
			}			
			if (MathUtil.isGreater(cumProbabilities[totalNonZeroProbElement],1.0f)){
				throw new IllegalArgumentException("Cumilative total probability can not be greater than 1");
			}
			totalNonZeroProbElement++;
		}
		
		if (totalNonZeroProbElement==0){
			throw new IllegalArgumentException("All probabilities are zero :(");
		}
		
		if (!MathUtil.isEqual(cumProbabilities[totalNonZeroProbElement-1],1.0f)){
			throw new IllegalArgumentException("Total probability must be 1");
		}

		// let's shrink arrays to save memory
		this.validRandomNums=Arrays.copyOf(validRandomNums, totalNonZeroProbElement);
		this.cumProbabilities=Arrays.copyOf(cumProbabilities, totalNonZeroProbElement); 
		this.randomNums=randomNums;
		this.probabilities=probabilities;
		
		// some logging 
		System.out.println("RandomNums :"+Arrays.toString(this.randomNums));
		System.out.println("probabilities :"+Arrays.toString(this.probabilities));
		System.out.println("totalNonZeroProbElement:"+totalNonZeroProbElement);
		System.out.println("validRandomNums :"+Arrays.toString(validRandomNums));
		System.out.println("cumProbabilities :"+Arrays.toString(cumProbabilities));				
	}
	
	/**
	Returns one of the randomNums. When this method is called
	multiple times over a long period, it should return the
	numbers roughly with the initialized probabilities.
	*/
	public int nextNum() {
		float nextFloat = random.nextFloat();
		return nextNum(nextFloat);
	}
	
	/**
	Returns one of the randomNums. When this method is called
	multiple times over a long period, it should return the
	numbers roughly with the initialized probabilities.
	*/
	// For testing purposes 
	protected int nextNum(float nextFloat) {		
		int index = Arrays.binarySearch(cumProbabilities,nextFloat);
		// if not found in the cum probability array, we need to infer index from insertation index 
		if (index<0 code="" index="-1*index-1;" int="" nextrandomval="" return="">

package com.denizstij.rand;

/**
 * Tolerance-based comparison helpers for {@code float} values.
 *
 * <p>Two values count as equal when they are identical or when their absolute
 * difference is below {@link #FLOAT_COMPARE_RESOLUTION}. The ordering helpers
 * are built on top of that notion of equality.
 *
 * @author Deniz Turan, http://denizstij.blogspot.co.uk denizstij@gmail.com
 */
public final class MathUtil {
	/** Absolute difference below which two floats are treated as equal. */
	public static final float FLOAT_COMPARE_RESOLUTION=0.00001f;	
	
	// static helpers only; not meant to be instantiated
	private MathUtil() {
	}
	
	/**
	 * @return true if the values are identical or differ by less than
	 *         {@link #FLOAT_COMPARE_RESOLUTION}; always false when either is NaN
	 */
	public static boolean isEqual(float val1, float val2) {		
		// Exact-match shortcut: (inf - inf) is NaN, so without it two identical
		// infinities would be reported as different.
		return val1==val2 || Math.abs(val1-val2)<FLOAT_COMPARE_RESOLUTION;
	}
	
	/** @return true if {@code val1} is smaller than {@code val2} and not equal within tolerance */
	public static boolean isLess(float val1, float val2) {
		return !isEqual(val1,val2) && val1<val2;
	}
	
	/** @return true if {@code val1} is smaller than {@code val2} or equal within tolerance */
	public static boolean isLessEqual(float val1, float val2) {
		return isEqual(val1,val2) || val1<val2;
	}
	
	/** @return true if {@code val1} is larger than {@code val2} and not equal within tolerance */
	public static boolean isGreater(float val1, float val2) {
		return !isEqual(val1,val2) && val1>val2;
	}
	
	/** @return true if {@code val1} is larger than {@code val2} or equal within tolerance */
	public static boolean isGreaterEqual(float val1, float val2) {
		return isEqual(val1,val2) || val1>val2;
	}
	
	/**
	 * @param values values to inspect; may be empty
	 * @return true if at least one value is NaN or infinite
	 */
	public static boolean checkAnyIsNanOrInfinite(float ...values){
		for (float val : values) {
			if (Float.isNaN(val) || Float.isInfinite(val)){
				return true;
			}
		}
		return false;
	}
}
Junit Test:

package com.denizstij.rand.test;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Map.Entry;

import org.junit.Assert;
import org.junit.Test;

import com.denizstij.rand.RandomGen;

/**
 * @author Deniz Turan, http://denizstij.blogspot.co.uk denizstij@gmail.com
 */
public class RandomGenTest {
	private static final float SIM_PROB_THRESHOLD = 0.001f;

	@Test(expected=IllegalArgumentException.class)
	public void testNullInputs() {
		int[] randomNums=null;
		float[] probabilities=null;
		new RandomGen(randomNums, probabilities);		
	}

	@Test(expected=IllegalArgumentException.class)
	public void testEmptyArray() {
		int[] randomNums= new int[0];
		float[] probabilities=new float[0];
		new RandomGen(randomNums, probabilities);		
	}

	@Test(expected=IllegalArgumentException.class)
	public void testDifferentSizeNumsAndProbs() {
		int[] randomNums= {1};
		float[] probabilities={0.3f,0.4f};
		new RandomGen(randomNums, probabilities);		
	}
	
	@Test(expected=IllegalArgumentException.class)
	public void testProbsAndNumberArrayDifferentSize() {
		int[] randomNums= {1};
		float[] probabilities={0.3f,0.4f};
		new RandomGen(randomNums, probabilities);		
	}
	
	@Test(expected=IllegalArgumentException.class)
	public void testProbsIsNotNegative() {
		int[] randomNums= {1,2};
		float[] probabilities={0.3f,-0.4f};
		new RandomGen(randomNums, probabilities);		
	}
	
	@Test(expected=IllegalArgumentException.class)
	public void testProbsAllZero() {
		int[] randomNums= {1,2};
		float[] probabilities={0.0f,0.0f};
		new RandomGen(randomNums, probabilities);		
	}
	
	@Test(expected=IllegalArgumentException.class)
	public void testProbArrayHasNaN() {
		int[] randomNums= {1,2};
		float[] probabilities={Float.NaN,0.4f};
		new RandomGen(randomNums, probabilities);		
	}	
	
	@Test(expected=IllegalArgumentException.class)
	public void testProbsIsNotGreaterThanOne() {
		int[] randomNums= {1,2};
		float[] probabilities={0.3f,1.4f};
		new RandomGen(randomNums, probabilities);		
	}
	
	@Test(expected=IllegalArgumentException.class)
	public void testProbsTotalIsNotOne() {
		int[] randomNums= {1,2};
		float[] probabilities={0.3f,0.4f};
		new RandomGen(randomNums, probabilities);		
	}	
	
	@Test
	public void testCumulativeProbLogic() {
		int[] randomNums= {-1,-2,-4,3410,893,9,343,10};
		float[] probabilities={0.0f,0.0f,0.0f,0.3f,0.5f,0.0f,0.2f,0.0f};
		//validRandomNums :[3410, 893, 343]
		//cumProbabilities :[0.3, 0.8, 1.0]
		RandomGenWrapper randomGen = new RandomGenWrapper(randomNums, probabilities);
		// checking random number:3410 which has cum prob of 0<= and <=0.3
		int nextRandomNum = randomGen.nextNum(0.2f);		
		Assert.assertTrue(nextRandomNum==3410);
		nextRandomNum = randomGen.nextNum(0.299f);		
		Assert.assertTrue(nextRandomNum==3410);
		nextRandomNum = randomGen.nextNum(0.3f);		
		Assert.assertTrue(nextRandomNum==3410);

		// checking random number:893 which has cum prob of 0.3< and <=0.8		
		nextRandomNum = randomGen.nextNum(0.30001f);		
		Assert.assertTrue(nextRandomNum==893);
		nextRandomNum = randomGen.nextNum(0.51f);		
		Assert.assertTrue(nextRandomNum==893);
		nextRandomNum = randomGen.nextNum(0.799999f);		
		Assert.assertTrue(nextRandomNum==893);
		nextRandomNum = randomGen.nextNum(0.8f);		
		Assert.assertTrue(nextRandomNum==893);
		
		// checking random number:343 which has cum prob of 0.8< and <=1		
		nextRandomNum = randomGen.nextNum(0.800001f);
		Assert.assertTrue(nextRandomNum==343);
		nextRandomNum = randomGen.nextNum(0.9f);
		Assert.assertTrue(nextRandomNum==343);
		nextRandomNum = randomGen.nextNum(0.9999999f);
		Assert.assertTrue(nextRandomNum==343);
		nextRandomNum = randomGen.nextNum(1);
		Assert.assertTrue(nextRandomNum==343);
	}
	
	@Test
	public void testNonOccuringEventsWithProbZero() {
		// -1 event should never happen here	
		int[] randomNums= {5,-1,2,4};
		float[] probabilities={0.1f,0.0f,0.6f,0.3f};
		RandomGen randomGen = new RandomGen(randomNums, probabilities);
		//lets try several times if this true
		int maxTry=100000;
		for (int i=0;i<maxTry;i++){
			int nextNum = randomGen.nextNum();
			// make sure we never get -1 which has prob 0
			Assert.assertNotEquals(nextNum, -1);
		}		
	}
	
	@Test
	public void testAlwaysOccuringEventsWithProbOne() {
		// -1 event should always happen here	
		int[] randomNums= {5,-1,2,4};
		float[] probabilities={0.0f,1.0f,0.0f,0.0f};
		RandomGen randomGen = new RandomGen(randomNums, probabilities);
		//lets try several times if this true
		int maxTry=100000;
		for (int i=0;i<maxTry;i++){
			int nextNum = randomGen.nextNum();
			// make sure we always get -1 which has prob 1
			Assert.assertEquals(nextNum, -1);
		}		
	}
	
	/*
	 * Generate huge amount of random next number, given probability and random number arrays. 
	 * Then, frequency of each random number is estimated.
	 * Finally it is asserted that frequency is close to the input probability within a threshold (0.001)   
	 */
	@Test
	public void testProvidedQuickCheck() {
		int[] randomNums= {-1,0,1,2,3};
		float[] probabilities={0.01f, 0.3f, 0.58f, 0.1f, 0.01f};
		float[] simFrequency= new float[probabilities.length];
		
		Map <Integer,Long> counterMap= new LinkedHashMap<Integer,Long>();
		// reset counter map
		for (int randomNum : randomNums) {
			counterMap.put(randomNum, 0L);
		}
		
		RandomGen randomGen = new RandomGen(randomNums, probabilities,1234L);
		//lets generate many times random numbers
		int maxTry=10000000;
		for (int i=0;i<maxTry;i++){
			int nextNum = randomGen.nextNum();
			Long counter = counterMap.get(nextNum);			
			counter++;
			counterMap.put(nextNum, counter);			
		}
		
		// estimate frequency 
		int index=0;
		for (Entry<Integer, Long> entry : counterMap.entrySet()) {			
			Long counter = entry.getValue();
			// simulation frequencies
			simFrequency[index]=counter/(float)maxTry;
			float simProbDifference = Math.abs(simFrequency[index]-probabilities[index]);
			// 
			Assert.assertTrue(simProbDifference<SIM_PROB_THRESHOLD);
			index++;
		}

		System.out.println("Counter Map "+counterMap);
		System.out.println("Freq map:"+Arrays.toString(simFrequency));
		
	}
	private class RandomGenWrapper extends RandomGen{

		public RandomGenWrapper(int[] randomNums, float[] probabilities) {
			super(randomNums, probabilities);		
		}
		
		public int nextNum(float nextFloat){
			return super.nextNum(nextFloat);
		}
	}
}


1 comment:

Unknown said...

nextNum(float nextFloat) {} doesn't finish