# Source code for nnfwtbn.tests.test_toydata


import unittest
import numpy as np
from nnfwtbn.toydata import draw

# Default seed for the random number generators created by the tests
# (used as the default argument of DrawTestCase._rng).
SEED = 798088218969

class DrawTestCase(unittest.TestCase):
    """
    Test the implementation of draw().

    Every test uses a PCG64-based numpy generator. By default the generator
    is seeded with the module-level SEED, so the drawn samples are the same
    on every run.
    """

    @staticmethod
    def _rng(seed=SEED):
        """
        Returns a new random number generator.

        The generator is a numpy Generator backed by a PCG64 bit generator
        initialised with the given seed (SEED by default).
        """
        return np.random.Generator(np.random.PCG64(seed))

    def setUp(self):
        """
        Instantiate a random number generator.

        A fresh generator with the default seed is stored in self.rng before
        each test.
        """
        self.rng = DrawTestCase._rng()

    @staticmethod
    def _toy_pdf(x):
        """
        A toy PDF for testing: Normalized parabola between 0 and 1.

        The integral of 3 * x**2 over [0, 1] is 1.
        """
        return 3 * x**2

    @staticmethod
    def _toy_pdf2(x):
        """
        A toy PDF for testing: Normalized parabola between 1 and 11.

        The integral of 3 * (x - 1)**2 / 1000 over [1, 11] is 1.
        """
        return 3 * (x - 1)**2 / 1000

    def test_draw_len(self):
        """
        Check that draw returns the number of samples given by the size
        parameter.

        Without an explicit size a single sample is expected; an integer size
        gives that many samples and a tuple size gives an array of that shape.
        """
        self.assertEqual(len(draw(self.rng, DrawTestCase._toy_pdf)), 1)
        self.assertEqual(len(draw(self.rng, DrawTestCase._toy_pdf, size=10)),
                         10)
        self.assertEqual(draw(self.rng, DrawTestCase._toy_pdf,
                              size=(2, 5)).shape,
                         (2, 5))

    def test_draw_reproducible(self):
        """
        Check that draw returns the same array when called with identical
        arguments.
        """
        # Two independent generators with the same (default) seed.
        rng1 = DrawTestCase._rng()
        rng2 = DrawTestCase._rng()

        sample1 = draw(rng1, DrawTestCase._toy_pdf, size=(10, 100))
        sample2 = draw(rng2, DrawTestCase._toy_pdf, size=(10, 100))

        self.assertTrue((sample1 == sample2).all())

    def test_draw_seed(self):
        """
        Check that different arrays are returned when different seeds are
        given.
        """
        rng1 = DrawTestCase._rng(42)
        rng2 = DrawTestCase._rng(43)

        sample1 = draw(rng1, DrawTestCase._toy_pdf, size=(10, 100))
        sample2 = draw(rng2, DrawTestCase._toy_pdf, size=(10, 100))

        # The arrays only need to differ in at least one element.
        self.assertFalse((sample1 == sample2).all())

    def test_draw_limits(self):
        """
        Check that the returned numbers are within the limit.

        No limits are passed to draw() here; the samples are expected to lie
        in [0, 1] and to come close to both ends of that interval.
        """
        sample = draw(self.rng, DrawTestCase._toy_pdf, size=1000*1000)

        self.assertGreaterEqual(sample.min(), 0)
        # NOTE(review): assuming draw() samples according to the pdf, a single
        # value falls below 0.01 with probability 0.01**3 = 1e-6, so with 1e6
        # samples this check holds for only about 63% of all seeds. It relies
        # on the fixed SEED used by setUp().
        self.assertLess(sample.min(), 0.01)

        self.assertGreater(sample.max(), 0.9999)
        self.assertLessEqual(sample.max(), 1)

    def test_draw_limits_2(self):
        """
        Check that the returned numbers are within the limit.

        Same as test_draw_limits, but with explicit limits lower=1 and
        upper=11 and the matching toy PDF.
        """
        sample = draw(self.rng, DrawTestCase._toy_pdf2, lower=1, upper=11,
                      size=1000*1000)

        self.assertGreaterEqual(sample.min(), 1)
        # NOTE(review): as in test_draw_limits, the probability of a single
        # value below 1.1 is 0.1**3 / 1000 = 1e-6 (assuming draw() follows the
        # pdf), so this check depends on the fixed SEED.
        self.assertLess(sample.min(), 1.1)

        self.assertGreater(sample.max(), 10.999)
        self.assertLessEqual(sample.max(), 11)