Source code for nnfwtbn.tests.test_variable


import os
import tempfile
import unittest
import numpy as np
import pandas as pd

from nnfwtbn.variable import Variable, RangeBlindingStrategy

class VariableTestCase(unittest.TestCase):
    """
    Test the implementation of the variable class.
    """

    def test_init_store(self):
        """
        Check that all arguments are stored in the object.
        """
        blinding = RangeBlindingStrategy(100, 125)
        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV", blinding)

        self.assertEqual(variable.name, "MMC")
        # The definition is only checked for presence here; its type is
        # covered by test_init_definition_string.
        self.assertIsNotNone(variable.definition)
        self.assertEqual(variable.unit, "GeV")
        self.assertEqual(variable.blinding, blinding)

    def test_init_definition_string(self):
        """
        Check that a string used as the variable definition is wrapped into
        a lambda.
        """
        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
        self.assertTrue(callable(variable.definition))

    def test_init_blinding_type(self):
        """
        Check that an error is thrown if the blinding object is not an
        instance of the abstract blinding class.
        """
        # The constructor call must happen inside the assertRaises context.
        # Passing the string "MMC" as the callable (as this test did before)
        # raises TypeError merely because a str is not callable, so the
        # check succeeded without ever constructing a Variable.
        with self.assertRaises(TypeError):
            Variable("MMC", "ditau_mmc_mlm_m", "GeV", "blind")

    def test_repr(self):
        """
        Check that the string representation contains the name of the
        variable.
        """
        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
        self.assertEqual(repr(variable), "<Variable 'MMC' [GeV]>")

        # Without a unit, the expected representation has no bracket part;
        # backslashes in the name appear escaped.
        variable = Variable(r"$\Delta \eta$",
                            lambda df: df.jet_0_eta - df.jet_1_eta)
        self.assertEqual(repr(variable), r"<Variable '$\\Delta \\eta$'>")

    def test_equal_same_values(self):
        """
        Check the equal operator for variables which are created with the
        same values.
        """
        variable1 = Variable("name", "branch", "unit")
        variable2 = Variable("name", "branch", "unit")

        # The == operator is exercised explicitly in both directions.
        self.assertTrue(variable1 == variable2)
        self.assertTrue(variable2 == variable1)

    def test_equal_different_name(self):
        """
        Check that variables with a different name are not equal.
        """
        variable1 = Variable("name1", "branch", "unit")
        variable2 = Variable("name2", "branch", "unit")

        self.assertFalse(variable1 == variable2)
        self.assertFalse(variable2 == variable1)

    def test_equal_different_definition(self):
        """
        Check that variables with a different definition are not equal.
        """
        variable1 = Variable("name", "branch1", "unit")
        variable2 = Variable("name", "branch2", "unit")

        self.assertFalse(variable1 == variable2)
        self.assertFalse(variable2 == variable1)

    def test_equal_different_unit(self):
        """
        Check that variables with a different unit are not equal.
        """
        variable1 = Variable("name", "branch", "unit1")
        variable2 = Variable("name", "branch", "unit2")

        self.assertFalse(variable1 == variable2)
        self.assertFalse(variable2 == variable1)

    # TODO implement test case for different blinding strategies
    # def test_equal_different_blinding(self):
    #     pass

    def generate_df(self):
        """
        Generate a toy dataframe.

        The frame has five rows with the columns "x" (0..4) and "y" (the
        squares of 0..4).
        """
        return pd.DataFrame({
            "x": np.arange(5),
            "y": np.arange(5)**2,
        })

    def test_call_column(self):
        """
        Check that calling the variable extracts the given column name.
        """
        df = self.generate_df()
        variable = Variable("$y$", "y")

        y_col = variable(df)
        self.assertListEqual(list(y_col), [0, 1, 4, 9, 16])

    def test_call_lambda(self):
        """
        Check that calling the variable called the given lambda.
        """
        df = self.generate_df()
        variable = Variable("$x + y$", lambda d: d.x + d.y)

        # Named `total` so the builtin `sum` is not shadowed.
        total = variable(df)
        self.assertListEqual(list(total), [0, 2, 6, 12, 20])

    def test_saving_and_loading(self):
        """
        Test that saving and loading a variable doesn't change the variable.
        """
        variable1 = Variable("MMC", "ditau_mmc_mlm_m", "GeV")

        fd, path = tempfile.mkstemp()
        try:
            variable1.save_to_h5(path, "variable")
            variable2 = Variable.load_from_h5(path, "variable")
        finally:
            # close file descriptor and delete file
            os.close(fd)
            os.remove(path)

        self.assertTrue(variable1 == variable2)
class RangeBlindingTestCase(unittest.TestCase):
    """
    Test the implementation of the RangeBlinding class.
    """

    def generate_df(self):
        """
        Returns a toy dataframe.

        The frame has 400 rows: "ditau_mmc_mlm_m" is evenly spaced from 0
        to 400 and "x" is evenly spaced from 0 to 1.
        """
        mmc_values = np.linspace(0, 400, 400)
        x_values = np.linspace(0, 1, 400)
        return pd.DataFrame({
            "ditau_mmc_mlm_m": mmc_values,
            "x": x_values,
        })

    def _blinded_mmc(self, start, end, bins):
        """
        Blind the toy dataframe and return its "ditau_mmc_mlm_m" column.

        A RangeBlindingStrategy(start, end) is applied to the MMC variable
        with the given number of bins over the fixed range (50, 200).
        """
        strategy = RangeBlindingStrategy(start, end)
        mmc_variable = Variable("MMC", "ditau_mmc_mlm_m")
        frame = self.generate_df()

        apply_blinding = strategy(mmc_variable, bins=bins, range=(50, 200))
        return apply_blinding(frame).ditau_mmc_mlm_m

    def test_init_store(self):
        """
        Check that the constructor stores all arguments.
        """
        strategy = RangeBlindingStrategy(100, 125)

        self.assertEqual(strategy.start, 100)
        self.assertEqual(strategy.end, 125)

    def test_event_blinding(self):
        """
        Check that events in the given region are removed.
        """
        mmc = self._blinded_mmc(100, 125, bins=30)

        # All events outside
        outside = (mmc < 100) | (mmc > 125)
        self.assertTrue(outside.all())

        # No events inside
        inside = (mmc > 100) & (mmc < 125)
        self.assertFalse(inside.any())

        # Boundary not enlarged
        up_to_130 = (mmc > 100) & (mmc < 130)
        self.assertTrue(up_to_130.any())

    def test_bin_border(self):
        """
        Check that the blind range is extended to match the bin borders.
        """
        mmc = self._blinded_mmc(100, 125, bins=15)

        # All events outside
        outside = (mmc < 100) | (mmc > 130)
        self.assertTrue(outside.all())

        # No events inside
        inside = (mmc > 100) & (mmc < 130)
        self.assertFalse(inside.any())

    def test_bin_border_left(self):
        """
        Check that the blinding does not break if the blinding starts left
        of the first bin.
        """
        mmc = self._blinded_mmc(10, 125, bins=15)

        # All events outside
        outside = (mmc < 10) | (mmc > 130)
        self.assertTrue(outside.all())

        # No events inside
        inside = (mmc > 10) & (mmc < 130)
        self.assertFalse(inside.any())

    def test_bin_border_right(self):
        """
        Check that the blinding does not break if the blinding ends right
        of the last bin.
        """
        mmc = self._blinded_mmc(100, 225, bins=15)

        # All events outside
        outside = (mmc < 100) | (mmc > 225)
        self.assertTrue(outside.all())

        # No events inside
        inside = (mmc > 100) & (mmc < 225)
        self.assertFalse(inside.any())