import unittest
import numpy as np
from nnfwtbn.toydata import draw
# Fixed seed shared by all tests so that every sample drawn below is
# reproducible from run to run.
SEED = 798_088_218_969
class DrawTestCase(unittest.TestCase):
    """
    Test the implementation of draw().
    """

    @staticmethod
    def _rng(seed=SEED):
        """
        Return a new PCG64-backed random number generator.

        :param seed: Seed for the bit generator. Defaults to the module-level
            SEED, so two generators created without arguments produce
            identical streams.
        """
        return np.random.Generator(np.random.PCG64(seed))

    def setUp(self):
        """
        Instantiate a fresh, deterministically seeded random number generator
        before every test.
        """
        self.rng = DrawTestCase._rng()

    @staticmethod
    def _toy_pdf(x):
        """
        A toy PDF for testing: normalized parabola between 0 and 1.

        The integral of 3 * x**2 over [0, 1] is 1; its CDF is x**3.
        """
        return 3 * x**2

    @staticmethod
    def _toy_pdf2(x):
        """
        A toy PDF for testing: normalized parabola between 1 and 11.

        The integral of 3 * (x - 1)**2 / 1000 over [1, 11] is 1; its CDF is
        ((x - 1) / 10)**3.
        """
        return 3 * (x - 1)**2 / 1000

    def test_draw_len(self):
        """
        Check that draw returns the number of samples given by the size
        parameter.
        """
        self.assertEqual(len(draw(self.rng, DrawTestCase._toy_pdf)), 1)
        self.assertEqual(len(draw(self.rng, DrawTestCase._toy_pdf, size=10)), 10)
        self.assertEqual(draw(self.rng, DrawTestCase._toy_pdf, size=(2, 5)).shape,
                         (2, 5))

    def test_draw_reproducible(self):
        """
        Check that draw returns the same array when called with identical
        arguments.
        """
        rng1 = DrawTestCase._rng()
        rng2 = DrawTestCase._rng()
        sample1 = draw(rng1, DrawTestCase._toy_pdf, size=(10, 100))
        sample2 = draw(rng2, DrawTestCase._toy_pdf, size=(10, 100))
        # assert_array_equal also checks that the shapes agree and reports the
        # mismatching elements on failure, unlike assertTrue((a == b).all()).
        np.testing.assert_array_equal(sample1, sample2)

    def test_draw_seed(self):
        """
        Check that different arrays are returned when different seeds are
        given.
        """
        rng1 = DrawTestCase._rng(42)
        rng2 = DrawTestCase._rng(43)
        sample1 = draw(rng1, DrawTestCase._toy_pdf, size=(10, 100))
        sample2 = draw(rng2, DrawTestCase._toy_pdf, size=(10, 100))
        # At least one element must differ between the two samples.
        self.assertFalse((sample1 == sample2).all())

    def test_draw_limits(self):
        """
        Check that the returned numbers are within the default limits [0, 1]
        and that both ends of the interval are actually reached.
        """
        sample = draw(self.rng, DrawTestCase._toy_pdf, size=1000 * 1000)
        self.assertGreaterEqual(sample.min(), 0)
        # With CDF x**3, P(x < 0.05) = 1.25e-4, so the chance that none of the
        # 10**6 samples falls below 0.05 is ~exp(-125). A tighter bound like
        # 0.01 (P = 1e-6) would be missed with ~37% probability for an
        # arbitrary seed, making the test depend on luck.
        self.assertLess(sample.min(), 0.05)
        # P(x > 0.9999) ~ 3e-4 per sample, so this bound is safe.
        self.assertGreater(sample.max(), 0.9999)
        self.assertLessEqual(sample.max(), 1)

    def test_draw_limits_2(self):
        """
        Check that the returned numbers are within custom lower/upper limits
        [1, 11] and that both ends of the interval are actually reached.
        """
        sample = draw(self.rng, DrawTestCase._toy_pdf2, lower=1, upper=11,
                      size=1000 * 1000)
        self.assertGreaterEqual(sample.min(), 1)
        # With CDF ((x - 1) / 10)**3, P(x < 1.5) = 1.25e-4; see
        # test_draw_limits for why a tighter bound such as 1.1 is too fragile.
        self.assertLess(sample.min(), 1.5)
        # P(x > 10.999) ~ 3e-4 per sample, so this bound is safe.
        self.assertGreater(sample.max(), 10.999)
        self.assertLessEqual(sample.max(), 11)