From 9a1c6b5dd98d6bc0f5381c1f1531a9fe93113675 Mon Sep 17 00:00:00 2001 From: Patrick Guo Date: Thu, 5 Dec 2024 16:58:08 -0500 Subject: [PATCH] restructuring to use data driven approach with datatypes --- pyrad/__init__.py | 2 +- pyrad/datatypes/__init__.py | 0 pyrad/datatypes/base.py | 63 ++++ pyrad/datatypes/leaf.py | 566 ++++++++++++++++++++++++++++++++++ pyrad/datatypes/structural.py | 130 ++++++++ pyrad/dictionary.py | 55 +++- pyrad/packet.py | 67 ++-- pyrad/tools.py | 255 --------------- tests/data/full | 2 + tests/testDatatypes.py | 98 ++++++ tests/testDictionary.py | 62 ++-- tests/testPacket.py | 12 +- tests/testTools.py | 127 -------- 13 files changed, 997 insertions(+), 442 deletions(-) create mode 100644 pyrad/datatypes/__init__.py create mode 100644 pyrad/datatypes/base.py create mode 100644 pyrad/datatypes/leaf.py create mode 100644 pyrad/datatypes/structural.py delete mode 100644 pyrad/tools.py create mode 100644 tests/testDatatypes.py delete mode 100644 tests/testTools.py diff --git a/pyrad/__init__.py b/pyrad/__init__.py index 56f924e..c1ec1d3 100644 --- a/pyrad/__init__.py +++ b/pyrad/__init__.py @@ -43,4 +43,4 @@ __copyright__ = 'Copyright 2002-2023 Wichert Akkerman, Istvan Ruzman and Christian Giese. All rights reserved.' __version__ = '2.4' -__all__ = ['client', 'dictionary', 'packet', 'server', 'tools', 'dictfile'] +__all__ = ['client', 'dictionary', 'packet', 'server', 'datatypes', 'dictfile'] diff --git a/pyrad/datatypes/__init__.py b/pyrad/datatypes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyrad/datatypes/base.py b/pyrad/datatypes/base.py new file mode 100644 index 0000000..92e833d --- /dev/null +++ b/pyrad/datatypes/base.py @@ -0,0 +1,63 @@ +""" +base.py + +Contains base datatype +""" +from abc import ABC, abstractmethod + +class AbstractDatatype(ABC): + """ + Root of entire datatype class hierarchy + """ + def __init__(self, name): + self.name = name + + @abstractmethod + def encode(self, attribute, decoded, *args, **kwargs): + """ + turns python data structure into bytes + + :param *args: + :param **kwargs: + :param attribute: + :param decoded: python data structure to encode + :return: encoded bytes + """ + + @abstractmethod + def print(self, attribute, decoded, *args, **kwargs): + """ + returns string representation of decoding + + :param *args: + :param **kwargs: + :param attribute: attribute object + :param decoded: value pair + :return: string + """ + + @abstractmethod + def parse(self, dictionary, string, *args, **kwargs): + """ + returns python structure from ASCII string + + :param *args: + :param **kwargs: + :param dictionary: + :param string: ASCII string of attribute + :return: python structure for attribute + """ + + @abstractmethod + def get_value(self, dictionary, code, attribute, packet, offset): + """ + retrieves the encapsulated value + :param dictionary: + :param code: + :param *args: + :param **kwargs: + :param attribute: attribute value + :param packet: packet + :param offset: attribute starting position + :return: encapsulated value, and bytes read + """ diff --git a/pyrad/datatypes/leaf.py b/pyrad/datatypes/leaf.py new file mode 100644 index 0000000..b5f5ad2 --- /dev/null +++ b/pyrad/datatypes/leaf.py @@ -0,0 +1,566 @@ +""" +leaf.py + +Contains all leaf datatypes (ones that can be encoded and decoded directly) +""" +import binascii +import enum +import struct +from abc import ABC, abstractmethod +from datetime import datetime +from ipaddress import IPv4Address, IPv6Network, IPv6Address, IPv4Network, \ + AddressValueError + +from netaddr import EUI, core + +from pyrad.datatypes import base + + +class AbinaryKeystores(enum.Enum): + PYRAD = 1 + FREERADIUS = 2 + +class AbstractLeaf(base.AbstractDatatype, ABC): + """ + abstract class for leaf datatypes + """ + @abstractmethod + def decode(self, raw, *args, **kwargs): + """ + turns bytes into python data structure + + :param *args: + :param **kwargs: + :param raw: bytes + :return: python data structure + """ + + def get_value(self, dictionary, code, attribute, packet, offset): + _, attr_len = struct.unpack('!BB', packet[offset:offset + 2])[0:2] + return ((code, packet[offset + 2:offset + attr_len]),), attr_len + +class AscendBinary(AbstractLeaf): + """ + leaf datatype class for ascend binary + """ + def __init__(self): + super().__init__('abinary') + + def encode(self, attribute, decoded, *args, **kwargs): + # unless specified throw kwargs, use the pyrad keystore to avoid + # causing breakages + keystore = AbinaryKeystores.PYRAD + if 'keystore_abinary' in kwargs: + keystore = kwargs['keystore_abinary'] + + match keystore: + case AbinaryKeystores.PYRAD: + terms = { + 'family': b'\x01', + 'action': b'\x00', + 'direction': b'\x01', + 'src': b'\x00\x00\x00\x00', + 'dst': b'\x00\x00\x00\x00', + 'srcl': b'\x00', + 'dstl': b'\x00', + 'proto': b'\x00', + 'sport': b'\x00\x00', + 'dport': b'\x00\x00', + 'sportq': b'\x00', + 'dportq': b'\x00' + } + case AbinaryKeystores.FREERADIUS: + terms = { + 'srcip': b'=x00', + 'dstip': None, + 'srcmask': '', + 'dstmask': '', + 'proto': '', + 'established': '', + 'srcport': '', + 'dstport': '', + 'srcPortCmp': '', + 'dstPortCmp': '', + 'fill': '' + } + case _: + raise ValueError('Unexpected abinary keystore name') + + + family = 'ipv4' + for t in decoded.split(' '): + key, value = t.split('=') + if key == 'family' and value == 'ipv6': + family = 'ipv6' + terms[key] = b'\x03' + if terms['src'] == b'\x00\x00\x00\x00': + terms['src'] = 16 * b'\x00' + if terms['dst'] == b'\x00\x00\x00\x00': + terms['dst'] = 16 * b'\x00' + elif key == 'action' and value == 'accept': + terms[key] = b'\x01' + elif key == 'action' and value == 'redirect': + terms[key] = b'\x20' + elif key == 'direction' and value == 'out': + terms[key] = b'\x00' + elif key in ('src', 'dst'): + if family == 'ipv4': + ip = IPv4Network(value) + else: + ip = IPv6Network(value) + terms[key] = ip.network_address.packed + terms[key + 'l'] = struct.pack('B', ip.prefixlen) + elif key in ('sport', 'dport'): + terms[key] = struct.pack('!H', int(value)) + elif key in ('sportq', 'dportq', 'proto'): + terms[key] = struct.pack('B', int(value)) + + trailer = 8 * b'\x00' + + result = b''.join( + (terms['family'], terms['action'], terms['direction'], b'\x00', + terms['src'], terms['dst'], terms['srcl'], terms['dstl'], + terms['proto'], b'\x00', + terms['sport'], terms['dport'], terms['sportq'], terms['dportq'], + b'\x00\x00', trailer)) + return result + + def decode(self, raw, *args, **kwargs): + # just return the raw binary string + return raw + + def print(self, attribute, decoded): + # the binary string is what we are looking for + return decoded + + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + return string + +class Byte(AbstractLeaf): + """ + leaf datatype class for bytes (1 byte unsigned int) + """ + def __init__(self): + super().__init__('byte') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as byte') from exc + return struct.pack('!B', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!B', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + # cast int to string before returning + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as byte') from e + else: + if num < 0: + raise ValueError('Parsed value too small for byte') + if num > 255: + raise ValueError('Parsed value too large for byte') + return num + +class Date(AbstractLeaf): + """ + leaf datatype class for dates + """ + def __init__(self): + super().__init__('date') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, int): + raise TypeError('Can not encode non-integer as date') + return struct.pack('!I', decoded) + + def decode(self, raw, *args, **kwargs): + # dates are stored as ints + return (struct.unpack('!I', raw))[0] + + def print(self, attribute, decoded, *args, **kwargs): + # turn seconds since epoch into timestamp with given format + return datetime.fromtimestamp(decoded).strftime('%Y-%m-%dT%H:%M:%S') + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + # parse string using given string, and return seconds since epoch + # as an int + return int(datetime.strptime(string, '%Y-%m-%dT%H:%M:%S') + .timestamp()) + except ValueError as e: + raise TypeError('Failed to parse date') from e + +class Ether(AbstractLeaf, ABC): + """ + leaf datatype class for ethernet addresses + """ + def __init__(self): + super().__init__('ether') + + def encode(self, attribute, decoded, *args, **kwargs): + return struct.pack('!6B', *map(lambda x: int(x, 16), decoded.split(':'))) + + def decode(self, raw, *args, **kwargs): + # return EUI object containing mac address + return EUI(':'.join(map('{0:02x}'.format, struct.unpack('!6B', raw)))) + + def print(self, attribute, decoded, *args, **kwargs): + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError('Can not encode non-string as ethernet address') + + try: + return EUI(string) + except core.AddrFormatError as e: + raise ValueError('Could not decode ethernet address') from e + +class Ifid(AbstractLeaf, ABC): + """ + leaf datatype class for IFID (IPV6 interface ID) + """ + def __init__(self): + super().__init__('ifid') + + def encode(self, attribute, decoded, *args, **kwargs): + struct.pack('!HHHH', *map(lambda x: int(x, 16), decoded.split(':'))) + + def decode(self, raw, *args, **kwargs): + ':'.join(map('{0:04x}'.format, struct.unpack('!HHHH', raw))) + + def print(self, attribute, decoded, *args, **kwargs): + # Following freeradius, IFIDs are displayed as a hex without any + # delimiters + return decoded.replace(':', '') + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + # adds a : delimiter after every second character + return ':'.join((string[i:i + 2] for i in range(0, len(string), 2))) + +class Integer(AbstractLeaf): + """ + leaf datatype class for integers + """ + def __init__(self): + super().__init__('integer') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as integer') from exc + return struct.pack('!I', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!I', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as int') from e + else: + if num < 0: + raise ValueError('Parsed value too small for int') + if num > 4294967295: + raise ValueError('Parsed value too large for int') + return num + +class Integer64(AbstractLeaf): + """ + leaf datatype class for 64bit integers + """ + def __init__(self): + super().__init__('integer64') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as 64bit integer') from exc + return struct.pack('!Q', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!Q', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as int64') from e + else: + if num < 0: + raise ValueError('Parsed value too small for int64') + if num > 18446744073709551615: + raise ValueError('Parsed value too large for int64') + return num + +class Ipaddr(AbstractLeaf): + """ + leaf datatype class for ipv4 addresses + """ + def __init__(self): + super().__init__('ipaddr') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError('Address has to be a string') + return IPv4Address(decoded).packed + + def decode(self, raw, *args, **kwargs): + # stored as strings, not ipaddress objects + return '.'.join(map(str, struct.unpack('BBBB', raw))) + + def print(self, attribute, decoded, *args, **kwargs): + # since object is already stored as a string, just return it as is + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + # check if string is valid ipv4 address, but still returning the + # string representation + return IPv4Address(string).exploded + except AddressValueError as e: + raise TypeError('Parsing invalid IPv4 address') from e + +class Ipv6addr(AbstractLeaf): + """ + leaf datatype class for ipv6 addresses + """ + def __init__(self): + super().__init__('ipv6addr') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError('IPv6 Address has to be a string') + return IPv6Address(decoded).packed + + def decode(self, raw, *args, **kwargs): + addr = raw + b'\x00' * (16 - len(raw)) + prefix = ':'.join( + map(lambda x: f'{0:x}', struct.unpack('!' + 'H' * 8, addr)) + ) + return str(IPv6Address(prefix)) + + def print(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError(f'Parsing expects a string, got {type(decoded)}') + + try: + # check if valid address, but return string representation + return IPv6Address(decoded).exploded + except AddressValueError as e: + raise TypeError('Parsing invalid IPv6 address') from e + + def parse(self, dictionary, string, *args, **kwargs): + return string + +class Ipv6prefix(AbstractLeaf): + """ + leaf datatype class for ipv6 prefixes + """ + def __init__(self): + super().__init__('ipv6prefix') + + def encode(self, attribute, decoded, *args, **kwargs): + if not isinstance(decoded, str): + raise TypeError('IPv6 Prefix has to be a string') + ip = IPv6Network(decoded) + return (struct.pack('2B', *[0, ip.prefixlen]) + + ip.network_address.packed) + + def decode(self, raw, *args, **kwargs): + addr = raw + b'\x00' * (18 - len(raw)) + _, length, prefix = ':'.join( + map(lambda x: f'{0:x}' , struct.unpack('!BB' + 'H' * 8, addr)) + ).split(":", 2) + # returns string representation in the form of / + return str(IPv6Network(f'{prefix}/{int(length, 16)}')) + + def print(self, attribute, decoded, *args, **kwargs): + # we already store this value as a string, so just return it as is + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + return str(IPv6Network(string)) + except AddressValueError as e: + raise TypeError('Parsing invalid IPv6 prefix') from e + +class Octets(AbstractLeaf): + """ + leaf datatype class for octets + """ + def __init__(self): + super().__init__('octets') + + def encode(self, attribute, decoded, *args, **kwargs): + # Check for max length of the hex encoded with 0x prefix, as a sanity check + if len(decoded) > 508: + raise ValueError('Can only encode strings of <= 253 characters') + + if isinstance(decoded, bytes) and decoded.startswith(b'0x'): + hexstring = decoded.split(b'0x')[1] + encoded_octets = binascii.unhexlify(hexstring) + elif isinstance(decoded, str) and decoded.startswith('0x'): + hexstring = decoded.split('0x')[1] + encoded_octets = binascii.unhexlify(hexstring) + elif isinstance(decoded, str) and decoded.isdecimal(): + encoded_octets = struct.pack('>L', int(decoded)).lstrip( + b'\x00') + else: + encoded_octets = decoded + + # Check for the encoded value being longer than 253 chars + if len(encoded_octets) > 253: + raise ValueError('Can only encode strings of <= 253 characters') + + return encoded_octets + + def decode(self, raw, *args, **kwargs): + return raw + + def print(self, attribute, decoded, *args, **kwargs): + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + return string + +class Short(AbstractLeaf): + """ + leaf datatype class for short integers + """ + def __init__(self): + super().__init__('short') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as integer') from exc + return struct.pack('!H', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!H', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as short') from e + else: + if num < 0: + raise ValueError('Parsed value too small for short') + if num > 65535: + raise ValueError('Parsed value too large for short') + return num + +class Signed(AbstractLeaf): + """ + leaf datatype class for signed integers + """ + def __init__(self): + super().__init__('signed') + + def encode(self, attribute, decoded, *args, **kwargs): + try: + num = int(decoded) + except Exception as exc: + raise TypeError('Can not encode non-integer as signed integer') from exc + return struct.pack('!i', num) + + def decode(self, raw, *args, **kwargs): + return struct.unpack('!i', raw)[0] + + def print(self, attribute, decoded, *args, **kwargs): + return str(decoded) + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + try: + num = int(string) + except ValueError as e: + raise TypeError('Can not parse non-integer as signed') from e + else: + if num < -2147483648: + raise ValueError('Parsed value too small for signed') + if num > 2147483647: + raise ValueError('Parsed value too large for signed') + return num + +class String(AbstractLeaf): + """ + leaf datatype class for strings + """ + def __init__(self): + super().__init__('string') + + def encode(self, attribute, decoded, *args, **kwargs): + if len(decoded) > 253: + raise ValueError('Can only encode strings of <= 253 characters') + if isinstance(decoded, str): + return decoded.encode('utf-8') + return decoded + + def decode(self, raw, *args, **kwargs): + return raw.decode('utf-8') + + def print(self, attribute, decoded, *args, **kwargs): + return decoded + + def parse(self, dictionary, string, *args, **kwargs): + if not isinstance(string, str): + raise TypeError(f'Parsing expects a string, got {type(string)}') + + return string diff --git a/pyrad/datatypes/structural.py b/pyrad/datatypes/structural.py new file mode 100644 index 0000000..e56ddb3 --- /dev/null +++ b/pyrad/datatypes/structural.py @@ -0,0 +1,130 @@ +""" +structural.py + +Contains all structural datatypes +""" +import struct + +from abc import ABC +from pyrad.datatypes import base +from pyrad.parser import ParserTLV +from pyrad.utility import tlv_name_to_codes, vsa_name_to_codes + +parser_tlv = ParserTLV() + +class AbstractStructural(base.AbstractDatatype, ABC): + """ + abstract class for structural datatypes + """ + +class Tlv(AbstractStructural): + """ + structural datatype class for TLV + """ + def __init__(self): + super().__init__('tlv') + + def encode(self, attribute, decoded, *args, **kwargs): + encoding = b'' + for key, value in decoded.items(): + encoding += attribute.sub_attributes[key].encode(value, ) + + if len(encoding) + 2 > 255: + raise ValueError('TLV length too long for one packet') + + return (struct.pack('!B', attribute.code) + + struct.pack('!B', len(encoding) + 2) + + encoding) + + def get_value(self, dictionary, code, attribute: 'Attribute', packet, offset): + sub_attrs = {} + + _, outer_len = struct.unpack('!BB', packet[offset:offset + 2])[0:2] + + if outer_len < 3: + raise ValueError('TLV length too short') + if offset + outer_len > len(packet): + raise ValueError('TLV length too long') + + # move cursor to TLV value + cursor = offset + 2 + while cursor < offset + outer_len: + (sub_type, sub_len) = struct.unpack( + '!BB', packet[cursor:cursor + 2] + ) + + if sub_len < 3: + raise ValueError('TLV length field too small') + + # future work will allow nested TLVs and structures. for now, TLVs + # must contain leaf attributes. As such, we can just extract the + # value from the packet + value = packet[cursor + 2:cursor + sub_len] + sub_attrs.setdefault(sub_type, []).append(value) + cursor += sub_len + return ((code, sub_attrs),), outer_len + + def print(self, attribute, decoded, *args, **kwargs): + sub_attr_strings = [sub_attr.print() + for sub_attr in attribute.sub_attributes] + return f"{attribute.name} = {{ {', '.join(sub_attr_strings)} }}" + + def parse(self, dictionary, string, *args, **kwargs): + return tlv_name_to_codes(dictionary, parser_tlv.parse(string)) + +class Vsa(AbstractStructural): + """ + structural datatype class for VSA + """ + def __init__(self): + super().__init__('vsa') + + # used for get_value() + self.tlv = Tlv() + + def encode(self, attribute, decoded, *args, **kwargs): + encoding = b'' + + for key, value in decoded.items(): + encoding += attribute.sub_attributes[key].encode(value, ) + + return (struct.pack('!B', attribute.code) + + struct.pack('!B', len(encoding) + 4) + + struct.pack('!L', attribute.vendor) + + encoding) + + def get_value(self, dictionary, code, attribute, packet, offset): + # currently, a list of (code, value) pair is returned. with the v4 + # update, a single (nested) object will be returned + values = [] + + (_, length) = struct.unpack('!BB', packet[offset:offset + 2]) + if length < 8: + return ((26, packet[offset + 2:offset + length]),), length + + vendor = struct.unpack('!L', packet[offset + 2:offset + 6]) + + cursor = offset + 6 + while cursor < offset + length: + (sub_type, _) = struct.unpack('!BB', packet[cursor:cursor + 2]) + + # first, using the vendor ID and sub attribute type, get the name + # of the sub attribute. then, using the name, get the Attribute + # object to call .get_value(...) + sub_attr_name = dictionary.attrindex.GetBackward(vendor + (sub_type,)) + sub_attr = dictionary.attributes[sub_attr_name] + + (sub_value, sub_offset) = sub_attr.get_value(dictionary, (vendor + (sub_type,)), packet, cursor) + + values += sub_value + cursor += sub_offset + + return values, length + + def print(self, attribute, decoded, *args, **kwargs): + sub_attr_strings = [sub_attr.print() + for sub_attr in attribute.sub_attributes] + return f"Vendor-Specific = {{ {attribute.vendor} = {{ {', '.join(sub_attr_strings)} }}" + + def parse(self, dictionary, string, *args, **kwargs): + return vsa_name_to_codes(dictionary, parser_tlv.parse(string)) diff --git a/pyrad/dictionary.py b/pyrad/dictionary.py index abe5263..7552561 100644 --- a/pyrad/dictionary.py +++ b/pyrad/dictionary.py @@ -72,18 +72,37 @@ +---------------+----------------------------------------------+ """ from pyrad import bidict -from pyrad import tools from pyrad import dictfile from copy import copy import logging -__docformat__ = 'epytext en' - +from pyrad.datatypes import leaf, structural -DATATYPES = frozenset(['string', 'ipaddr', 'integer', 'date', 'octets', - 'abinary', 'ipv6addr', 'ipv6prefix', 'short', 'byte', - 'signed', 'ifid', 'ether', 'tlv', 'integer64']) +__docformat__ = 'epytext en' +from pyrad.datatypes.structural import AbstractStructural + +DATATYPES = { + # leaf attributes + 'abinary': leaf.AscendBinary(), + 'byte': leaf.Byte(), + 'date': leaf.Date(), + 'ether': leaf.Ether(), + 'ifid': leaf.Ifid(), + 'integer': leaf.Integer(), + 'integer64': leaf.Integer64(), + 'ipaddr': leaf.Ipaddr(), + 'ipv6addr': leaf.Ipv6addr(), + 'ipv6prefix': leaf.Ipv6prefix(), + 'octets': leaf.Octets(), + 'short': leaf.Short(), + 'signed': leaf.Signed(), + 'string': leaf.String(), + + # structural attributes + 'tlv': structural.Tlv(), + 'vsa': structural.Vsa() +} class ParseError(Exception): """Dictionary parser exceptions. @@ -121,7 +140,7 @@ def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', valu raise ValueError('Invalid data type') self.name = name self.code = code - self.type = datatype + self.type = DATATYPES[datatype] self.vendor = vendor self.encrypt = encrypt self.has_tag = has_tag @@ -133,6 +152,26 @@ def __init__(self, name, code, datatype, is_sub_attribute=False, vendor='', valu for (key, value) in values.items(): self.values.Add(key, value) + def encode(self, decoded, *args, **kwargs): + return self.type.encode(self, decoded, args, kwargs) + + def decode(self, raw): + # Use datatype.decode to decode leaf attributes + if isinstance(raw, bytes): + # precautionary check to see if the raw data is truly being held + # by a leaf attribute + if isinstance(self.type, AbstractStructural): + raise ValueError('Structural datatype holding string!') + return self.type.decode(raw) + + # Recursively calls sub attribute's .decode() until a leaf attribute + # is reached + for sub_attr, value in raw.items(): + raw[sub_attr] = self.sub_attributes[sub_attr].decode(value) + return raw + + def get_value(self, dictionary, code, packet, offset): + return self.type.get_value(dictionary, code, self, packet, offset) class Dictionary(object): """RADIUS dictionary class. @@ -289,7 +328,7 @@ def __ParseValue(self, state, tokens, defer): if adef.type in ['integer', 'signed', 'short', 'byte', 'integer64']: value = int(value, 0) - value = tools.EncodeAttr(adef.type, value) + value = adef.encode(value) self.attributes[attr].values.Add(key, value) def __ParseVendor(self, state, tokens): diff --git a/pyrad/packet.py b/pyrad/packet.py index 4564f8f..b904d69 100644 --- a/pyrad/packet.py +++ b/pyrad/packet.py @@ -6,6 +6,11 @@ from collections import OrderedDict import struct + +from pyrad.datatypes.leaf import Integer, Octets +from pyrad.datatypes.structural import Tlv +from pyrad.dictionary import Attribute + try: import secrets random_generator = secrets.SystemRandom() @@ -27,7 +32,6 @@ # BBB for python 2.4 import md5 md5_constructor = md5.new -from pyrad import tools # Packet codes AccessRequest = 1 @@ -100,6 +104,13 @@ def __init__(self, code=0, id=None, secret=b'', authenticator=None, self.message_authenticator = None self.raw_packet = None + # the presence of some attributes require us to perform certain + # actions. this dict maps the attribute names to the functions to + # perform those actions + self.attr_actions = { + 'Message-Authenticator': self.__attr_action_message_authenticator + } + if 'dict' in attributes: self.dict = attributes['dict'] @@ -248,14 +259,14 @@ def _DecodeValue(self, attr, value): if attr.values.HasBackward(value): return attr.values.GetBackward(value) else: - return tools.DecodeAttr(attr.type, value) + return attr.decode(value) def _EncodeValue(self, attr, value): result = '' if attr.values.HasForward(value): result = attr.values.GetForward(value) else: - result = tools.EncodeAttr(attr.type, value) + result = attr.encode(value) if attr.encrypt == 2: # salt encrypt attribute @@ -275,7 +286,7 @@ def _EncodeKeyValues(self, key, values): key = self._EncodeKey(key) if tag: tag = struct.pack('B', int(tag)) - if attr.type == "integer": + if isinstance(attr.type, Integer): return (key, [tag + self._EncodeValue(attr, v)[1:] for v in values]) else: return (key, [tag + self._EncodeValue(attr, v) for v in values]) @@ -333,7 +344,7 @@ def __getitem__(self, key): values = OrderedDict.__getitem__(self, self._EncodeKey(key)) attr = self.dict.attributes[key] - if attr.type == 'tlv': # return map from sub attribute code to its values + if isinstance(attr.type, Tlv): # return map from sub attribute code to its values res = {} for (sub_attr_key, sub_attr_val) in values.items(): sub_attr_name = attr.sub_attributes[sub_attr_key] @@ -548,33 +559,37 @@ def DecodePacket(self, packet): self.clear() - packet = packet[20:] - while packet: + cursor = 20 + while cursor < len(packet): try: - (key, attrlen) = struct.unpack('!BB', packet[0:2]) + (key, length) = struct.unpack('!BB', packet[cursor:cursor + 2]) except struct.error: raise PacketError('Attribute header is corrupt') - if attrlen < 2: - raise PacketError( - 'Attribute length is too small (%d)' % attrlen) - - value = packet[2:attrlen] - attribute = self.dict.attributes.get(self._DecodeKey(key)) - if key == 26: - for (key, value) in self._PktDecodeVendorAttribute(value): - self.setdefault(key, []).append(value) - elif key == 80: - # POST: Message Authenticator AVP is present. - self.message_authenticator = True - self.setdefault(key, []).append(value) - elif attribute and attribute.type == 'tlv': - self._PktDecodeTlvAttribute(key,value) - else: + if length < 2: + raise PacketError(f'Attribute length is too small {length}') + + attribute: Attribute = self.dict.attributes.get(self._DecodeKey(key)) + if attribute is None: + raise PacketError(f'Unknown attribute key {key}') + + # perform attribute actions as needed + if attribute.name in self.attr_actions: + self.attr_actions[attribute.name](attribute, packet, cursor) + + raw, offset = attribute.get_value(self.dict, key, packet, cursor) + + # for each (key, value) pair from the raw values, add them to the + # packet's data + for key, value in raw: self.setdefault(key, []).append(value) - packet = packet[attrlen:] + cursor += offset + def __attr_action_message_authenticator(self, attribute, packet, offset): + # if the Message-Authenticator attribute is present, set the + # class attribute to True + self.message_authenticator = True def _salt_en_decrypt(self, data, salt): result = b'' @@ -796,7 +811,7 @@ def VerifyChapPasswd(self, userpwd): if isinstance(userpwd, str): userpwd = userpwd.strip().encode('utf-8') - chap_password = tools.DecodeOctets(self.get(3)[0]) + chap_password = Octets().decode(self.get(3))[0] if len(chap_password) != 17: return False diff --git a/pyrad/tools.py b/pyrad/tools.py deleted file mode 100644 index 303eb7a..0000000 --- a/pyrad/tools.py +++ /dev/null @@ -1,255 +0,0 @@ -# tools.py -# -# Utility functions -from ipaddress import IPv4Address, IPv6Address -from ipaddress import IPv4Network, IPv6Network -import struct -import binascii - - -def EncodeString(origstr): - if len(origstr) > 253: - raise ValueError('Can only encode strings of <= 253 characters') - if isinstance(origstr, str): - return origstr.encode('utf-8') - else: - return origstr - - -def EncodeOctets(octetstring): - # Check for max length of the hex encoded with 0x prefix, as a sanity check - if len(octetstring) > 508: - raise ValueError('Can only encode strings of <= 253 characters') - - if isinstance(octetstring, bytes) and octetstring.startswith(b'0x'): - hexstring = octetstring.split(b'0x')[1] - encoded_octets = binascii.unhexlify(hexstring) - elif isinstance(octetstring, str) and octetstring.startswith('0x'): - hexstring = octetstring.split('0x')[1] - encoded_octets = binascii.unhexlify(hexstring) - elif isinstance(octetstring, str) and octetstring.isdecimal(): - encoded_octets = struct.pack('>L',int(octetstring)).lstrip((b'\x00')) - else: - encoded_octets = octetstring - - # Check for the encoded value being longer than 253 chars - if len(encoded_octets) > 253: - raise ValueError('Can only encode strings of <= 253 characters') - - return encoded_octets - - -def EncodeAddress(addr): - if not isinstance(addr, str): - raise TypeError('Address has to be a string') - return IPv4Address(addr).packed - - -def EncodeIPv6Prefix(addr): - if not isinstance(addr, str): - raise TypeError('IPv6 Prefix has to be a string') - ip = IPv6Network(addr) - return struct.pack('2B', *[0, ip.prefixlen]) + ip.ip.packed - - -def EncodeIPv6Address(addr): - if not isinstance(addr, str): - raise TypeError('IPv6 Address has to be a string') - return IPv6Address(addr).packed - - -def EncodeAscendBinary(orig_str): - """ - Format: List of type=value pairs separated by spaces. - - Example: 'family=ipv4 action=discard direction=in dst=10.10.255.254/32' - - Note: redirect(0x20) action is added for http-redirect (walled garden) use case - - Type: - family ipv4(default) or ipv6 - action discard(default) or accept or redirect - direction in(default) or out - src source prefix (default ignore) - dst destination prefix (default ignore) - proto protocol number / next-header number (default ignore) - sport source port (default ignore) - dport destination port (default ignore) - sportq source port qualifier (default 0) - dportq destination port qualifier (default 0) - - Source/Destination Port Qualifier: - 0 no compare - 1 less than - 2 equal to - 3 greater than - 4 not equal to - """ - - terms = { - 'family': b'\x01', - 'action': b'\x00', - 'direction': b'\x01', - 'src': b'\x00\x00\x00\x00', - 'dst': b'\x00\x00\x00\x00', - 'srcl': b'\x00', - 'dstl': b'\x00', - 'proto': b'\x00', - 'sport': b'\x00\x00', - 'dport': b'\x00\x00', - 'sportq': b'\x00', - 'dportq': b'\x00' - } - - family = 'ipv4' - for t in orig_str.split(' '): - key, value = t.split('=') - if key == 'family' and value == 'ipv6': - family = 'ipv6' - terms[key] = b'\x03' - if terms['src'] == b'\x00\x00\x00\x00': - terms['src'] = 16 * b'\x00' - if terms['dst'] == b'\x00\x00\x00\x00': - terms['dst'] = 16 * b'\x00' - elif key == 'action' and value == 'accept': - terms[key] = b'\x01' - elif key == 'action' and value == 'redirect': - terms[key] = b'\x20' - elif key == 'direction' and value == 'out': - terms[key] = b'\x00' - elif key == 'src' or key == 'dst': - if family == 'ipv4': - ip = IPv4Network(value) - else: - ip = IPv6Network(value) - terms[key] = ip.network_address.packed - terms[key+'l'] = struct.pack('B', ip.prefixlen) - elif key == 'sport' or key == 'dport': - terms[key] = struct.pack('!H', int(value)) - elif key == 'sportq' or key == 'dportq' or key == 'proto': - terms[key] = struct.pack('B', int(value)) - - trailer = 8 * b'\x00' - - result = b''.join((terms['family'], terms['action'], terms['direction'], b'\x00', - terms['src'], terms['dst'], terms['srcl'], terms['dstl'], terms['proto'], b'\x00', - terms['sport'], terms['dport'], terms['sportq'], terms['dportq'], b'\x00\x00', trailer)) - return result - - -def EncodeInteger(num, format='!I'): - try: - num = int(num) - except: - raise TypeError('Can not encode non-integer as integer') - return struct.pack(format, num) - - -def EncodeInteger64(num, format='!Q'): - try: - num = int(num) - except: - raise TypeError('Can not encode non-integer as integer64') - return struct.pack(format, num) - - -def EncodeDate(num): - if not isinstance(num, int): - raise TypeError('Can not encode non-integer as date') - return struct.pack('!I', num) - - -def DecodeString(orig_str): - return orig_str.decode('utf-8') - - -def DecodeOctets(orig_bytes): - return orig_bytes - - -def DecodeAddress(addr): - return '.'.join(map(str, struct.unpack('BBBB', addr))) - - -def DecodeIPv6Prefix(addr): - addr = addr + b'\x00' * (18-len(addr)) - _, length, prefix = ':'.join(map('{0:x}'.format, struct.unpack('!BB'+'H'*8, addr))).split(":", 2) - return str(IPv6Network("%s/%s" % (prefix, int(length, 16)))) - - -def DecodeIPv6Address(addr): - addr = addr + b'\x00' * (16-len(addr)) - prefix = ':'.join(map('{0:x}'.format, struct.unpack('!'+'H'*8, addr))) - return str(IPv6Address(prefix)) - - -def DecodeAscendBinary(orig_bytes): - return orig_bytes - - -def DecodeInteger(num, format='!I'): - return (struct.unpack(format, num))[0] - -def DecodeInteger64(num, format='!Q'): - return (struct.unpack(format, num))[0] - -def DecodeDate(num): - return (struct.unpack('!I', num))[0] - - -def EncodeAttr(datatype, value): - if datatype == 'string': - return EncodeString(value) - elif datatype == 'octets': - return EncodeOctets(value) - elif datatype == 'integer': - return EncodeInteger(value) - elif datatype == 'ipaddr': - return EncodeAddress(value) - elif datatype == 'ipv6prefix': - return EncodeIPv6Prefix(value) - elif datatype == 'ipv6addr': - return EncodeIPv6Address(value) - elif datatype == 'abinary': - return EncodeAscendBinary(value) - elif datatype == 'signed': - return EncodeInteger(value, '!i') - elif datatype == 'short': - return EncodeInteger(value, '!H') - elif datatype == 'byte': - return EncodeInteger(value, '!B') - elif datatype == 'date': - return EncodeDate(value) - elif datatype == 'integer64': - return EncodeInteger64(value) - else: - raise ValueError('Unknown attribute type %s' % datatype) - - -def DecodeAttr(datatype, value): - if datatype == 'string': - return DecodeString(value) - elif datatype == 'octets': - return DecodeOctets(value) - elif datatype == 'integer': - return DecodeInteger(value) - elif datatype == 'ipaddr': - return DecodeAddress(value) - elif datatype == 'ipv6prefix': - return DecodeIPv6Prefix(value) - elif datatype == 'ipv6addr': - return DecodeIPv6Address(value) - elif datatype == 'abinary': - return DecodeAscendBinary(value) - elif datatype == 'signed': - return DecodeInteger(value, '!i') - elif datatype == 'short': - return DecodeInteger(value, '!H') - elif datatype == 'byte': - return DecodeInteger(value, '!B') - elif datatype == 'date': - return DecodeDate(value) - elif datatype == 'integer64': - return DecodeInteger64(value) - else: - raise ValueError('Unknown attribute type %s' % datatype) diff --git a/tests/data/full b/tests/data/full index c0256b6..8c5aca9 100644 --- a/tests/data/full +++ b/tests/data/full @@ -20,6 +20,8 @@ ATTRIBUTE Test-Encrypted-String 5 string encrypt=2 ATTRIBUTE Test-Encrypted-Octets 6 octets encrypt=2 ATTRIBUTE Test-Encrypted-Integer 7 integer encrypt=2 +ATTRIBUTE Vendor-Specific 26 vsa + VENDOR Simplon 16 diff --git a/tests/testDatatypes.py b/tests/testDatatypes.py new file mode 100644 index 0000000..c8159fe --- /dev/null +++ b/tests/testDatatypes.py @@ -0,0 +1,98 @@ +from ipaddress import AddressValueError +from pyrad.datatypes.leaf import * +import unittest + + +class LeafEncodingTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.abinary = AscendBinary() + cls.byte = Byte() + cls.date = Date() + cls.ether = Ether() + cls.ifid = Ifid() + cls.integer = Integer() + cls.integer64 = Integer64() + cls.ipaddr = Ipaddr() + cls.ipv6addr = Ipv6addr() + cls.ipv6prefix = Ipv6prefix() + cls.octets = Octets() + cls.short = Short() + cls.signed = Signed() + cls.string = String() + + def testStringEncoding(self): + self.assertRaises(ValueError, self.string.encode, None, 'x' * 254) + self.assertEqual( + self.string.encode(None, '1234567890'), + b'1234567890') + + def testInvalidStringEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.string.encode, None, 1) + + def testAddressEncoding(self): + self.assertRaises(AddressValueError, self.ipaddr.encode, None,'TEST123') + self.assertEqual( + self.ipaddr.encode(None, '192.168.0.255'), + b'\xc0\xa8\x00\xff') + + def testInvalidAddressEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.ipaddr.encode, None, 1) + + def testIntegerEncoding(self): + self.assertEqual(self.integer.encode(None, 0x01020304), b'\x01\x02\x03\x04') + + def testInteger64Encoding(self): + self.assertEqual( + self.integer64.encode(None, 0xFFFFFFFFFFFFFFFF), b'\xff' * 8 + ) + + def testUnsignedIntegerEncoding(self): + self.assertEqual(self.integer.encode(None, 0xFFFFFFFF), b'\xff\xff\xff\xff') + + def testInvalidIntegerEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.integer.encode, None, 'ONE') + + def testDateEncoding(self): + self.assertEqual(self.date.encode(None, 0x01020304), b'\x01\x02\x03\x04') + + def testInvalidDataEncodingRaisesTypeError(self): + self.assertRaises(TypeError, self.date.encode, None, '1') + + def testEncodeAscendBinary(self): + self.assertEqual( + self.abinary.encode(None, 'family=ipv4 action=discard direction=in dst=10.10.255.254/32'), + b'\x01\x00\x01\x00\x00\x00\x00\x00\n\n\xff\xfe\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') + + def testStringDecoding(self): + self.assertEqual( + self.string.decode(b'1234567890'), + '1234567890') + + def testAddressDecoding(self): + self.assertEqual( + self.ipaddr.decode(b'\xc0\xa8\x00\xff'), + '192.168.0.255') + + def testIntegerDecoding(self): + self.assertEqual( + self.integer.decode(b'\x01\x02\x03\x04'), + 0x01020304) + + def testInteger64Decoding(self): + self.assertEqual( + self.integer64.decode(b'\xff' * 8), 0xFFFFFFFFFFFFFFFF + ) + + def testDateDecoding(self): + self.assertEqual( + self.date.decode(b'\x01\x02\x03\x04'), + 0x01020304) + + def testOctetsEncoding(self): + self.assertEqual(self.octets.encode(None, '0x01020304'), b'\x01\x02\x03\x04') + self.assertEqual(self.octets.encode(None, b'0x01020304'), b'\x01\x02\x03\x04') + self.assertEqual(self.octets.encode(None, '16909060'), b'\x01\x02\x03\x04') + # encodes to 253 bytes + self.assertEqual(self.octets.encode(None, '0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D'), b'\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r') + self.assertRaisesRegex(ValueError, 'Can only encode strings of <= 253 characters', self.octets.encode, None, '0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E') diff --git a/tests/testDictionary.py b/tests/testDictionary.py index 0d1fb99..fd55ee2 100644 --- a/tests/testDictionary.py +++ b/tests/testDictionary.py @@ -3,13 +3,15 @@ import os from io import StringIO +from pyrad.datatypes.leaf import Integer from . import home from pyrad.dictionary import Attribute from pyrad.dictionary import Dictionary from pyrad.dictionary import ParseError -from pyrad.tools import DecodeAttr from pyrad.dictfile import DictFile +from pyrad.datatypes import leaf, structural + class AttributeTests(unittest.TestCase): def testInvalidDataType(self): @@ -19,7 +21,7 @@ def testConstructionParameters(self): attr = Attribute('name', 'code', 'integer', False, 'vendor') self.assertEqual(attr.name, 'name') self.assertEqual(attr.code, 'code') - self.assertEqual(attr.type, 'integer') + self.assertIsInstance(attr.type, Integer) self.assertEqual(attr.is_sub_attribute, False) self.assertEqual(attr.vendor, 'vendor') self.assertEqual(len(attr.values), 0) @@ -30,7 +32,7 @@ def testNamedConstructionParameters(self): vendor='vendor') self.assertEqual(attr.name, 'name') self.assertEqual(attr.code, 'code') - self.assertEqual(attr.type, 'integer') + self.assertIsInstance(attr.type, Integer) self.assertEqual(attr.vendor, 'vendor') self.assertEqual(len(attr.values), 0) @@ -83,6 +85,28 @@ class DictionaryParsingTests(unittest.TestCase): ('Test-Integer64-Oct', 10, 'integer64'), ] + @classmethod + def setUpClass(cls): + # leaf attributes + cls.abinary = leaf.AscendBinary() + cls.byte = leaf.Byte() + cls.date = leaf.Date() + cls.ether = leaf.Ether() + cls.ifid = leaf.Ifid() + cls.integer = leaf.Integer() + cls.integer64 = leaf.Integer64() + cls.ipaddr = leaf.Ipaddr() + cls.ipv6addr = leaf.Ipv6addr() + cls.ipv6prefix = leaf.Ipv6prefix() + cls.octets = leaf.Octets() + cls.short = leaf.Short() + cls.signed = leaf.Signed() + cls.string = leaf.String() + + # structural attributes + cls.tlv = structural.Tlv() + cls.vsa = structural.Vsa() + def setUp(self): self.path = os.path.join(home, 'data') self.dict = Dictionary(os.path.join(self.path, 'simple')) @@ -104,7 +128,7 @@ def testParseSimpleDictionary(self): for (attr, code, type) in self.simple_dict_values: attr = self.dict[attr] self.assertEqual(attr.code, code) - self.assertEqual(attr.type, type) + self.assertEqual(attr.type.name, type) def testAttributeTooFewColumnsError(self): try: @@ -168,18 +192,18 @@ def testIntegerValueParsing(self): self.dict.ReadDictionary(StringIO('VALUE Test-Integer Value-Six 5')) self.assertEqual(len(self.dict['Test-Integer'].values), 1) self.assertEqual( - DecodeAttr('integer', - self.dict['Test-Integer'].values['Value-Six']), - 5) + self.integer.decode( + self.dict['Test-Integer'].values['Value-Six'] + ), 5) def testInteger64ValueParsing(self): self.assertEqual(len(self.dict['Test-Integer64'].values), 0) self.dict.ReadDictionary(StringIO('VALUE Test-Integer64 Value-Six 5')) self.assertEqual(len(self.dict['Test-Integer64'].values), 1) self.assertEqual( - DecodeAttr('integer64', - self.dict['Test-Integer64'].values['Value-Six']), - 5) + self.integer64.decode( + self.dict['Test-Integer64'].values['Value-Six'] + ), 5) def testStringValueParsing(self): self.assertEqual(len(self.dict['Test-String'].values), 0) @@ -187,9 +211,9 @@ def testStringValueParsing(self): 'VALUE Test-String Value-Custard custardpie')) self.assertEqual(len(self.dict['Test-String'].values), 1) self.assertEqual( - DecodeAttr('string', - self.dict['Test-String'].values['Value-Custard']), - 'custardpie') + self.string.decode( + self.dict['Test-String'].values['Value-Custard'] + ), 'custardpie') def testOctetValueParsing(self): self.assertEqual(len(self.dict['Test-Octets'].values), 0) @@ -199,13 +223,13 @@ def testOctetValueParsing(self): 'VALUE Test-Octets Value-B 0x42\n')) # "B" self.assertEqual(len(self.dict['Test-Octets'].values), 2) self.assertEqual( - DecodeAttr('octets', - self.dict['Test-Octets'].values['Value-A']), - b'A') + self.octets.decode( + self.dict['Test-Octets'].values['Value-A'] + ), b'A') self.assertEqual( - DecodeAttr('octets', - self.dict['Test-Octets'].values['Value-B']), - b'B') + self.octets.decode( + self.dict['Test-Octets'].values['Value-B'] + ), b'B') def testTlvParsing(self): self.assertEqual(len(self.dict['Test-Tlv'].sub_attributes), 2) diff --git a/tests/testPacket.py b/tests/testPacket.py index f7649a0..61e7c14 100644 --- a/tests/testPacket.py +++ b/tests/testPacket.py @@ -124,7 +124,7 @@ def _create_reply_with_duplicate_attributes(self, request): def _get_attribute_bytes(self, attr_name, value): attr = self.dict.attributes[attr_name] attr_key = attr.code - attr_value = packet.tools.EncodeAttr(attr.type, value) + attr_value = attr.encode(value) attr_len = len(attr_value) + 2 return struct.pack('!BB', attr_key, attr_len) + attr_value @@ -437,22 +437,22 @@ def testDecodePacketWithAttribute(self): def testDecodePacketWithTlvAttribute(self): self.packet.DecodePacket( b'\x01\x02\x00\x1d1234567890123456\x04\x09\x01\x07value') - self.assertEqual(self.packet[4], {1:[b'value']}) + self.assertEqual(self.packet[4], [{1:[b'value']}]) def testDecodePacketWithVendorTlvAttribute(self): self.packet.DecodePacket( b'\x01\x02\x00\x231234567890123456\x1a\x0f\x00\x00\x00\x10\x03\x09\x01\x07value') - self.assertEqual(self.packet[(16,3)], {1:[b'value']}) + self.assertEqual(self.packet[(16,3)], [{1:[b'value']}]) def testDecodePacketWithTlvAttributeWith2SubAttributes(self): self.packet.DecodePacket( b'\x01\x02\x00\x231234567890123456\x04\x0f\x01\x07value\x02\x06\x00\x00\x00\x09') - self.assertEqual(self.packet[4], {1:[b'value'], 2:[b'\x00\x00\x00\x09']}) + self.assertEqual(self.packet[4], [{1:[b'value'], 2:[b'\x00\x00\x00\x09']}]) def testDecodePacketWithSplitTlvAttribute(self): self.packet.DecodePacket( - b'\x01\x02\x00\x251234567890123456\x04\x09\x01\x07value\x04\x09\x02\x06\x00\x00\x00\x09') - self.assertEqual(self.packet[4], {1:[b'value'], 2:[b'\x00\x00\x00\x09']}) + b'\x01\x02\x00\x251234567890123456\x04\x09\x01\x07value\x04\x08\x02\x06\x00\x00\x00\x09') + self.assertEqual(self.packet[4], [{1:[b'value']}, {2:[b'\x00\x00\x00\x09']}]) def testDecodePacketWithMultiValuedAttribute(self): self.packet.DecodePacket( diff --git a/tests/testTools.py b/tests/testTools.py deleted file mode 100644 index f220e7b..0000000 --- a/tests/testTools.py +++ /dev/null @@ -1,127 +0,0 @@ -from ipaddress import AddressValueError -from pyrad import tools -import unittest - - -class EncodingTests(unittest.TestCase): - def testStringEncoding(self): - self.assertRaises(ValueError, tools.EncodeString, 'x' * 254) - self.assertEqual( - tools.EncodeString('1234567890'), - b'1234567890') - - def testInvalidStringEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeString, 1) - - def testAddressEncoding(self): - self.assertRaises(AddressValueError, tools.EncodeAddress, 'TEST123') - self.assertEqual( - tools.EncodeAddress('192.168.0.255'), - b'\xc0\xa8\x00\xff') - - def testInvalidAddressEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeAddress, 1) - - def testIntegerEncoding(self): - self.assertEqual(tools.EncodeInteger(0x01020304), b'\x01\x02\x03\x04') - - def testInteger64Encoding(self): - self.assertEqual( - tools.EncodeInteger64(0xFFFFFFFFFFFFFFFF), b'\xff' * 8 - ) - - def testUnsignedIntegerEncoding(self): - self.assertEqual(tools.EncodeInteger(0xFFFFFFFF), b'\xff\xff\xff\xff') - - def testInvalidIntegerEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeInteger, 'ONE') - - def testDateEncoding(self): - self.assertEqual(tools.EncodeDate(0x01020304), b'\x01\x02\x03\x04') - - def testInvalidDataEncodingRaisesTypeError(self): - self.assertRaises(TypeError, tools.EncodeDate, '1') - - def testEncodeAscendBinary(self): - self.assertEqual( - tools.EncodeAscendBinary('family=ipv4 action=discard direction=in dst=10.10.255.254/32'), - b'\x01\x00\x01\x00\x00\x00\x00\x00\n\n\xff\xfe\x00 \x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00') - - def testStringDecoding(self): - self.assertEqual( - tools.DecodeString(b'1234567890'), - '1234567890') - - def testAddressDecoding(self): - self.assertEqual( - tools.DecodeAddress(b'\xc0\xa8\x00\xff'), - '192.168.0.255') - - def testIntegerDecoding(self): - self.assertEqual( - tools.DecodeInteger(b'\x01\x02\x03\x04'), - 0x01020304) - - def testInteger64Decoding(self): - self.assertEqual( - tools.DecodeInteger64(b'\xff' * 8), 0xFFFFFFFFFFFFFFFF - ) - - def testDateDecoding(self): - self.assertEqual( - tools.DecodeDate(b'\x01\x02\x03\x04'), - 0x01020304) - - def testOctetsEncoding(self): - self.assertEqual(tools.EncodeOctets('0x01020304'), b'\x01\x02\x03\x04') - self.assertEqual(tools.EncodeOctets(b'0x01020304'), b'\x01\x02\x03\x04') - self.assertEqual(tools.EncodeOctets('16909060'), b'\x01\x02\x03\x04') - # encodes to 253 bytes - self.assertEqual(tools.EncodeOctets('0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D'), b'\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f\x10\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r') - self.assertRaisesRegex(ValueError, 'Can only encode strings of <= 253 characters', tools.EncodeOctets, '0x0102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E0F100102030405060708090A0B0C0D0E') - - def testUnknownTypeEncoding(self): - self.assertRaises(ValueError, tools.EncodeAttr, 'unknown', None) - - def testUnknownTypeDecoding(self): - self.assertRaises(ValueError, tools.DecodeAttr, 'unknown', None) - - def testEncodeFunction(self): - self.assertEqual( - tools.EncodeAttr('string', 'string'), - b'string') - self.assertEqual( - tools.EncodeAttr('octets', b'string'), - b'string') - self.assertEqual( - tools.EncodeAttr('ipaddr', '192.168.0.255'), - b'\xc0\xa8\x00\xff') - self.assertEqual( - tools.EncodeAttr('integer', 0x01020304), - b'\x01\x02\x03\x04') - self.assertEqual( - tools.EncodeAttr('date', 0x01020304), - b'\x01\x02\x03\x04') - self.assertEqual( - tools.EncodeAttr('integer64', 0xFFFFFFFFFFFFFFFF), - b'\xff'*8) - - def testDecodeFunction(self): - self.assertEqual( - tools.DecodeAttr('string', b'string'), - 'string') - self.assertEqual( - tools.EncodeAttr('octets', b'string'), - b'string') - self.assertEqual( - tools.DecodeAttr('ipaddr', b'\xc0\xa8\x00\xff'), - '192.168.0.255') - self.assertEqual( - tools.DecodeAttr('integer', b'\x01\x02\x03\x04'), - 0x01020304) - self.assertEqual( - tools.DecodeAttr('integer64', b'\xff'*8), - 0xFFFFFFFFFFFFFFFF) - self.assertEqual( - tools.DecodeAttr('date', b'\x01\x02\x03\x04'), - 0x01020304)