diff --git a/test_secplus.py b/test_secplus.py index 78cafe7..fb21b2f 100755 --- a/test_secplus.py +++ b/test_secplus.py @@ -134,6 +134,15 @@ class TestSecplus(unittest.TestCase): 0xf085, 0x728c, 0x728c, 0x728c, 0x728c, 0xa191, 0x8081, 0x8092, 0x360e281, 0x260f281, 0x360e281, 0x260f281, 0x8193, 0x8193, 0x8193, 0x8193, ] + wireline_device_id_list = [fixed & 0xf0ffffffff for fixed in wireline_fixed_list] + wireline_command_list = [ + 0x285, 0x18c, 0x18c, 0x18c, 0x18c, 0x391, 0x181, 0x392, + 0x081, 0x081, 0x081, 0x081, 0x393, 0x393, 0x393, 0x393, + ] + wireline_payload_list = [ + 0x00000, 0x20000, 0x20000, 0x20000, 0x20000, 0x10000, 0x00000, 0x00000, + 0x26003, 0x26002, 0x26003, 0x26002, 0x10000, 0x10000, 0x10000, 0x10000, + ] @classmethod def setUpClass(cls): @@ -586,6 +595,136 @@ def test_decode_wireline(self): self.assertEqual(fixed, fixed_out) self.assertEqual(data, data_out) + def test_encode_wireline_command_decode_wireline_command(self): + for _ in range(self.test_cycles): + rolling = random.randrange(2**28) + device_id = random.randrange(2**40) & 0xf0ffffffff + command = random.randrange(2**12) + payload = random.randrange(2**20) + + code = secplus.encode_wireline_command(rolling, device_id, command, payload) + rolling_out, device_id_out, command_out, payload_out = secplus.decode_wireline_command(code) + + self.assertEqual(rolling, rolling_out) + self.assertEqual(device_id, device_id_out) + self.assertEqual(command, command_out) + self.assertEqual(payload, payload_out) + + def test_decode_wireline_command_robustness(self): + for _ in range(self.test_cycles): + random_code = bytes([0x55, 0x01, 0x00] + [random.randrange(256) for _ in range(16)]) + try: + rolling, device_id, command, payload = secplus.decode_wireline_command(random_code) + self.assertLessEqual(rolling, 2**28-1) + self.assertLessEqual(device_id, 2**40-1) + self.assertEqual(device_id & 0x0f00000000, 0) + self.assertLessEqual(command, 2**12-1) + self.assertLessEqual(payload, 2**20-1) + except ValueError: + pass + + for _ in range(self.test_cycles): + rolling = random.randrange(2**28) + device_id = random.randrange(2**40) & 0xf0ffffffff + command = random.randrange(2**12) + payload = random.randrange(2**20) + code = secplus.encode_wireline_command(rolling, device_id, command, payload) + random_code = bytes(b if random.randrange(19) > 0 else random.randrange(256) for b in code) + try: + rolling, device_id, command, payload = secplus.decode_wireline_command(random_code) + self.assertLessEqual(rolling, 2**28-1) + self.assertLessEqual(device_id, 2**40-1) + self.assertEqual(device_id & 0x0f00000000, 0) + self.assertLessEqual(command, 2**12-1) + self.assertLessEqual(payload, 2**20-1) + except ValueError: + pass + + def test_decode_wireline_command_input_validation(self): + with self.assertRaises(ValueError) as cm: + secplus.decode_wireline_command("foo") + self.assertEqual(str(cm.exception), "Input must be bytes") + with self.assertRaises(ValueError) as cm: + secplus.decode_wireline_command(b"foo") + self.assertEqual(str(cm.exception), "Input must be 19 bytes long") + with self.assertRaises(ValueError) as cm: + secplus.decode_wireline_command(b"foo bar foo bar foo") + self.assertIn(str(cm.exception), ["First three bytes must be 0x55, 0x01, 0x00", "Invalid input"]) + + def test_encode_wireline_command_rolling_limit(self): + rolling = 2**28 + device_id = 2**40 - 1 + command = 2**12 - 1 + payload = 2**20 - 1 + + with self.assertRaises(ValueError) as cm: + secplus.encode_wireline_command(rolling, device_id, command, payload) + self.assertIn(str(cm.exception), ["Rolling code must be less than 2^28", "Invalid input"]) + + def test_encode_wireline_command_device_id_limit(self): + rolling = 2**28 - 1 + device_id = 2**40 + command = 2**12 - 1 + payload = 2**20 - 1 + + with self.assertRaises(ValueError) as cm: + secplus.encode_wireline_command(rolling, device_id, command, payload) + self.assertIn(str(cm.exception), ["Device ID must be less than 2^40", "Invalid input"]) + + def test_encode_wireline_command_limit(self): + rolling = 2**28 - 1 + device_id = 2**40 - 1 + command = 2**12 + payload = 2**20 - 1 + + with self.assertRaises(ValueError) as cm: + secplus.encode_wireline_command(rolling, device_id, command, payload) + self.assertEqual(str(cm.exception), "Command must be less than 2^12") + + def test_encode_wireline_payload_limit(self): + rolling = 2**28 - 1 + device_id = 2**40 - 1 + command = 2**12 - 1 + payload = 2**20 + + with self.assertRaises(ValueError) as cm: + secplus.encode_wireline_command(rolling, device_id, command, payload) + self.assertEqual(str(cm.exception), "Payload value must be less than 2^20") + + def test_decode_wireline_command_bits_8_9(self): + for code in self.wireline_codes: + for byte_offset in (4, 12): + for bit_mask in (0x40, 0x80, 0xc0): + broken_code = code.copy() + broken_code[byte_offset] |= bit_mask + with self.assertRaises(ValueError) as cm: + secplus.decode_wireline_command(bytes(broken_code)) + self.assertIn(str(cm.exception), ["Unexpected values for bits 8 and 9", "Invalid input"]) + + def test_encode_wireline_command(self): + for code, rolling, device_id, command, payload in zip(self.wireline_codes, + self.wireline_rolling_list, + self.wireline_device_id_list, + self.wireline_command_list, + self.wireline_payload_list): + code = bytes(code) + code_out = secplus.encode_wireline_command(rolling, device_id, command, payload) + + self.assertEqual(code, code_out) + + def test_decode_wireline_command(self): + for code, rolling, device_id, command, payload in zip(self.wireline_codes, + self.wireline_rolling_list, + self.wireline_device_id_list, + self.wireline_command_list, + self.wireline_payload_list): + rolling_out, device_id_out, command_out, payload_out = secplus.decode_wireline_command(bytes(code)) + + self.assertEqual(rolling, rolling_out) + self.assertEqual(device_id, device_id_out) + self.assertEqual(command, command_out) + self.assertEqual(payload, payload_out) + def substitute_c(): import platform